From 5b5ffa942bc7d90eaaa7979d3f3c3a442d505188 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Apr 2023 17:27:50 +0200 Subject: [PATCH] pack packets into large buffers when GSO is available --- buffer_pool.go | 5 +- connection.go | 121 ++++++++++++++---- connection_test.go | 269 +++++++++++++++++++++++++---------------- mock_send_conn_test.go | 9 +- mock_sender_test.go | 9 +- packet_handler_map.go | 4 +- send_conn.go | 12 +- send_conn_test.go | 2 +- send_queue.go | 23 ++-- send_queue_test.go | 36 +++--- server.go | 6 +- sys_conn.go | 6 +- sys_conn_df_linux.go | 4 +- sys_conn_no_gso.go | 4 +- sys_conn_oob.go | 11 +- transport.go | 4 +- 16 files changed, 339 insertions(+), 186 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index 7d676c84..48589e12 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -51,9 +51,8 @@ func (b *packetBuffer) Release() { } // Len returns the length of Data -func (b *packetBuffer) Len() protocol.ByteCount { - return protocol.ByteCount(len(b.Data)) -} +func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) } +func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) } func (b *packetBuffer) putBack() { if cap(b.Data) == protocol.MaxPacketBufferSize { diff --git a/connection.go b/connection.go index ddc60183..ec538a24 100644 --- a/connection.go +++ b/connection.go @@ -1792,15 +1792,19 @@ func (s *connection) triggerSending() error { func (s *connection) sendPackets() error { now := time.Now() + // Path MTU Discovery + // Can't use GSO, since we need to send a single packet that's larger than our current maximum size. + // Performance-wise, this doesn't matter, since we only send a very small (<10) number of + // MTU probe packets per connection. if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) { ping, size := s.mtuDiscoverer.GetPing() - p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now, s.version) + p, buf, err := s.packer.PackMTUProbePacket(ping, size, now, s.version) if err != nil { return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.registerPackedShortHeaderPacket(p.Packet, now) - s.sendQueue.Send(buffer) + s.sendQueue.Send(buf, buf.Len()) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1827,21 +1831,28 @@ func (s *connection) sendPackets() error { return nil } + if s.conn.capabilities().GSO { + return s.sendPacketsWithGSO(now) + } + return s.sendPacketsWithoutGSO(now) +} + +func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - sent, err := s.appendPacket(buf, now) - if err != nil || !sent { + if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { + if err == errNothingToPack { + buf.Release() + return nil + } return err } + + s.sendQueue.Send(buf, buf.Len()) + if s.sendQueue.WouldBlock() { return nil } - // Prioritize receiving of packets over sending out more packets. - if len(s.receivedPackets) > 0 { - s.pacingDeadline = deadlineSendImmediately - return nil - } - sendMode := s.sentPacketHandler.SendMode() if sendMode == ackhandler.SendPacingLimited { s.resetPacingDeadline() @@ -1850,6 +1861,66 @@ func (s *connection) sendPackets() error { if sendMode != ackhandler.SendAny { return nil } + // Prioritize receiving of packets over sending out more packets. + if len(s.receivedPackets) > 0 { + s.pacingDeadline = deadlineSendImmediately + return nil + } + } +} + +func (s *connection) sendPacketsWithGSO(now time.Time) error { + buf := getLargePacketBuffer() + maxSize := s.mtuDiscoverer.CurrentSize() + + for { + var dontSendMore bool + size, err := s.appendPacket(buf, maxSize, now) + if err != nil { + if err != errNothingToPack { + return err + } + if buf.Len() == 0 { + buf.Release() + return nil + } + dontSendMore = true + } + + if !dontSendMore { + sendMode := s.sentPacketHandler.SendMode() + if sendMode == ackhandler.SendPacingLimited { + s.resetPacingDeadline() + } + if sendMode != ackhandler.SendAny { + dontSendMore = true + } + } + + // Append another packet if + // 1. The congestion controller and pacer allow sending more + // 2. The last packet appended was a full-size packet + // 3. We still have enough space for another full-size packet in the buffer + if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() { + continue + } + + s.sendQueue.Send(buf, maxSize) + + if dontSendMore { + return nil + } + if s.sendQueue.WouldBlock() { + return nil + } + + // Prioritize receiving of packets over sending out more packets. + if len(s.receivedPackets) > 0 { + s.pacingDeadline = deadlineSendImmediately + return nil + } + + buf = getLargePacketBuffer() } } @@ -1875,16 +1946,16 @@ func (s *connection) maybeSendAckOnlyPacket() error { } now := time.Now() - p, buffer, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) + p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { return nil } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.registerPackedShortHeaderPacket(p.Packet, now) - s.sendQueue.Send(buffer) + s.sendQueue.Send(buf, buf.Len()) return nil } @@ -1930,18 +2001,18 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { return nil } -func (s *connection) appendPacket(buf *packetBuffer, now time.Time) (bool, error) { - p, err := s.packer.AppendPacket(buf, s.mtuDiscoverer.CurrentSize(), s.version) +// appendPacket appends a new packet to the given packetBuffer. +// If there was nothing to pack, the returned size is 0. +func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { + startLen := buf.Len() + p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { - if err == errNothingToPack { - return false, nil - } - return false, err + return 0, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) + size := buf.Len() - startLen + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) s.registerPackedShortHeaderPacket(p.Packet, now) - s.sendQueue.Send(buf) - return true, nil + return size, nil } func (s *connection) registerPackedShortHeaderPacket(p *ackhandler.Packet, now time.Time) { @@ -1968,7 +2039,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time s.sentPacketHandler.SentPacket(p.Packet) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer) + s.sendQueue.Send(packet.buffer, packet.buffer.Len()) } func (s *connection) sendConnectionClose(e error) ([]byte, error) { @@ -1990,7 +2061,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { return nil, err } s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) } func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { diff --git a/connection_test.go b/connection_test.go index 1b7a7bbe..c259086b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -46,6 +46,7 @@ var _ = Describe("Connection", func() { packer *MockPacker cryptoSetup *mocks.MockCryptoSetup tracer *mocklogging.MockConnectionTracer + capabilities connCapabilities ) remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} @@ -53,12 +54,6 @@ var _ = Describe("Connection", func() { destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - getShortHeaderPacket := func(pn protocol.PacketNumber) shortHeaderPacket { - buffer := getPacketBuffer() - buffer.Data = append(buffer.Data, []byte("foobar")...) - return shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}} - } - getCoalescedPacket := func(pn protocol.PacketNumber, isLongHeader bool) *coalescedPacket { buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) @@ -91,11 +86,21 @@ var _ = Describe("Connection", func() { }) } + expectAppendPacket := func(packer *MockPacker, p shortHeaderPacket, b []byte) *gomock.Call { + return packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), Version1).DoAndReturn(func(buf *packetBuffer, _ protocol.ByteCount, _ protocol.VersionNumber) (shortHeaderPacket, error) { + buf.Data = append(buf.Data, b...) + return p, nil + }) + } + + enableGSO := func() { capabilities = connCapabilities{GSO: true} } + BeforeEach(func() { Eventually(areConnsRunning).Should(BeFalse()) connRunner = NewMockConnRunner(mockCtrl) mconn = NewMockSendConn(mockCtrl) + mconn.EXPECT().capabilities().DoAndReturn(func() connCapabilities { return capabilities }).AnyTimes() mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) @@ -136,6 +141,7 @@ var _ = Describe("Connection", func() { AfterEach(func() { Eventually(areConnsRunning).Should(BeFalse()) + capabilities = connCapabilities{} }) Context("frame handling", func() { @@ -447,7 +453,7 @@ var _ = Describe("Connection", func() { Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil }) - mconn.EXPECT().Write([]byte("connection close")) + mconn.EXPECT().Write([]byte("connection close"), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { var appErr *ApplicationError @@ -468,7 +474,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -487,7 +493,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -508,7 +514,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), @@ -557,7 +563,7 @@ var _ = Describe("Connection", func() { close(returned) }() Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -599,7 +605,8 @@ var _ = Describe("Connection", func() { It("closes when the sendQueue encounters an error", func() { conn.handshakeConfirmed = true sconn := NewMockSendConn(mockCtrl) - sconn.EXPECT().Write(gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() + sconn.EXPECT().capabilities().AnyTimes() + sconn.EXPECT().Write(gomock.Any(), gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() @@ -613,8 +620,7 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() conn.sentPacketHandler = sph - p := getShortHeaderPacket(1) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1}}, []byte("foobar")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() runConn() conn.queueControlFrame(&wire.PingFrame{}) @@ -827,7 +833,7 @@ var _ = Describe("Connection", func() { // make the go routine return tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -861,7 +867,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -896,7 +902,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("close")) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -917,7 +923,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) packet := getShortHeaderPacket(srcConnID, 0x42, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -944,7 +950,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -965,7 +971,7 @@ var _ = Describe("Connection", func() { close(done) }() expectReplaceWithClosed() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) @@ -1181,7 +1187,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1208,12 +1214,17 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph runConn() - p := getShortHeaderPacket(1) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + p := shortHeaderPacket{ + DestConnID: protocol.ParseConnectionID([]byte{1, 2, 3}), + PacketNumberLen: protocol.PacketNumberLen3, + Packet: &ackhandler.Packet{PacketNumber: 1337}, + KeyPhase: protocol.KeyPhaseOne, + } + expectAppendPacket(packer, p, []byte("foobar")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ DestConnectionID: p.DestConnID, PacketNumber: p.PacketNumber, @@ -1256,13 +1267,12 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) - p := getShortHeaderPacket(1) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 13}}, []byte("foobar")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() conn.connFlowController = fc runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1318,7 +1328,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1343,7 +1353,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1383,7 +1393,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() @@ -1396,12 +1406,63 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - p := getShortHeaderPacket(10) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) - p = getShortHeaderPacket(11) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, []byte("packet10")) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 11}}, []byte("packet11")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + Expect(b.Data).To(Equal([]byte("packet10"))) + }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + Expect(b.Data).To(Equal([]byte("packet11"))) + }) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent + }) + + It("sends multiple packets one by one immediately, with GSO", func() { + enableGSO() + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) + payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload1) + payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload2) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, payload1) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 11}}, payload2) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + Expect(b.Data).To(Equal(append(payload1, payload2...))) + }) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent + }) + + It("stops appending packets when a smaller packet is packed, with GSO", func() { + enableGSO() + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) + sph.EXPECT().SendMode().Return(ackhandler.SendNone) + payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize()) + rand.Read(payload1) + payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()-1) + rand.Read(payload2) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, payload1) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 11}}, payload2) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + Expect(b.Data).To(Equal(append(payload1, payload2...))) + }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1414,11 +1475,10 @@ var _ = Describe("Connection", func() { It("sends multiple packets, when the pacer allows immediate sending", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) - p := getShortHeaderPacket(10) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1432,11 +1492,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited) - p := getShortHeaderPacket(10) - packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(p, getPacketBuffer(), nil) + packer.EXPECT().PackAckOnlyPacket(gomock.Any(), conn.version).Return(shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 123}}, getPacketBuffer(), nil) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1452,10 +1511,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) - p := getShortHeaderPacket(100) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 100}}, []byte("packet100")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()) + sender.EXPECT().Send(gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1467,23 +1525,21 @@ var _ = Describe("Connection", func() { It("paces packets", func() { pacingDelay := scaleDuration(100 * time.Millisecond) - p1 := getShortHeaderPacket(100) - p2 := getShortHeaderPacket(101) gomock.InOrder( sph.EXPECT().SendMode().Return(ackhandler.SendAny), - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p1, nil), + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 100}}, []byte("packet100")), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().SendMode().Return(ackhandler.SendAny), - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p2, nil), + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 101}}, []byte("packet101")), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), ) written := make(chan struct{}, 2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1501,12 +1557,11 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendPacingLimited) sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) for pn := protocol.PacketNumber(1000); pn < 1003; pn++ { - p := getShortHeaderPacket(pn) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}}, []byte("packet")) } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1516,29 +1571,34 @@ var _ = Describe("Connection", func() { Eventually(written).Should(HaveLen(3)) }) - It("doesn't try to send if the send queue is full", func() { - available := make(chan struct{}, 1) - sender.EXPECT().WouldBlock().Return(true) - sender.EXPECT().Available().Return(available) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - conn.run() - }() - conn.scheduleSending() - time.Sleep(scaleDuration(50 * time.Millisecond)) + for _, withGSO := range []bool{false, true} { + withGSO := withGSO + It(fmt.Sprintf("doesn't try to send if the send queue is full: %t", withGSO), func() { + if withGSO { + enableGSO() + } + available := make(chan struct{}, 1) + sender.EXPECT().WouldBlock().Return(true) + sender.EXPECT().Available().Return(available) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + time.Sleep(scaleDuration(50 * time.Millisecond)) - written := make(chan struct{}) - sender.EXPECT().WouldBlock().AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - p := getShortHeaderPacket(1000) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) - available <- struct{}{} - Eventually(written).Should(BeClosed()) - }) + written := make(chan struct{}) + sender.EXPECT().WouldBlock().AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1000}}, []byte("packet1000")) + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + available <- struct{}{} + Eventually(written).Should(BeClosed()) + }) + } It("stops sending when there are new packets to receive", func() { sender.EXPECT().WouldBlock().AnyTimes() @@ -1555,10 +1615,9 @@ var _ = Describe("Connection", func() { conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - p := getShortHeaderPacket(1000) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1569,12 +1628,11 @@ var _ = Describe("Connection", func() { It("stops sending when the send queue is full", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny) - p := getShortHeaderPacket(1000) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1000}}, []byte("packet1000")) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1590,10 +1648,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - p = getShortHeaderPacket(1001) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1001}}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1625,12 +1682,11 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) - p := getShortHeaderPacket(1) - packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any(), conn.version).Return(p, getPacketBuffer(), nil) + packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any(), conn.version).Return(shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1}}, getPacketBuffer(), nil) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1659,7 +1715,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -1675,8 +1731,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph - p := getShortHeaderPacket(1) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1}}, []byte("packet1")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) go func() { @@ -1688,15 +1743,14 @@ var _ = Describe("Connection", func() { time.Sleep(50 * time.Millisecond) // only EXPECT calls after scheduleSending is called written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) }) It("sets the timer to the ack timer", func() { - p := getShortHeaderPacket(1234) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(p, nil) + expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 1234}}, []byte("packet1234")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() @@ -1712,7 +1766,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() @@ -1776,7 +1830,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), protocol.ByteCount(6)).Do(func([]byte, protocol.ByteCount) { close(sent) }) go func() { defer GinkgoRecover() @@ -1792,7 +1846,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1827,7 +1881,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1872,7 +1926,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1894,7 +1948,7 @@ var _ = Describe("Connection", func() { }() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) conn.closeLocal(errors.New("handshake error")) Consistently(handshakeCtx).ShouldNot(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -1907,7 +1961,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() sph.EXPECT().SentPacket(gomock.Any()) - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) @@ -1925,7 +1979,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) close(conn.handshakeCompleteChan) conn.run() }() @@ -1953,7 +2007,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -1977,7 +2031,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) @@ -2034,7 +2088,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2170,7 +2224,7 @@ var _ = Describe("Connection", func() { // make the go routine return expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2247,7 +2301,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2366,6 +2420,7 @@ var _ = Describe("Client Connection", func() { Eventually(areConnsRunning).Should(BeFalse()) mconn = NewMockSendConn(mockCtrl) + mconn.EXPECT().capabilities().AnyTimes() mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() mconn.EXPECT().capabilities().AnyTimes() @@ -2433,7 +2488,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() @@ -2694,7 +2749,7 @@ var _ = Describe("Client Connection", func() { packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()), tracer.EXPECT().Close(), diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index ac6c3d15..62c01a23 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + protocol "github.com/quic-go/quic-go/internal/protocol" ) // MockSendConn is a mock of SendConn interface. @@ -77,17 +78,17 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { } // Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte) error { +func (m *MockSendConn) Write(arg0 []byte, arg1 protocol.ByteCount) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) + ret := m.ctrl.Call(m, "Write", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0 interface{}) *gomock.Call { +func (mr *MockSendConnMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1) } // capabilities mocks base method. diff --git a/mock_sender_test.go b/mock_sender_test.go index 3ec60235..feafdf4e 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -8,6 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + protocol "github.com/quic-go/quic-go/internal/protocol" ) // MockSender is a mock of Sender interface. @@ -74,15 +75,15 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call { } // Send mocks base method. -func (m *MockSender) Send(arg0 *packetBuffer) { +func (m *MockSender) Send(arg0 *packetBuffer, arg1 protocol.ByteCount) { m.ctrl.T.Helper() - m.ctrl.Call(m, "Send", arg0) + m.ctrl.Call(m, "Send", arg0, arg1) } // Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(arg0 interface{}) *gomock.Call { +func (mr *MockSenderMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1) } // WouldBlock mocks base method. diff --git a/packet_handler_map.go b/packet_handler_map.go index 4c309138..823c6836 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -26,7 +26,9 @@ type connCapabilities struct { // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (*receivedPacket, error) - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + // The size parameter is used for GSO. + // If GSO is not support, len(b) must be equal to size. + WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/send_conn.go b/send_conn.go index 1ec26cf9..fab36747 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,12 +1,15 @@ package quic import ( + "math" "net" + + "github.com/quic-go/quic-go/internal/protocol" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write([]byte) error + Write(b []byte, size protocol.ByteCount) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -40,8 +43,11 @@ func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { } } -func (c *sconn) Write(p []byte) error { - _, err := c.WritePacket(p, c.remoteAddr, c.oob) +func (c *sconn) Write(p []byte, size protocol.ByteCount) error { + if size > math.MaxUint16 { + panic("size overflow") + } + _, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob) return err } diff --git a/send_conn_test.go b/send_conn_test.go index 2da3e3ab..0b5cc621 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -24,7 +24,7 @@ var _ = Describe("Connection (for sending packets)", func() { It("writes", func() { packetConn.EXPECT().WriteTo([]byte("foobar"), addr) - Expect(c.Write([]byte("foobar"))).To(Succeed()) + Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) }) It("gets the remote address", func() { diff --git a/send_queue.go b/send_queue.go index 9eafcd37..ab7a45ca 100644 --- a/send_queue.go +++ b/send_queue.go @@ -1,15 +1,22 @@ package quic +import "github.com/quic-go/quic-go/internal/protocol" + type sender interface { - Send(p *packetBuffer) + Send(p *packetBuffer, packetSize protocol.ByteCount) Run() error WouldBlock() bool Available() <-chan struct{} Close() } +type queueEntry struct { + buf *packetBuffer + size protocol.ByteCount +} + type sendQueue struct { - queue chan *packetBuffer + queue chan queueEntry closeCalled chan struct{} // runStopped when Close() is called runStopped chan struct{} // runStopped when the run loop returns available chan struct{} @@ -26,16 +33,16 @@ func newSendQueue(conn sendConn) sender { runStopped: make(chan struct{}), closeCalled: make(chan struct{}), available: make(chan struct{}, 1), - queue: make(chan *packetBuffer, sendQueueCapacity), + queue: make(chan queueEntry, sendQueueCapacity), } } // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer) { +func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { select { - case h.queue <- p: + case h.queue <- queueEntry{buf: p, size: size}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -69,8 +76,8 @@ func (h *sendQueue) Run() error { h.closeCalled = nil // prevent this case from being selected again // make sure that all queued packets are actually sent out shouldClose = true - case p := <-h.queue: - if err := h.conn.Write(p.Data); err != nil { + case e := <-h.queue: + if err := h.conn.Write(e.buf.Data, e.size); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and @@ -79,7 +86,7 @@ func (h *sendQueue) Run() error { return err } } - p.Release() + e.buf.Release() select { case h.available <- struct{}{}: default: diff --git a/send_queue_test.go b/send_queue_test.go index 7fae24a4..5a9e6598 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -3,6 +3,8 @@ package quic import ( "errors" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -26,10 +28,10 @@ var _ = Describe("Send Queue", func() { It("sends a packet", func() { p := getPacket([]byte("foobar")) - q.Send(p) + q.Send(p, 10) // make sure the packet size is passed through to the conn written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar")).Do(func([]byte) { close(written) }) + c.EXPECT().Write([]byte("foobar"), protocol.ByteCount(10)).Do(func([]byte, protocol.ByteCount) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -45,19 +47,19 @@ var _ = Describe("Send Queue", func() { It("panics when Send() is called although there's no space in the queue", func() { for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar"))) + q.Send(getPacket([]byte("foobar")), 6) } Expect(q.WouldBlock()).To(BeTrue()) - Expect(func() { q.Send(getPacket([]byte("raboof"))) }).To(Panic()) + Expect(func() { q.Send(getPacket([]byte("raboof")), 6) }).To(Panic()) }) It("signals when sending is possible again", func() { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar1"))) + q.Send(getPacket([]byte("foobar1")), 6) Consistently(q.Available()).ShouldNot(Receive()) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any()).MinTimes(1).MaxTimes(2) + c.EXPECT().Write(gomock.Any(), gomock.Any()).MinTimes(1).MaxTimes(2) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -67,7 +69,7 @@ var _ = Describe("Send Queue", func() { Eventually(q.Available()).Should(Receive()) Expect(q.WouldBlock()).To(BeFalse()) - Expect(func() { q.Send(getPacket([]byte("foobar2"))) }).ToNot(Panic()) + Expect(func() { q.Send(getPacket([]byte("foobar2")), 7) }).ToNot(Panic()) q.Close() Eventually(done).Should(BeClosed()) @@ -77,7 +79,7 @@ var _ = Describe("Send Queue", func() { write := make(chan struct{}, 1) written := make(chan struct{}, 100) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) error { + c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, protocol.ByteCount) error { written <- struct{}{} <-write return nil @@ -92,19 +94,19 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar"))) + q.Send(getPacket([]byte("foobar")), 6) <-written // now fill up the send queue for i := 0; i < sendQueueCapacity; i++ { Expect(q.WouldBlock()).To(BeFalse()) - q.Send(getPacket([]byte("foobar"))) + q.Send(getPacket([]byte("foobar")), 6) } // One more packet is queued when it's picked up by Run and written to the connection. // In this test, it's blocked on write channel in the mocked Write call. <-written Eventually(q.WouldBlock()).Should(BeFalse()) - q.Send(getPacket([]byte("foobar"))) + q.Send(getPacket([]byte("foobar")), 6) Expect(q.WouldBlock()).To(BeTrue()) Consistently(q.Available()).ShouldNot(Receive()) @@ -130,15 +132,15 @@ var _ = Describe("Send Queue", func() { // the run loop exits if there is a write error testErr := errors.New("test error") - c.EXPECT().Write(gomock.Any()).Return(testErr) - q.Send(getPacket([]byte("foobar"))) + c.EXPECT().Write(gomock.Any(), gomock.Any()).Return(testErr) + q.Send(getPacket([]byte("foobar")), 6) Eventually(done).Should(BeClosed()) sent := make(chan struct{}) go func() { defer GinkgoRecover() - q.Send(getPacket([]byte("raboof"))) - q.Send(getPacket([]byte("quux"))) + q.Send(getPacket([]byte("raboof")), 6) + q.Send(getPacket([]byte("quux")), 4) close(sent) }() @@ -147,7 +149,7 @@ var _ = Describe("Send Queue", func() { It("blocks Close() until the packet has been sent out", func() { written := make(chan []byte) - c.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }) + c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ protocol.ByteCount) { written <- p }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -155,7 +157,7 @@ var _ = Describe("Send Queue", func() { close(done) }() - q.Send(getPacket([]byte("foobar"))) + q.Send(getPacket([]byte("foobar")), 6) closed := make(chan struct{}) go func() { diff --git a/server.go b/server.go index 600cd801..352a83d3 100644 --- a/server.go +++ b/server.go @@ -742,7 +742,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB()) return err } @@ -839,7 +839,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB()) return err } @@ -877,7 +877,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p *receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index 29c098a0..f7feabae 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -1,6 +1,7 @@ package quic import ( + "fmt" "net" "syscall" "time" @@ -95,7 +96,10 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) { + if uint16(len(b)) != packetSize { + panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) + } return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_df_linux.go b/sys_conn_df_linux.go index 8de9c78f..ea8d585b 100644 --- a/sys_conn_df_linux.go +++ b/sys_conn_df_linux.go @@ -60,7 +60,7 @@ func isMsgSizeErr(err error) bool { return errors.Is(err, unix.EMSGSIZE) } -func appendUDPSegmentSizeMsg(b []byte, size int) []byte { +func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { startLen := len(b) const dataLen = 2 // payload is a uint16 b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) @@ -71,6 +71,6 @@ func appendUDPSegmentSizeMsg(b []byte, size int) []byte { // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) - *(*uint16)(unsafe.Pointer(&b[offset])) = uint16(size) + *(*uint16)(unsafe.Pointer(&b[offset])) = size return b } diff --git a/sys_conn_no_gso.go b/sys_conn_no_gso.go index aa09f6bf..6f6a8c91 100644 --- a/sys_conn_no_gso.go +++ b/sys_conn_no_gso.go @@ -4,5 +4,5 @@ package quic import "syscall" -func maybeSetGSO(_ syscall.RawConn) bool { return false } -func appendUDPSegmentSizeMsg(_ []byte, _ int) []byte { return nil } +func maybeSetGSO(_ syscall.RawConn) bool { return false } +func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index f044d7bf..136f726f 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -5,6 +5,7 @@ package quic import ( "encoding/binary" "errors" + "fmt" "net" "syscall" "time" @@ -241,15 +242,19 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { // This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection. // With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set. func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) { - return c.WritePacket(p, addr, nil) + return c.WritePacket(p, uint16(len(p)), addr, nil) } // WritePacket writes a new packet. // If the connection supports GSO (and we activated GSO support before), // it appends the UDP_SEGMENT size message to oob. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { +// Callers are advised to make sure that oob has a sufficient capacity, +// such that appending the UDP_SEGMENT size message doesn't cause an allocation. +func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) { if c.cap.GSO { - oob = appendUDPSegmentSizeMsg(oob, len(b)) + oob = appendUDPSegmentSizeMsg(oob, packetSize) + } else if uint16(len(b)) != packetSize { + panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) } n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err diff --git a/transport.go b/transport.go index fcd63e62..b6d402b4 100644 --- a/transport.go +++ b/transport.go @@ -232,7 +232,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB()) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -406,7 +406,7 @@ func (t *Transport) sendStatelessReset(p *receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } }