pass a context to logging.Tracer.NewConnectionTracer

This context has the same value attached to it as the context returned
by Session.Context().
In the case of a dialed connection, this context is derived from the
context used for dialing.
This commit is contained in:
Marten Seemann 2021-04-14 16:45:42 +07:00
parent 4917760726
commit 878e0b261a
13 changed files with 60 additions and 39 deletions

View file

@ -203,13 +203,17 @@ func dialContext(
} }
c.packetHandlers = packetHandlers c.packetHandlers = packetHandlers
c.tracingID = nextSessionTracingID()
if c.config.Tracer != nil { if c.config.Tracer != nil {
c.tracer = c.config.Tracer.TracerForConnection(protocol.PerspectiveClient, c.destConnID) c.tracer = c.config.Tracer.TracerForConnection(
context.WithValue(ctx, SessionTracingKey, c.tracingID),
protocol.PerspectiveClient,
c.destConnID,
)
} }
if c.tracer != nil { if c.tracer != nil {
c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID) c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID)
} }
c.tracingID = nextSessionTracingID()
if err := c.dial(ctx); err != nil { if err := c.dial(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -54,7 +54,7 @@ var _ = Describe("Client", func() {
originalClientSessConstructor = newClientSession originalClientSessConstructor = newClientSession
tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
tr := mocklogging.NewMockTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl)
tr.EXPECT().TracerForConnection(protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}} config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}}
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockQuicSession(mockCtrl) // sess = NewMockQuicSession(mockCtrl)

View file

@ -3,6 +3,7 @@ package self_test
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
@ -327,7 +328,7 @@ func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
return &tracer{createNewConnTracer: c} return &tracer{createNewConnTracer: c}
} }
func (t *tracer) TracerForConnection(p logging.Perspective, odcid logging.ConnectionID) logging.ConnectionTracer { func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return t.createNewConnTracer() return t.createNewConnTracer()
} }
func (t *tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {} func (t *tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {}

View file

@ -25,7 +25,7 @@ type customTracer struct{}
var _ logging.Tracer = &customTracer{} var _ logging.Tracer = &customTracer{}
func (t *customTracer) TracerForConnection(p logging.Perspective, odcid logging.ConnectionID) logging.ConnectionTracer { func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return &customConnTracer{} return &customConnTracer{}
} }
func (t *customTracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {} func (t *customTracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {}

View file

@ -5,6 +5,7 @@
package mocklogging package mocklogging
import ( import (
context "context"
net "net" net "net"
reflect "reflect" reflect "reflect"
@ -62,15 +63,15 @@ func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{})
} }
// TracerForConnection mocks base method. // TracerForConnection mocks base method.
func (m *MockTracer) TracerForConnection(arg0 protocol.Perspective, arg1 protocol.ConnectionID) logging.ConnectionTracer { func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1) ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
ret0, _ := ret[0].(logging.ConnectionTracer) ret0, _ := ret[0].(logging.ConnectionTracer)
return ret0 return ret0
} }
// TracerForConnection indicates an expected call of TracerForConnection. // TracerForConnection indicates an expected call of TracerForConnection.
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1 interface{}) *gomock.Call { func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
} }

View file

@ -3,6 +3,7 @@
package logging package logging
import ( import (
"context"
"net" "net"
"time" "time"
@ -95,7 +96,7 @@ type Tracer interface {
// The ODCID is the original destination connection ID: // The ODCID is the original destination connection ID:
// The destination connection ID that the client used on the first Initial packet it sent on this connection. // The destination connection ID that the client used on the first Initial packet it sent on this connection.
// If nil is returned, tracing will be disabled for this connection. // If nil is returned, tracing will be disabled for this connection.
TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer
SentPacket(net.Addr, *Header, ByteCount, []Frame) SentPacket(net.Addr, *Header, ByteCount, []Frame)
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)

View file

@ -5,6 +5,7 @@
package logging package logging
import ( import (
context "context"
net "net" net "net"
reflect "reflect" reflect "reflect"
@ -61,15 +62,15 @@ func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{})
} }
// TracerForConnection mocks base method. // TracerForConnection mocks base method.
func (m *MockTracer) TracerForConnection(arg0 protocol.Perspective, arg1 protocol.ConnectionID) ConnectionTracer { func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1) ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
ret0, _ := ret[0].(ConnectionTracer) ret0, _ := ret[0].(ConnectionTracer)
return ret0 return ret0
} }
// TracerForConnection indicates an expected call of TracerForConnection. // TracerForConnection indicates an expected call of TracerForConnection.
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1 interface{}) *gomock.Call { func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
} }

View file

@ -1,6 +1,7 @@
package logging package logging
import ( import (
"context"
"net" "net"
"time" "time"
) )
@ -22,10 +23,10 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
return &tracerMultiplexer{tracers} return &tracerMultiplexer{tracers}
} }
func (m *tracerMultiplexer) TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer { func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer var connTracers []ConnectionTracer
for _, t := range m.tracers { for _, t := range m.tracers {
if ct := t.TracerForConnection(p, odcid); ct != nil { if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
connTracers = append(connTracers, ct) connTracers = append(connTracers, ct)
} }
} }

View file

