From e7c4e756ad2d5db63f747041e5c100fc574c2fd4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 19 Apr 2021 11:16:30 +0700 Subject: [PATCH] trace and qlog version selection / negotiation --- integrationtests/self/handshake_test.go | 56 +++++++++++++++++++-- integrationtests/self/self_suite_test.go | 3 ++ integrationtests/self/tracer_test.go | 3 ++ internal/mocks/logging/connection_tracer.go | 12 +++++ logging/interface.go | 1 + logging/mock_connection_tracer_test.go | 12 +++++ logging/multiplex.go | 6 +++ qlog/event.go | 19 +++++++ qlog/qlog.go | 23 +++++++++ qlog/qlog_test.go | 24 +++++++++ session.go | 13 +++++ session_test.go | 2 + 12 files changed, 170 insertions(+), 4 deletions(-) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index ace1504e..cac5f12e 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -8,10 +8,11 @@ import ( "net" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/integrationtests/tools/israce" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/logging" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -46,6 +47,29 @@ func (c *tokenStore) Pop(key string) *quic.ClientToken { return c.store.Pop(key) } +type versionNegotiationTracer struct { + connTracer + + loggedVersions bool + receivedVersionNegotiation bool + chosen logging.VersionNumber + clientVersions, serverVersions []logging.VersionNumber +} + +func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + if t.loggedVersions { + Fail("only expected one call to NegotiatedVersions") + } + t.loggedVersions = true + t.chosen = chosen + t.clientVersions = clientVersions + t.serverVersions = serverVersions +} + +func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(*logging.Header, []logging.VersionNumber) { + t.receivedVersionNegotiation = true +} + var _ = Describe("Handshake tests", func() { var ( server quic.Listener @@ -97,37 +121,61 @@ var _ = Describe("Handshake tests", func() { }) It("when the server supports more versions than the client", func() { + expectedVersion := protocol.SupportedVersions[0] // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} + serverTracer := &versionNegotiationTracer{} + serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) runServer(getTLSConfig()) defer server.Close() + clientTracer := &versionNegotiationTracer{} sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - nil, + getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) + Expect(sess.(versioner).GetVersion()).To(Equal(expectedVersion)) Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(clientTracer.chosen).To(Equal(expectedVersion)) + Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) + Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) + Expect(clientTracer.serverVersions).To(BeEmpty()) + Expect(serverTracer.chosen).To(Equal(expectedVersion)) + Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverTracer.clientVersions).To(BeEmpty()) }) It("when the client supports more versions than the server supports", func() { + expectedVersion := protocol.SupportedVersions[0] // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig.Versions = supportedVersions + serverTracer := &versionNegotiationTracer{} + serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) runServer(getTLSConfig()) defer server.Close() + clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientTracer := &versionNegotiationTracer{} sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}, + Versions: clientVersions, + Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), }), ) Expect(err).ToNot(HaveOccurred()) Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(clientTracer.chosen).To(Equal(expectedVersion)) + Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) + Expect(clientTracer.clientVersions).To(Equal(clientVersions)) + Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions + Expect(serverTracer.chosen).To(Equal(expectedVersion)) + Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverTracer.clientVersions).To(BeEmpty()) }) }) } diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index ea2e3384..1f1870e4 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -340,6 +340,9 @@ var _ logging.ConnectionTracer = &connTracer{} func (t *connTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { } + +func (t *connTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { +} func (t *connTracer) ClosedConnection(logging.CloseReason) {} func (t *connTracer) SentTransportParameters(*logging.TransportParameters) {} func (t *connTracer) ReceivedTransportParameters(*logging.TransportParameters) {} diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 69c0b52c..8aaad395 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -38,6 +38,9 @@ var _ logging.ConnectionTracer = &customConnTracer{} func (t *customConnTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { } + +func (t *customConnTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { +} func (t *customConnTracer) ClosedConnection(logging.CloseReason) {} func (t *customConnTracer) SentTransportParameters(*logging.TransportParameters) {} func (t *customConnTracer) ReceivedTransportParameters(*logging.TransportParameters) {} diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 84eef243..372e3f2a 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -171,6 +171,18 @@ func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) } +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + // ReceivedPacket mocks base method. func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { m.ctrl.T.Helper() diff --git a/logging/interface.go b/logging/interface.go index 959b28b5..2edd2223 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -104,6 +104,7 @@ type Tracer interface { // A ConnectionTracer records events. type ConnectionTracer interface { StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) ClosedConnection(CloseReason) SentTransportParameters(*TransportParameters) ReceivedTransportParameters(*TransportParameters) diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index f939b356..56439e66 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -170,6 +170,18 @@ func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) } +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + // ReceivedPacket mocks base method. func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) { m.ctrl.T.Helper() diff --git a/logging/multiplex.go b/logging/multiplex.go index fc3240fb..2bac6b60 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -67,6 +67,12 @@ func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcCon } } +func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range m.tracers { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } +} + func (m *connTracerMultiplexer) ClosedConnection(reason CloseReason) { for _, t := range m.tracers { t.ClosedConnection(reason) diff --git a/qlog/event.go b/qlog/event.go index aa00a10a..24361d09 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -83,6 +83,25 @@ func (e eventConnectionStarted) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("dst_cid", connectionID(e.DestConnectionID).String()) } +type eventVersionNegotiated struct { + clientVersions, serverVersions []versionNumber + chosenVersion versionNumber +} + +func (e eventVersionNegotiated) Category() category { return categoryTransport } +func (e eventVersionNegotiated) Name() string { return "version_information" } +func (e eventVersionNegotiated) IsNil() bool { return false } + +func (e eventVersionNegotiated) MarshalJSONObject(enc *gojay.Encoder) { + if len(e.clientVersions) > 0 { + enc.ArrayKey("client_versions", versions(e.clientVersions)) + } + if len(e.serverVersions) > 0 { + enc.ArrayKey("server_versions", versions(e.serverVersions)) + } + enc.StringKey("chosen_version", e.chosenVersion.String()) +} + type eventConnectionClosed struct { Reason logging.CloseReason } diff --git a/qlog/qlog.go b/qlog/qlog.go index 5a553aa3..7a2c7306 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -182,6 +182,29 @@ func (t *connectionTracer) StartedConnection(local, remote net.Addr, srcConnID, t.mutex.Unlock() } +func (t *connectionTracer) NegotiatedVersion(chosen logging.VersionNumber, client, server []logging.VersionNumber) { + var clientVersions, serverVersions []versionNumber + if len(client) > 0 { + clientVersions = make([]versionNumber, len(client)) + for i, v := range client { + clientVersions[i] = versionNumber(v) + } + } + if len(server) > 0 { + serverVersions = make([]versionNumber, len(server)) + for i, v := range server { + serverVersions[i] = versionNumber(v) + } + } + t.mutex.Lock() + t.recordEvent(time.Now(), &eventVersionNegotiated{ + clientVersions: clientVersions, + serverVersions: serverVersions, + chosenVersion: versionNumber(chosen), + }) + t.mutex.Unlock() +} + func (t *connectionTracer) ClosedConnection(r logging.CloseReason) { t.mutex.Lock() t.recordEvent(time.Now(), &eventConnectionClosed{Reason: r}) diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index 7f40c25b..f9d51402 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -169,6 +169,30 @@ var _ = Describe("Tracing", func() { Expect(ev).To(HaveKeyWithValue("dst_cid", "05060708")) }) + It("records the version, if no version negotiation happened", func() { + tracer.NegotiatedVersion(0x1337, nil, nil) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:version_information")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) + }) + + It("records the version, if version negotiation happened", func() { + tracer.NegotiatedVersion(0x1337, []logging.VersionNumber{1, 2, 3}, []logging.VersionNumber{4, 5, 6}) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:version_information")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) + Expect(ev).To(HaveKey("client_versions")) + Expect(ev["client_versions"].([]interface{})).To(Equal([]interface{}{"1", "2", "3"})) + Expect(ev).To(HaveKey("server_versions")) + Expect(ev["server_versions"].([]interface{})).To(Equal([]interface{}{"4", "5", "6"})) + }) + It("records idle timeouts", func() { tracer.ClosedConnection(logging.NewTimeoutCloseReason(logging.TimeoutReasonIdle)) entry := exportAndParseSingle() diff --git a/session.go b/session.go index 0dd9f828..bf264dc3 100644 --- a/session.go +++ b/session.go @@ -1093,6 +1093,9 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { s.logger.Infof("No compatible QUIC version found.") return } + if s.tracer != nil { + s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) + } s.logger.Infof("Switching to QUIC version %s.", newVersion) nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial) @@ -1114,6 +1117,16 @@ func (s *session) handleUnpackedPacket( if !s.receivedFirstPacket { s.receivedFirstPacket = true + if !s.versionNegotiated && s.tracer != nil { + var clientVersions, serverVersions []protocol.VersionNumber + switch s.perspective { + case protocol.PerspectiveClient: + clientVersions = s.config.Versions + case protocol.PerspectiveServer: + serverVersions = s.config.Versions + } + s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) + } // The server can change the source connection ID with the first Handshake packet. if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { cid := packet.hdr.SrcConnectionID diff --git a/session_test.go b/session_test.go index fdfc80a5..c5f90eb2 100644 --- a/session_test.go +++ b/session_test.go @@ -90,6 +90,7 @@ var _ = Describe("Session", func() { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) Expect(err).ToNot(HaveOccurred()) tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().UpdatedCongestionState(gomock.Any()) @@ -2464,6 +2465,7 @@ var _ = Describe("Client Session", func() { } sessionRunner = NewMockSessionRunner(mockCtrl) tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().UpdatedCongestionState(gomock.Any())