Merge pull request #4059 from quic-go/ecn

add ECN support
This commit is contained in:
Marten Seemann 2023-09-11 22:04:31 +07:00 committed by GitHub
commit 1f25153884
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 1616 additions and 369 deletions

View file

@ -41,6 +41,11 @@ jobs:
env:
QUIC_GO_DISABLE_GSO: true
run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }}
- name: Run self tests, with ECN disabled
if: ${{ matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed
env:
QUIC_GO_DISABLE_ECN: true
run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }}
- name: Run tests (32 bit)
if: ${{ matrix.os != 'macos' && (success() || failure()) }} # run this step even if the previous one failed
env:

View file

@ -278,6 +278,7 @@ var newConnection = func(
getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats,
clientAddressValidated,
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
@ -385,7 +386,8 @@ var newClientConnection = func(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats,
false, /* has no effect */
false, // has no effect
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
@ -905,6 +907,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc
KeyPhase: keyPhase,
},
p.Size(),
p.ecn,
frames,
)
}
@ -1190,7 +1193,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames)
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames)
}
}
isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log)
@ -1830,9 +1833,10 @@ func (s *connection) sendPackets(now time.Time) error {
if err != nil {
return err
}
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, now)
s.sendQueue.Send(buf, 0)
ecn := s.sentPacketHandler.ECNMode(true)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, ecn, now)
s.sendQueue.Send(buf, 0, ecn)
// This is kind of a hack. We need to trigger sending again somehow.
s.pacingDeadline = deadlineSendImmediately
return nil
@ -1852,7 +1856,7 @@ func (s *connection) sendPackets(now time.Time) error {
return err
}
s.sentFirstPacket = true
if err := s.sendPackedCoalescedPacket(packet, now); err != nil {
if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil {
return err
}
sendMode := s.sentPacketHandler.SendMode(now)
@ -1873,7 +1877,8 @@ func (s *connection) sendPackets(now time.Time) error {
func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil {
ecn := s.sentPacketHandler.ECNMode(true)
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil
@ -1881,7 +1886,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
return err
}
s.sendQueue.Send(buf, 0)
s.sendQueue.Send(buf, 0, ecn)
if s.sendQueue.WouldBlock() {
return nil
@ -1906,9 +1911,10 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
ecn := s.sentPacketHandler.ECNMode(true)
for {
var dontSendMore bool
size, err := s.appendPacket(buf, maxSize, now)
size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now)
if err != nil {
if err != errNothingToPack {
return err
@ -1930,15 +1936,19 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error {
}
}
// Don't send more packets in this batch if they require a different ECN marking than the previous ones.
nextECN := s.sentPacketHandler.ECNMode(true)
// Append another packet if
// 1. The congestion controller and pacer allow sending more
// 2. The last packet appended was a full-size packet
// 3. We still have enough space for another full-size packet in the buffer
if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() {
// 3. The next packet will have the same ECN marking
// 4. We still have enough space for another full-size packet in the buffer
if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() {
continue
}
s.sendQueue.Send(buf, uint16(maxSize))
s.sendQueue.Send(buf, uint16(maxSize), ecn)
if dontSendMore {
return nil
@ -1967,6 +1977,7 @@ func (s *connection) resetPacingDeadline() {
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed {
ecn := s.sentPacketHandler.ECNMode(false)
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil {
return err
@ -1974,9 +1985,10 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if packet == nil {
return nil
}
return s.sendPackedCoalescedPacket(packet, time.Now())
return s.sendPackedCoalescedPacket(packet, ecn, time.Now())
}
ecn := s.sentPacketHandler.ECNMode(true)
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil {
if err == errNothingToPack {
@ -1984,9 +1996,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
}
return err
}
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, now)
s.sendQueue.Send(buf, 0)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, ecn, now)
s.sendQueue.Send(buf, 0, ecn)
return nil
}
@ -2018,24 +2030,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
}
return s.sendPackedCoalescedPacket(packet, now)
return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now)
}
// appendPacket appends a new packet to the given packetBuffer.
// appendOneShortHeaderPacket appends a new packet to the given packetBuffer.
// If there was nothing to pack, the returned size is 0.
func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) {
func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) {
startLen := buf.Len()
p, err := s.packer.AppendPacket(buf, maxSize, s.version)
if err != nil {
return 0, err
}
size := buf.Len() - startLen
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false)
s.registerPackedShortHeaderPacket(p, now)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false)
s.registerPackedShortHeaderPacket(p, ecn, now)
return size, nil
}
func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) {
func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
s.firstAckElicitingPacketAfterIdleSentTime = now
}
@ -2044,12 +2056,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket)
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket)
s.connIDManager.SentPacket()
}
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error {
s.logCoalescedPacket(packet)
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error {
s.logCoalescedPacket(packet, ecn)
for _, p := range packet.longHdrPackets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now
@ -2058,7 +2070,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false)
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001.
@ -2075,10 +2087,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket)
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket)
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, 0)
s.sendQueue.Send(packet.buffer, 0, ecn)
return nil
}
@ -2100,11 +2112,12 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
if err != nil {
return nil, err
}
s.logCoalescedPacket(packet)
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0)
ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket())
s.logCoalescedPacket(packet, ecn)
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
}
func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
p.header.Log(s.logger)
@ -2132,7 +2145,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
if p.ack != nil {
ack = logutils.ConvertAckFrame(p.ack)
}
s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames)
s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
}
}
@ -2144,11 +2157,12 @@ func (s *connection) logShortHeaderPacket(
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount,
isCoalesced bool,
) {
if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID)
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
}
// quic-go logging
if s.logger.Debug() {
@ -2185,13 +2199,14 @@ func (s *connection) logShortHeaderPacket(
KeyPhase: kp,
},
size,
ecn,
ack,
fs,
)
}
}
func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet.
@ -2204,6 +2219,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length,
false,
)
@ -2216,10 +2232,10 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
}
}
for _, p := range packet.longHdrPackets {
s.logLongHeaderPacket(p)
s.logLongHeaderPacket(p, ecn)
}
if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
}
}

View file

@ -454,7 +454,7 @@ var _ = Describe("Connection", func() {
Expect(e.ErrorMessage).To(BeEmpty())
return &coalescedPacket{buffer: buffer}, nil
})
mconn.EXPECT().Write([]byte("connection close"), gomock.Any())
mconn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
var appErr *ApplicationError
@ -475,7 +475,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -494,7 +494,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(),
@ -516,7 +516,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(),
@ -565,7 +565,7 @@ var _ = Describe("Connection", func() {
close(returned)
}()
Consistently(returned).ShouldNot(BeClosed())
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -590,7 +590,7 @@ var _ = Describe("Connection", func() {
return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil
})
gomock.InOrder(
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()),
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
tracer.EXPECT().ClosedConnection(gomock.Any()),
tracer.EXPECT().Close(),
)
@ -609,14 +609,15 @@ var _ = Describe("Connection", func() {
conn.handshakeConfirmed = true
sconn := NewMockSendConn(mockCtrl)
sconn.EXPECT().capabilities().AnyTimes()
sconn.EXPECT().Write(gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes()
sconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes()
conn.sendQueue = newSendQueue(sconn)
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
// only expect a single SentPacket() call
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any())
@ -777,7 +778,7 @@ var _ = Describe("Connection", func() {
conn.receivedPacketHandler = rph
packet.rcvTime = rcvTime
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), []logging.Frame{})
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
})
@ -795,7 +796,12 @@ var _ = Describe("Connection", func() {
)
conn.receivedPacketHandler = rph
packet.rcvTime = rcvTime
tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
tracer.EXPECT().ReceivedShortHeaderPacket(
&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero},
protocol.ByteCount(len(packet.data)),
logging.ECT1,
[]logging.Frame{&logging.PingFrame{}},
)
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
})
@ -837,7 +843,7 @@ var _ = Describe("Connection", func() {
// make the go routine return
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("close"))
Eventually(conn.Context().Done()).Should(BeClosed())
})
@ -849,8 +855,7 @@ var _ = Describe("Connection", func() {
pn++
return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
}).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version) // only expect a single call
for i := 0; i < 3; i++ {
@ -872,7 +877,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("close"))
Eventually(conn.Context().Done()).Should(BeClosed())
})
@ -885,8 +890,7 @@ var _ = Describe("Connection", func() {
pn++
return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
}).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Times(3)
for i := 0; i < 3; i++ {
@ -908,7 +912,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("close"))
Eventually(conn.Context().Done()).Should(BeClosed())
})
@ -930,7 +934,7 @@ var _ = Describe("Connection", func() {
close(done)
}()
expectReplaceWithClosed()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
packet := getShortHeaderPacket(srcConnID, 0x42, nil)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
@ -958,7 +962,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed())
})
@ -980,7 +984,7 @@ var _ = Describe("Connection", func() {
close(done)
}()
expectReplaceWithClosed()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil))
@ -1020,7 +1024,7 @@ var _ = Describe("Connection", func() {
}, nil)
p1 := getLongHeaderPacket(hdr1, nil)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(p1)).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match.
p2 := getLongHeaderPacket(hdr2, nil)
@ -1053,7 +1057,7 @@ var _ = Describe("Connection", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil)
packet := getShortHeaderPacket(srcConnID, 0x42, nil)
packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
})
})
@ -1093,12 +1097,13 @@ var _ = Describe("Connection", func() {
})
cryptoSetup.EXPECT().DiscardInitialKeys()
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial)
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
})
It("handles coalesced packets", func() {
hdrLen1, packet1 := getPacketWithLength(srcConnID, 456)
packet1.ecn = protocol.ECT1
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) {
Expect(data).To(HaveLen(hdrLen1 + 456 - 3))
return &unpackedPacket{
@ -1125,8 +1130,8 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes()
cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes()
gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), logging.ECT1, gomock.Any()),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.ECT1, gomock.Any()),
)
packet1.data = append(packet1.data, packet2.data...)
Expect(conn.handlePacketImpl(packet1)).To(BeTrue())
@ -1151,7 +1156,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes()
gomock.InOrder(
tracer.EXPECT().BufferedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data))),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any(), gomock.Any()),
)
packet1.data = append(packet1.data, packet2.data...)
Expect(conn.handlePacketImpl(packet1)).To(BeTrue())
@ -1177,7 +1182,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes()
// don't EXPECT any more calls to unpacker.UnpackLongHeader()
gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any(), gomock.Any()),
tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID),
)
packet1.data = append(packet1.data, packet2.data...)
@ -1190,6 +1195,7 @@ var _ = Describe("Connection", func() {
var (
connDone chan struct{}
sender *MockSender
sph *mockackhandler.MockSentPacketHandler
)
BeforeEach(func() {
@ -1198,14 +1204,17 @@ var _ = Describe("Connection", func() {
sender.EXPECT().WouldBlock().AnyTimes()
conn.sendQueue = sender
connDone = make(chan struct{})
sph = mockackhandler.NewMockSentPacketHandler(mockCtrl)
conn.sentPacketHandler = sph
})
AfterEach(func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNCE).MaxTimes(1)
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
sender.EXPECT().Close()
@ -1226,12 +1235,11 @@ var _ = Describe("Connection", func() {
It("sends packets", func() {
conn.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
sph.EXPECT().ECNMode(true).Return(protocol.ECNNon).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
runConn()
p := shortHeaderPacket{
DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}),
@ -1243,19 +1251,22 @@ var _ = Describe("Connection", func() {
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
sent := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{
DestConnectionID: p.DestConnID,
PacketNumber: p.PacketNumber,
PacketNumberLen: p.PacketNumberLen,
KeyPhase: p.KeyPhase,
}, gomock.Any(), nil, []logging.Frame{})
}, gomock.Any(), gomock.Any(), nil, []logging.Frame{})
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
})
It("doesn't send packets if there's nothing to send", func() {
conn.handshakeConfirmed = true
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(true).AnyTimes()
runConn()
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true)
@ -1264,13 +1275,12 @@ var _ = Describe("Connection", func() {
})
It("sends ACK only packets", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
done := make(chan struct{})
packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) })
conn.sentPacketHandler = sph
runConn()
conn.scheduleSending()
Eventually(done).Should(BeClosed())
@ -1278,12 +1288,11 @@ var _ = Describe("Connection", func() {
It("adds a BLOCKED frame when it is connection-level flow control blocked", func() {
conn.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
fc := mocks.NewMockConnectionFlowController(mockCtrl)
fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337))
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar"))
@ -1291,8 +1300,8 @@ var _ = Describe("Connection", func() {
conn.connFlowController = fc
runConn()
sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{})
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), nil, []logging.Frame{})
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
frames, _ := conn.framer.AppendControlFrames(nil, 1000, protocol.Version1)
@ -1300,11 +1309,9 @@ var _ = Describe("Connection", func() {
})
It("doesn't send when the SentPacketHandler doesn't allow it", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone).AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
conn.sentPacketHandler = sph
runConn()
conn.scheduleSending()
time.Sleep(50 * time.Millisecond)
@ -1333,50 +1340,45 @@ var _ = Describe("Connection", func() {
})
It("sends a probe packet", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().QueueProbePacket(encLevel)
sph.EXPECT().ECNMode(gomock.Any())
p := getCoalescedPacket(123, enc != protocol.Encryption1RTT)
packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) {
Expect(pn).To(Equal(protocol.PacketNumber(123)))
})
sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
runConn()
sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
if enc == protocol.Encryption1RTT {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any(), gomock.Any())
} else {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any())
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any(), gomock.Any())
}
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
})
It("sends a PING as a probe packet", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT0)
sph.EXPECT().QueueProbePacket(encLevel).Return(false)
p := getCoalescedPacket(123, enc != protocol.Encryption1RTT)
packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) {
Expect(pn).To(Equal(protocol.PacketNumber(123)))
})
conn.sentPacketHandler = sph
sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
runConn()
sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
if enc == protocol.Encryption1RTT {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, logging.ECT0, gomock.Any(), gomock.Any())
} else {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any())
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, logging.ECT0, gomock.Any(), gomock.Any())
}
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
@ -1395,7 +1397,7 @@ var _ = Describe("Connection", func() {
)
BeforeEach(func() {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
sph = mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
conn.handshakeConfirmed = true
@ -1409,10 +1411,11 @@ var _ = Describe("Connection", func() {
AfterEach(func() {
// make the go routine return
sph.EXPECT().ECNMode(gomock.Any()).MaxTimes(1)
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
sender.EXPECT().Close()
@ -1421,17 +1424,18 @@ var _ = Describe("Connection", func() {
})
It("sends multiple packets one by one immediately", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().ECNMode(gomock.Any()).Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited)
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10"))
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11"))
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) {
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal([]byte("packet10")))
})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) {
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal([]byte("packet11")))
})
go func() {
@ -1446,7 +1450,8 @@ var _ = Describe("Connection", func() {
It("sends multiple packets one by one immediately, with GSO", func() {
enableGSO()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(4)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload1)
@ -1456,7 +1461,7 @@ var _ = Describe("Connection", func() {
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
go func() {
@ -1471,19 +1476,59 @@ var _ = Describe("Connection", func() {
It("stops appending packets when a smaller packet is packed, with GSO", func() {
enableGSO()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(true).Times(4)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()-1)
rand.Read(payload2)
payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload3)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 12}, payload3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(payload3))
})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent
})
It("stops appending packets when the ECN marking changes, with GSO", func() {
enableGSO()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(2)
sph.EXPECT().ECNMode(true).Return(protocol.ECT0).Times(2)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload2)
payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload3)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(payload3))
})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1495,12 +1540,13 @@ var _ = Describe("Connection", func() {
})
It("sends multiple packets, when the pacer allows immediate sending", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().ECNMode(gomock.Any()).Times(2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any())
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1512,13 +1558,14 @@ var _ = Describe("Connection", func() {
})
It("allows an ACK to be sent when pacing limited", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited)
sph.EXPECT().ECNMode(gomock.Any())
packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any())
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1532,12 +1579,13 @@ var _ = Describe("Connection", func() {
// when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck
// we shouldn't send the ACK in the same run
It("doesn't send an ACK right after becoming congestion limited", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
sph.EXPECT().ECNMode(gomock.Any()).Times(2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100"))
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any())
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1552,19 +1600,21 @@ var _ = Describe("Connection", func() {
pacingDelay := scaleDuration(100 * time.Millisecond)
gomock.InOrder(
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
sph.EXPECT().ECNMode(gomock.Any()),
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)),
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
sph.EXPECT().ECNMode(gomock.Any()),
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
)
written := make(chan struct{}, 2)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(2)
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(2)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1578,8 +1628,9 @@ var _ = Describe("Connection", func() {
})
It("sends multiple packets at once", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
sph.EXPECT().ECNMode(gomock.Any()).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited)
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
for pn := protocol.PacketNumber(1000); pn < 1003; pn++ {
@ -1587,7 +1638,7 @@ var _ = Describe("Connection", func() {
}
written := make(chan struct{}, 3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(3)
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(3)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1618,11 +1669,12 @@ var _ = Describe("Connection", func() {
written := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
available <- struct{}{}
Eventually(written).Should(BeClosed())
})
@ -1639,14 +1691,15 @@ var _ = Describe("Connection", func() {
written := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ByteCount, bool) {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool) {
sph.EXPECT().ReceivedBytes(gomock.Any())
conn.handlePacket(receivedPacket{buffer: getPacketBuffer()})
})
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
conn.scheduleSending()
time.Sleep(scaleDuration(50 * time.Millisecond))
@ -1655,13 +1708,14 @@ var _ = Describe("Connection", func() {
})
It("stops sending when the send queue is full", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny)
sph.EXPECT().ECNMode(gomock.Any())
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000"))
written := make(chan struct{}, 1)
sender.EXPECT().WouldBlock()
sender.EXPECT().WouldBlock().Return(true).Times(2)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} })
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1675,12 +1729,13 @@ var _ = Describe("Connection", func() {
time.Sleep(scaleDuration(50 * time.Millisecond))
// now make room in the send queue
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
sender.EXPECT().WouldBlock().AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} })
available <- struct{}{}
Eventually(written).Should(Receive())
@ -1691,6 +1746,7 @@ var _ = Describe("Connection", func() {
It("doesn't set a pacing timer when there is no data to send", func() {
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
sender.EXPECT().WouldBlock().AnyTimes()
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
// don't EXPECT any calls to mconn.Write()
@ -1708,12 +1764,13 @@ var _ = Describe("Connection", func() {
mtuDiscoverer := NewMockMTUDiscoverer(mockCtrl)
conn.mtuDiscoverer = mtuDiscoverer
conn.config.DisablePathMTUDiscovery = false
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny)
sph.EXPECT().ECNMode(true)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
written := make(chan struct{}, 1)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} })
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} })
mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true)
ping := ackhandler.Frame{Frame: &wire.PingFrame{}}
mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234))
@ -1747,7 +1804,7 @@ var _ = Describe("Connection", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
sender.EXPECT().Close()
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
@ -1760,8 +1817,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1}, []byte("packet1"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
@ -1776,8 +1834,8 @@ var _ = Describe("Connection", func() {
time.Sleep(50 * time.Millisecond)
// only EXPECT calls after scheduleSending is called
written := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
conn.scheduleSending()
Eventually(written).Should(BeClosed())
})
@ -1788,9 +1846,8 @@ var _ = Describe("Connection", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) {
Expect(pn).To(Equal(protocol.PacketNumber(1234)))
})
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond))
@ -1799,8 +1856,8 @@ var _ = Describe("Connection", func() {
conn.receivedPacketHandler = rph
written := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1841,30 +1898,23 @@ var _ = Describe("Connection", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(false).Return(protocol.ECT1).AnyTimes()
sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes()
gomock.InOrder(
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) {
Expect(encLevel).To(Equal(protocol.EncryptionInitial))
Expect(pn).To(Equal(protocol.PacketNumber(13)))
Expect(size).To(BeEquivalentTo(123))
}),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) {
Expect(encLevel).To(Equal(protocol.EncryptionHandshake))
Expect(pn).To(Equal(protocol.PacketNumber(37)))
Expect(size).To(BeEquivalentTo(1234))
}),
sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()),
sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(37), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionHandshake, protocol.ECT1, protocol.ByteCount(1234), gomock.Any()),
)
gomock.InOrder(
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ logging.ECN, _ *wire.AckFrame, _ []logging.Frame) {
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
}),
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ logging.ECN, _ *wire.AckFrame, _ []logging.Frame) {
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
}),
)
sent := make(chan struct{})
mconn.EXPECT().Write([]byte("foobar"), uint16(0)).Do(func([]byte, uint16) { close(sent) })
mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) })
go func() {
defer GinkgoRecover()
@ -1881,7 +1931,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -1952,7 +2002,7 @@ var _ = Describe("Connection", func() {
}()
handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx).ShouldNot(BeClosed())
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("handshake error"))
Consistently(handshakeCtx).ShouldNot(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed())
@ -1961,12 +2011,13 @@ var _ = Describe("Connection", func() {
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SetHandshakeConfirmed()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
done := make(chan struct{})
connRunner.EXPECT().Retire(clientDestConnID)
@ -1987,7 +2038,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handleHandshakeComplete()).To(Succeed())
conn.run()
}()
@ -2016,7 +2067,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -2043,7 +2094,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed())
@ -2102,7 +2153,7 @@ var _ = Describe("Connection", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -2255,7 +2306,7 @@ var _ = Describe("Connection", func() {
// make the go routine return
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed())
})
@ -2338,7 +2389,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -2541,7 +2592,7 @@ var _ = Describe("Client Connection", func() {
},
PacketNumberLen: protocol.PacketNumberLen2,
}, []byte("foobar"))
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), []logging.Frame{})
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), gomock.Any(), []logging.Frame{})
Expect(conn.handlePacketImpl(p)).To(BeTrue())
go func() {
defer GinkgoRecover()
@ -2554,7 +2605,7 @@ var _ = Describe("Client Connection", func() {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1)
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
@ -2585,7 +2636,7 @@ var _ = Describe("Client Connection", func() {
DestConnectionID: srcConnID,
SrcConnectionID: destConnID,
}
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handleLongHeaderPacket(receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue())
})
@ -2851,7 +2902,7 @@ var _ = Describe("Client Connection", func() {
packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
}
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()),
tracer.EXPECT().Close(),
@ -3097,7 +3148,7 @@ var _ = Describe("Client Connection", func() {
hdr: hdr1,
data: []byte{0}, // one PADDING frame
}, nil)
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match.
tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any())
@ -3124,7 +3175,7 @@ var _ = Describe("Client Connection", func() {
It("fails on Initial-level ACK for unsent packet", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack})
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
})
@ -3136,7 +3187,7 @@ var _ = Describe("Client Connection", func() {
ReasonPhrase: "mitm attacker",
}
initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame})
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue())
})

View file

@ -42,11 +42,13 @@ type keyUpdateConnTracer struct {
logging.NullConnectionTracer
}
func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) {
var _ logging.ConnectionTracer = &keyUpdateConnTracer{}
func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
sentHeaders = append(sentHeaders, hdr)
}
func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, frames []logging.Frame) {
func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
receivedHeaders = append(receivedHeaders, hdr)
}

View file

@ -265,19 +265,21 @@ type packetTracer struct {
rcvdLongHdr []packet
}
var _ logging.ConnectionTracer = &packetTracer{}
func newPacketTracer() *packetTracer {
return &packetTracer{closed: make(chan struct{})}
}
func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) {
func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, frames []logging.Frame) {
func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) {
func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
if ack != nil {
frames = append(frames, ack)
}

View file

@ -14,10 +14,11 @@ func NewAckHandler(
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger)
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger)
}

267
internal/ackhandler/ecn.go Normal file
View file

@ -0,0 +1,267 @@
package ackhandler
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type ecnState uint8
const (
ecnStateInitial ecnState = iota
ecnStateTesting
ecnStateUnknown
ecnStateCapable
ecnStateFailed
)
// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type
const numECNTestingPackets = 10
type ecnHandler interface {
SentPacket(protocol.PacketNumber, protocol.ECN)
Mode() protocol.ECN
HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool)
LostPacket(protocol.PacketNumber)
}
// The ecnTracker performs ECN validation of a path.
// Once failed, it doesn't do any re-validation of the path.
// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces.
// In order to avoid revealing any internal state to on-path observers,
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
type ecnTracker struct {
state ecnState
numSentTesting, numLostTesting uint8
firstTestingPacket protocol.PacketNumber
lastTestingPacket protocol.PacketNumber
firstCapablePacket protocol.PacketNumber
numSentECT0, numSentECT1 int64
numAckedECT0, numAckedECT1, numAckedECNCE int64
tracer logging.ConnectionTracer
logger utils.Logger
}
var _ ecnHandler = &ecnTracker{}
func newECNTracker(logger utils.Logger, tracer logging.ConnectionTracer) *ecnTracker {
return &ecnTracker{
firstTestingPacket: protocol.InvalidPacketNumber,
lastTestingPacket: protocol.InvalidPacketNumber,
firstCapablePacket: protocol.InvalidPacketNumber,
state: ecnStateInitial,
logger: logger,
tracer: tracer,
}
}
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
//nolint:exhaustive // These are the only ones we need to take care of.
switch ecn {
case protocol.ECNNon:
return
case protocol.ECT0:
e.numSentECT0++
case protocol.ECT1:
e.numSentECT1++
case protocol.ECNUnsupported:
if e.state != ecnStateFailed {
panic("didn't expect ECN to be unsupported")
}
default:
panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
}
if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
e.firstCapablePacket = pn
}
if e.state != ecnStateTesting {
return
}
e.numSentTesting++
if e.firstTestingPacket == protocol.InvalidPacketNumber {
e.firstTestingPacket = pn
}
if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateUnknown
e.lastTestingPacket = pn
}
}
func (e *ecnTracker) Mode() protocol.ECN {
switch e.state {
case ecnStateInitial:
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateTesting
return e.Mode()
case ecnStateTesting, ecnStateCapable:
return protocol.ECT0
case ecnStateUnknown, ecnStateFailed:
return protocol.ECNNon
default:
panic(fmt.Sprintf("unknown ECN state: %d", e.state))
}
}
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
if e.state != ecnStateTesting && e.state != ecnStateUnknown {
return
}
if !e.isTestingPacket(pn) {
return
}
e.numLostTesting++
if e.numLostTesting >= e.numSentTesting {
e.logger.Debugf("Disabling ECN. All testing packets were lost.")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets)
}
e.state = ecnStateFailed
}
}
// HandleNewlyAcked handles the ECN counts on an ACK frame.
// It must only be called for ACK frames that increase the largest acknowledged packet number,
// see section 13.4.2.1 of RFC 9000.
func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) {
if e.state == ecnStateFailed {
return false
}
// ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds
// the total number of packets sent with each corresponding ECT codepoint.
if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 {
e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
}
e.state = ecnStateFailed
return false
}
// Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged.
var ackedECT0, ackedECT1 int64
for _, p := range packets {
//nolint:exhaustive // We only ever send ECT(0) and ECT(1).
switch e.ecnMarking(p.PacketNumber) {
case protocol.ECT0:
ackedECT0++
case protocol.ECT1:
ackedECT1++
}
}
// If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1)
// codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame.
// This check detects:
// * paths that bleach all ECN marks, and
// * peers that don't report any ECN counts
if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 {
e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
}
e.state = ecnStateFailed
return false
}
// Determine the increase in ECT0, ECT1 and ECNCE marks
newECT0 := ect0 - e.numAckedECT0
newECT1 := ect1 - e.numAckedECT1
newECNCE := ecnce - e.numAckedECNCE
// We're only processing ACKs that increase the Largest Acked.
// Therefore, the ECN counters should only ever increase.
// Any decrease means that the peer's counting logic is broken.
if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 {
e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts)
}
e.state = ecnStateFailed
return false
}
// ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number
// of newly acknowledged packets that were originally sent with an ECT(0) marking.
// This could be the result of (partial) bleaching.
if newECT0+newECNCE < ackedECT0 {
e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than
// the number of newly acknowledged packets sent with an ECT(1) marking.
if newECT1+newECNCE < ackedECT1 {
e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
if e.state == ecnStateTesting || e.state == ecnStateUnknown {
var ackedTestingPacket bool
for _, p := range packets {
if e.isTestingPacket(p.PacketNumber) {
ackedTestingPacket = true
break
}
}
// This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE).
if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) {
e.logger.Debugf("ECN capability confirmed.")
if e.tracer != nil {
e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateCapable
}
}
// update our counters
e.numAckedECT0 = ect0
e.numAckedECT1 = ect1
e.numAckedECNCE = ecnce
return newECNCE > 0
}
func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN {
if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECT0
}
if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
// We don't need to deal with the case when ECN validation fails,
// since we're ignoring any ECN counts reported in ACK frames in that case.
return protocol.ECT0
}
func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool {
if e.firstTestingPacket == protocol.InvalidPacketNumber {
return false
}
return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber)
}

View file

@ -0,0 +1,192 @@
package ackhandler
import (
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ECN tracker", func() {
var ecnTracker *ecnTracker
var tracer *mocklogging.MockConnectionTracer
getAckedPackets := func(pns ...protocol.PacketNumber) []*packet {
var packets []*packet
for _, p := range pns {
packets = append(packets, &packet{PacketNumber: p})
}
return packets
}
BeforeEach(func() {
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
ecnTracker = newECNTracker(utils.DefaultLogger, tracer)
})
It("sends exactly 10 testing packets", func() {
tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
for i := 0; i < 9; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0))
// Do this twice to make sure only sent packets are counted
Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0))
ecnTracker.SentPacket(protocol.PacketNumber(10+i), protocol.ECT0)
}
Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0))
tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
ecnTracker.SentPacket(20, protocol.ECT0)
// In unknown state, packets shouldn't be ECN-marked.
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
})
sendAllTestingPackets := func() {
tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
tracer.EXPECT().ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
for i := 0; i < 10; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0)
}
}
It("fails ECN validation if all ECN testing packets are lost", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
for i := 0; i < 9; i++ {
ecnTracker.LostPacket(protocol.PacketNumber(i))
}
// We don't care about the loss of non-testing packets
ecnTracker.LostPacket(15)
// Now lose the last testing packet.
tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets)
ecnTracker.LostPacket(9)
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
// We still don't care about more non-testing packets being lost
ecnTracker.LostPacket(16)
})
It("passes ECN validation when a testing packet is acknowledged, while still in testing state", func() {
tracer.EXPECT().ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
for i := 0; i < 5; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0)
}
tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(3), 1, 0, 0)).To(BeFalse())
// make sure we continue sending ECT(0) packets
for i := 5; i < 100; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECT0))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECT0)
}
})
It("passes ECN validation when a testing packet is acknowledged, while in unknown state", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
// Lose some packets to make sure this doesn't influence the outcome.
for i := 0; i < 5; i++ {
ecnTracker.LostPacket(protocol.PacketNumber(i))
}
tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
Expect(ecnTracker.HandleNewlyAcked([]*packet{{PacketNumber: 7}}, 1, 0, 0)).To(BeFalse())
})
It("fails ECN validation when the ACK contains more ECN counts than we sent packets", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
// only 10 ECT(0) packets were sent, but the ACK claims to have received 12 of them
tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 12, 0, 0)).To(BeFalse())
})
It("fails ECN validation when the ACK contains ECN counts for the wrong code point", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
// We sent ECT(0), but this ACK acknowledges ECT(1).
tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 0, 1, 0)).To(BeFalse())
})
It("fails ECN validation when the ACK doesn't contain ECN counts", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
// First only acknowledge packets sent without ECN marks.
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(12, 13, 14), 0, 0, 0)).To(BeFalse())
// Now acknowledge some packets sent with ECN marks.
tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 15), 0, 0, 0)).To(BeFalse())
})
It("fails ECN validation when an ACK decreases ECN counts", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 3, 0, 0)).To(BeFalse())
// Now acknowledge some more packets, but decrease the ECN counts. Obviously, this doesn't make any sense.
tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 2, 0, 0)).To(BeFalse())
// make sure that new ACKs are ignored
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 5, 0, 0)).To(BeFalse())
})
// This can happen if ACK are lost / reordered.
It("doesn't fail validation if the ACK contains more ECN counts than it acknowledges packets", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 8, 0, 0)).To(BeFalse())
})
It("fails ECN validation when the ACK doesn't contain enough ECN counts", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
// First only acknowledge some packets sent with ECN marks.
tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1)).To(BeTrue())
// Now acknowledge some more packets sent with ECN marks, but don't increase the counters enough.
// This ACK acknowledges 3 more ECN-marked packets, but the counters only increase by 2.
tracer.EXPECT().ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 15), 3, 0, 2)).To(BeFalse())
})
It("declares congestion", func() {
sendAllTestingPackets()
for i := 10; i < 20; i++ {
Expect(ecnTracker.Mode()).To(Equal(protocol.ECNNon))
ecnTracker.SentPacket(protocol.PacketNumber(i), protocol.ECNNon)
}
// Receive one CE count.
tracer.EXPECT().ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(1, 2, 3, 12), 2, 0, 1)).To(BeTrue())
// No increase in CE. No congestion.
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(4, 5, 6, 13), 5, 0, 1)).To(BeFalse())
// Increase in CE. More congestion.
Expect(ecnTracker.HandleNewlyAcked(getAckedPackets(7, 8, 9, 14), 7, 0, 2)).To(BeTrue())
})
})

View file

@ -10,7 +10,7 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool)
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
// ReceivedAck processes an ACK frame.
// It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
@ -29,6 +29,7 @@ type SentPacketHandler interface {
// only to be called once the handshake is complete
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber

View file

@ -0,0 +1,87 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go/internal/ackhandler (interfaces: ECNHandler)
// Package ackhandler is a generated GoMock package.
package ackhandler
import (
reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock"
)
// MockECNHandler is a mock of ECNHandler interface.
type MockECNHandler struct {
ctrl *gomock.Controller
recorder *MockECNHandlerMockRecorder
}
// MockECNHandlerMockRecorder is the mock recorder for MockECNHandler.
type MockECNHandlerMockRecorder struct {
mock *MockECNHandler
}
// NewMockECNHandler creates a new mock instance.
func NewMockECNHandler(ctrl *gomock.Controller) *MockECNHandler {
mock := &MockECNHandler{ctrl: ctrl}
mock.recorder = &MockECNHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockECNHandler) EXPECT() *MockECNHandlerMockRecorder {
return m.recorder
}
// HandleNewlyAcked mocks base method.
func (m *MockECNHandler) HandleNewlyAcked(arg0 []*packet, arg1, arg2, arg3 int64) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleNewlyAcked", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(bool)
return ret0
}
// HandleNewlyAcked indicates an expected call of HandleNewlyAcked.
func (mr *MockECNHandlerMockRecorder) HandleNewlyAcked(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleNewlyAcked", reflect.TypeOf((*MockECNHandler)(nil).HandleNewlyAcked), arg0, arg1, arg2, arg3)
}
// LostPacket mocks base method.
func (m *MockECNHandler) LostPacket(arg0 protocol.PacketNumber) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "LostPacket", arg0)
}
// LostPacket indicates an expected call of LostPacket.
func (mr *MockECNHandlerMockRecorder) LostPacket(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockECNHandler)(nil).LostPacket), arg0)
}
// Mode mocks base method.
func (m *MockECNHandler) Mode() protocol.ECN {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Mode")
ret0, _ := ret[0].(protocol.ECN)
return ret0
}
// Mode indicates an expected call of Mode.
func (mr *MockECNHandlerMockRecorder) Mode() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mode", reflect.TypeOf((*MockECNHandler)(nil).Mode))
}
// SentPacket mocks base method.
func (m *MockECNHandler) SentPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentPacket", arg0, arg1)
}
// SentPacket indicates an expected call of SentPacket.
func (mr *MockECNHandlerMockRecorder) SentPacket(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockECNHandler)(nil).SentPacket), arg0, arg1)
}

View file

@ -4,3 +4,6 @@ package ackhandler
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker"
type SentPacketTracker = sentPacketTracker
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler"
type ECNHandler = ecnHandler

View file

@ -62,8 +62,8 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
if ackEliciting {
h.maybeQueueACK(pn, rcvTime, isMissing)
}
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE.
switch ecn {
case protocol.ECNNon:
case protocol.ECT0:
h.ect0++
case protocol.ECT1:

View file

@ -92,6 +92,9 @@ type sentPacketHandler struct {
// The alarm timeout
alarm time.Time
enableECN bool
ecnTracker ecnHandler
perspective protocol.Perspective
tracer logging.ConnectionTracer
@ -110,6 +113,7 @@ func newSentPacketHandler(
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
@ -122,7 +126,7 @@ func newSentPacketHandler(
tracer,
)
return &sentPacketHandler{
h := &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false),
@ -134,6 +138,11 @@ func newSentPacketHandler(
tracer: tracer,
logger: logger,
}
if enableECN {
h.enableECN = true
h.ecnTracker = newECNTracker(logger, tracer)
}
return h
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
@ -228,6 +237,7 @@ func (h *sentPacketHandler) SentPacket(
streamFrames []StreamFrame,
frames []Frame,
encLevel protocol.EncryptionLevel,
ecn protocol.ECN,
size protocol.ByteCount,
isPathMTUProbePacket bool,
) {
@ -252,6 +262,10 @@ func (h *sentPacketHandler) SentPacket(
}
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.SentPacket(pn, ecn)
}
if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation {
@ -302,8 +316,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
}
}
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
// Servers complete address validation when a protected packet is received.
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
@ -333,6 +345,17 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.congestion.MaybeExitSlowStart()
}
}
// Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked.
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked {
congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE))
if congested {
h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight)
}
}
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
return false, err
}
@ -635,7 +658,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
if !p.IsPathMTUProbePacket {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight)
}
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.LostPacket(p.PacketNumber)
}
}
}
@ -712,6 +738,16 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
if !h.enableECN {
return protocol.ECNUnsupported
}
if !isShortHeaderPacket {
return protocol.ECNNon
}
return h.ecnTracker.Mode()
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek()

View file

@ -44,7 +44,7 @@ var _ = Describe("SentPacketHandler", func() {
JustBeforeEach(func() {
lostPackets = nil
rttStats := utils.NewRTTStats()
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger)
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger)
streamFrame = wire.StreamFrame{
StreamID: 5,
Data: []byte{0x13, 0x37},
@ -106,7 +106,7 @@ var _ = Describe("SentPacketHandler", func() {
}
sentPacket := func(p *packet) {
handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, p.Length, p.IsPathMTUProbePacket)
handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, protocol.ECNNon, p.Length, p.IsPathMTUProbePacket)
}
expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) {
@ -563,7 +563,7 @@ var _ = Describe("SentPacketHandler", func() {
// lose packet 1
gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)),
cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
@ -575,7 +575,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(err).ToNot(HaveOccurred())
})
It("doesn't call OnPacketLost when a Path MTU probe packet is lost", func() {
It("doesn't call OnCongestionEvent when a Path MTU probe packet is lost", func() {
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
var mtuPacketDeclaredLost bool
sentPacket(ackElicitingPacket(&packet{
@ -590,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() {
},
}))
sentPacket(ackElicitingPacket(&packet{PacketNumber: 2}))
// lose packet 1, but don't EXPECT any calls to OnPacketLost()
// lose packet 1, but don't EXPECT any calls to OnCongestionEvent()
gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
@ -602,7 +602,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.bytesInFlight).To(BeZero())
})
It("calls OnPacketAcked and OnPacketLost with the right bytes_in_flight value", func() {
It("calls OnPacketAcked and OnCongestionEvent with the right bytes_in_flight value", func() {
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(4)
sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: time.Now().Add(-30 * time.Minute)}))
@ -611,7 +611,7 @@ var _ = Describe("SentPacketHandler", func() {
// receive the first ACK
gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)),
cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()),
)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
@ -620,7 +620,7 @@ var _ = Describe("SentPacketHandler", func() {
// receive the second ACK
gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)),
cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
)
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}}
@ -984,7 +984,7 @@ var _ = Describe("SentPacketHandler", func() {
Context("amplification limit, for the server, with validated address", func() {
JustBeforeEach(func() {
rttStats := utils.NewRTTStats()
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger)
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, false, perspective, nil, utils.DefaultLogger)
})
It("do not limits the window", func() {
@ -1429,4 +1429,151 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.rttStats.SmoothedRTT()).To(BeZero())
})
})
Context("ECN handling", func() {
var ecnHandler *MockECNHandler
var cong *mocks.MockSendAlgorithmWithDebugInfos
JustBeforeEach(func() {
cong = mocks.NewMockSendAlgorithmWithDebugInfos(mockCtrl)
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
cong.EXPECT().OnPacketAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
cong.EXPECT().MaybeExitSlowStart().AnyTimes()
ecnHandler = NewMockECNHandler(mockCtrl)
lostPackets = nil
rttStats := utils.NewRTTStats()
rttStats.UpdateRTT(time.Hour, 0, time.Now())
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger)
handler.ecnTracker = ecnHandler
handler.congestion = cong
})
It("informs about sent packets", func() {
// Check that only 1-RTT packets are reported
handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false)
handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false)
handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false)
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1)
handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false)
})
It("informs about sent packets", func() {
// Check that only 1-RTT packets are reported
handler.SentPacket(time.Now(), 100, -1, nil, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false)
handler.SentPacket(time.Now(), 101, -1, nil, nil, protocol.EncryptionHandshake, protocol.ECT0, 1200, false)
handler.SentPacket(time.Now(), 102, -1, nil, nil, protocol.Encryption0RTT, protocol.ECNCE, 1200, false)
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(103), protocol.ECT1)
handler.SentPacket(time.Now(), 103, -1, nil, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false)
})
It("informs about lost packets", func() {
for i := 10; i < 20; i++ {
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1)
handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false)
}
cong.EXPECT().OnCongestionEvent(gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(10))
ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(11))
ecnHandler.EXPECT().LostPacket(protocol.PacketNumber(12))
ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 16, Smallest: 13}}}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
It("processes ACKs", func() {
// Check that we only care about 1-RTT packets.
handler.SentPacket(time.Now(), 100, -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.EncryptionInitial, protocol.ECT1, 1200, false)
_, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 100}}}, protocol.EncryptionInitial, time.Now())
Expect(err).ToNot(HaveOccurred())
for i := 10; i < 20; i++ {
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1)
handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false)
}
ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool {
Expect(packets).To(HaveLen(5))
Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(10)))
Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(11)))
Expect(packets[2].PacketNumber).To(Equal(protocol.PacketNumber(12)))
Expect(packets[3].PacketNumber).To(Equal(protocol.PacketNumber(14)))
Expect(packets[4].PacketNumber).To(Equal(protocol.PacketNumber(15)))
return false
})
_, err = handler.ReceivedAck(&wire.AckFrame{
AckRanges: []wire.AckRange{
{Largest: 15, Smallest: 14},
{Largest: 12, Smallest: 10},
},
ECT0: 1,
ECT1: 2,
ECNCE: 3,
}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
It("ignores reordered ACKs", func() {
for i := 10; i < 20; i++ {
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1)
handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false)
}
ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool {
Expect(packets).To(HaveLen(2))
Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11)))
Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(12)))
return false
})
_, err := handler.ReceivedAck(&wire.AckFrame{
AckRanges: []wire.AckRange{{Largest: 12, Smallest: 11}},
ECT0: 1,
ECT1: 2,
ECNCE: 3,
}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
// acknowledge packet 10 now, but don't increase the largest acked
_, err = handler.ReceivedAck(&wire.AckFrame{
AckRanges: []wire.AckRange{{Largest: 12, Smallest: 10}},
ECT0: 1,
ECNCE: 3,
}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
It("ignores ACKs that don't increase the largest acked", func() {
for i := 10; i < 20; i++ {
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT1)
handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT1, 1200, false)
}
ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(1), int64(2), int64(3)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool {
Expect(packets).To(HaveLen(1))
Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(11)))
return false
})
_, err := handler.ReceivedAck(&wire.AckFrame{
AckRanges: []wire.AckRange{{Largest: 11, Smallest: 11}},
ECT0: 1,
ECT1: 2,
ECNCE: 3,
}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
_, err = handler.ReceivedAck(&wire.AckFrame{
AckRanges: []wire.AckRange{{Largest: 11, Smallest: 10}},
ECT0: 1,
ECNCE: 3,
}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
It("informs the congestion controller about CE events", func() {
for i := 10; i < 20; i++ {
ecnHandler.EXPECT().SentPacket(protocol.PacketNumber(i), protocol.ECT0)
handler.SentPacket(time.Now(), protocol.PacketNumber(i), -1, []StreamFrame{{Frame: &streamFrame}}, nil, protocol.Encryption1RTT, protocol.ECT0, 1200, false)
}
ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(0), int64(0), int64(0)).Return(true)
cong.EXPECT().OnCongestionEvent(protocol.PacketNumber(15), gomock.Any(), gomock.Any())
_, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 15, Smallest: 10}}}, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
})
})

View file

@ -188,7 +188,7 @@ func (c *cubicSender) OnPacketAcked(
}
}
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {

View file

@ -80,14 +80,14 @@ var _ = Describe("Cubic Sender", func() {
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
for i := 0; i < n; i++ {
ackedPacketNumber++
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
sender.OnCongestionEvent(ackedPacketNumber, packetLength, bytesInFlight)
}
bytesInFlight -= protocol.ByteCount(n) * packetLength
}
// Does not increment acked_packet_number_.
LosePacket := func(number protocol.PacketNumber) {
sender.OnPacketLost(number, maxDatagramSize, bytesInFlight)
sender.OnCongestionEvent(number, maxDatagramSize, bytesInFlight)
bytesInFlight -= maxDatagramSize
}

View file

@ -14,7 +14,7 @@ type SendAlgorithm interface {
CanSend(bytesInFlight protocol.ByteCount) bool
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
OnRetransmissionTimeout(packetsRetransmitted bool)
SetMaxDatagramSize(protocol.ByteCount)
}

View file

@ -49,6 +49,20 @@ func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0)
}
// ECNMode mocks base method.
func (m *MockSentPacketHandler) ECNMode(arg0 bool) protocol.ECN {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ECNMode", arg0)
ret0, _ := ret[0].(protocol.ECN)
return ret0
}
// ECNMode indicates an expected call of ECNMode.
func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0)
}
// GetLossDetectionTimeout mocks base method.
func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time {
m.ctrl.T.Helper()
@ -176,15 +190,15 @@ func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 interface{}) *gomock.
}
// SentPacket mocks base method.
func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ByteCount, arg7 bool) {
func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ECN, arg7 protocol.ByteCount, arg8 bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)
m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8)
}
// SentPacket indicates an expected call of SentPacket.
func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7 interface{}) *gomock.Call {
func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8)
}
// SetHandshakeConfirmed mocks base method.

View file

@ -117,6 +117,18 @@ func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart))
}
// OnCongestionEvent mocks base method.
func (m *MockSendAlgorithmWithDebugInfos) OnCongestionEvent(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnCongestionEvent", arg0, arg1, arg2)
}
// OnCongestionEvent indicates an expected call of OnCongestionEvent.
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), arg0, arg1, arg2)
}
// OnPacketAcked mocks base method.
func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount, arg3 time.Time) {
m.ctrl.T.Helper()
@ -129,18 +141,6 @@ func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3)
}
// OnPacketLost mocks base method.
func (m *MockSendAlgorithmWithDebugInfos) OnPacketLost(arg0 protocol.PacketNumber, arg1, arg2 protocol.ByteCount) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnPacketLost", arg0, arg1, arg2)
}
// OnPacketLost indicates an expected call of OnPacketLost.
func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketLost(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketLost", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketLost), arg0, arg1, arg2)
}
// OnPacketSent mocks base method.
func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 protocol.ByteCount, arg2 protocol.PacketNumber, arg3 protocol.ByteCount, arg4 bool) {
m.ctrl.T.Helper()

View file

@ -135,6 +135,18 @@ func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2)
}
// ECNStateUpdated mocks base method.
func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1)
}
// ECNStateUpdated indicates an expected call of ECNStateUpdated.
func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1)
}
// LossTimerCanceled mocks base method.
func (m *MockConnectionTracer) LossTimerCanceled() {
m.ctrl.T.Helper()
@ -184,15 +196,15 @@ func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 i
}
// ReceivedLongHeaderPacket mocks base method.
func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) {
func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2)
m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3)
}
// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3)
}
// ReceivedRetry mocks base method.
@ -208,15 +220,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom
}
// ReceivedShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) {
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2)
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3)
}
// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3)
}
// ReceivedTransportParameters mocks base method.
@ -256,27 +268,27 @@ func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 int
}
// SentLongHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) {
func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3)
m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4)
}
// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4)
}
// SentShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) {
func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3)
m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4)
}
// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4)
}
// SentTransportParameters mocks base method.

View file

@ -37,14 +37,48 @@ func (t PacketType) String() string {
type ECN uint8
const (
ECNNon ECN = iota // 00
ECT1 // 01
ECT0 // 10
ECNCE // 11
ECNUnsupported ECN = iota
ECNNon // 00
ECT1 // 01
ECT0 // 10
ECNCE // 11
)
func ParseECNHeaderBits(bits byte) ECN {
switch bits {
case 0:
return ECNNon
case 0b00000010:
return ECT0
case 0b00000001:
return ECT1
case 0b00000011:
return ECNCE
default:
panic("invalid ECN bits")
}
}
func (e ECN) ToHeaderBits() byte {
//nolint:exhaustive // There are only 4 values.
switch e {
case ECNNon:
return 0
case ECT0:
return 0b00000010
case ECT1:
return 0b00000001
case ECNCE:
return 0b00000011
default:
panic("ECN unsupported")
}
}
func (e ECN) String() string {
switch e {
case ECNUnsupported:
return "ECN unsupported"
case ECNNon:
return "Not-ECT"
case ECT1:

View file

@ -17,13 +17,22 @@ var _ = Describe("Protocol", func() {
})
It("converts ECN bits from the IP header wire to the correct types", func() {
Expect(ECN(0)).To(Equal(ECNNon))
Expect(ECN(0b00000010)).To(Equal(ECT0))
Expect(ECN(0b00000001)).To(Equal(ECT1))
Expect(ECN(0b00000011)).To(Equal(ECNCE))
Expect(ParseECNHeaderBits(0)).To(Equal(ECNNon))
Expect(ParseECNHeaderBits(0b00000010)).To(Equal(ECT0))
Expect(ParseECNHeaderBits(0b00000001)).To(Equal(ECT1))
Expect(ParseECNHeaderBits(0b00000011)).To(Equal(ECNCE))
Expect(func() { ParseECNHeaderBits(0b1010101) }).To(Panic())
})
It("converts to IP header bits", func() {
for _, v := range [...]ECN{ECNNon, ECT0, ECT1, ECNCE} {
Expect(ParseECNHeaderBits(v.ToHeaderBits())).To(Equal(v))
}
Expect(func() { ECN(42).ToHeaderBits() }).To(Panic())
})
It("has a string representation for ECN", func() {
Expect(ECNUnsupported.String()).To(Equal("ECN unsupported"))
Expect(ECNNon.String()).To(Equal("Not-ECT"))
Expect(ECT0.String()).To(Equal("ECT(0)"))
Expect(ECT1.String()).To(Equal("ECT(1)"))

View file

@ -15,6 +15,8 @@ import (
type (
// A ByteCount is used to count bytes.
ByteCount = protocol.ByteCount
// ECN is the ECN value
ECN = protocol.ECN
// A ConnectionID is a QUIC Connection ID.
ConnectionID = protocol.ConnectionID
// An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long.
@ -58,6 +60,19 @@ type (
RTTStats = utils.RTTStats
)
const (
// ECNUnsupported means that no ECN value was set / received
ECNUnsupported = protocol.ECNUnsupported
// ECTNot is Not-ECT
ECTNot = protocol.ECNNon
// ECT0 is ECT(0)
ECT0 = protocol.ECT0
// ECT1 is ECT(1)
ECT1 = protocol.ECT1
// ECNCE is CE
ECNCE = protocol.ECNCE
)
const (
// KeyPhaseZero is key phase bit 0
KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero
@ -113,12 +128,12 @@ type ConnectionTracer interface {
SentTransportParameters(*TransportParameters)
ReceivedTransportParameters(*TransportParameters)
RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT
SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame)
SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame)
SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame)
SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame)
ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber)
ReceivedRetry(*Header)
ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame)
ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame)
ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame)
ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame)
BufferedPacket(PacketType, ByteCount)
DroppedPacket(PacketType, ByteCount, PacketDropReason)
UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
@ -133,6 +148,7 @@ type ConnectionTracer interface {
SetLossTimer(TimerType, EncryptionLevel, time.Time)
LossTimerExpired(TimerType, EncryptionLevel)
LossTimerCanceled()
ECNStateUpdated(state ECNState, trigger ECNStateTrigger)
// Close is called when the connection is closed.
Close()
Debug(name, msg string)

View file

@ -134,6 +134,18 @@ func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2)
}
// ECNStateUpdated mocks base method.
func (m *MockConnectionTracer) ECNStateUpdated(arg0 ECNState, arg1 ECNStateTrigger) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1)
}
// ECNStateUpdated indicates an expected call of ECNStateUpdated.
func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1)
}
// LossTimerCanceled mocks base method.
func (m *MockConnectionTracer) LossTimerCanceled() {
m.ctrl.T.Helper()
@ -183,15 +195,15 @@ func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 i
}
// ReceivedLongHeaderPacket mocks base method.
func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) {
func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2)
m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3)
}
// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3)
}
// ReceivedRetry mocks base method.
@ -207,15 +219,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom
}
// ReceivedShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) {
func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2)
m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3)
}
// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3)
}
// ReceivedTransportParameters mocks base method.
@ -255,27 +267,27 @@ func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 int
}
// SentLongHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) {
func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3)
m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4)
}
// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4)
}
// SentShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) {
func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3)
m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4)
}
// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4)
}
// SentTransportParameters mocks base method.

View file

