mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
make the logging.Tracer and logging.ConnectionTracer a struct (#4082)
This commit is contained in:
parent
d8cc4cb3ef
commit
9b82196578
46 changed files with 1388 additions and 1158 deletions
|
@ -38,18 +38,13 @@ func countKeyPhases() (sent, received int) {
|
|||
return
|
||||
}
|
||||
|
||||
type keyUpdateConnTracer struct {
|
||||
logging.NullConnectionTracer
|
||||
}
|
||||
|
||||
var _ logging.ConnectionTracer = &keyUpdateConnTracer{}
|
||||
|
||||
func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
|
||||
sentHeaders = append(sentHeaders, hdr)
|
||||
}
|
||||
|
||||
func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
|
||||
receivedHeaders = append(receivedHeaders, hdr)
|
||||
var keyUpdateConnTracer = &logging.ConnectionTracer{
|
||||
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) {
|
||||
sentHeaders = append(sentHeaders, hdr)
|
||||
},
|
||||
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) {
|
||||
receivedHeaders = append(receivedHeaders, hdr)
|
||||
},
|
||||
}
|
||||
|
||||
var _ = Describe("Key Update tests", func() {
|
||||
|
@ -77,8 +72,8 @@ var _ = Describe("Key Update tests", func() {
|
|||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return &keyUpdateConnTracer{}
|
||||
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return keyUpdateConnTracer
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -21,7 +21,7 @@ var _ = Describe("Packetization", func() {
|
|||
It("bundles ACKs", func() {
|
||||
const numMsg = 100
|
||||
|
||||
serverTracer := newPacketTracer()
|
||||
serverCounter, serverTracer := newPacketTracer()
|
||||
server, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
getTLSConfig(),
|
||||
|
@ -43,7 +43,7 @@ var _ = Describe("Packetization", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
clientTracer := newPacketTracer()
|
||||
clientCounter, clientTracer := newPacketTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
|
@ -104,8 +104,8 @@ var _ = Describe("Packetization", func() {
|
|||
return
|
||||
}
|
||||
|
||||
numBundledIncoming := countBundledPackets(clientTracer.getRcvdShortHeaderPackets())
|
||||
numBundledOutgoing := countBundledPackets(serverTracer.getRcvdShortHeaderPackets())
|
||||
numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets())
|
||||
numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets())
|
||||
fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg)
|
||||
fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg)
|
||||
Expect(numBundledIncoming).To(And(
|
||||
|
|
|
@ -86,7 +86,7 @@ var (
|
|||
logBuf *syncedBuffer
|
||||
versionParam string
|
||||
|
||||
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
|
||||
enableQlog bool
|
||||
|
||||
version quic.VersionNumber
|
||||
|
@ -177,10 +177,16 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
|
|||
}
|
||||
if enableQlog {
|
||||
if conf.Tracer == nil {
|
||||
conf.Tracer = qlogTracer
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogTracer(ctx, p, connID),
|
||||
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
|
||||
&logging.ConnectionTracer{},
|
||||
)
|
||||
}
|
||||
} else if qlogTracer != nil {
|
||||
origTracer := conf.Tracer
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogTracer(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
|
@ -242,8 +248,8 @@ func scaleDuration(d time.Duration) time.Duration {
|
|||
return time.Duration(scaleFactor) * d
|
||||
}
|
||||
|
||||
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
|
||||
func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer }
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
|
@ -258,51 +264,46 @@ type shortHeaderPacket struct {
|
|||
frames []logging.Frame
|
||||
}
|
||||
|
||||
type packetTracer struct {
|
||||
logging.NullConnectionTracer
|
||||
type packetCounter struct {
|
||||
closed chan struct{}
|
||||
sentShortHdr, rcvdShortHdr []shortHeaderPacket
|
||||
rcvdLongHdr []packet
|
||||
}
|
||||
|
||||
var _ logging.ConnectionTracer = &packetTracer{}
|
||||
|
||||
func newPacketTracer() *packetTracer {
|
||||
return &packetTracer{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
|
||||
t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
|
||||
}
|
||||
|
||||
func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
|
||||
t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
|
||||
}
|
||||
|
||||
func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
|
||||
if ack != nil {
|
||||
frames = append(frames, ack)
|
||||
}
|
||||
t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
|
||||
}
|
||||
|
||||
func (t *packetTracer) Close() { close(t.closed) }
|
||||
|
||||
func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket {
|
||||
func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket {
|
||||
<-t.closed
|
||||
return t.sentShortHdr
|
||||
}
|
||||
|
||||
func (t *packetTracer) getRcvdLongHeaderPackets() []packet {
|
||||
func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
|
||||
<-t.closed
|
||||
return t.rcvdLongHdr
|
||||
}
|
||||
|
||||
func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket {
|
||||
func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
|
||||
<-t.closed
|
||||
return t.rcvdShortHdr
|
||||
}
|
||||
|
||||
func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) {
|
||||
c := &packetCounter{closed: make(chan struct{})}
|
||||
return c, &logging.ConnectionTracer{
|
||||
ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
|
||||
c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
|
||||
},
|
||||
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
|
||||
c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
|
||||
},
|
||||
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
|
||||
if ack != nil {
|
||||
frames = append(frames, ack)
|
||||
}
|
||||
c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
|
||||
},
|
||||
Close: func() { close(c.closed) },
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelf(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Self integration tests")
|
||||
|
|
|
@ -194,7 +194,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
close(serverConnClosed)
|
||||
}()
|
||||
|
||||
tr := newPacketTracer()
|
||||
counter, tr := newPacketTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
|
@ -215,7 +215,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
}()
|
||||
Eventually(done, 2*idleTimeout).Should(BeClosed())
|
||||
var lastAckElicitingPacketSentAt time.Time
|
||||
for _, p := range tr.getSentShortHeaderPackets() {
|
||||
for _, p := range counter.getSentShortHeaderPackets() {
|
||||
var hasAckElicitingFrame bool
|
||||
for _, f := range p.frames {
|
||||
if _, ok := f.(*logging.AckFrame); ok {
|
||||
|
@ -228,7 +228,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
lastAckElicitingPacketSentAt = p.time
|
||||
}
|
||||
}
|
||||
rcvdPackets := tr.getRcvdShortHeaderPackets()
|
||||
rcvdPackets := counter.getRcvdShortHeaderPackets()
|
||||
lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
|
||||
// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
|
||||
// This is ok since we're dealing with a lossless connection here,
|
||||
|
|
|
@ -26,9 +26,9 @@ var _ = Describe("Handshake tests", func() {
|
|||
|
||||
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer)
|
||||
|
||||
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
|
||||
if enableQlog {
|
||||
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID)
|
||||
return nil
|
||||
|
@ -38,13 +38,13 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
}
|
||||
if enableCustomTracer {
|
||||
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NullConnectionTracer{}
|
||||
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{}
|
||||
})
|
||||
}
|
||||
c := conf.Clone()
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors))
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
tracers := make([]*logging.ConnectionTracer, 0, len(tracerConstructors))
|
||||
for _, c := range tracerConstructors {
|
||||
if tr := c(ctx, p, connID); tr != nil {
|
||||
tracers = append(tracers, tr)
|
||||
|
|
|
@ -202,7 +202,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
}
|
||||
|
||||
// can be used to extract 0-RTT from a packetTracer
|
||||
// can be used to extract 0-RTT from a packetCounter
|
||||
get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
|
||||
var zeroRTTPackets []protocol.PacketNumber
|
||||
for _, p := range packets {
|
||||
|
@ -219,7 +219,7 @@ var _ = Describe("0-RTT", func() {
|
|||
It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
|
||||
tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -244,7 +244,7 @@ var _ = Describe("0-RTT", func() {
|
|||
)
|
||||
|
||||
var numNewConnIDs int
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, p := range counter.getRcvdLongHeaderPackets() {
|
||||
for _, f := range p.frames {
|
||||
if _, ok := f.(*logging.NewConnectionIDFrame); ok {
|
||||
numNewConnIDs++
|
||||
|
@ -260,7 +260,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
|
||||
})
|
||||
|
@ -273,7 +273,7 @@ var _ = Describe("0-RTT", func() {
|
|||
zeroRTTData := GeneratePRData(5 << 10)
|
||||
oneRTTData := PRData
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -330,7 +330,7 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
// check that 0-RTT packets only contain STREAM frames for the first stream
|
||||
var num0RTT int
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, p := range counter.getRcvdLongHeaderPackets() {
|
||||
if p.hdr.Header.Type != protocol.PacketType0RTT {
|
||||
continue
|
||||
}
|
||||
|
@ -355,7 +355,7 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -406,7 +406,7 @@ var _ = Describe("0-RTT", func() {
|
|||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
|
||||
Expect(numDropped).ToNot(BeZero())
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("retransmits all 0-RTT data when the server performs a Retry", func() {
|
||||
|
@ -430,7 +430,7 @@ var _ = Describe("0-RTT", func() {
|
|||
return
|
||||
}
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -480,7 +480,7 @@ var _ = Describe("0-RTT", func() {
|
|||
defer mutex.Unlock()
|
||||
Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
|
||||
Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
|
||||
Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
|
||||
})
|
||||
|
@ -491,14 +491,12 @@ var _ = Describe("0-RTT", func() {
|
|||
MaxIncomingUniStreams: maxStreams,
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams + 1,
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -536,7 +534,7 @@ var _ = Describe("0-RTT", func() {
|
|||
MaxIncomingStreams: maxStreams,
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -556,7 +554,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the ALPN changed", func() {
|
||||
|
@ -565,7 +563,7 @@ var _ = Describe("0-RTT", func() {
|
|||
// now close the listener and dial new connection with a different ALPN
|
||||
clientConf.NextProtos = []string{"new-alpn"}
|
||||
tlsConf.NextProtos = []string{"new-alpn"}
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -585,14 +583,14 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the application doesn't allow it", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
// now close the listener and dial new connection with a different ALPN
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -612,12 +610,12 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
DescribeTable("flow control limits",
|
||||
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
|
||||
addFlowControlLimit(firstConf, 3)
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
|
||||
|
@ -669,7 +667,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
|
||||
var processedFirst bool
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, p := range counter.getRcvdLongHeaderPackets() {
|
||||
for _, f := range p.frames {
|
||||
if sf, ok := f.(*logging.StreamFrame); ok {
|
||||
if !processedFirst {
|
||||
|
@ -695,7 +693,7 @@ var _ = Describe("0-RTT", func() {
|
|||
It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
// now dial new connection with different transport parameters
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -764,14 +762,14 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
}
|
||||
|
||||
It("queues 0-RTT packets, if the Initial is delayed", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -796,8 +794,8 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
|
||||
})
|
||||
|
@ -807,7 +805,7 @@ var _ = Describe("0-RTT", func() {
|
|||
EnableDatagrams: true,
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -856,7 +854,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(zeroRTTPackets).To(HaveLen(1))
|
||||
})
|
||||
|
||||
|
@ -865,7 +863,7 @@ var _ = Describe("0-RTT", func() {
|
|||
EnableDatagrams: true,
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -911,6 +909,6 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -232,7 +232,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
}
|
||||
|
||||
// can be used to extract 0-RTT from a packetTracer
|
||||
// can be used to extract 0-RTT from a packetCounter
|
||||
get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
|
||||
var zeroRTTPackets []protocol.PacketNumber
|
||||
for _, p := range packets {
|
||||
|
@ -251,7 +251,7 @@ var _ = Describe("0-RTT", func() {
|
|||
clientTLSConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -276,7 +276,7 @@ var _ = Describe("0-RTT", func() {
|
|||
)
|
||||
|
||||
var numNewConnIDs int
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, p := range counter.getRcvdLongHeaderPackets() {
|
||||
for _, f := range p.frames {
|
||||
if _, ok := f.(*logging.NewConnectionIDFrame); ok {
|
||||
numNewConnIDs++
|
||||
|
@ -292,7 +292,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
|
||||
})
|
||||
|
@ -307,7 +307,7 @@ var _ = Describe("0-RTT", func() {
|
|||
zeroRTTData := GeneratePRData(5 << 10)
|
||||
oneRTTData := PRData
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -364,7 +364,7 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
// check that 0-RTT packets only contain STREAM frames for the first stream
|
||||
var num0RTT int
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, p := range counter.getRcvdLongHeaderPackets() {
|
||||
if p.hdr.Header.Type != protocol.PacketType0RTT {
|
||||
continue
|
||||
}
|
||||
|
@ -391,7 +391,7 @@ var _ = Describe("0-RTT", func() {
|
|||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -442,7 +442,7 @@ var _ = Describe("0-RTT", func() {
|
|||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
|
||||
Expect(numDropped).ToNot(BeZero())
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("retransmits all 0-RTT data when the server performs a Retry", func() {
|
||||
|
@ -468,7 +468,7 @@ var _ = Describe("0-RTT", func() {
|
|||
return
|
||||
}
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -518,7 +518,7 @@ var _ = Describe("0-RTT", func() {
|
|||
defer mutex.Unlock()
|
||||
Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
|
||||
Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
|
||||
Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
|
||||
})
|
||||
|
@ -531,14 +531,12 @@ var _ = Describe("0-RTT", func() {
|
|||
MaxIncomingUniStreams: maxStreams,
|
||||
}), clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams + 1,
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -578,7 +576,7 @@ var _ = Describe("0-RTT", func() {
|
|||
MaxIncomingStreams: maxStreams,
|
||||
}), clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -599,7 +597,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the ALPN changed", func() {
|
||||
|
@ -612,7 +610,7 @@ var _ = Describe("0-RTT", func() {
|
|||
// Append to the client's ALPN.
|
||||
// crypto/tls will attempt to resume with the ALPN from the original connection
|
||||
clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn")
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -632,7 +630,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the application doesn't allow it", func() {
|
||||
|
@ -641,7 +639,7 @@ var _ = Describe("0-RTT", func() {
|
|||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
// now close the listener and dial new connection with a different ALPN
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -661,12 +659,12 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
DescribeTable("flow control limits",
|
||||
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
|
||||
addFlowControlLimit(firstConf, 3)
|
||||
tlsConf := getTLSConfig()
|
||||
|
@ -720,7 +718,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
|
||||
var processedFirst bool
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, p := range counter.getRcvdLongHeaderPackets() {
|
||||
for _, f := range p.frames {
|
||||
if sf, ok := f.(*logging.StreamFrame); ok {
|
||||
if !processedFirst {
|
||||
|
@ -748,7 +746,7 @@ var _ = Describe("0-RTT", func() {
|
|||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
// now dial new connection with different transport parameters
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -817,7 +815,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -826,7 +824,7 @@ var _ = Describe("0-RTT", func() {
|
|||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -851,8 +849,8 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
|
||||
})
|
||||
|
@ -878,14 +876,10 @@ var _ = Describe("0-RTT", func() {
|
|||
clientTLSConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
getQuicConfig(&quic.Config{Allow0RTT: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
@ -916,14 +910,10 @@ var _ = Describe("0-RTT", func() {
|
|||
}
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
getQuicConfig(&quic.Config{Allow0RTT: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
@ -946,7 +936,7 @@ var _ = Describe("0-RTT", func() {
|
|||
EnableDatagrams: true,
|
||||
}), clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -994,7 +984,7 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets())
|
||||
Expect(zeroRTTPackets).To(HaveLen(1))
|
||||
})
|
||||
|
||||
|
@ -1005,7 +995,7 @@ var _ = Describe("0-RTT", func() {
|
|||
EnableDatagrams: true,
|
||||
}), clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
counter, tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
|
@ -1051,6 +1041,6 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -14,8 +14,8 @@ import (
|
|||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
role := "server"
|
||||
if p == logging.PerspectiveClient {
|
||||
role = "client"
|
||||
|
|
|
@ -21,29 +21,29 @@ type versioner interface {
|
|||
GetVersion() protocol.VersionNumber
|
||||
}
|
||||
|
||||
type versionNegotiationTracer struct {
|
||||
logging.NullConnectionTracer
|
||||
|
||||
type result struct {
|
||||
loggedVersions bool
|
||||
receivedVersionNegotiation bool
|
||||
chosen logging.VersionNumber
|
||||
clientVersions, serverVersions []logging.VersionNumber
|
||||
}
|
||||
|
||||
var _ logging.ConnectionTracer = &versionNegotiationTracer{}
|
||||
|
||||
func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
|
||||
if t.loggedVersions {
|
||||
Fail("only expected one call to NegotiatedVersions")
|
||||
func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) {
|
||||
r := &result{}
|
||||
return r, &logging.ConnectionTracer{
|
||||
NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
|
||||
if r.loggedVersions {
|
||||
Fail("only expected one call to NegotiatedVersions")
|
||||
}
|
||||
r.loggedVersions = true
|
||||
r.chosen = chosen
|
||||
r.clientVersions = clientVersions
|
||||
r.serverVersions = serverVersions
|
||||
},
|
||||
ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
|
||||
r.receivedVersionNegotiation = true
|
||||
},
|
||||
}
|
||||
t.loggedVersions = true
|
||||
t.chosen = chosen
|
||||
t.clientVersions = clientVersions
|
||||
t.serverVersions = serverVersions
|
||||
}
|
||||
|
||||
func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
|
||||
t.receivedVersionNegotiation = true
|
||||
}
|
||||
|
||||
var _ = Describe("Handshake tests", func() {
|
||||
|
@ -86,54 +86,54 @@ var _ = Describe("Handshake tests", func() {
|
|||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
serverResult, serverTracer := newVersionNegotiationTracer()
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
server, cl := startServer(getTLSConfig(), serverConfig)
|
||||
defer cl()
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
clientResult, clientTracer := newVersionNegotiationTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
|
||||
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
Expect(clientTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
|
||||
Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(clientTracer.serverVersions).To(BeEmpty())
|
||||
Expect(serverTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
|
||||
Expect(serverTracer.clientVersions).To(BeEmpty())
|
||||
Expect(clientResult.chosen).To(Equal(expectedVersion))
|
||||
Expect(clientResult.receivedVersionNegotiation).To(BeFalse())
|
||||
Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(clientResult.serverVersions).To(BeEmpty())
|
||||
Expect(serverResult.chosen).To(Equal(expectedVersion))
|
||||
Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions))
|
||||
Expect(serverResult.clientVersions).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("when the client supports more versions than the server supports", func() {
|
||||
expectedVersion := protocol.SupportedVersions[0]
|
||||
// The server doesn't support the highest supported version, which is the first one the client will try,
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverResult, serverTracer := newVersionNegotiationTracer()
|
||||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = supportedVersions
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
server, cl := startServer(getTLSConfig(), serverConfig)
|
||||
defer cl()
|
||||
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
clientResult, clientTracer := newVersionNegotiationTracer()
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQLOGTracer(&quic.Config{
|
||||
Versions: clientVersions,
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
},
|
||||
}),
|
||||
|
@ -141,22 +141,22 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
Expect(clientTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(clientTracer.receivedVersionNegotiation).To(BeTrue())
|
||||
Expect(clientTracer.clientVersions).To(Equal(clientVersions))
|
||||
Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
|
||||
Expect(serverTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
|
||||
Expect(serverTracer.clientVersions).To(BeEmpty())
|
||||
Expect(clientResult.chosen).To(Equal(expectedVersion))
|
||||
Expect(clientResult.receivedVersionNegotiation).To(BeTrue())
|
||||
Expect(clientResult.clientVersions).To(Equal(clientVersions))
|
||||
Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
|
||||
Expect(serverResult.chosen).To(Equal(expectedVersion))
|
||||
Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions))
|
||||
Expect(serverResult.clientVersions).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("fails if the server disables version negotiation", func() {
|
||||
// The server doesn't support the highest supported version, which is the first one the client will try,
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
_, serverTracer := newVersionNegotiationTracer()
|
||||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = supportedVersions
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
|
||||
|
@ -170,14 +170,14 @@ var _ = Describe("Handshake tests", func() {
|
|||
defer ln.Close()
|
||||
|
||||
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
clientResult, clientTracer := newVersionNegotiationTracer()
|
||||
_, err = quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQLOGTracer(&quic.Config{
|
||||
Versions: clientVersions,
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
},
|
||||
HandshakeIdleTimeout: 100 * time.Millisecond,
|
||||
|
@ -187,7 +187,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
var nerr net.Error
|
||||
Expect(errors.As(err, &nerr)).To(BeTrue())
|
||||
Expect(nerr.Timeout()).To(BeTrue())
|
||||
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
|
||||
Expect(clientResult.receivedVersionNegotiation).To(BeFalse())
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
|
@ -70,7 +70,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
|
|||
c.Tracer = qlogger
|
||||
} else if qlogger != nil {
|
||||
origTracer := c.Tracer
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogger(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue