use the transport tracer in integration tests

This commit is contained in:
Marten Seemann 2024-02-03 12:35:19 +07:00
parent 55c05aceed
commit 30e01b9524
9 changed files with 78 additions and 34 deletions

View file

@ -50,6 +50,7 @@ var _ = Describe("Connection ID lengths tests", func() {
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator, ConnectionIDGenerator: connIDGenerator,
} }
addTracer(tr)
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go func() {
@ -92,6 +93,7 @@ var _ = Describe("Connection ID lengths tests", func() {
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator, ConnectionIDGenerator: connIDGenerator,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
cl, err := tr.Dial( cl, err := tr.Dial(
context.Background(), context.Background(),

View file

@ -64,6 +64,7 @@ var _ = Describe("Handshake RTT tests", func() {
Conn: udpConn, Conn: udpConn,
MaxUnvalidatedHandshakes: -1, MaxUnvalidatedHandshakes: -1,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
ln, err := tr.Listen(serverTLSConfig, serverConfig) ln, err := tr.Listen(serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -328,7 +328,10 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
pconn, err = net.ListenUDP("udp", laddr) pconn, err = net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} dialer = &quic.Transport{
Conn: pconn,
ConnectionIDLength: 4,
}
}) })
AfterEach(func() { AfterEach(func() {
@ -431,9 +434,8 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr) udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
tr := quic.Transport{ tr := &quic.Transport{Conn: udpConn}
Conn: udpConn, addTracer(tr)
}
defer tr.Close() defer tr.Close()
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
done := make(chan struct{}) done := make(chan struct{})
@ -476,10 +478,11 @@ var _ = Describe("Handshake tests", func() {
It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() { It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
const limit = 3 const limit = 3
tr := quic.Transport{ tr := &quic.Transport{
Conn: conn, Conn: conn,
MaxUnvalidatedHandshakes: limit, MaxUnvalidatedHandshakes: limit,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
// Block all handshakes. // Block all handshakes.
@ -541,10 +544,11 @@ var _ = Describe("Handshake tests", func() {
It("rejects connections when the number of handshakes reaches MaxHandshakes", func() { It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
const limit = 3 const limit = 3
tr := quic.Transport{ tr := &quic.Transport{
Conn: conn, Conn: conn,
MaxHandshakes: limit, MaxHandshakes: limit,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
// Block all handshakes. // Block all handshakes.
@ -717,6 +721,7 @@ var _ = Describe("Handshake tests", func() {
Conn: udpConn, Conn: udpConn,
MaxUnvalidatedHandshakes: -1, MaxUnvalidatedHandshakes: -1,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
server, err := tr.Listen(getTLSConfig(), serverConfig) server, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -41,6 +41,7 @@ var _ = Describe("MITM test", func() {
Conn: c, Conn: c,
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
} }
addTracer(serverTransport)
if forceAddressValidation { if forceAddressValidation {
serverTransport.MaxUnvalidatedHandshakes = -1 serverTransport.MaxUnvalidatedHandshakes = -1
} }
@ -86,6 +87,7 @@ var _ = Describe("MITM test", func() {
Conn: clientUDPConn, Conn: clientUDPConn,
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
} }
addTracer(clientTransport)
}) })
Context("unsuccessful attacks", func() { Context("unsuccessful attacks", func() {

View file

@ -74,6 +74,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.Close() defer conn.Close()
tr := &quic.Transport{Conn: conn} tr := &quic.Transport{Conn: conn}
addTracer(tr)
done1 := make(chan struct{}) done1 := make(chan struct{})
done2 := make(chan struct{}) done2 := make(chan struct{})
@ -109,6 +110,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.Close() defer conn.Close()
tr := &quic.Transport{Conn: conn} tr := &quic.Transport{Conn: conn}
addTracer(tr)
done1 := make(chan struct{}) done1 := make(chan struct{})
done2 := make(chan struct{}) done2 := make(chan struct{})
@ -139,6 +141,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.Close() defer conn.Close()
tr := &quic.Transport{Conn: conn} tr := &quic.Transport{Conn: conn}
addTracer(tr)
server, err := tr.Listen( server, err := tr.Listen(
getTLSConfig(), getTLSConfig(),
getQuicConfig(nil), getQuicConfig(nil),
@ -167,6 +170,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn1.Close() defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1} tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
addr2, err := net.ResolveUDPAddr("udp", "localhost:0") addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -174,6 +178,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn2.Close() defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2} tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
server1, err := tr1.Listen( server1, err := tr1.Listen(
getTLSConfig(), getTLSConfig(),
@ -220,6 +225,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn1.Close() defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1} tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)
addr2, err := net.ResolveUDPAddr("udp", "localhost:0") addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -227,6 +233,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn2.Close() defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2} tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -86,7 +86,6 @@ var (
logBuf *syncedBuffer logBuf *syncedBuffer
versionParam string versionParam string
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
enableQlog bool enableQlog bool
version quic.Version version quic.Version
@ -138,9 +137,6 @@ func init() {
} }
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {
if enableQlog {
qlogTracer = tools.NewQlogger(GinkgoWriter)
}
switch versionParam { switch versionParam {
case "1": case "1":
version = quic.Version1 version = quic.Version1
@ -175,28 +171,48 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
} else { } else {
conf = conf.Clone() conf = conf.Clone()
} }
if enableQlog { if !enableQlog {
return conf
}
if conf.Tracer == nil { if conf.Tracer == nil {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer( return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID), tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{}, &logging.ConnectionTracer{},
) )
} }
} else if qlogTracer != nil { return conf
}
origTracer := conf.Tracer origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer( return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID), tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
origTracer(ctx, p, connID), origTracer(ctx, p, connID),
) )
} }
}
}
return conf return conf
} }
func addTracer(tr *quic.Transport) {
if !enableQlog {
return
}
if tr.Tracer == nil {
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.Tracer{},
)
return
}
origTracer := tr.Tracer
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
origTracer,
)
}
var _ = BeforeEach(func() { var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds) log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)

View file

@ -175,6 +175,7 @@ var _ = Describe("0-RTT", func() {
Conn: udpConn, Conn: udpConn,
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
conn, err = tr.DialEarly( conn, err = tr.DialEarly(
context.Background(), context.Background(),
@ -463,6 +464,7 @@ var _ = Describe("0-RTT", func() {
Conn: udpConn, Conn: udpConn,
MaxUnvalidatedHandshakes: -1, MaxUnvalidatedHandshakes: -1,
} }
addTracer(tr)
defer tr.Close() defer tr.Close()
ln, err := tr.ListenEarly( ln, err := tr.ListenEarly(
tlsConf, tlsConf,

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"log" "log"
"os" "os"
"time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
@ -14,13 +15,21 @@ import (
"github.com/quic-go/quic-go/qlog" "github.com/quic-go/quic-go/qlog"
) )
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { func QlogTracer(logger io.Writer) *logging.Tracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("log_%s_transport.qlog", time.Now().Format("2006-01-02T15:04:05"))
role := "server" fmt.Fprintf(logger, "Creating %s.\n", filename)
if p == logging.PerspectiveClient { f, err := os.Create(filename)
role = "client" if err != nil {
log.Fatalf("failed to create qlog file: %s", err)
return nil
} }
filename := fmt.Sprintf("log_%s_%s.qlog", connID, role) bw := bufio.NewWriter(f)
return qlog.NewTracer(utils.NewBufferedWriteCloser(bw, f))
}
func NewQlogConnectionTracer(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
filename := fmt.Sprintf("log_%s_%s.qlog", connID, p.String())
fmt.Fprintf(logger, "Creating %s.\n", filename) fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {

View file

@ -65,7 +65,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
if !enableQlog { if !enableQlog {
return c return c
} }
qlogger := tools.NewQlogger(GinkgoWriter) qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)
if c.Tracer == nil { if c.Tracer == nil {
c.Tracer = qlogger c.Tracer = qlogger
} else if qlogger != nil { } else if qlogger != nil {