send and track packets with ECN markings

This commit is contained in:
Marten Seemann 2023-08-12 10:08:40 +08:00
parent f919473598
commit 5dd6d91c11
21 changed files with 264 additions and 206 deletions

View file

@ -1830,9 +1830,10 @@ func (s *connection) sendPackets(now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) ecn := s.sentPacketHandler.ECNMode()
s.registerPackedShortHeaderPacket(p, now) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
s.sendQueue.Send(buf, 0) s.registerPackedShortHeaderPacket(p, ecn, now)
s.sendQueue.Send(buf, 0, ecn)
// This is kind of a hack. We need to trigger sending again somehow. // This is kind of a hack. We need to trigger sending again somehow.
s.pacingDeadline = deadlineSendImmediately s.pacingDeadline = deadlineSendImmediately
return nil return nil
@ -1852,7 +1853,7 @@ func (s *connection) sendPackets(now time.Time) error {
return err return err
} }
s.sentFirstPacket = true s.sentFirstPacket = true
if err := s.sendPackedCoalescedPacket(packet, now); err != nil { if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now); err != nil {
return err return err
} }
sendMode := s.sentPacketHandler.SendMode(now) sendMode := s.sentPacketHandler.SendMode(now)
@ -1873,7 +1874,8 @@ func (s *connection) sendPackets(now time.Time) error {
func (s *connection) sendPacketsWithoutGSO(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for { for {
buf := getPacketBuffer() buf := getPacketBuffer()
if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { ecn := s.sentPacketHandler.ECNMode()
if _, err := s.appendOnePacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if err == errNothingToPack { if err == errNothingToPack {
buf.Release() buf.Release()
return nil return nil
@ -1881,7 +1883,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
return err return err
} }
s.sendQueue.Send(buf, 0) s.sendQueue.Send(buf, 0, ecn)
if s.sendQueue.WouldBlock() { if s.sendQueue.WouldBlock() {
return nil return nil
@ -1908,7 +1910,8 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error {
for { for {
var dontSendMore bool var dontSendMore bool
size, err := s.appendPacket(buf, maxSize, now) ecn := s.sentPacketHandler.ECNMode()
size, err := s.appendOnePacket(buf, maxSize, ecn, now)
if err != nil { if err != nil {
if err != errNothingToPack { if err != errNothingToPack {
return err return err
@ -1938,7 +1941,7 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error {
continue continue
} }
s.sendQueue.Send(buf, uint16(maxSize)) s.sendQueue.Send(buf, uint16(maxSize), ecn)
if dontSendMore { if dontSendMore {
return nil return nil
@ -1966,6 +1969,7 @@ func (s *connection) resetPacingDeadline() {
} }
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
ecn := s.sentPacketHandler.ECNMode()
if !s.handshakeConfirmed { if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
@ -1974,7 +1978,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if packet == nil { if packet == nil {
return nil return nil
} }
return s.sendPackedCoalescedPacket(packet, time.Now()) return s.sendPackedCoalescedPacket(packet, ecn, time.Now())
} }
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
@ -1984,9 +1988,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
} }
return err return err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, now) s.registerPackedShortHeaderPacket(p, ecn, now)
s.sendQueue.Send(buf, 0) s.sendQueue.Send(buf, 0, ecn)
return nil return nil
} }
@ -2018,24 +2022,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
} }
return s.sendPackedCoalescedPacket(packet, now) return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(), now)
} }
// appendPacket appends a new packet to the given packetBuffer. // appendOnePacket appends a new packet to the given packetBuffer.
// If there was nothing to pack, the returned size is 0. // If there was nothing to pack, the returned size is 0.
func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { func (s *connection) appendOnePacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) {
startLen := buf.Len() startLen := buf.Len()
p, err := s.packer.AppendPacket(buf, maxSize, s.version) p, err := s.packer.AppendPacket(buf, maxSize, s.version)
if err != nil { if err != nil {
return 0, err return 0, err
} }
size := buf.Len() - startLen size := buf.Len() - startLen
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false)
s.registerPackedShortHeaderPacket(p, now) s.registerPackedShortHeaderPacket(p, ecn, now)
return size, nil return size, nil
} }
func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
@ -2044,12 +2048,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti
if p.Ack != nil { if p.Ack != nil {
largestAcked = p.Ack.LargestAcked() largestAcked = p.Ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket)
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
} }
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error {
s.logCoalescedPacket(packet) s.logCoalescedPacket(packet, ecn)
for _, p := range packet.longHdrPackets { for _, p := range packet.longHdrPackets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
@ -2058,7 +2062,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if p.ack != nil { if p.ack != nil {
largestAcked = p.ack.LargestAcked() largestAcked = p.ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001. // See Section 4.9.1 of RFC 9001.
@ -2075,10 +2079,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if p.Ack != nil { if p.Ack != nil {
largestAcked = p.Ack.LargestAcked() largestAcked = p.Ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket)
} }
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, 0) s.sendQueue.Send(packet.buffer, 0, ecn)
return nil return nil
} }
@ -2100,8 +2104,9 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logCoalescedPacket(packet) ecn := s.sentPacketHandler.ECNMode()
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0) s.logCoalescedPacket(packet, ecn)
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
} }
func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
@ -2144,11 +2149,12 @@ func (s *connection) logShortHeaderPacket(
pn protocol.PacketNumber, pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount, size protocol.ByteCount,
isCoalesced bool, isCoalesced bool,
) { ) {
if s.logger.Debug() && !isCoalesced { if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
} }
// quic-go logging // quic-go logging
if s.logger.Debug() { if s.logger.Debug() {
@ -2191,7 +2197,7 @@ func (s *connection) logShortHeaderPacket(
} }
} }
func (s *connection) logCoalescedPacket(packet *coalescedPacket) { func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if s.logger.Debug() { if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet. // during which we might call PackCoalescedPacket but just pack a short header packet.
@ -2204,6 +2210,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase, packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length, packet.shortHdrPacket.Length,
false, false,
) )
@ -2219,7 +2226,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
s.logLongHeaderPacket(p) s.logLongHeaderPacket(p)
} }
if p := packet.shortHdrPacket; p != nil { if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
} }
} }

