make the logging.Tracer and logging.ConnectionTracer a struct (#4082)

This commit is contained in:
Marten Seemann 2023-09-16 18:58:51 +07:00 committed by GitHub
parent d8cc4cb3ef
commit 9b82196578
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 1388 additions and 1158 deletions

View file

@ -38,18 +38,13 @@ func countKeyPhases() (sent, received int) {
return
}
type keyUpdateConnTracer struct {
logging.NullConnectionTracer
}
var _ logging.ConnectionTracer = &keyUpdateConnTracer{}
func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
sentHeaders = append(sentHeaders, hdr)
}
func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
receivedHeaders = append(receivedHeaders, hdr)
var keyUpdateConnTracer = &logging.ConnectionTracer{
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
sentHeaders = append(sentHeaders, hdr)
},
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
receivedHeaders = append(receivedHeaders, hdr)
},
}
var _ = Describe("Key Update tests", func() {
@ -77,8 +72,8 @@ var _ = Describe("Key Update tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return &keyUpdateConnTracer{}
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return keyUpdateConnTracer
}}),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -21,7 +21,7 @@ var _ = Describe("Packetization", func() {
It("bundles ACKs", func() {
const numMsg = 100
serverTracer := newPacketTracer()
serverCounter, serverTracer := newPacketTracer()
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
@ -43,7 +43,7 @@ var _ = Describe("Packetization", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
clientTracer := newPacketTracer()
clientCounter, clientTracer := newPacketTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
@ -104,8 +104,8 @@ var _ = Describe("Packetization", func() {
return
}
numBundledIncoming := countBundledPackets(clientTracer.getRcvdShortHeaderPackets())
numBundledOutgoing := countBundledPackets(serverTracer.getRcvdShortHeaderPackets())
numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets())
numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets())
fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg)
fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg)
Expect(numBundledIncoming).To(And(

View file

@ -86,7 +86,7 @@ var (
logBuf *syncedBuffer
versionParam string
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
enableQlog bool
version quic.VersionNumber
@ -177,10 +177,16 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
}
if enableQlog {
if conf.Tracer == nil {
conf.Tracer = qlogTracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{},
)
}
} else if qlogTracer != nil {
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
origTracer(ctx, p, connID),
@ -242,8 +248,8 @@ func scaleDuration(d time.Duration) time.Duration {
return time.Duration(scaleFactor) * d
}
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer }
}
type packet struct {
@ -258,51 +264,46 @@ type shortHeaderPacket struct {
frames []logging.Frame
}
type packetTracer struct {
logging.NullConnectionTracer
type packetCounter struct {
closed chan struct{}
sentShortHdr, rcvdShortHdr []shortHeaderPacket
rcvdLongHdr []packet
}
var _ logging.ConnectionTracer = &packetTracer{}
func newPacketTracer() *packetTracer {
return &packetTracer{closed: make(chan struct{})}
}
func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
if ack != nil {
frames = append(frames, ack)
}
t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) Close() { close(t.closed) }
func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket {
func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket {
<-t.closed
return t.sentShortHdr
}
func (t *packetTracer) getRcvdLongHeaderPackets() []packet {
func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
<-t.closed
return t.rcvdLongHdr
}
func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket {
func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
<-t.closed
return t.rcvdShortHdr
}
func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) {
c := &packetCounter{closed: make(chan struct{})}
return c, &logging.ConnectionTracer{
ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
},
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
},
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
if ack != nil {
frames = append(frames, ack)
}
c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
},
Close: func() { close(c.closed) },
}
}
func TestSelf(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Self integration tests")

View file

@ -194,7 +194,7 @@ var _ = Describe("Timeout tests", func() {
close(serverConnClosed)
}()
tr := newPacketTracer()
counter, tr := newPacketTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@ -215,7 +215,7 @@ var _ = Describe("Timeout tests", func() {
}()
Eventually(done, 2*idleTimeout).Should(BeClosed())
var lastAckElicitingPacketSentAt time.Time
for _, p := range tr.getSentShortHeaderPackets() {
for _, p := range counter.getSentShortHeaderPackets() {
var hasAckElicitingFrame bool
for _, f := range p.frames {
if _, ok := f.(*logging.AckFrame); ok {
@ -228,7 +228,7 @@ var _ = Describe("Timeout tests", func() {
lastAckElicitingPacketSentAt = p.time
}
}
rcvdPackets := tr.getRcvdShortHeaderPackets()
rcvdPackets := counter.getRcvdShortHeaderPackets()
lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
// This is ok since we're dealing with a lossless connection here,

View file

@ -26,9 +26,9 @@ var _ = Describe("Handshake tests", func() {
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer)
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
if enableQlog {
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID)
return nil
@ -38,13 +38,13 @@ var _ = Describe("Handshake tests", func() {
})
}
if enableCustomTracer {
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return logging.NullConnectionTracer{}
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{}
})
}
c := conf.Clone()
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors))
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
tracers := make([]*logging.ConnectionTracer, 0, len(tracerConstructors))
for _, c := range tracerConstructors {
if tr := c(ctx, p, connID); tr != nil {
tracers = append(tracers, tr)

View file

@ -202,7 +202,7 @@ var _ = Describe("0-RTT", func() {
Eventually(conn.Context().Done()).Should(BeClosed())
}
// can be used to extract 0-RTT from a packetTracer
// can be used to extract 0-RTT from a packetCounter
get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
var zeroRTTPackets []protocol.PacketNumber
for _, p := range packets {
@ -219,7 +219,7 @@ var _ = Describe("0-RTT", func() {
It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -244,7 +244,7 @@ var _ = Describe("0-RTT", func() {
)
var numNewConnIDs int
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, p := range counter.getRcvdLongHeaderPackets() {
for _, f := range p.frames {
if _, ok := f.(*logging.NewConnectionIDFrame); ok {
numNewConnIDs++
@ -260,7 +260,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
})
@ -273,7 +273,7 @@ var _ = Describe("0-RTT", func() {
zeroRTTData := GeneratePRData(5 << 10)
oneRTTData := PRData
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -330,7 +330,7 @@ var _ = Describe("0-RTT", func() {
// check that 0-RTT packets only contain STREAM frames for the first stream
var num0RTT int
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, p := range counter.getRcvdLongHeaderPackets() {
if p.hdr.Header.Type != protocol.PacketType0RTT {
continue
}
@ -355,7 +355,7 @@ var _ = Describe("0-RTT", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -406,7 +406,7 @@ var _ = Describe("0-RTT", func() {
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
Expect(numDropped).ToNot(BeZero())
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
})
It("retransmits all 0-RTT data when the server performs a Retry", func() {
@ -430,7 +430,7 @@ var _ = Describe("0-RTT", func() {
return
}
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -480,7 +480,7 @@ var _ = Describe("0-RTT", func() {
defer mutex.Unlock()
Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
})
@ -491,14 +491,12 @@ var _ = Describe("0-RTT", func() {
MaxIncomingUniStreams: maxStreams,
}))
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -536,7 +534,7 @@ var _ = Describe("0-RTT", func() {
MaxIncomingStreams: maxStreams,
}))
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -556,7 +554,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the ALPN changed", func() {
@ -565,7 +563,7 @@ var _ = Describe("0-RTT", func() {
// now close the listener and dial new connection with a different ALPN
clientConf.NextProtos = []string{"new-alpn"}
tlsConf.NextProtos = []string{"new-alpn"}
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -585,14 +583,14 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the application doesn't allow it", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
// now close the listener and dial new connection with a different ALPN
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -612,12 +610,12 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
@ -669,7 +667,7 @@ var _ = Describe("0-RTT", func() {
Eventually(conn.Context().Done()).Should(BeClosed())
var processedFirst bool
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, p := range counter.getRcvdLongHeaderPackets() {
for _, f := range p.frames {
if sf, ok := f.(*logging.StreamFrame); ok {
if !processedFirst {
@ -695,7 +693,7 @@ var _ = Describe("0-RTT", func() {
It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
// now dial new connection with different transport parameters
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -764,14 +762,14 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
}
It("queues 0-RTT packets, if the Initial is delayed", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -796,8 +794,8 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
})
@ -807,7 +805,7 @@ var _ = Describe("0-RTT", func() {
EnableDatagrams: true,
}))
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -856,7 +854,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(zeroRTTPackets).To(HaveLen(1))
})
@ -865,7 +863,7 @@ var _ = Describe("0-RTT", func() {
EnableDatagrams: true,
}))
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -911,6 +909,6 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
})

View file

@ -232,7 +232,7 @@ var _ = Describe("0-RTT", func() {
Eventually(conn.Context().Done()).Should(BeClosed())
}
// can be used to extract 0-RTT from a packetTracer
// can be used to extract 0-RTT from a packetCounter
get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
var zeroRTTPackets []protocol.PacketNumber
for _, p := range packets {
@ -251,7 +251,7 @@ var _ = Describe("0-RTT", func() {
clientTLSConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -276,7 +276,7 @@ var _ = Describe("0-RTT", func() {
)
var numNewConnIDs int
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, p := range counter.getRcvdLongHeaderPackets() {
for _, f := range p.frames {
if _, ok := f.(*logging.NewConnectionIDFrame); ok {
numNewConnIDs++
@ -292,7 +292,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
})
@ -307,7 +307,7 @@ var _ = Describe("0-RTT", func() {
zeroRTTData := GeneratePRData(5 << 10)
oneRTTData := PRData
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -364,7 +364,7 @@ var _ = Describe("0-RTT", func() {
// check that 0-RTT packets only contain STREAM frames for the first stream
var num0RTT int
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, p := range counter.getRcvdLongHeaderPackets() {
if p.hdr.Header.Type != protocol.PacketType0RTT {
continue
}
@ -391,7 +391,7 @@ var _ = Describe("0-RTT", func() {
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -442,7 +442,7 @@ var _ = Describe("0-RTT", func() {
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
Expect(numDropped).ToNot(BeZero())
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
})
It("retransmits all 0-RTT data when the server performs a Retry", func() {
@ -468,7 +468,7 @@ var _ = Describe("0-RTT", func() {
return
}
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -518,7 +518,7 @@ var _ = Describe("0-RTT", func() {
defer mutex.Unlock()
Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
})
@ -531,14 +531,12 @@ var _ = Describe("0-RTT", func() {
MaxIncomingUniStreams: maxStreams,
}), clientConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -578,7 +576,7 @@ var _ = Describe("0-RTT", func() {
MaxIncomingStreams: maxStreams,
}), clientConf)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -599,7 +597,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the ALPN changed", func() {
@ -612,7 +610,7 @@ var _ = Describe("0-RTT", func() {
// Append to the client's ALPN.
// crypto/tls will attempt to resume with the ALPN from the original connection
clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn")
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -632,7 +630,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the application doesn't allow it", func() {
@ -641,7 +639,7 @@ var _ = Describe("0-RTT", func() {
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now close the listener and dial new connection with a different ALPN
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -661,12 +659,12 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
addFlowControlLimit(firstConf, 3)
tlsConf := getTLSConfig()
@ -720,7 +718,7 @@ var _ = Describe("0-RTT", func() {
Eventually(conn.Context().Done()).Should(BeClosed())
var processedFirst bool
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, p := range counter.getRcvdLongHeaderPackets() {
for _, f := range p.frames {
if sf, ok := f.(*logging.StreamFrame); ok {
if !processedFirst {
@ -748,7 +746,7 @@ var _ = Describe("0-RTT", func() {
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now dial new connection with different transport parameters
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -817,7 +815,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
}
@ -826,7 +824,7 @@ var _ = Describe("0-RTT", func() {
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -851,8 +849,8 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
})
@ -878,14 +876,10 @@ var _ = Describe("0-RTT", func() {
clientTLSConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
getQuicConfig(&quic.Config{Allow0RTT: true}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
@ -916,14 +910,10 @@ var _ = Describe("0-RTT", func() {
}
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
getQuicConfig(&quic.Config{Allow0RTT: true}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
@ -946,7 +936,7 @@ var _ = Describe("0-RTT", func() {
EnableDatagrams: true,
}), clientTLSConf)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -994,7 +984,7 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
Expect(zeroRTTPackets).To(HaveLen(1))
})
@ -1005,7 +995,7 @@ var _ = Describe("0-RTT", func() {
EnableDatagrams: true,
}), clientTLSConf)
tracer := newPacketTracer()
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -1051,6 +1041,6 @@ var _ = Describe("0-RTT", func() {
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
})

View file

@ -14,8 +14,8 @@ import (
"github.com/quic-go/quic-go/qlog"
)
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"

View file

@ -21,29 +21,29 @@ type versioner interface {
GetVersion() protocol.VersionNumber
}
type versionNegotiationTracer struct {
logging.NullConnectionTracer
type result struct {
loggedVersions bool
receivedVersionNegotiation bool
chosen logging.VersionNumber
clientVersions, serverVersions []logging.VersionNumber
}
var _ logging.ConnectionTracer = &versionNegotiationTracer{}
func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
if t.loggedVersions {
Fail("only expected one call to NegotiatedVersions")
func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) {
r := &result{}
return r, &logging.ConnectionTracer{
NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
if r.loggedVersions {
Fail("only expected one call to NegotiatedVersions")
}
r.loggedVersions = true
r.chosen = chosen
r.clientVersions = clientVersions
r.serverVersions = serverVersions
},
ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
r.receivedVersionNegotiation = true
},
}
t.loggedVersions = true
t.chosen = chosen
t.clientVersions = clientVersions
t.serverVersions = serverVersions
}
func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
t.receivedVersionNegotiation = true
}
var _ = Describe("Handshake tests", func() {
@ -86,54 +86,54 @@ var _ = Describe("Handshake tests", func() {
// but it supports a bunch of versions that the client doesn't speak
serverConfig := &quic.Config{}
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
serverResult, serverTracer := newVersionNegotiationTracer()
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return serverTracer
}
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientTracer := &versionNegotiationTracer{}
clientResult, clientTracer := newVersionNegotiationTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer {
return clientTracer
}}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(clientTracer.chosen).To(Equal(expectedVersion))
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions))
Expect(clientTracer.serverVersions).To(BeEmpty())
Expect(serverTracer.chosen).To(Equal(expectedVersion))
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverTracer.clientVersions).To(BeEmpty())
Expect(clientResult.chosen).To(Equal(expectedVersion))
Expect(clientResult.receivedVersionNegotiation).To(BeFalse())
Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions))
Expect(clientResult.serverVersions).To(BeEmpty())
Expect(serverResult.chosen).To(Equal(expectedVersion))
Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverResult.clientVersions).To(BeEmpty())
})
It("when the client supports more versions than the server supports", func() {
expectedVersion := protocol.SupportedVersions[0]
// The server doesn't support the highest supported version, which is the first one the client will try,
// but it supports a bunch of versions that the client doesn't speak
serverTracer := &versionNegotiationTracer{}
serverResult, serverTracer := newVersionNegotiationTracer()
serverConfig := &quic.Config{}
serverConfig.Versions = supportedVersions
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return serverTracer
}
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
clientTracer := &versionNegotiationTracer{}
clientResult, clientTracer := newVersionNegotiationTracer()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQLOGTracer(&quic.Config{
Versions: clientVersions,
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return clientTracer
},
}),
@ -141,22 +141,22 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
Expect(conn.CloseWithError(0, "")).To(Succeed())
Expect(clientTracer.chosen).To(Equal(expectedVersion))
Expect(clientTracer.receivedVersionNegotiation).To(BeTrue())
Expect(clientTracer.clientVersions).To(Equal(clientVersions))
Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
Expect(serverTracer.chosen).To(Equal(expectedVersion))
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverTracer.clientVersions).To(BeEmpty())
Expect(clientResult.chosen).To(Equal(expectedVersion))
Expect(clientResult.receivedVersionNegotiation).To(BeTrue())
Expect(clientResult.clientVersions).To(Equal(clientVersions))
Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
Expect(serverResult.chosen).To(Equal(expectedVersion))
Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions))
Expect(serverResult.clientVersions).To(BeEmpty())
})
It("fails if the server disables version negotiation", func() {
// The server doesn't support the highest supported version, which is the first one the client will try,
// but it supports a bunch of versions that the client doesn't speak
serverTracer := &versionNegotiationTracer{}
_, serverTracer := newVersionNegotiationTracer()
serverConfig := &quic.Config{}
serverConfig.Versions = supportedVersions
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return serverTracer
}
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
@ -170,14 +170,14 @@ var _ = Describe("Handshake tests", func() {
defer ln.Close()
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
clientTracer := &versionNegotiationTracer{}
clientResult, clientTracer := newVersionNegotiationTracer()
_, err = quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQLOGTracer(&quic.Config{
Versions: clientVersions,
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return clientTracer
},
HandshakeIdleTimeout: 100 * time.Millisecond,
@ -187,7 +187,7 @@ var _ = Describe("Handshake tests", func() {
var nerr net.Error
Expect(errors.As(err, &nerr)).To(BeTrue())
Expect(nerr.Timeout()).To(BeTrue())
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
Expect(clientResult.receivedVersionNegotiation).To(BeFalse())
})
}
})

View file

@ -70,7 +70,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
c.Tracer = qlogger
} else if qlogger != nil {
origTracer := c.Tracer
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogger(ctx, p, connID),
origTracer(ctx, p, connID),