@ -93,15 +93,15 @@ func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParamet
}
}
func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) {
func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range m.tracers {
t.SentLongHeaderPacket(hdr, size, ack, frames)
t.SentLongHeaderPacket(hdr, size, ecn, ack, frames)
}
}
func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) {
func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range m.tracers {
t.SentShortHeaderPacket(hdr, size, ack, frames)
t.SentShortHeaderPacket(hdr, size, ecn, ack, frames)
}
}
@ -117,15 +117,15 @@ func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) {
}
}
func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) {
func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range m.tracers {
t.ReceivedLongHeaderPacket(hdr, size, frames)
t.ReceivedLongHeaderPacket(hdr, size, ecn, frames)
}
}
func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) {
func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range m.tracers {
t.ReceivedShortHeaderPacket(hdr, size, frames)
t.ReceivedShortHeaderPacket(hdr, size, ecn, frames)
}
}
@ -213,6 +213,12 @@ func (m *connTracerMultiplexer) LossTimerCanceled() {
}
}
func (m *connTracerMultiplexer) ECNStateUpdated(state ECNState, trigger ECNStateTrigger) {
for _, t := range m.tracers {
t.ECNStateUpdated(state, trigger)
}
}
func (m *connTracerMultiplexer) Debug(name, msg string) {
for _, t := range m.tracers {
t.Debug(name, msg)

View file

@ -119,18 +119,18 @@ var _ = Describe("Tracing", func() {
hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}}
ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}}
ping := &PingFrame{}
tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tracer.SentLongHeaderPacket(hdr, 1337, ack, []Frame{ping})
tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ECTNot, ack, []Frame{ping})
tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ECTNot, ack, []Frame{ping})
tracer.SentLongHeaderPacket(hdr, 1337, ECTNot, ack, []Frame{ping})
})
It("traces the SentShortHeaderPacket event", func() {
hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}
ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}}
ping := &PingFrame{}
tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tracer.SentShortHeaderPacket(hdr, 1337, ack, []Frame{ping})
tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ECNCE, ack, []Frame{ping})
tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ECNCE, ack, []Frame{ping})
tracer.SentShortHeaderPacket(hdr, 1337, ECNCE, ack, []Frame{ping})
})
It("traces the ReceivedVersionNegotiationPacket event", func() {
@ -151,17 +151,17 @@ var _ = Describe("Tracing", func() {
It("traces the ReceivedLongHeaderPacket event", func() {
hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}}
ping := &PingFrame{}
tr1.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), []Frame{ping})
tr2.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), []Frame{ping})
tracer.ReceivedLongHeaderPacket(hdr, 1337, []Frame{ping})
tr1.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), ECT1, []Frame{ping})
tr2.EXPECT().ReceivedLongHeaderPacket(hdr, ByteCount(1337), ECT1, []Frame{ping})
tracer.ReceivedLongHeaderPacket(hdr, 1337, ECT1, []Frame{ping})
})
It("traces the ReceivedShortHeaderPacket event", func() {
hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}
ping := &PingFrame{}
tr1.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), []Frame{ping})
tr2.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), []Frame{ping})
tracer.ReceivedShortHeaderPacket(hdr, 1337, []Frame{ping})
tr1.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), ECT0, []Frame{ping})
tr2.EXPECT().ReceivedShortHeaderPacket(hdr, ByteCount(1337), ECT0, []Frame{ping})
tracer.ReceivedShortHeaderPacket(hdr, 1337, ECT0, []Frame{ping})
})
It("traces the BufferedPacket event", func() {

View file

@ -27,32 +27,37 @@ func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnI
func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
}
func (n NullConnectionTracer) ClosedConnection(err error) {}
func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) ClosedConnection(err error) {}
func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) {
}
func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) {
}
func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
}
func (n NullConnectionTracer) ReceivedRetry(*Header) {}
func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, []Frame) {}
func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, []Frame) {}
func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {}
func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {}
func (n NullConnectionTracer) ReceivedRetry(*Header) {}
func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) {}
func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) {}
func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {}
func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {}
func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) {
}
func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {}
func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {}
func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {}
func (n NullConnectionTracer) UpdatedPTOCount(uint32) {}
func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {}
func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {}
func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {}
func (n NullConnectionTracer) DroppedKey(KeyPhase) {}
func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {}
func (n NullConnectionTracer) LossTimerExpired(timerType TimerType, level EncryptionLevel) {}
func (n NullConnectionTracer) LossTimerCanceled() {}
func (n NullConnectionTracer) Close() {}
func (n NullConnectionTracer) Debug(name, msg string) {}
func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {}
func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {}
func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {}
func (n NullConnectionTracer) UpdatedPTOCount(uint32) {}
func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {}
func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {}
func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {}
func (n NullConnectionTracer) DroppedKey(KeyPhase) {}
func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {}
func (n NullConnectionTracer) LossTimerExpired(TimerType, EncryptionLevel) {}
func (n NullConnectionTracer) LossTimerCanceled() {}
func (n NullConnectionTracer) ECNStateUpdated(ECNState, ECNStateTrigger) {}
func (n NullConnectionTracer) Close() {}
func (n NullConnectionTracer) Debug(name, msg string) {}

View file

@ -92,3 +92,35 @@ const (
// CongestionStateApplicationLimited means that the congestion controller is application limited
CongestionStateApplicationLimited
)
// ECNState is the state of the ECN state machine (see Appendix A.4 of RFC 9000)
type ECNState uint8
const (
// ECNStateTesting is the testing state
ECNStateTesting ECNState = 1 + iota
// ECNStateUnknown is the unknown state
ECNStateUnknown
// ECNStateFailed is the failed state
ECNStateFailed
// ECNStateCapable is the capable state
ECNStateCapable
)
// ECNStateTrigger is a trigger for an ECN state transition.
type ECNStateTrigger uint8
const (
ECNTriggerNoTrigger ECNStateTrigger = iota
// ECNFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets,
// but doesn't contain any ECN counts
ECNFailedNoECNCounts
// ECNFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts
ECNFailedDecreasedECNCounts
// ECNFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost
ECNFailedLostAllTestingPackets
// ECNFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent
ECNFailedMoreECNCountsThanSent
// ECNFailedTooFewECNCounts is emitted when contains fewer ECT(0) / ECT(1) counts than it acknowledges packets
ECNFailedTooFewECNCounts
)

View file

@ -9,6 +9,7 @@ import (
reflect "reflect"
time "time"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock"
)
@ -93,18 +94,18 @@ func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Cal
}
// WritePacket mocks base method.
func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16) (int, error) {
func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16, arg4 protocol.ECN) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3)
ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// WritePacket indicates an expected call of WritePacket.
func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4)
}
// capabilities mocks base method.

View file

@ -8,6 +8,7 @@ import (
net "net"
reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock"
)
@ -77,17 +78,17 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call {
}
// Write mocks base method.
func (m *MockSendConn) Write(arg0 []byte, arg1 uint16) error {
func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0, arg1)
ret := m.ctrl.Call(m, "Write", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// Write indicates an expected call of Write.
func (mr *MockSendConnMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2)
}
// capabilities mocks base method.

View file

@ -7,6 +7,7 @@ package quic
import (
reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock"
)
@ -74,15 +75,15 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call {
}
// Send mocks base method.
func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16) {
func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Send", arg0, arg1)
m.ctrl.Call(m, "Send", arg0, arg1, arg2)
}
// Send indicates an expected call of Send.
func (mr *MockSenderMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2)
}
// WouldBlock mocks base method.

View file

@ -21,6 +21,8 @@ type connCapabilities struct {
DF bool
// GSO (Generic Segmentation Offload) supported
GSO bool
// ECN (Explicit Congestion Notifications) supported
ECN bool
}
// rawConn is a connection that allow reading of a receivedPackeh.
@ -29,7 +31,7 @@ type rawConn interface {
// WritePacket writes a packet on the wire.
// gsoSize is the size of a single packet, or 0 to disable GSO.
// It is invalid to set gsoSize if capabilities.GSO is not set.
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error)
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer

View file

@ -71,6 +71,11 @@ type coalescedPacket struct {
shortHdrPacket *shortHeaderPacket
}
// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets).
func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool {
return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil
}
func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
//nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
switch p.header.Type {

View file

@ -655,6 +655,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
Expect(packet.longHdrPackets).To(HaveLen(1))
Expect(packet.IsOnlyShortHeaderPacket()).To(BeFalse())
// cut off the tag that the mock sealer added
// packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())]
hdr, _, _, err := wire.ParsePacket(packet.buffer.Data)
@ -874,6 +875,7 @@ var _ = Describe("Packet packer", func() {
p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse())
parsePacket(p.buffer.Data)
})
@ -1047,6 +1049,7 @@ var _ = Describe("Packet packer", func() {
packer.retransmissionQueue.addAppData(&wire.PingFrame{})
p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(p.IsOnlyShortHeaderPacket()).To(BeFalse())
Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize))
Expect(p.longHdrPackets).To(HaveLen(1))
Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
@ -1422,6 +1425,7 @@ var _ = Describe("Packet packer", func() {
p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue())
Expect(p.longHdrPackets).To(BeEmpty())
Expect(p.shortHdrPacket).ToNot(BeNil())
packet := p.shortHdrPacket
@ -1448,6 +1452,7 @@ var _ = Describe("Packet packer", func() {
p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT, maxPacketSize, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.IsOnlyShortHeaderPacket()).To(BeTrue())
Expect(p.longHdrPackets).To(BeEmpty())
Expect(p.shortHdrPacket).ToNot(BeNil())
packet := p.shortHdrPacket

View file

@ -158,6 +158,7 @@ type eventPacketSent struct {
PayloadLength logging.ByteCount
Frames frames
IsCoalesced bool
ECN logging.ECN
Trigger string
}
@ -172,6 +173,9 @@ func (e eventPacketSent) MarshalJSONObject(enc *gojay.Encoder) {
enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength})
enc.ArrayKeyOmitEmpty("frames", e.Frames)
enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced)
if e.ECN != logging.ECNUnsupported {
enc.StringKey("ecn", ecn(e.ECN).String())
}
enc.StringKeyOmitEmpty("trigger", e.Trigger)
}
@ -180,6 +184,7 @@ type eventPacketReceived struct {
Length logging.ByteCount
PayloadLength logging.ByteCount
Frames frames
ECN logging.ECN
IsCoalesced bool
Trigger string
}
@ -195,6 +200,9 @@ func (e eventPacketReceived) MarshalJSONObject(enc *gojay.Encoder) {
enc.ObjectKey("raw", rawInfo{Length: e.Length, PayloadLength: e.PayloadLength})
enc.ArrayKeyOmitEmpty("frames", e.Frames)
enc.BoolKeyOmitEmpty("is_coalesced", e.IsCoalesced)
if e.ECN != logging.ECNUnsupported {
enc.StringKey("ecn", ecn(e.ECN).String())
}
enc.StringKeyOmitEmpty("trigger", e.Trigger)
}
@ -516,6 +524,20 @@ func (e eventCongestionStateUpdated) MarshalJSONObject(enc *gojay.Encoder) {
enc.StringKey("new", e.state.String())
}
type eventECNStateUpdated struct {
state logging.ECNState
trigger logging.ECNStateTrigger
}
func (e eventECNStateUpdated) Category() category { return categoryRecovery }
func (e eventECNStateUpdated) Name() string { return "ecn_state_updated" }
func (e eventECNStateUpdated) IsNil() bool { return false }
func (e eventECNStateUpdated) MarshalJSONObject(enc *gojay.Encoder) {
enc.StringKey("new", ecnState(e.state).String())
enc.StringKeyOmitEmpty("trigger", ecnStateTrigger(e.trigger).String())
}
type eventGeneric struct {
name string
msg string

View file

@ -253,15 +253,33 @@ func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) *
}
}
func (t *connectionTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
t.sentPacket(*transformLongHeader(hdr), packetSize, hdr.Length, ack, frames)
func (t *connectionTracer) SentLongHeaderPacket(
hdr *logging.ExtendedHeader,
size logging.ByteCount,
ecn logging.ECN,
ack *logging.AckFrame,
frames []logging.Frame,
) {
t.sentPacket(*transformLongHeader(hdr), size, hdr.Length, ecn, ack, frames)
}
func (t *connectionTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
t.sentPacket(*transformShortHeader(hdr), packetSize, 0, ack, frames)
func (t *connectionTracer) SentShortHeaderPacket(
hdr *logging.ShortHeader,
size logging.ByteCount,
ecn logging.ECN,
ack *logging.AckFrame,
frames []logging.Frame,
) {
t.sentPacket(*transformShortHeader(hdr), size, 0, ecn, ack, frames)
}
func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize, payloadLen logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
func (t *connectionTracer) sentPacket(
hdr gojay.MarshalerJSONObject,
size, payloadLen logging.ByteCount,
ecn logging.ECN,
ack *logging.AckFrame,
frames []logging.Frame,
) {
numFrames := len(frames)
if ack != nil {
numFrames++
@ -276,14 +294,15 @@ func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize,
t.mutex.Lock()
t.recordEvent(time.Now(), &eventPacketSent{
Header: hdr,
Length: packetSize,
Length: size,
PayloadLength: payloadLen,
ECN: ecn,
Frames: fs,
})
t.mutex.Unlock()
}
func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, frames []logging.Frame) {
func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
fs := make([]frame, len(frames))
for i, f := range frames {
fs[i] = frame{Frame: f}
@ -292,14 +311,15 @@ func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader,
t.mutex.Lock()
t.recordEvent(time.Now(), &eventPacketReceived{
Header: header,
Length: packetSize,
Length: size,
PayloadLength: hdr.Length,
ECN: ecn,
Frames: fs,
})
t.mutex.Unlock()
}
func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, frames []logging.Frame) {
func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
fs := make([]frame, len(frames))
for i, f := range frames {
fs[i] = frame{Frame: f}
@ -308,8 +328,9 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p
t.mutex.Lock()
t.recordEvent(time.Now(), &eventPacketReceived{
Header: header,
Length: packetSize,
PayloadLength: packetSize - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen),
Length: size,
PayloadLength: size - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen),
ECN: ecn,
Frames: fs,
})
t.mutex.Unlock()
@ -482,6 +503,12 @@ func (t *connectionTracer) LossTimerCanceled() {
t.mutex.Unlock()
}
func (t *connectionTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) {
t.mutex.Lock()
t.recordEvent(time.Now(), &eventECNStateUpdated{state: state, trigger: trigger})
t.mutex.Unlock()
}
func (t *connectionTracer) Debug(name, msg string) {
t.mutex.Lock()
t.recordEvent(time.Now(), &eventGeneric{

View file

@ -419,6 +419,7 @@ var _ = Describe("Tracing", func() {
PacketNumber: 1337,
},
987,
logging.ECNCE,
nil,
[]logging.Frame{
&logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987},
@ -439,6 +440,7 @@ var _ = Describe("Tracing", func() {
Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337)))
Expect(hdr).To(HaveKeyWithValue("scid", "04030201"))
Expect(ev).To(HaveKey("frames"))
Expect(ev).To(HaveKeyWithValue("ecn", "CE"))
frames := ev["frames"].([]interface{})
Expect(frames).To(HaveLen(2))
Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_stream_data"))
@ -452,6 +454,7 @@ var _ = Describe("Tracing", func() {
PacketNumber: 1337,
},
123,
logging.ECNUnsupported,
&logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}},
[]logging.Frame{&logging.MaxDataFrame{MaximumData: 987}},
)
@ -461,6 +464,7 @@ var _ = Describe("Tracing", func() {
Expect(raw).To(HaveKeyWithValue("length", float64(123)))
Expect(raw).ToNot(HaveKey("payload_length"))
Expect(ev).To(HaveKey("header"))
Expect(ev).ToNot(HaveKey("ecn"))
hdr := ev["header"].(map[string]interface{})
Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT"))
Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337)))
@ -485,6 +489,7 @@ var _ = Describe("Tracing", func() {
PacketNumber: 1337,
},
789,
logging.ECT0,
[]logging.Frame{
&logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987},
&logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true},
@ -498,6 +503,7 @@ var _ = Describe("Tracing", func() {
raw := ev["raw"].(map[string]interface{})
Expect(raw).To(HaveKeyWithValue("length", float64(789)))
Expect(raw).To(HaveKeyWithValue("payload_length", float64(1234)))
Expect(ev).To(HaveKeyWithValue("ecn", "ECT(0)"))
Expect(ev).To(HaveKey("header"))
hdr := ev["header"].(map[string]interface{})
Expect(hdr).To(HaveKeyWithValue("packet_type", "initial"))
@ -520,6 +526,7 @@ var _ = Describe("Tracing", func() {
tracer.ReceivedShortHeaderPacket(
shdr,
789,
logging.ECT1,
[]logging.Frame{
&logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987},
&logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true},
@ -533,6 +540,7 @@ var _ = Describe("Tracing", func() {
raw := ev["raw"].(map[string]interface{})
Expect(raw).To(HaveKeyWithValue("length", float64(789)))
Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3))))
Expect(ev).To(HaveKeyWithValue("ecn", "ECT(1)"))
Expect(ev).To(HaveKey("header"))
hdr := ev["header"].(map[string]interface{})
Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT"))
@ -857,6 +865,27 @@ var _ = Describe("Tracing", func() {
Expect(ev).To(HaveKeyWithValue("event_type", "cancelled"))
})
It("records an ECN state transition, without a trigger", func() {
tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
entry := exportAndParseSingle()
Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond)))
Expect(entry.Name).To(Equal("recovery:ecn_state_updated"))
ev := entry.Event
Expect(ev).To(HaveLen(1))
Expect(ev).To(HaveKeyWithValue("new", "unknown"))
})
It("records an ECN state transition, with a trigger", func() {
tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
entry := exportAndParseSingle()
Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond)))
Expect(entry.Name).To(Equal("recovery:ecn_state_updated"))
ev := entry.Event
Expect(ev).To(HaveLen(2))
Expect(ev).To(HaveKeyWithValue("new", "failed"))
Expect(ev).To(HaveKeyWithValue("trigger", "ACK doesn't contain ECN marks"))
})
It("records a generic event", func() {
tracer.Debug("foo", "bar")
entry := exportAndParseSingle()

View file

@ -312,3 +312,59 @@ func (s congestionState) String() string {
return "unknown congestion state"
}
}
type ecn logging.ECN
func (e ecn) String() string {
//nolint:exhaustive // The unsupported value is never logged.
switch logging.ECN(e) {
case logging.ECTNot:
return "Not-ECT"
case logging.ECT0:
return "ECT(0)"
case logging.ECT1:
return "ECT(1)"
case logging.ECNCE:
return "CE"
default:
return "unknown ECN"
}
}
type ecnState logging.ECNState
func (e ecnState) String() string {
switch logging.ECNState(e) {
case logging.ECNStateTesting:
return "testing"
case logging.ECNStateUnknown:
return "unknown"
case logging.ECNStateCapable:
return "capable"
case logging.ECNStateFailed:
return "failed"
default:
return "unknown ECN state"
}
}
type ecnStateTrigger logging.ECNStateTrigger
func (e ecnStateTrigger) String() string {
switch logging.ECNStateTrigger(e) {
case logging.ECNTriggerNoTrigger:
return ""
case logging.ECNFailedNoECNCounts:
return "ACK doesn't contain ECN marks"
case logging.ECNFailedDecreasedECNCounts:
return "ACK decreases ECN counts"
case logging.ECNFailedLostAllTestingPackets:
return "all ECN testing packets declared lost"
case logging.ECNFailedMoreECNCountsThanSent:
return "ACK contains more ECN counts than ECN-marked packets sent"
case logging.ECNFailedTooFewECNCounts:
return "ACK contains fewer new ECN counts than acknowledged ECN-marked packets"
default:
return "unknown ECN state trigger"
}
}

View file

@ -127,4 +127,30 @@ var _ = Describe("Types", func() {
Expect(congestionState(logging.CongestionStateApplicationLimited).String()).To(Equal("application_limited"))
Expect(congestionState(logging.CongestionStateRecovery).String()).To(Equal("recovery"))
})
It("has a string representation for the ECN bits", func() {
Expect(ecn(logging.ECT0).String()).To(Equal("ECT(0)"))
Expect(ecn(logging.ECT1).String()).To(Equal("ECT(1)"))
Expect(ecn(logging.ECNCE).String()).To(Equal("CE"))
Expect(ecn(logging.ECTNot).String()).To(Equal("Not-ECT"))
Expect(ecn(42).String()).To(Equal("unknown ECN"))
})
It("has a string representation for the ECN state", func() {
Expect(ecnState(logging.ECNStateTesting).String()).To(Equal("testing"))
Expect(ecnState(logging.ECNStateUnknown).String()).To(Equal("unknown"))
Expect(ecnState(logging.ECNStateFailed).String()).To(Equal("failed"))
Expect(ecnState(logging.ECNStateCapable).String()).To(Equal("capable"))
Expect(ecnState(42).String()).To(Equal("unknown ECN state"))
})
It("has a string representation for the ECN state trigger", func() {
Expect(ecnStateTrigger(logging.ECNTriggerNoTrigger).String()).To(Equal(""))
Expect(ecnStateTrigger(logging.ECNFailedNoECNCounts).String()).To(Equal("ACK doesn't contain ECN marks"))
Expect(ecnStateTrigger(logging.ECNFailedDecreasedECNCounts).String()).To(Equal("ACK decreases ECN counts"))
Expect(ecnStateTrigger(logging.ECNFailedLostAllTestingPackets).String()).To(Equal("all ECN testing packets declared lost"))
Expect(ecnStateTrigger(logging.ECNFailedMoreECNCountsThanSent).String()).To(Equal("ACK contains more ECN counts than ECN-marked packets sent"))
Expect(ecnStateTrigger(logging.ECNFailedTooFewECNCounts).String()).To(Equal("ACK contains fewer new ECN counts than acknowledged ECN-marked packets"))
Expect(ecnStateTrigger(42).String()).To(Equal("unknown ECN state trigger"))
})
})

View file

@ -3,12 +3,13 @@ package quic
import (
"net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface {
Write(b []byte, gsoSize uint16) error
Write(b []byte, gsoSize uint16, ecn protocol.ECN) error
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
@ -42,10 +43,9 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
}
oob := info.OOB()
// add 32 bytes, so we can add the UDP_SEGMENT msg
// increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating
l := len(oob)
oob = append(oob, make([]byte, 32)...)
oob = oob[:l]
oob = append(oob, make([]byte, 64)...)[:l]
return &sconn{
rawConn: c,
localAddr: localAddr,
@ -55,8 +55,8 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
}
}
func (c *sconn) Write(p []byte, gsoSize uint16) error {
_, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize)
func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
_, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn)
if err != nil && isGSOError(err) {
// disable GSO for future calls
c.gotGSOError = true
@ -69,7 +69,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16) error {
if l > int(gsoSize) {
l = int(gsoSize)
}
if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0); err != nil {
if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil {
return err
}
p = p[l:]

View file

@ -5,6 +5,7 @@ import (
"net/netip"
"runtime"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
. "github.com/onsi/ginkgo/v2"
@ -45,8 +46,8 @@ var _ = Describe("Connection (for sending packets)", func() {
pi := packetInfo{addr: netip.IPv6Loopback()}
Expect(pi.OOB()).ToNot(BeEmpty())
c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger)
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0))
Expect(c.Write([]byte("foobar"), 0)).To(Succeed())
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0), protocol.ECT1)
Expect(c.Write([]byte("foobar"), 0, protocol.ECT1)).To(Succeed())
})
}
@ -55,8 +56,8 @@ var _ = Describe("Connection (for sending packets)", func() {
rawConn.EXPECT().LocalAddr()
rawConn.EXPECT().capabilities().AnyTimes()
c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger)
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3))
Expect(c.Write([]byte("foobar"), 3)).To(Succeed())
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3), protocol.ECNCE)
Expect(c.Write([]byte("foobar"), 3, protocol.ECNCE)).To(Succeed())
})
if platformSupportsGSO {
@ -67,11 +68,11 @@ var _ = Describe("Connection (for sending packets)", func() {
c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger)
Expect(c.capabilities().GSO).To(BeTrue())
gomock.InOrder(
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4)).Return(0, errGSO),
rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0)).Return(4, nil),
rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0)).Return(2, nil),
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4), protocol.ECNCE).Return(0, errGSO),
rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(4, nil),
rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(2, nil),
)
Expect(c.Write([]byte("foobar"), 4)).To(Succeed())
Expect(c.Write([]byte("foobar"), 4, protocol.ECNCE)).To(Succeed())
Expect(c.capabilities().GSO).To(BeFalse())
})
}

View file