View file

@ -454,7 +454,7 @@ var _ = Describe("Connection", func() {
Expect(e.ErrorMessage).To(BeEmpty()) Expect(e.ErrorMessage).To(BeEmpty())
return &coalescedPacket{buffer: buffer}, nil return &coalescedPacket{buffer: buffer}, nil
}) })
mconn.EXPECT().Write([]byte("connection close"), gomock.Any()) mconn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
var appErr *ApplicationError var appErr *ApplicationError
@ -475,7 +475,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -494,7 +494,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
@ -516,7 +516,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
@ -565,7 +565,7 @@ var _ = Describe("Connection", func() {
close(returned) close(returned)
}() }()
Consistently(returned).ShouldNot(BeClosed()) Consistently(returned).ShouldNot(BeClosed())
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -609,13 +609,14 @@ var _ = Describe("Connection", func() {
conn.handshakeConfirmed = true conn.handshakeConfirmed = true
sconn := NewMockSendConn(mockCtrl) sconn := NewMockSendConn(mockCtrl)
sconn.EXPECT().capabilities().AnyTimes() sconn.EXPECT().capabilities().AnyTimes()
sconn.EXPECT().Write(gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() sconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes()
conn.sendQueue = newSendQueue(sconn) conn.sendQueue = newSendQueue(sconn)
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
// only expect a single SentPacket() call // only expect a single SentPacket() call
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
@ -837,7 +838,7 @@ var _ = Describe("Connection", func() {
// make the go routine return // make the go routine return
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("close")) conn.closeLocal(errors.New("close"))
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
}) })
@ -872,7 +873,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("close")) conn.closeLocal(errors.New("close"))
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
}) })
@ -908,7 +909,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("close")) conn.closeLocal(errors.New("close"))
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
}) })
@ -930,7 +931,7 @@ var _ = Describe("Connection", func() {
close(done) close(done)
}() }()
expectReplaceWithClosed() expectReplaceWithClosed()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
packet := getShortHeaderPacket(srcConnID, 0x42, nil) packet := getShortHeaderPacket(srcConnID, 0x42, nil)
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
@ -958,7 +959,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.shutdown() conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
}) })
@ -980,7 +981,7 @@ var _ = Describe("Connection", func() {
close(done) close(done)
}() }()
expectReplaceWithClosed() expectReplaceWithClosed()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil))
@ -1190,6 +1191,7 @@ var _ = Describe("Connection", func() {
var ( var (
connDone chan struct{} connDone chan struct{}
sender *MockSender sender *MockSender
sph *mockackhandler.MockSentPacketHandler
) )
BeforeEach(func() { BeforeEach(func() {
@ -1198,14 +1200,17 @@ var _ = Describe("Connection", func() {
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
conn.sendQueue = sender conn.sendQueue = sender
connDone = make(chan struct{}) connDone = make(chan struct{})
sph = mockackhandler.NewMockSentPacketHandler(mockCtrl)
conn.sentPacketHandler = sph
}) })
AfterEach(func() { AfterEach(func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sph.EXPECT().ECNMode().Return(protocol.ECNCE).MaxTimes(1)
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
sender.EXPECT().Close() sender.EXPECT().Close()
@ -1226,12 +1231,11 @@ var _ = Describe("Connection", func() {
It("sends packets", func() { It("sends packets", func() {
conn.handshakeConfirmed = true conn.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().ECNMode().Return(protocol.ECNNon).AnyTimes()
conn.sentPacketHandler = sph sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
runConn() runConn()
p := shortHeaderPacket{ p := shortHeaderPacket{
DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}), DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}),
@ -1243,7 +1247,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
sent := make(chan struct{}) sent := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{
DestConnectionID: p.DestConnID, DestConnectionID: p.DestConnID,
PacketNumber: p.PacketNumber, PacketNumber: p.PacketNumber,
@ -1256,6 +1260,9 @@ var _ = Describe("Connection", func() {
It("doesn't send packets if there's nothing to send", func() { It("doesn't send packets if there's nothing to send", func() {
conn.handshakeConfirmed = true conn.handshakeConfirmed = true
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
runConn() runConn()
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true)
@ -1264,13 +1271,12 @@ var _ = Describe("Connection", func() {
}) })
It("sends ACK only packets", func() { It("sends ACK only packets", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes()
done := make(chan struct{}) done := make(chan struct{})
packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) })
conn.sentPacketHandler = sph
runConn() runConn()
conn.scheduleSending() conn.scheduleSending()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1278,12 +1284,11 @@ var _ = Describe("Connection", func() {
It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() {
conn.handshakeConfirmed = true conn.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().ECNMode().AnyTimes()
conn.sentPacketHandler = sph sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
fc := mocks.NewMockConnectionFlowController(mockCtrl) fc := mocks.NewMockConnectionFlowController(mockCtrl)
fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337))
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 13}, []byte("foobar"))
@ -1291,7 +1296,7 @@ var _ = Describe("Connection", func() {
conn.connFlowController = fc conn.connFlowController = fc
runConn() runConn()
sent := make(chan struct{}) sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{})
conn.scheduleSending() conn.scheduleSending()
Eventually(sent).Should(BeClosed()) Eventually(sent).Should(BeClosed())
@ -1300,11 +1305,9 @@ var _ = Describe("Connection", func() {
}) })
It("doesn't send when the SentPacketHandler doesn't allow it", func() { It("doesn't send when the SentPacketHandler doesn't allow it", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone).AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
conn.sentPacketHandler = sph
runConn() runConn()
conn.scheduleSending() conn.scheduleSending()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@ -1333,21 +1336,19 @@ var _ = Describe("Connection", func() {
}) })
It("sends a probe packet", func() { It("sends a probe packet", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().QueueProbePacket(encLevel) sph.EXPECT().QueueProbePacket(encLevel)
sph.EXPECT().ECNMode()
p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT)
packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(pn).To(Equal(protocol.PacketNumber(123)))
})
conn.sentPacketHandler = sph conn.sentPacketHandler = sph
runConn() runConn()
sent := make(chan struct{}) sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
if enc == protocol.Encryption1RTT { if enc == protocol.Encryption1RTT {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any())
} else { } else {
@ -1358,21 +1359,18 @@ var _ = Describe("Connection", func() {
}) })
It("sends a PING as a probe packet", func() { It("sends a PING as a probe packet", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(sendMode) sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode()
sph.EXPECT().QueueProbePacket(encLevel).Return(false) sph.EXPECT().QueueProbePacket(encLevel).Return(false)
p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) p := getCoalescedPacket(123, enc != protocol.Encryption1RTT)
packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(pn).To(Equal(protocol.PacketNumber(123)))
})
conn.sentPacketHandler = sph
runConn() runConn()
sent := make(chan struct{}) sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) })
if enc == protocol.Encryption1RTT { if enc == protocol.Encryption1RTT {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any())
} else { } else {
@ -1409,10 +1407,11 @@ var _ = Describe("Connection", func() {
AfterEach(func() { AfterEach(func() {
// make the go routine return // make the go routine return
sph.EXPECT().ECNMode().MaxTimes(1)
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
sender.EXPECT().Close() sender.EXPECT().Close()
@ -1421,17 +1420,18 @@ var _ = Describe("Connection", func() {
}) })
It("sends multiple packets one by one immediately", func() { It("sends multiple packets one by one immediately", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().ECNMode().Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited)
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10"))
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11"))
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal([]byte("packet10"))) Expect(b.Data).To(Equal([]byte("packet10")))
}) })
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal([]byte("packet11"))) Expect(b.Data).To(Equal([]byte("packet11")))
}) })
go func() { go func() {
@ -1446,7 +1446,8 @@ var _ = Describe("Connection", func() {
It("sends multiple packets one by one immediately, with GSO", func() { It("sends multiple packets one by one immediately, with GSO", func() {
enableGSO() enableGSO()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().ECNMode().Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload1) rand.Read(payload1)
@ -1456,7 +1457,7 @@ var _ = Describe("Connection", func() {
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...))) Expect(b.Data).To(Equal(append(payload1, payload2...)))
}) })
go func() { go func() {
@ -1471,8 +1472,9 @@ var _ = Describe("Connection", func() {
It("stops appending packets when a smaller packet is packed, with GSO", func() { It("stops appending packets when a smaller packet is packed, with GSO", func() {
enableGSO() enableGSO()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().ECNMode().Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
rand.Read(payload1) rand.Read(payload1)
@ -1481,7 +1483,7 @@ var _ = Describe("Connection", func() {
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...))) Expect(b.Data).To(Equal(append(payload1, payload2...)))
}) })
go func() { go func() {
@ -1495,12 +1497,13 @@ var _ = Describe("Connection", func() {
}) })
It("sends multiple packets, when the pacer allows immediate sending", func() { It("sends multiple packets, when the pacer allows immediate sending", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().ECNMode().Times(2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1512,13 +1515,14 @@ var _ = Describe("Connection", func() {
}) })
It("allows an ACK to be sent when pacing limited", func() { It("allows an ACK to be sent when pacing limited", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited)
sph.EXPECT().ECNMode()
packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil) packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{PacketNumber: 123}, getPacketBuffer(), nil)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1532,12 +1536,13 @@ var _ = Describe("Connection", func() {
// when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck
// we shouldn't send the ACK in the same run // we shouldn't send the ACK in the same run
It("doesn't send an ACK right after becoming congestion limited", func() { It("doesn't send an ACK right after becoming congestion limited", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
sph.EXPECT().ECNMode().Times(2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100"))
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1552,19 +1557,21 @@ var _ = Describe("Connection", func() {
pacingDelay := scaleDuration(100 * time.Millisecond) pacingDelay := scaleDuration(100 * time.Millisecond)
gomock.InOrder( gomock.InOrder(
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
sph.EXPECT().ECNMode(),
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 100}, []byte("packet100")),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)),
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
sph.EXPECT().ECNMode(),
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")), expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 101}, []byte("packet101")),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited), sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
) )
written := make(chan struct{}, 2) written := make(chan struct{}, 2)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(2) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(2)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1578,8 +1585,9 @@ var _ = Describe("Connection", func() {
}) })
It("sends multiple packets at once", func() { It("sends multiple packets at once", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
sph.EXPECT().ECNMode().Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited)
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { for pn := protocol.PacketNumber(1000); pn < 1003; pn++ {
@ -1587,7 +1595,7 @@ var _ = Describe("Connection", func() {
} }
written := make(chan struct{}, 3) written := make(chan struct{}, 3)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(3) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} }).Times(3)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1618,11 +1626,12 @@ var _ = Describe("Connection", func() {
written := make(chan struct{}) written := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
available <- struct{}{} available <- struct{}{}
Eventually(written).Should(BeClosed()) Eventually(written).Should(BeClosed())
}) })
@ -1639,14 +1648,15 @@ var _ = Describe("Connection", func() {
written := make(chan struct{}) written := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ByteCount, bool) { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool) {
sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().ReceivedBytes(gomock.Any())
conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) conn.handlePacket(receivedPacket{buffer: getPacketBuffer()})
}) })
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
conn.scheduleSending() conn.scheduleSending()
time.Sleep(scaleDuration(50 * time.Millisecond)) time.Sleep(scaleDuration(50 * time.Millisecond))
@ -1655,13 +1665,14 @@ var _ = Describe("Connection", func() {
}) })
It("stops sending when the send queue is full", func() { It("stops sending when the send queue is full", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny)
sph.EXPECT().ECNMode()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000"))
written := make(chan struct{}, 1) written := make(chan struct{}, 1)
sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock()
sender.EXPECT().WouldBlock().Return(true).Times(2) sender.EXPECT().WouldBlock().Return(true).Times(2)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} })
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
@ -1675,12 +1686,13 @@ var _ = Describe("Connection", func() {
time.Sleep(scaleDuration(50 * time.Millisecond)) time.Sleep(scaleDuration(50 * time.Millisecond))
// now make room in the send queue // now make room in the send queue
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} })
available <- struct{}{} available <- struct{}{}
Eventually(written).Should(Receive()) Eventually(written).Should(Receive())
@ -1691,6 +1703,7 @@ var _ = Describe("Connection", func() {
It("doesn't set a pacing timer when there is no data to send", func() { It("doesn't set a pacing timer when there is no data to send", func() {
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
// don't EXPECT any calls to mconn.Write() // don't EXPECT any calls to mconn.Write()
@ -1708,12 +1721,13 @@ var _ = Describe("Connection", func() {
mtuDiscoverer := NewMockMTUDiscoverer(mockCtrl) mtuDiscoverer := NewMockMTUDiscoverer(mockCtrl)
conn.mtuDiscoverer = mtuDiscoverer conn.mtuDiscoverer = mtuDiscoverer
conn.config.DisablePathMTUDiscovery = false conn.config.DisablePathMTUDiscovery = false
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny)
sph.EXPECT().ECNMode()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
written := make(chan struct{}, 1) written := make(chan struct{}, 1)
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16, protocol.ECN) { written <- struct{}{} })
mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true)
ping := ackhandler.Frame{Frame: &wire.PingFrame{}} ping := ackhandler.Frame{Frame: &wire.PingFrame{}}
mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234))
@ -1747,7 +1761,7 @@ var _ = Describe("Connection", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
sender.EXPECT().Close() sender.EXPECT().Close()
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
@ -1760,8 +1774,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph conn.sentPacketHandler = sph
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1}, []byte("packet1")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1}, []byte("packet1"))
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack)
@ -1776,7 +1791,7 @@ var _ = Describe("Connection", func() {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
// only EXPECT calls after scheduleSending is called // only EXPECT calls after scheduleSending is called
written := make(chan struct{}) written := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
conn.scheduleSending() conn.scheduleSending()
Eventually(written).Should(BeClosed()) Eventually(written).Should(BeClosed())
@ -1788,9 +1803,8 @@ var _ = Describe("Connection", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, _ protocol.EncryptionLevel, _ protocol.ByteCount, _ bool) { sph.EXPECT().ECNMode().AnyTimes()
Expect(pn).To(Equal(protocol.PacketNumber(1234))) sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(1234), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
})
conn.sentPacketHandler = sph conn.sentPacketHandler = sph
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond))
@ -1799,7 +1813,7 @@ var _ = Describe("Connection", func() {
conn.receivedPacketHandler = rph conn.receivedPacketHandler = rph
written := make(chan struct{}) written := make(chan struct{})
sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(written) })
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1841,18 +1855,11 @@ var _ = Describe("Connection", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().Return(protocol.ECT1).AnyTimes()
sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes() sph.EXPECT().TimeUntilSend().Return(time.Now()).AnyTimes()
gomock.InOrder( gomock.InOrder(
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) { sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(13), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionInitial, protocol.ECT1, protocol.ByteCount(123), gomock.Any()),
Expect(encLevel).To(Equal(protocol.EncryptionInitial)) sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(37), gomock.Any(), gomock.Any(), gomock.Any(), protocol.EncryptionHandshake, protocol.ECT1, protocol.ByteCount(1234), gomock.Any()),
Expect(pn).To(Equal(protocol.PacketNumber(13)))
Expect(size).To(BeEquivalentTo(123))
}),
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ time.Time, pn, _ protocol.PacketNumber, _ []ackhandler.StreamFrame, _ []ackhandler.Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, _ bool) {
Expect(encLevel).To(Equal(protocol.EncryptionHandshake))
Expect(pn).To(Equal(protocol.PacketNumber(37)))
Expect(size).To(BeEquivalentTo(1234))
}),
) )
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) { tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
@ -1864,7 +1871,7 @@ var _ = Describe("Connection", func() {
) )
sent := make(chan struct{}) sent := make(chan struct{})
mconn.EXPECT().Write([]byte("foobar"), uint16(0)).Do(func([]byte, uint16) { close(sent) }) mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) })
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1881,7 +1888,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -1952,7 +1959,7 @@ var _ = Describe("Connection", func() {
}() }()
handshakeCtx := conn.HandshakeComplete() handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx).ShouldNot(BeClosed()) Consistently(handshakeCtx).ShouldNot(BeClosed())
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.closeLocal(errors.New("handshake error")) conn.closeLocal(errors.New("handshake error"))
Consistently(handshakeCtx).ShouldNot(BeClosed()) Consistently(handshakeCtx).ShouldNot(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
@ -1961,11 +1968,12 @@ var _ = Describe("Connection", func() {
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SetHandshakeConfirmed() sph.EXPECT().SetHandshakeConfirmed()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph conn.sentPacketHandler = sph
done := make(chan struct{}) done := make(chan struct{})
@ -1987,7 +1995,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket() cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handleHandshakeComplete()).To(Succeed()) Expect(conn.handleHandshakeComplete()).To(Succeed())
conn.run() conn.run()
}() }()
@ -2016,7 +2024,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -2043,7 +2051,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed() expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed())
@ -2102,7 +2110,7 @@ var _ = Describe("Connection", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -2255,7 +2263,7 @@ var _ = Describe("Connection", func() {
// make the go routine return // make the go routine return
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.shutdown() conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
}) })
@ -2338,7 +2346,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed() expectReplaceWithClosed()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -2554,7 +2562,7 @@ var _ = Describe("Client Connection", func() {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close() tracer.EXPECT().Close()
conn.shutdown() conn.shutdown()
@ -2851,7 +2859,7 @@ var _ = Describe("Client Connection", func() {
packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
} }
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
gomock.InOrder( gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()),
tracer.EXPECT().Close(), tracer.EXPECT().Close(),

View file

@ -10,7 +10,7 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
// ReceivedAck processes an ACK frame. // ReceivedAck processes an ACK frame.
// It does not store a copy of the frame. // It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
@ -29,6 +29,8 @@ type SentPacketHandler interface {
// only to be called once the handshake is complete // only to be called once the handshake is complete
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
ECNMode() protocol.ECN
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber

View file

@ -228,6 +228,7 @@ func (h *sentPacketHandler) SentPacket(
streamFrames []StreamFrame, streamFrames []StreamFrame,
frames []Frame, frames []Frame,
encLevel protocol.EncryptionLevel, encLevel protocol.EncryptionLevel,
_ protocol.ECN,
size protocol.ByteCount, size protocol.ByteCount,
isPathMTUProbePacket bool, isPathMTUProbePacket bool,
) { ) {
@ -712,6 +713,11 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm return h.alarm
} }
func (h *sentPacketHandler) ECNMode() protocol.ECN {
// TODO: implement ECN logic
return protocol.ECNNon
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek() pn := pnSpace.pns.Peek()

View file

@ -106,7 +106,7 @@ var _ = Describe("SentPacketHandler", func() {
} }
sentPacket := func(p *packet) { sentPacket := func(p *packet) {
handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, p.Length, p.IsPathMTUProbePacket) handler.SentPacket(p.SendTime, p.PacketNumber, p.LargestAcked, p.StreamFrames, p.Frames, p.EncryptionLevel, protocol.ECNNon, p.Length, p.IsPathMTUProbePacket)
} }
expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) {

View file

@ -49,6 +49,20 @@ func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0)
} }
// ECNMode mocks base method.
func (m *MockSentPacketHandler) ECNMode() protocol.ECN {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ECNMode")
ret0, _ := ret[0].(protocol.ECN)
return ret0
}
// ECNMode indicates an expected call of ECNMode.
func (mr *MockSentPacketHandlerMockRecorder) ECNMode() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode))
}
// GetLossDetectionTimeout mocks base method. // GetLossDetectionTimeout mocks base method.
func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -176,15 +190,15 @@ func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 interface{}) *gomock.
} }
// SentPacket mocks base method. // SentPacket mocks base method.
func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ByteCount, arg7 bool) { func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.PacketNumber, arg3 []ackhandler.StreamFrame, arg4 []ackhandler.Frame, arg5 protocol.EncryptionLevel, arg6 protocol.ECN, arg7 protocol.ByteCount, arg8 bool) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8)
} }
// SentPacket indicates an expected call of SentPacket. // SentPacket indicates an expected call of SentPacket.
func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7 interface{}) *gomock.Call { func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8)
} }
// SetHandshakeConfirmed mocks base method. // SetHandshakeConfirmed mocks base method.

View file

@ -9,6 +9,7 @@ import (
reflect "reflect" reflect "reflect"
time "time" time "time"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@ -93,18 +94,18 @@ func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Cal
} }
// WritePacket mocks base method. // WritePacket mocks base method.
func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16) (int, error) { func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16, arg4 protocol.ECN) (int, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(int) ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// WritePacket indicates an expected call of WritePacket. // WritePacket indicates an expected call of WritePacket.
func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4)
} }
// capabilities mocks base method. // capabilities mocks base method.

View file

@ -8,6 +8,7 @@ import (
net "net" net "net"
reflect "reflect" reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@ -77,17 +78,17 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call {
} }
// Write mocks base method. // Write mocks base method.
func (m *MockSendConn) Write(arg0 []byte, arg1 uint16) error { func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0, arg1) ret := m.ctrl.Call(m, "Write", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Write indicates an expected call of Write. // Write indicates an expected call of Write.
func (mr *MockSendConnMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2)
} }
// capabilities mocks base method. // capabilities mocks base method.

View file

@ -7,6 +7,7 @@ package quic
import ( import (
reflect "reflect" reflect "reflect"
protocol "github.com/quic-go/quic-go/internal/protocol"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@ -74,15 +75,15 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call {
} }
// Send mocks base method. // Send mocks base method.
func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16) { func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Send", arg0, arg1) m.ctrl.Call(m, "Send", arg0, arg1, arg2)
} }
// Send indicates an expected call of Send. // Send indicates an expected call of Send.
func (mr *MockSenderMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2)
} }
// WouldBlock mocks base method. // WouldBlock mocks base method.

