diff --git a/packet_packer.go b/packet_packer.go index 858fcd71..6599a277 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -156,6 +156,7 @@ type packetPacker struct { pnManager packetNumberManager framer frameSource acks ackFrameSource + datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue maxPacketSize protocol.ByteCount @@ -175,6 +176,7 @@ func newPacketPacker( cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, + datagramQueue *datagramQueue, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { @@ -185,6 +187,7 @@ func newPacketPacker( initialStream: initialStream, handshakeStream: handshakeStream, retransmissionQueue: retransmissionQueue, + datagramQueue: datagramQueue, perspective: perspective, version: version, framer: framer, @@ -576,10 +579,25 @@ func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { payload := &payload{} + + var hasDatagram bool + if p.datagramQueue != nil { + if datagram := p.datagramQueue.Get(); datagram != nil { + payload.frames = append(payload.frames, ackhandler.Frame{ + Frame: datagram, + // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + OnLost: func(wire.Frame) {}, + }) + payload.length += datagram.Length(p.version) + hasDatagram = true + } + } + var ack *wire.AckFrame hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasAppData() - if ackAllowed { + // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued + if !hasDatagram && ackAllowed { ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) if ack != nil { payload.ack = ack diff --git a/packet_packer_test.go b/packet_packer_test.go index 49c9f23b..92b883d3 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -7,17 +7,16 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -30,6 +29,7 @@ var _ = Describe("Packet packer", func() { var ( packer *packetPacker retransmissionQueue *retransmissionQueue + datagramQueue *datagramQueue framer *MockFrameSource ackFramer *MockAckFrameSource initialStream *MockCryptoStream @@ -90,6 +90,7 @@ var _ = Describe("Packet packer", func() { ackFramer = NewMockAckFrameSource(mockCtrl) sealingManager = NewMockSealingManager(mockCtrl) pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) + datagramQueue = newDatagramQueue(func() {}) packer = newPacketPacker( protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, @@ -102,6 +103,7 @@ var _ = Describe("Packet packer", func() { sealingManager, framer, ackFramer, + datagramQueue, protocol.PerspectiveServer, version, ) @@ -537,6 +539,33 @@ var _ = Describe("Packet packer", func() { Expect(p.buffer.Len()).ToNot(BeZero()) }) + It("packs DATAGRAM frames", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + f := &wire.DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + datagramQueue.AddAndWait(f) + }() + // make sure the DATAGRAM has actually been queued + time.Sleep(scaleDuration(20 * time.Millisecond)) + + framer.EXPECT().HasData() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + Expect(p.frames[0].Frame).To(Equal(f)) + Expect(p.buffer.Data).ToNot(BeEmpty()) + Eventually(done).Should(BeClosed()) + }) + It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) diff --git a/session.go b/session.go index 7fc28c0c..24b165a1 100644 --- a/session.go +++ b/session.go @@ -204,6 +204,8 @@ type session struct { keepAlivePingSent bool keepAliveInterval time.Duration + datagramQueue *datagramQueue + logID string tracer logging.ConnectionTracer logger utils.Logger @@ -334,6 +336,7 @@ var newSession = func( cs, s.framer, s.receivedPacketHandler, + s.datagramQueue, s.perspective, s.version, ) @@ -454,6 +457,7 @@ var newClientSession = func( cs, s.framer, s.receivedPacketHandler, + s.datagramQueue, s.perspective, s.version, ) @@ -1307,6 +1311,9 @@ func (s *session) handleCloseError(closeErr closeError) { s.streamsMap.CloseWithError(quicErr) s.connIDManager.Close() + if s.datagramQueue != nil { + s.datagramQueue.CloseWithError(quicErr) + } if s.tracer != nil { // timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts)