use the transport tracer in integration tests

This commit is contained in:
Marten Seemann 2024-02-03 12:35:19 +07:00
parent 55c05aceed
commit 30e01b9524
9 changed files with 78 additions and 34 deletions

View file

@ -50,6 +50,7 @@ var _ = Describe("Connection ID lengths tests", func() {
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
addTracer(tr)
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
go func() {
@ -92,6 +93,7 @@ var _ = Describe("Connection ID lengths tests", func() {
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
addTracer(tr)
defer tr.Close()
cl, err := tr.Dial(
context.Background(),

View file

@ -64,6 +64,7 @@ var _ = Describe("Handshake RTT tests", func() {
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
addTracer(tr)
defer tr.Close()
ln, err := tr.Listen(serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())

View file

@ -328,7 +328,10 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
pconn, err = net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
dialer = &quic.Transport{
Conn: pconn,
ConnectionIDLength: 4,
}
})
AfterEach(func() {
@ -431,9 +434,8 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
tr := quic.Transport{
Conn: udpConn,
}
tr := &quic.Transport{Conn: udpConn}
addTracer(tr)
defer tr.Close()
tlsConf := &tls.Config{}
done := make(chan struct{})
@ -476,10 +478,11 @@ var _ = Describe("Handshake tests", func() {
It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
const limit = 3
tr := quic.Transport{
tr := &quic.Transport{
Conn: conn,
MaxUnvalidatedHandshakes: limit,
}
addTracer(tr)
defer tr.Close()
// Block all handshakes.
@ -541,10 +544,11 @@ var _ = Describe("Handshake tests", func() {
It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
const limit = 3
tr := quic.Transport{
tr := &quic.Transport{
Conn: conn,
MaxHandshakes: limit,
}
addTracer(tr)
defer tr.Close()
// Block all handshakes.
@ -717,6 +721,7 @@ var _ = Describe("Handshake tests", func() {
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
addTracer(tr)
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())

View file

@ -41,6 +41,7 @@ var _ = Describe("MITM test", func() {
Conn: c,
ConnectionIDLength: connIDLen,
}
addTracer(serverTransport)
if forceAddressValidation {
serverTransport.MaxUnvalidatedHandshakes = -1
}
@ -86,6 +87,7 @@ var _ = Describe("MITM test", func() {
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
addTracer(clientTransport)
})
Context("unsuccessful attacks", func() {

View file

@ -74,6 +74,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
done1 := make(chan struct{})
done2 := make(chan struct{})
@ -109,6 +110,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
done1 := make(chan struct{})
done2 := make(chan struct{})
@ -139,6 +141,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
@ -167,6 +170,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
@ -174,6 +178,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
server1, err := tr1.Listen(
getTLSConfig(),
@ -220,6 +225,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
@ -227,6 +233,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())

View file

@ -86,7 +86,6 @@ var (
logBuf *syncedBuffer
versionParam string
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
enableQlog bool
version quic.Version
@ -138,9 +137,6 @@ func init() {
}
var _ = BeforeSuite(func() {
if enableQlog {
qlogTracer = tools.NewQlogger(GinkgoWriter)
}
switch versionParam {
case "1":
version = quic.Version1
@ -175,28 +171,48 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
} else {
conf = conf.Clone()
}
if enableQlog {
if conf.Tracer == nil {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{},
)
}
} else if qlogTracer != nil {
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
if !enableQlog {
return conf
}
if conf.Tracer == nil {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{},
)
}
return conf
}
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
return conf
}
func addTracer(tr *quic.Transport) {
if !enableQlog {
return
}
if tr.Tracer == nil {
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.Tracer{},
)
return
}
origTracer := tr.Tracer
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
origTracer,
)
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)

View file

@ -175,6 +175,7 @@ var _ = Describe("0-RTT", func() {
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
addTracer(tr)
defer tr.Close()
conn, err = tr.DialEarly(
context.Background(),
@ -463,6 +464,7 @@ var _ = Describe("0-RTT", func() {
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
addTracer(tr)
defer tr.Close()
ln, err := tr.ListenEarly(
tlsConf,

View file

@ -7,6 +7,7 @@ import (
"io"
"log"
"os"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
@ -14,13 +15,21 @@ import (
"github.com/quic-go/quic-go/qlog"
)
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
func QlogTracer(logger io.Writer) *logging.Tracer {
filename := fmt.Sprintf("log_%s_transport.qlog", time.Now().Format("2006-01-02T15:04:05"))
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {
log.Fatalf("failed to create qlog file: %s", err)
return nil
}
bw := bufio.NewWriter(f)
return qlog.NewTracer(utils.NewBufferedWriteCloser(bw, f))
}
func NewQlogConnectionTracer(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%s_%s.qlog", connID, role)
filename := fmt.Sprintf("log_%s_%s.qlog", connID, p.String())
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {

View file

@ -65,7 +65,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
if !enableQlog {
return c
}
qlogger := tools.NewQlogger(GinkgoWriter)
qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)
if c.Tracer == nil {
c.Tracer = qlogger
} else if qlogger != nil {