View file

@ -29,7 +29,7 @@ type rawConn interface {
// WritePacket writes a packet on the wire. // WritePacket writes a packet on the wire.
// gsoSize is the size of a single packet, or 0 to disable GSO. // gsoSize is the size of a single packet, or 0 to disable GSO.
// It is invalid to set gsoSize if capabilities.GSO is not set. // It is invalid to set gsoSize if capabilities.GSO is not set.
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
LocalAddr() net.Addr LocalAddr() net.Addr
SetReadDeadline(time.Time) error SetReadDeadline(time.Time) error
io.Closer io.Closer

View file

@ -9,7 +9,7 @@ import (
// A sendConn allows sending using a simple Write() on a non-connected packet conn. // A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface { type sendConn interface {
Write(b []byte, gsoSize uint16) error Write(b []byte, gsoSize uint16, ecn protocol.ECN) error
Close() error Close() error
LocalAddr() net.Addr LocalAddr() net.Addr
RemoteAddr() net.Addr RemoteAddr() net.Addr
@ -43,13 +43,6 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
} }
oob := info.OOB() oob := info.OOB()
if remoteUDPAddr, ok := remote.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, protocol.ECT1)
} else {
oob = appendIPv6ECNMsg(oob, protocol.ECT1)
}
}
// increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating
l := len(oob) l := len(oob)
oob = append(oob, make([]byte, 64)...)[:l] oob = append(oob, make([]byte, 64)...)[:l]
@ -62,8 +55,8 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
} }
} }
func (c *sconn) Write(p []byte, gsoSize uint16) error { func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
_, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize) _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn)
if err != nil && isGSOError(err) { if err != nil && isGSOError(err) {
// disable GSO for future calls // disable GSO for future calls
c.gotGSOError = true c.gotGSOError = true
@ -76,7 +69,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16) error {
if l > int(gsoSize) { if l > int(gsoSize) {
l = int(gsoSize) l = int(gsoSize)
} }
if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0); err != nil { if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil {
return err return err
} }
p = p[l:] p = p[l:]

View file

@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -45,8 +46,8 @@ var _ = Describe("Connection (for sending packets)", func() {
pi := packetInfo{addr: netip.IPv6Loopback()} pi := packetInfo{addr: netip.IPv6Loopback()}
Expect(pi.OOB()).ToNot(BeEmpty()) Expect(pi.OOB()).ToNot(BeEmpty())
c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger)
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0)) rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0), protocol.ECT1)
Expect(c.Write([]byte("foobar"), 0)).To(Succeed()) Expect(c.Write([]byte("foobar"), 0, protocol.ECT1)).To(Succeed())
}) })
} }
@ -55,8 +56,8 @@ var _ = Describe("Connection (for sending packets)", func() {
rawConn.EXPECT().LocalAddr() rawConn.EXPECT().LocalAddr()
rawConn.EXPECT().capabilities().AnyTimes() rawConn.EXPECT().capabilities().AnyTimes()
c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger)
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3)) rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3), protocol.ECNCE)
Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) Expect(c.Write([]byte("foobar"), 3, protocol.ECNCE)).To(Succeed())
}) })
if platformSupportsGSO { if platformSupportsGSO {
@ -67,11 +68,11 @@ var _ = Describe("Connection (for sending packets)", func() {
c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger)
Expect(c.capabilities().GSO).To(BeTrue()) Expect(c.capabilities().GSO).To(BeTrue())
gomock.InOrder( gomock.InOrder(
rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4)).Return(0, errGSO), rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4), protocol.ECNCE).Return(0, errGSO),
rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0)).Return(4, nil), rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(4, nil),
rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0)).Return(2, nil), rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0), protocol.ECNCE).Return(2, nil),
) )
Expect(c.Write([]byte("foobar"), 4)).To(Succeed()) Expect(c.Write([]byte("foobar"), 4, protocol.ECNCE)).To(Succeed())
Expect(c.capabilities().GSO).To(BeFalse()) Expect(c.capabilities().GSO).To(BeFalse())
}) })
} }

