From 30e01b9524a4a9fbbb285d2280a31ac4c427d0bd Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 Feb 2024 12:35:19 +0700 Subject: [PATCH] use the transport tracer in integration tests --- integrationtests/self/conn_id_test.go | 2 + integrationtests/self/handshake_rtt_test.go | 1 + integrationtests/self/handshake_test.go | 17 ++++-- integrationtests/self/mitm_test.go | 2 + integrationtests/self/multiplex_test.go | 7 +++ integrationtests/self/self_suite_test.go | 58 ++++++++++++------- integrationtests/self/zero_rtt_test.go | 2 + integrationtests/tools/qlog.go | 21 +++++-- .../versionnegotiation_suite_test.go | 2 +- 9 files changed, 78 insertions(+), 34 deletions(-) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 0d8c4b44..24047bce 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -50,6 +50,7 @@ var _ = Describe("Connection ID lengths tests", func() { ConnectionIDLength: connIDLen, ConnectionIDGenerator: connIDGenerator, } + addTracer(tr) ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) go func() { @@ -92,6 +93,7 @@ var _ = Describe("Connection ID lengths tests", func() { ConnectionIDLength: connIDLen, ConnectionIDGenerator: connIDGenerator, } + addTracer(tr) defer tr.Close() cl, err := tr.Dial( context.Background(), diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 40e541ab..3e2a210f 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -64,6 +64,7 @@ var _ = Describe("Handshake RTT tests", func() { Conn: udpConn, MaxUnvalidatedHandshakes: -1, } + addTracer(tr) defer tr.Close() ln, err := tr.Listen(serverTLSConfig, serverConfig) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 5d7f5868..7ffdd57d 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -328,7 +328,10 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) pconn, err = net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) - dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} + dialer = &quic.Transport{ + Conn: pconn, + ConnectionIDLength: 4, + } }) AfterEach(func() { @@ -431,9 +434,8 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) udpConn, err := net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) - tr := quic.Transport{ - Conn: udpConn, - } + tr := &quic.Transport{Conn: udpConn} + addTracer(tr) defer tr.Close() tlsConf := &tls.Config{} done := make(chan struct{}) @@ -476,10 +478,11 @@ var _ = Describe("Handshake tests", func() { It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() { const limit = 3 - tr := quic.Transport{ + tr := &quic.Transport{ Conn: conn, MaxUnvalidatedHandshakes: limit, } + addTracer(tr) defer tr.Close() // Block all handshakes. @@ -541,10 +544,11 @@ var _ = Describe("Handshake tests", func() { It("rejects connections when the number of handshakes reaches MaxHandshakes", func() { const limit = 3 - tr := quic.Transport{ + tr := &quic.Transport{ Conn: conn, MaxHandshakes: limit, } + addTracer(tr) defer tr.Close() // Block all handshakes. @@ -717,6 +721,7 @@ var _ = Describe("Handshake tests", func() { Conn: udpConn, MaxUnvalidatedHandshakes: -1, } + addTracer(tr) defer tr.Close() server, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index b0b247ba..be35d7da 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -41,6 +41,7 @@ var _ = Describe("MITM test", func() { Conn: c, ConnectionIDLength: connIDLen, } + addTracer(serverTransport) if forceAddressValidation { serverTransport.MaxUnvalidatedHandshakes = -1 } @@ -86,6 +87,7 @@ var _ = Describe("MITM test", func() { Conn: clientUDPConn, ConnectionIDLength: connIDLen, } + addTracer(clientTransport) }) Context("unsuccessful attacks", func() { diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index dace9a8c..7b202c71 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -74,6 +74,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} + addTracer(tr) done1 := make(chan struct{}) done2 := make(chan struct{}) @@ -109,6 +110,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} + addTracer(tr) done1 := make(chan struct{}) done2 := make(chan struct{}) @@ -139,6 +141,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} + addTracer(tr) server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), @@ -167,6 +170,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn1.Close() tr1 := &quic.Transport{Conn: conn1} + addTracer(tr1) addr2, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) @@ -174,6 +178,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn2.Close() tr2 := &quic.Transport{Conn: conn2} + addTracer(tr2) server1, err := tr1.Listen( getTLSConfig(), @@ -220,6 +225,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn1.Close() tr1 := &quic.Transport{Conn: conn1} + addTracer(tr1) addr2, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) @@ -227,6 +233,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn2.Close() tr2 := &quic.Transport{Conn: conn2} + addTracer(tr2) server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 276adf32..42c8d944 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -86,7 +86,6 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer enableQlog bool version quic.Version @@ -138,9 +137,6 @@ func init() { } var _ = BeforeSuite(func() { - if enableQlog { - qlogTracer = tools.NewQlogger(GinkgoWriter) - } switch versionParam { case "1": version = quic.Version1 @@ -175,28 +171,48 @@ func getQuicConfig(conf *quic.Config) *quic.Config { } else { conf = conf.Clone() } - if enableQlog { - if conf.Tracer == nil { - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - return logging.NewMultiplexedConnectionTracer( - qlogTracer(ctx, p, connID), - // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere - &logging.ConnectionTracer{}, - ) - } - } else if qlogTracer != nil { - origTracer := conf.Tracer - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - return logging.NewMultiplexedConnectionTracer( - qlogTracer(ctx, p, connID), - origTracer(ctx, p, connID), - ) - } + if !enableQlog { + return conf + } + if conf.Tracer == nil { + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.ConnectionTracer{}, + ) } + return conf + } + origTracer := conf.Tracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID), + origTracer(ctx, p, connID), + ) } return conf } +func addTracer(tr *quic.Transport) { + if !enableQlog { + return + } + if tr.Tracer == nil { + tr.Tracer = logging.NewMultiplexedTracer( + tools.QlogTracer(GinkgoWriter), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.Tracer{}, + ) + return + } + origTracer := tr.Tracer + tr.Tracer = logging.NewMultiplexedTracer( + tools.QlogTracer(GinkgoWriter), + origTracer, + ) +} + var _ = BeforeEach(func() { log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 55307679..d25b0482 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -175,6 +175,7 @@ var _ = Describe("0-RTT", func() { Conn: udpConn, ConnectionIDLength: connIDLen, } + addTracer(tr) defer tr.Close() conn, err = tr.DialEarly( context.Background(), @@ -463,6 +464,7 @@ var _ = Describe("0-RTT", func() { Conn: udpConn, MaxUnvalidatedHandshakes: -1, } + addTracer(tr) defer tr.Close() ln, err := tr.ListenEarly( tlsConf, diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index 049432cc..ae19ef4b 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -7,6 +7,7 @@ import ( "io" "log" "os" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" @@ -14,13 +15,21 @@ import ( "github.com/quic-go/quic-go/qlog" ) -func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { +func QlogTracer(logger io.Writer) *logging.Tracer { + filename := fmt.Sprintf("log_%s_transport.qlog", time.Now().Format("2006-01-02T15:04:05")) + fmt.Fprintf(logger, "Creating %s.\n", filename) + f, err := os.Create(filename) + if err != nil { + log.Fatalf("failed to create qlog file: %s", err) + return nil + } + bw := bufio.NewWriter(f) + return qlog.NewTracer(utils.NewBufferedWriteCloser(bw, f)) +} + +func NewQlogConnectionTracer(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - role := "server" - if p == logging.PerspectiveClient { - role = "client" - } - filename := fmt.Sprintf("log_%s_%s.qlog", connID, role) + filename := fmt.Sprintf("log_%s_%s.qlog", connID, p.String()) fmt.Fprintf(logger, "Creating %s.\n", filename) f, err := os.Create(filename) if err != nil { diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index 150181f2..2cb3f865 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -65,7 +65,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config { if !enableQlog { return c } - qlogger := tools.NewQlogger(GinkgoWriter) + qlogger := tools.NewQlogConnectionTracer(GinkgoWriter) if c.Tracer == nil { c.Tracer = qlogger } else if qlogger != nil {