@ -1,6 +1,7 @@
package logging package logging
import ( import (
"context"
"net" "net"
"time" "time"
@ -35,35 +36,39 @@ var _ = Describe("Tracing", func() {
}) })
It("multiplexes the TracerForConnection call", func() { It("multiplexes the TracerForConnection call", func() {
tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) ctx := context.Background()
tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
}) })
It("uses multiple connection tracers", func() { It("uses multiple connection tracers", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl) ctr1 := NewMockConnectionTracer(mockCtrl)
ctr2 := NewMockConnectionTracer(mockCtrl) ctr2 := NewMockConnectionTracer(mockCtrl)
tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2)
tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}) tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3})
ctr1.EXPECT().LossTimerCanceled() ctr1.EXPECT().LossTimerCanceled()
ctr2.EXPECT().LossTimerCanceled() ctr2.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled() tr.LossTimerCanceled()
}) })
It("handles tracers that return a nil ConnectionTracer", func() { It("handles tracers that return a nil ConnectionTracer", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl) ctr1 := NewMockConnectionTracer(mockCtrl)
tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}) tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3})
tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}) tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3})
ctr1.EXPECT().LossTimerCanceled() ctr1.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled() tr.LossTimerCanceled()
}) })
It("returns nil when all tracers return a nil ConnectionTracer", func() { It("returns nil when all tracers return a nil ConnectionTracer", func() {
tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) ctx := context.Background()
tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
Expect(tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil())
}) })
It("traces the PacketSent event", func() { It("traces the PacketSent event", func() {

View file

@ -2,6 +2,7 @@ package qlog
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -59,7 +60,7 @@ func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.
return &tracer{getLogWriter: getLogWriter} return &tracer{getLogWriter: getLogWriter}
} }
func (t *tracer) TracerForConnection(p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
if w := t.getLogWriter(p, odcid.Bytes()); w != nil { if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
return NewConnectionTracer(w, p, odcid) return NewConnectionTracer(w, p, odcid)
} }

View file

@ -2,6 +2,7 @@ package qlog
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
@ -52,7 +53,7 @@ var _ = Describe("Tracing", func() {
Context("tracer", func() { Context("tracer", func() {
It("returns nil when there's no io.WriteCloser", func() { It("returns nil when there's no io.WriteCloser", func() {
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil })
Expect(t.TracerForConnection(logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) Expect(t.TracerForConnection(context.Background(), logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil())
}) })
}) })
@ -83,7 +84,7 @@ var _ = Describe("Tracing", func() {
BeforeEach(func() { BeforeEach(func() {
buf = &bytes.Buffer{} buf = &bytes.Buffer{}
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) })
tracer = t.TracerForConnection(logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) tracer = t.TracerForConnection(context.Background(), logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef})
}) })
It("exports a trace that has the right metadata", func() { It("exports a trace that has the right metadata", func() {

View file

@ -451,6 +451,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
var sess quicSession var sess quicSession
tracingID := nextSessionTracingID()
if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
var tracer logging.ConnectionTracer var tracer logging.ConnectionTracer
if s.config.Tracer != nil { if s.config.Tracer != nil {
@ -459,7 +460,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
if origDestConnID.Len() > 0 { if origDestConnID.Len() > 0 {
connID = origDestConnID connID = origDestConnID
} }
tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID) tracer = s.config.Tracer.TracerForConnection(
context.WithValue(context.Background(), SessionTracingKey, tracingID),
protocol.PerspectiveServer,
connID,
)
} }
sess = s.newSession( sess = s.newSession(
newSendConn(s.conn, p.remoteAddr, p.info), newSendConn(s.conn, p.remoteAddr, p.info),
@ -475,7 +480,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.tokenGenerator, s.tokenGenerator,
s.acceptEarlySessions, s.acceptEarlySessions,
tracer, tracer,
nextSessionTracingID(), tracingID,
s.logger, s.logger,
hdr.Version, hdr.Version,
) )

View file

@ -322,7 +322,7 @@ var _ = Describe("Server", func() {
fn() fn()
return true return true
}) })
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
serv.newSession = func( serv.newSession = func(
_ sendConn, _ sendConn,
@ -579,7 +579,7 @@ var _ = Describe("Server", func() {
fn() fn()
return true return true
}) })
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
serv.newSession = func( serv.newSession = func(
@ -637,7 +637,7 @@ var _ = Describe("Server", func() {
fn() fn()
return true return true
}).AnyTimes() }).AnyTimes()
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
acceptSession := make(chan struct{}) acceptSession := make(chan struct{})
@ -760,7 +760,7 @@ var _ = Describe("Server", func() {
fn() fn()
return true return true
}).Times(protocol.MaxAcceptQueueSize) }).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize) wg.Add(protocol.MaxAcceptQueueSize)
@ -832,7 +832,7 @@ var _ = Describe("Server", func() {
fn() fn()
return true return true
}) })
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handlePacket(p) serv.handlePacket(p)
// make sure there are no Write calls on the packet conn // make sure there are no Write calls on the packet conn
@ -940,7 +940,7 @@ var _ = Describe("Server", func() {
fn() fn()
return true return true
}) })
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handleInitialImpl( serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()}, &receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},