View file

@ -1,8 +1,9 @@
package quic package quic
import "github.com/quic-go/quic-go/internal/protocol"
type sender interface { type sender interface {
// Send sends a packet. GSO is only used if gsoSize > 0. Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN)
Send(p *packetBuffer, gsoSize uint16)
Run() error Run() error
WouldBlock() bool WouldBlock() bool
Available() <-chan struct{} Available() <-chan struct{}
@ -12,6 +13,7 @@ type sender interface {
type queueEntry struct { type queueEntry struct {
buf *packetBuffer buf *packetBuffer
gsoSize uint16 gsoSize uint16
ecn protocol.ECN
} }
type sendQueue struct { type sendQueue struct {
@ -39,9 +41,9 @@ func newSendQueue(conn sendConn) sender {
// Send sends out a packet. It's guaranteed to not block. // Send sends out a packet. It's guaranteed to not block.
// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock.
// Otherwise Send will panic. // Otherwise Send will panic.
func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16) { func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) {
select { select {
case h.queue <- queueEntry{buf: p, gsoSize: gsoSize}: case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}:
// clear available channel if we've reached capacity // clear available channel if we've reached capacity
if len(h.queue) == sendQueueCapacity { if len(h.queue) == sendQueueCapacity {
select { select {
@ -76,7 +78,7 @@ func (h *sendQueue) Run() error {
// make sure that all queued packets are actually sent out // make sure that all queued packets are actually sent out
shouldClose = true shouldClose = true
case e := <-h.queue: case e := <-h.queue:
if err := h.conn.Write(e.buf.Data, e.gsoSize); err != nil { if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil {
// This additional check enables: // This additional check enables:
// 1. Checking for "datagram too large" message from the kernel, as such, // 1. Checking for "datagram too large" message from the kernel, as such,
// 2. Path MTU discovery,and // 2. Path MTU discovery,and

View file

@ -3,6 +3,8 @@ package quic
import ( import (
"errors" "errors"
"github.com/quic-go/quic-go/internal/protocol"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
@ -26,10 +28,10 @@ var _ = Describe("Send Queue", func() {
It("sends a packet", func() { It("sends a packet", func() {
p := getPacket([]byte("foobar")) p := getPacket([]byte("foobar"))
q.Send(p, 10) // make sure the packet size is passed through to the conn q.Send(p, 10, protocol.ECT1) // make sure the packet size is passed through to the conn
written := make(chan struct{}) written := make(chan struct{})
c.EXPECT().Write([]byte("foobar"), uint16(10)).Do(func([]byte, uint16) { close(written) }) c.EXPECT().Write([]byte("foobar"), uint16(10), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(written) })
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -45,19 +47,19 @@ var _ = Describe("Send Queue", func() {
It("panics when Send() is called although there's no space in the queue", func() { It("panics when Send() is called although there's no space in the queue", func() {
for i := 0; i < sendQueueCapacity; i++ { for i := 0; i < sendQueueCapacity; i++ {
Expect(q.WouldBlock()).To(BeFalse()) Expect(q.WouldBlock()).To(BeFalse())
q.Send(getPacket([]byte("foobar")), 6) q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
} }
Expect(q.WouldBlock()).To(BeTrue()) Expect(q.WouldBlock()).To(BeTrue())
Expect(func() { q.Send(getPacket([]byte("raboof")), 6) }).To(Panic()) Expect(func() { q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon) }).To(Panic())
}) })
It("signals when sending is possible again", func() { It("signals when sending is possible again", func() {
Expect(q.WouldBlock()).To(BeFalse()) Expect(q.WouldBlock()).To(BeFalse())
q.Send(getPacket([]byte("foobar1")), 6) q.Send(getPacket([]byte("foobar1")), 6, protocol.ECNNon)
Consistently(q.Available()).ShouldNot(Receive()) Consistently(q.Available()).ShouldNot(Receive())
// now start sending out packets. This should free up queue space. // now start sending out packets. This should free up queue space.
c.EXPECT().Write(gomock.Any(), gomock.Any()).MinTimes(1).MaxTimes(2) c.EXPECT().Write(gomock.Any(), gomock.Any(), protocol.ECNNon).MinTimes(1).MaxTimes(2)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -67,7 +69,7 @@ var _ = Describe("Send Queue", func() {
Eventually(q.Available()).Should(Receive()) Eventually(q.Available()).Should(Receive())
Expect(q.WouldBlock()).To(BeFalse()) Expect(q.WouldBlock()).To(BeFalse())
Expect(func() { q.Send(getPacket([]byte("foobar2")), 7) }).ToNot(Panic()) Expect(func() { q.Send(getPacket([]byte("foobar2")), 7, protocol.ECNNon) }).ToNot(Panic())
q.Close() q.Close()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -77,7 +79,7 @@ var _ = Describe("Send Queue", func() {
write := make(chan struct{}, 1) write := make(chan struct{}, 1)
written := make(chan struct{}, 100) written := make(chan struct{}, 100)
// now start sending out packets. This should free up queue space. // now start sending out packets. This should free up queue space.
c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16) error { c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16, protocol.ECN) error {
written <- struct{}{} written <- struct{}{}
<-write <-write
return nil return nil
@ -92,19 +94,19 @@ var _ = Describe("Send Queue", func() {
close(done) close(done)
}() }()
q.Send(getPacket([]byte("foobar")), 6) q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
<-written <-written
// now fill up the send queue // now fill up the send queue
for i := 0; i < sendQueueCapacity; i++ { for i := 0; i < sendQueueCapacity; i++ {
Expect(q.WouldBlock()).To(BeFalse()) Expect(q.WouldBlock()).To(BeFalse())
q.Send(getPacket([]byte("foobar")), 6) q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
} }
// One more packet is queued when it's picked up by Run and written to the connection. // One more packet is queued when it's picked up by Run and written to the connection.
// In this test, it's blocked on write channel in the mocked Write call. // In this test, it's blocked on write channel in the mocked Write call.
<-written <-written
Eventually(q.WouldBlock()).Should(BeFalse()) Eventually(q.WouldBlock()).Should(BeFalse())
q.Send(getPacket([]byte("foobar")), 6) q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
Expect(q.WouldBlock()).To(BeTrue()) Expect(q.WouldBlock()).To(BeTrue())
Consistently(q.Available()).ShouldNot(Receive()) Consistently(q.Available()).ShouldNot(Receive())
@ -130,15 +132,15 @@ var _ = Describe("Send Queue", func() {
// the run loop exits if there is a write error // the run loop exits if there is a write error
testErr := errors.New("test error") testErr := errors.New("test error")
c.EXPECT().Write(gomock.Any(), gomock.Any()).Return(testErr) c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Return(testErr)
q.Send(getPacket([]byte("foobar")), 6) q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
sent := make(chan struct{}) sent := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
q.Send(getPacket([]byte("raboof")), 6) q.Send(getPacket([]byte("raboof")), 6, protocol.ECNNon)
q.Send(getPacket([]byte("quux")), 4) q.Send(getPacket([]byte("quux")), 4, protocol.ECNNon)
close(sent) close(sent)
}() }()
@ -147,7 +149,7 @@ var _ = Describe("Send Queue", func() {
It("blocks Close() until the packet has been sent out", func() { It("blocks Close() until the packet has been sent out", func() {
written := make(chan []byte) written := make(chan []byte)
c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16) { written <- p }) c.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16, _ protocol.ECN) { written <- p })
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -155,7 +157,7 @@ var _ = Describe("Send Queue", func() {
close(done) close(done)
}() }()
q.Send(getPacket([]byte("foobar")), 6) q.Send(getPacket([]byte("foobar")), 6, protocol.ECNNon)
closed := make(chan struct{}) closed := make(chan struct{})
go func() { go func() {

View file

@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
} }
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0) _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon)
return err return err
} }
@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
} }
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0) _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon)
return err return err
} }
@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
} }
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err) s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
} }

View file

@ -104,7 +104,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) {
}, nil }, nil
} }
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16) (n int, err error) { func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, _ protocol.ECN) (n int, err error) {
if gsoSize != 0 { if gsoSize != 0 {
panic("cannot use GSO with a basicConn") panic("cannot use GSO with a basicConn")
} }

View file

@ -8,7 +8,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
) )
@ -53,6 +52,3 @@ func isRecvMsgSizeErr(err error) bool {
// https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
return errors.Is(err, windows.WSAEMSGSIZE) return errors.Is(err, windows.WSAEMSGSIZE)
} }
func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil }
func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil }

View file

@ -5,8 +5,6 @@ package quic
import ( import (
"net" "net"
"net/netip" "net/netip"
"github.com/quic-go/quic-go/internal/protocol"
) )
func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) {
@ -16,9 +14,6 @@ func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) {
func inspectReadBuffer(any) (int, error) { return 0, nil } func inspectReadBuffer(any) (int, error) { return 0, nil }
func inspectWriteBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil }
func appendIPv4ECNMsg([]byte, protocol.ECN) []byte { return nil }
func appendIPv6ECNMsg([]byte, protocol.ECN) []byte { return nil }
type packetInfo struct { type packetInfo struct {
addr netip.Addr addr netip.Addr
} }

View file

@ -229,7 +229,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
} }
// WritePacket writes a new packet. // WritePacket writes a new packet.
func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) { func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) {
oob := packetInfoOOB oob := packetInfoOOB
if gsoSize > 0 { if gsoSize > 0 {
if !c.capabilities().GSO { if !c.capabilities().GSO {
@ -237,6 +237,13 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso
} }
oob = appendUDPSegmentSizeMsg(oob, gsoSize) oob = appendUDPSegmentSizeMsg(oob, gsoSize)
} }
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, ecn)
} else {
oob = appendIPv6ECNMsg(oob, ecn)
}
}
n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err return n, err
} }

View file

@ -53,7 +53,7 @@ var _ = Describe("OOB Conn Test", func() {
return udpConn, packetChan return udpConn, packetChan
} }
Context("ECN conn", func() { Context("reading ECN-marked packets", func() {
sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr {
conn, err := net.DialUDP(network, nil, addr) conn, err := net.DialUDP(network, nil, addr)
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
@ -289,6 +289,27 @@ var _ = Describe("OOB Conn Test", func() {
}) })
}) })
Context("sending ECN-marked packets", func() {
It("sets the ECN control message", func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
c := &oobRecordingConn{UDPConn: udpConn}
oobConn, err := newConn(c, true)
Expect(err).ToNot(HaveOccurred())
oob := make([]byte, 0, 123)
oobConn.WritePacket([]byte("foobar"), addr, oob, 0, protocol.ECNCE)
Expect(c.oobs).To(HaveLen(1))
oobMsg := c.oobs[0]
Expect(oobMsg).ToNot(BeEmpty())
Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob
expected := appendIPv4ECNMsg([]byte{}, protocol.ECNCE)
Expect(oobMsg).To(Equal(expected))
})
})
if platformSupportsGSO { if platformSupportsGSO {
Context("GSO", func() { Context("GSO", func() {
It("appends the GSO control message", func() { It("appends the GSO control message", func() {
@ -301,14 +322,15 @@ var _ = Describe("OOB Conn Test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(oobConn.capabilities().GSO).To(BeTrue()) Expect(oobConn.capabilities().GSO).To(BeTrue())
oob := make([]byte, 0, 42) oob := make([]byte, 0, 123)
oobConn.WritePacket([]byte("foobar"), addr, oob, 3) oobConn.WritePacket([]byte("foobar"), addr, oob, 3, protocol.ECNCE)
Expect(c.oobs).To(HaveLen(1)) Expect(c.oobs).To(HaveLen(1))
oobMsg := c.oobs[0] oobMsg := c.oobs[0]
Expect(oobMsg).ToNot(BeEmpty()) Expect(oobMsg).ToNot(BeEmpty())
Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob
expected := appendUDPSegmentSizeMsg([]byte{}, 3) expected := appendUDPSegmentSizeMsg([]byte{}, 3)
Expect(oobMsg).To(Equal(expected)) // Check that the first control message is the OOB control message.
Expect(oobMsg[:len(expected)]).To(Equal(expected))
}) })
}) })
} }

View file

@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil { if err := t.init(false); err != nil {
return 0, err return 0, err
} }
return t.conn.WritePacket(b, addr, nil, 0) return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNNon)
} }
func (t *Transport) enqueueClosePacket(p closePacket) { func (t *Transport) enqueueClosePacket(p closePacket) {
@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() {
case <-t.listening: case <-t.listening:
return return
case p := <-t.closeQueue: case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0) t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNNon)
case p := <-t.statelessResetQueue: case p := <-t.statelessResetQueue:
t.sendStatelessReset(p) t.sendStatelessReset(p)
} }
@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
rand.Read(data) rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40 data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...) data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
} }
} }