add a method to retrieve non-QUIC packets from the Transport (#3992)

This commit is contained in:
Marten Seemann 2023-08-19 15:19:17 +07:00 committed by GitHub
parent 6880f88089
commit fe3c4f271d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 4 deletions

View file

@ -2,9 +2,11 @@ package self_test
import (
"context"
"crypto/rand"
"io"
"net"
"runtime"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
@ -210,4 +212,67 @@ var _ = Describe("Multiplexing", func() {
})
}
})
It("sends and receives non-QUIC packets", func() {
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
runServer(server)
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var sentPackets, rcvdPackets atomic.Int64
const packetLen = 128
// send a non-QUIC packet every 100µs
go func() {
defer GinkgoRecover()
ticker := time.NewTicker(time.Millisecond / 10)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-ctx.Done():
return
}
b := make([]byte, packetLen)
rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
sentPackets.Add(1)
}
}()
// receive and count non-QUIC packets
go func() {
defer GinkgoRecover()
for {
b := make([]byte, 1024)
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
if err != nil {
Expect(err).To(MatchError(context.Canceled))
return
}
Expect(addr).To(Equal(tr1.Conn.LocalAddr()))
Expect(n).To(Equal(packetLen))
rcvdPackets.Add(1)
}
}()
dial(tr2, server.Addr())
Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10))
Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5))
})
})