From 878e0b261a7237bda10aa6f77aa2cb282c0ff138 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 14 Apr 2021 16:45:42 +0700 Subject: [PATCH] pass a context to logging.Tracer.NewConnectionTracer This context has the same value attached to it as the context returned by Session.Context(). In the case of a dialed connection, this context is derived from the context used for dialing. --- client.go | 8 +++++-- client_test.go | 2 +- integrationtests/self/self_suite_test.go | 3 ++- integrationtests/self/tracer_test.go | 2 +- internal/mocks/logging/tracer.go | 9 ++++---- logging/interface.go | 3 ++- logging/mock_tracer_test.go | 9 ++++---- logging/multiplex.go | 5 ++-- logging/multiplex_test.go | 29 ++++++++++++++---------- qlog/qlog.go | 3 ++- qlog/qlog_test.go | 5 ++-- server.go | 9 ++++++-- server_test.go | 12 +++++----- 13 files changed, 60 insertions(+), 39 deletions(-) diff --git a/client.go b/client.go index 7331e215..267cdbc6 100644 --- a/client.go +++ b/client.go @@ -203,13 +203,17 @@ func dialContext( } c.packetHandlers = packetHandlers + c.tracingID = nextSessionTracingID() if c.config.Tracer != nil { - c.tracer = c.config.Tracer.TracerForConnection(protocol.PerspectiveClient, c.destConnID) + c.tracer = c.config.Tracer.TracerForConnection( + context.WithValue(ctx, SessionTracingKey, c.tracingID), + protocol.PerspectiveClient, + c.destConnID, + ) } if c.tracer != nil { c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID) } - c.tracingID = nextSessionTracingID() if err := c.dial(ctx); err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index 1efb7008..42031a47 100644 --- a/client_test.go +++ b/client_test.go @@ -54,7 +54,7 @@ var _ = Describe("Client", func() { originalClientSessConstructor = newClientSession tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl) - tr.EXPECT().TracerForConnection(protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) + tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}} Eventually(areSessionsRunning).Should(BeFalse()) // sess = NewMockQuicSession(mockCtrl) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index ea2e3384..e467e58c 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -3,6 +3,7 @@ package self_test import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -327,7 +328,7 @@ func newTracer(c func() logging.ConnectionTracer) logging.Tracer { return &tracer{createNewConnTracer: c} } -func (t *tracer) TracerForConnection(p logging.Perspective, odcid logging.ConnectionID) logging.ConnectionTracer { +func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { return t.createNewConnTracer() } func (t *tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {} diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 69c0b52c..431d4613 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -25,7 +25,7 @@ type customTracer struct{} var _ logging.Tracer = &customTracer{} -func (t *customTracer) TracerForConnection(p logging.Perspective, odcid logging.ConnectionID) logging.ConnectionTracer { +func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { return &customConnTracer{} } func (t *customTracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {} diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 16a942b3..04c72623 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -5,6 +5,7 @@ package mocklogging import ( + context "context" net "net" reflect "reflect" @@ -62,15 +63,15 @@ func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) } // TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 protocol.Perspective, arg1 protocol.ConnectionID) logging.ConnectionTracer { +func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1) + ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) ret0, _ := ret[0].(logging.ConnectionTracer) return ret0 } // TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) } diff --git a/logging/interface.go b/logging/interface.go index 959b28b5..7a217dfa 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -3,6 +3,7 @@ package logging import ( + "context" "net" "time" @@ -95,7 +96,7 @@ type Tracer interface { // The ODCID is the original destination connection ID: // The destination connection ID that the client used on the first Initial packet it sent on this connection. // If nil is returned, tracing will be disabled for this connection. - TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer + TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer SentPacket(net.Addr, *Header, ByteCount, []Frame) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go index f8dd2ad9..e970c09b 100644 --- a/logging/mock_tracer_test.go +++ b/logging/mock_tracer_test.go @@ -5,6 +5,7 @@ package logging import ( + context "context" net "net" reflect "reflect" @@ -61,15 +62,15 @@ func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) } // TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 protocol.Perspective, arg1 protocol.ConnectionID) ConnectionTracer { +func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1) + ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) ret0, _ := ret[0].(ConnectionTracer) return ret0 } // TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) } diff --git a/logging/multiplex.go b/logging/multiplex.go index fc3240fb..20072036 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -1,6 +1,7 @@ package logging import ( + "context" "net" "time" ) @@ -22,10 +23,10 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer { return &tracerMultiplexer{tracers} } -func (m *tracerMultiplexer) TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer { +func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { var connTracers []ConnectionTracer for _, t := range m.tracers { - if ct := t.TracerForConnection(p, odcid); ct != nil { + if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { connTracers = append(connTracers, ct) } } diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 95693b65..657a6eeb 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -1,6 +1,7 @@ package logging import ( + "context" "net" "time" @@ -35,35 +36,39 @@ var _ = Describe("Tracing", func() { }) It("multiplexes the TracerForConnection call", func() { - tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) - tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) + ctx := context.Background() + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) }) It("uses multiple connection tracers", func() { + ctx := context.Background() ctr1 := NewMockConnectionTracer(mockCtrl) ctr2 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) - tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) ctr1.EXPECT().LossTimerCanceled() ctr2.EXPECT().LossTimerCanceled() tr.LossTimerCanceled() }) It("handles tracers that return a nil ConnectionTracer", func() { + ctx := context.Background() ctr1 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}) - tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) ctr1.EXPECT().LossTimerCanceled() tr.LossTimerCanceled() }) It("returns nil when all tracers return a nil ConnectionTracer", func() { - tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3}) - Expect(tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) + ctx := context.Background() + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) }) It("traces the PacketSent event", func() { diff --git a/qlog/qlog.go b/qlog/qlog.go index 5a553aa3..8b8cadf5 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -2,6 +2,7 @@ package qlog import ( "bytes" + "context" "fmt" "io" "log" @@ -59,7 +60,7 @@ func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io. return &tracer{getLogWriter: getLogWriter} } -func (t *tracer) TracerForConnection(p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { +func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { if w := t.getLogWriter(p, odcid.Bytes()); w != nil { return NewConnectionTracer(w, p, odcid) } diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index 7f40c25b..c74f6046 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -2,6 +2,7 @@ package qlog import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -52,7 +53,7 @@ var _ = Describe("Tracing", func() { Context("tracer", func() { It("returns nil when there's no io.WriteCloser", func() { t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) - Expect(t.TracerForConnection(logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) + Expect(t.TracerForConnection(context.Background(), logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) }) }) @@ -83,7 +84,7 @@ var _ = Describe("Tracing", func() { BeforeEach(func() { buf = &bytes.Buffer{} t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) - tracer = t.TracerForConnection(logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + tracer = t.TracerForConnection(context.Background(), logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) }) It("exports a trace that has the right metadata", func() { diff --git a/server.go b/server.go index bff996b6..c75324d2 100644 --- a/server.go +++ b/server.go @@ -451,6 +451,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } s.logger.Debugf("Changing connection ID to %s.", connID) var sess quicSession + tracingID := nextSessionTracingID() if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { var tracer logging.ConnectionTracer if s.config.Tracer != nil { @@ -459,7 +460,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro if origDestConnID.Len() > 0 { connID = origDestConnID } - tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID) + tracer = s.config.Tracer.TracerForConnection( + context.WithValue(context.Background(), SessionTracingKey, tracingID), + protocol.PerspectiveServer, + connID, + ) } sess = s.newSession( newSendConn(s.conn, p.remoteAddr, p.info), @@ -475,7 +480,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro s.tokenGenerator, s.acceptEarlySessions, tracer, - nextSessionTracingID(), + tracingID, s.logger, hdr.Version, ) diff --git a/server_test.go b/server_test.go index 8343c68f..2fd21a85 100644 --- a/server_test.go +++ b/server_test.go @@ -322,7 +322,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) sess := NewMockQuicSession(mockCtrl) serv.newSession = func( _ sendConn, @@ -579,7 +579,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) sess := NewMockQuicSession(mockCtrl) serv.newSession = func( @@ -637,7 +637,7 @@ var _ = Describe("Server", func() { fn() return true }).AnyTimes() - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()).AnyTimes() + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } acceptSession := make(chan struct{}) @@ -760,7 +760,7 @@ var _ = Describe("Server", func() { fn() return true }).Times(protocol.MaxAcceptQueueSize) - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) var wg sync.WaitGroup wg.Add(protocol.MaxAcceptQueueSize) @@ -832,7 +832,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -940,7 +940,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},