@ -1,8 +1,9 @@
package quic
import "github.com/quic-go/quic-go/internal/protocol"
type sender interface {
// Send sends a packet. GSO is only used if gsoSize > 0.
Send(p *packetBuffer, gsoSize uint16)
Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN)
Run() error
WouldBlock() bool
Available() <-chan struct{}
@ -12,6 +13,7 @@ type sender interface {
type queueEntry struct {
buf *packetBuffer
gsoSize uint16
ecn protocol.ECN
}
type sendQueue struct {
@ -39,9 +41,9 @@ func newSendQueue(conn sendConn) sender {
// Send sends out a packet. It's guaranteed to not block.
// Callers need to make sure that there's actually space in the send queue by calling WouldBlock.
// Otherwise Send will panic.
func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16) {
func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) {
select {
case h.queue <- queueEntry{buf: p, gsoSize: gsoSize}:
case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}:
// clear available channel if we've reached capacity
if len(h.queue) == sendQueueCapacity {
select {
@ -76,7 +78,7 @@ func (h *sendQueue) Run() error {
// make sure that all queued packets are actually sent out
shouldClose = true
case e := <-h.queue:
if err := h.conn.Write(e.buf.Data, e.gsoSize); err != nil {
if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil {
// This additional check enables:
// 1. Checking for "datagram too large" message from the kernel, as such,
// 2. Path MTU discovery,and

View file

@ -3,6 +3,8 @@ package quic
import (
"errors"
"github.com/quic-go/quic-go/internal/protocol"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
@ -26,10 +28,10 @@ var _ = Describe("Send Queue", func() {
It("sends a packet", func() {
p := getPacket([]byte("foobar"))
q.Send(p, 10) // make sure the packet size is passed through to the conn
q.Send(p, 10, protocol.ECT1) // make sure the packet size is passed through to the conn
written := make(chan struct{})
c.EXPECT().Write([]byte("foobar"), uint16(10)).Do(func([]byte, uint16) { close(written) })
c.EXPECT().Write([]byte("foobar"), uint16(10), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(written) })
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -45,19 +47,19 @@ var _ = Describe("Send Queue", func() {
It("panics when Send() is called although there's no space in the queue", func() {
for i := 0; i < sendQueueCapacity; i++ {
Expect(q.WouldBlock()).To(BeFalse())
q.Send(getPacket([]byte("foobar")), 6)
q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
}
Expect(q.WouldBlock()).To(BeTrue())
Expect(func() { q.Send(getPacket([]byte("raboof")), 6) }).To(Panic())
Expect(func() { q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) }).To(Panic())
})
It("signals when sending is possible again", func() {
Expect(q.WouldBlock()).To(BeFalse())
q.Send(getPacket([]byte("foobar1")), 6)
q.Send(getPacket([]byte("foobar1")), 6, protocol.ECNNon)
Consistently(q.Available()).ShouldNot(Receive())
// now start sending out packets. This should free up queue space.
c.EXPECT().Write(gomock.Any(), gomock.Any()).MinTimes(1).MaxTimes(2)
c.EXPECT().Write(gomock.Any(), gomock.Any(), protocol.ECNNon).MinTimes(1).MaxTimes(2)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -67,7 +69,7 @@ var _ = Describe("Send Queue", func() {
Eventually(q.Available()).Should(Receive())
Expect(q.WouldBlock()).To(BeFalse())
Expect(func() { q.Send(getPacket([]byte("foobar2")), 7) }).ToNot(Panic())
Expect(func() { q.Send(getPacket([]byte("foobar2")), 7, protocol.ECNNon) }).ToNot(Panic())
q.Close()
Eventually(done).Should(BeClosed())
@ -77,7 +79,7 @@ var _ = Describe("Send Queue", func() {
write := make(chan struct{}, 1)
written := make(chan struct{}, 100)
// now start sending out packets. This should free up queue space.
c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16) error {
c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16, protocol.ECN) error {
written <- struct{}{}
<-write
return nil
@ -92,19 +94,19 @@ var _ = Describe("Send Queue", func() {
close(done)
}()
q.Send(getPacket([]byte("foobar")), 6)
q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
<-written
// now fill up the send queue
for i := 0; i < sendQueueCapacity; i++ {
Expect(q.WouldBlock()).To(BeFalse())
q.Send(getPacket([]byte("foobar")), 6)
q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
}
// One more packet is queued when it's picked up by Run and written to the connection.
// In this test, it's blocked on write channel in the mocked Write call.
<-written
Eventually(q.WouldBlock()).Should(BeFalse())
q.Send(getPacket([]byte("foobar")), 6)
q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
Expect(q.WouldBlock()).To(BeTrue())
Consistently(q.Available()).ShouldNot(Receive())
@ -130,15 +132,15 @@ var _ = Describe("Send Queue", func() {
// the run loop exits if there is a write error
testErr := errors.New("test error")
c.EXPECT().Write(gomock.Any(), gomock.Any()).Return(testErr)
q.Send(getPacket([]byte("foobar")), 6)
c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(testErr)
q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
Eventually(done).Should(BeClosed())
sent := make(chan struct{})
go func() {
defer GinkgoRecover()
q.Send(getPacket([]byte("raboof")), 6)
q.Send(getPacket([]byte("quux")), 4)
q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon)
q.Send(getPacket([]byte("quux")), 4, protocol.ECNNon)
close(sent)
}()
@ -147,7 +149,7 @@ var _ = Describe("Send Queue", func() {
It("blocks Close() until the packet has been sent out", func() {
written := make(chan []byte)
c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16) { written <- p })
c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16, _ protocol.ECN) { written <- p })
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -155,7 +157,7 @@ var _ = Describe("Send Queue", func() {
close(done)
}()
q.Send(getPacket([]byte("foobar")), 6)
q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
closed := make(chan struct{})
go func() {

View file

@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0)
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0)
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
if s.tracer != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
}
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil {
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
}

View file

@ -104,10 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) {
}, nil
}
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16) (n int, err error) {
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) {
if gsoSize != 0 {
panic("cannot use GSO with a basicConn")
}
if ecn != protocol.ECNUnsupported {
panic("cannot use ECN with a basicConn")
}
return c.PacketConn.WriteTo(b, addr)
}

View file

@ -15,6 +15,8 @@ const (
ipv4PKTINFO = unix.IP_RECVPKTINFO
)
const ecnIPv4DataLen = 4
// ReadBatch only returns a single packet on OSX,
// see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch.
const batchSize = 1

View file

@ -14,6 +14,8 @@ const (
ipv4PKTINFO = 0x7
)
const ecnIPv4DataLen = 4
const batchSize = 8
func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) {

View file

@ -19,6 +19,8 @@ const (
ipv4PKTINFO = unix.IP_PKTINFO
)
const ecnIPv4DataLen = 4
const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed)
func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error {

View file

@ -8,9 +8,12 @@ import (
"log"
"net"
"net/netip"
"os"
"strconv"
"sync"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -56,6 +59,11 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) {
return size, serr
}
func isECNDisabled() bool {
disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN"))
return err == nil && disabled
}
type oobConn struct {
OOBCapablePacketConn
batchConn batchConn
@ -140,6 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
cap: connCapabilities{
DF: supportsDF,
GSO: isGSOSupported(rawConn),
ECN: !isECNDisabled(),
},
}
for i := 0; i < batchSize; i++ {
@ -188,7 +197,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type {
case msgTypeIPTOS:
p.ecn = protocol.ECN(body[0] & ecnMask)
p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case ipv4PKTINFO:
ip, ifIndex, ok := parseIPv4PktInfo(body)
if ok {
@ -205,7 +214,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type {
case unix.IPV6_TCLASS:
p.ecn = protocol.ECN(body[0] & ecnMask)
p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case unix.IPV6_PKTINFO:
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
@ -228,7 +237,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
}
// WritePacket writes a new packet.
func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) {
func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) {
oob := packetInfoOOB
if gsoSize > 0 {
if !c.capabilities().GSO {
@ -236,6 +245,18 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso
}
oob = appendUDPSegmentSizeMsg(oob, gsoSize)
}
if ecn != protocol.ECNUnsupported {
if !c.capabilities().ECN {
panic("tried to send a ECN-marked packet although ECN is disabled")
}
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, ecn)
} else {
oob = appendIPv6ECNMsg(oob, ecn)
}
}
}
n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err
}
@ -279,3 +300,32 @@ func (info *packetInfo) OOB() []byte {
}
return nil
}
func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte {
startLen := len(b)
b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_IP
h.Type = unix.IP_TOS
h.SetLen(unix.CmsgLen(ecnIPv4DataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = val.ToHeaderBits()
return b
}
func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte {
startLen := len(b)
const dataLen = 4
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_IPV6
h.Type = unix.IPV6_TCLASS
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = val.ToHeaderBits()
return b
}

View file

@ -53,7 +53,7 @@ var _ = Describe("OOB Conn Test", func() {
return udpConn, packetChan
}
Context("ECN conn", func() {
Context("reading ECN-marked packets", func() {
sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr {
conn, err := net.DialUDP(network, nil, addr)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
@ -139,6 +139,42 @@ var _ = Describe("OOB Conn Test", func() {
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
Expect(p.ecn).To(Equal(protocol.ECT1))
})
It("sends packets with ECN on IPv4", func() {
conn, packetChan := runServer("udp4", "localhost:0")
defer conn.Close()
c, err := net.ListenUDP("udp4", nil)
Expect(err).ToNot(HaveOccurred())
defer c.Close()
for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} {
_, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv4ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
var p receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.ecn).To(Equal(val))
}
})
It("sends packets with ECN on IPv6", func() {
conn, packetChan := runServer("udp6", "[::1]:0")
defer conn.Close()
c, err := net.ListenUDP("udp6", nil)
Expect(err).ToNot(HaveOccurred())
defer c.Close()
for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} {
_, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv6ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
var p receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.ecn).To(Equal(val))
}
})
})
Context("Packet Info conn", func() {
@ -253,6 +289,27 @@ var _ = Describe("OOB Conn Test", func() {
})
})
Context("sending ECN-marked packets", func() {
It("sets the ECN control message", func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
c := &oobRecordingConn{UDPConn: udpConn}
oobConn, err := newConn(c, true)
Expect(err).ToNot(HaveOccurred())
oob := make([]byte, 0, 123)
oobConn.WritePacket([]byte("foobar"), addr, oob, 0, protocol.ECNCE)
Expect(c.oobs).To(HaveLen(1))
oobMsg := c.oobs[0]
Expect(oobMsg).ToNot(BeEmpty())
Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob
expected := appendIPv4ECNMsg([]byte{}, protocol.ECNCE)
Expect(oobMsg).To(Equal(expected))
})
})
if platformSupportsGSO {
Context("GSO", func() {
It("appends the GSO control message", func() {
@ -265,14 +322,15 @@ var _ = Describe("OOB Conn Test", func() {
Expect(err).ToNot(HaveOccurred())
Expect(oobConn.capabilities().GSO).To(BeTrue())
oob := make([]byte, 0, 42)
oobConn.WritePacket([]byte("foobar"), addr, oob, 3)
oob := make([]byte, 0, 123)
oobConn.WritePacket([]byte("foobar"), addr, oob, 3, protocol.ECNCE)
Expect(c.oobs).To(HaveLen(1))
oobMsg := c.oobs[0]
Expect(oobMsg).ToNot(BeEmpty())
Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob
expected := appendUDPSegmentSizeMsg([]byte{}, 3)
Expect(oobMsg).To(Equal(expected))
// Check that the first control message is the OOB control message.
Expect(oobMsg[:len(expected)]).To(Equal(expected))
})
})
}

View file

@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
return 0, err
}
return t.conn.WritePacket(b, addr, nil, 0)
return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
}
func (t *Transport) enqueueClosePacket(p closePacket) {
@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() {
case <-t.listening:
return
case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0)
t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
case p := <-t.statelessResetQueue:
t.sendStatelessReset(p)
}
@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil {
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
}
}