mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
Merge pull request #3153 from lucas-clemente/trace-version-selection
trace and qlog version selection / negotiation
This commit is contained in:
commit
0413afd615
12 changed files with 170 additions and 4 deletions
|
@ -8,10 +8,11 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/israce"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/logging"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -46,6 +47,29 @@ func (c *tokenStore) Pop(key string) *quic.ClientToken {
|
|||
return c.store.Pop(key)
|
||||
}
|
||||
|
||||
type versionNegotiationTracer struct {
|
||||
connTracer
|
||||
|
||||
loggedVersions bool
|
||||
receivedVersionNegotiation bool
|
||||
chosen logging.VersionNumber
|
||||
clientVersions, serverVersions []logging.VersionNumber
|
||||
}
|
||||
|
||||
func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
|
||||
if t.loggedVersions {
|
||||
Fail("only expected one call to NegotiatedVersions")
|
||||
}
|
||||
t.loggedVersions = true
|
||||
t.chosen = chosen
|
||||
t.clientVersions = clientVersions
|
||||
t.serverVersions = serverVersions
|
||||
}
|
||||
|
||||
func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(*logging.Header, []logging.VersionNumber) {
|
||||
t.receivedVersionNegotiation = true
|
||||
}
|
||||
|
||||
var _ = Describe("Handshake tests", func() {
|
||||
var (
|
||||
server quic.Listener
|
||||
|
@ -97,37 +121,61 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
|
||||
It("when the server supports more versions than the client", func() {
|
||||
expectedVersion := protocol.SupportedVersions[0]
|
||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
|
||||
runServer(getTLSConfig())
|
||||
defer server.Close()
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
nil,
|
||||
getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
|
||||
Expect(sess.(versioner).GetVersion()).To(Equal(expectedVersion))
|
||||
Expect(sess.CloseWithError(0, "")).To(Succeed())
|
||||
Expect(clientTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
|
||||
Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(clientTracer.serverVersions).To(BeEmpty())
|
||||
Expect(serverTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
|
||||
Expect(serverTracer.clientVersions).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("when the client supports more versions than the server supports", func() {
|
||||
expectedVersion := protocol.SupportedVersions[0]
|
||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig.Versions = supportedVersions
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
|
||||
runServer(getTLSConfig())
|
||||
defer server.Close()
|
||||
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10},
|
||||
Versions: clientVersions,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
|
||||
Expect(sess.CloseWithError(0, "")).To(Succeed())
|
||||
Expect(clientTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(clientTracer.receivedVersionNegotiation).To(BeTrue())
|
||||
Expect(clientTracer.clientVersions).To(Equal(clientVersions))
|
||||
Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
|
||||
Expect(serverTracer.chosen).To(Equal(expectedVersion))
|
||||
Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
|
||||
Expect(serverTracer.clientVersions).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -341,6 +341,9 @@ var _ logging.ConnectionTracer = &connTracer{}
|
|||
|
||||
func (t *connTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) {
|
||||
}
|
||||
|
||||
func (t *connTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
|
||||
}
|
||||
func (t *connTracer) ClosedConnection(logging.CloseReason) {}
|
||||
func (t *connTracer) SentTransportParameters(*logging.TransportParameters) {}
|
||||
func (t *connTracer) ReceivedTransportParameters(*logging.TransportParameters) {}
|
||||
|
|
|
@ -38,6 +38,9 @@ var _ logging.ConnectionTracer = &customConnTracer{}
|
|||
|
||||
func (t *customConnTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) {
|
||||
}
|
||||
|
||||
func (t *customConnTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
|
||||
}
|
||||
func (t *customConnTracer) ClosedConnection(logging.CloseReason) {}
|
||||
func (t *customConnTracer) SentTransportParameters(*logging.TransportParameters) {}
|
||||
func (t *customConnTracer) ReceivedTransportParameters(*logging.TransportParameters) {}
|
||||
|
|
|
@ -171,6 +171,18 @@ func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interfac
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// NegotiatedVersion mocks base method.
|
||||
func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// NegotiatedVersion indicates an expected call of NegotiatedVersion.
|
||||
func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ReceivedPacket mocks base method.
|
||||
func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -105,6 +105,7 @@ type Tracer interface {
|
|||
// A ConnectionTracer records events.
|
||||
type ConnectionTracer interface {
|
||||
StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID)
|
||||
NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber)
|
||||
ClosedConnection(CloseReason)
|
||||
SentTransportParameters(*TransportParameters)
|
||||
ReceivedTransportParameters(*TransportParameters)
|
||||
|
|
|
@ -170,6 +170,18 @@ func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interfac
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// NegotiatedVersion mocks base method.
|
||||
func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// NegotiatedVersion indicates an expected call of NegotiatedVersion.
|
||||
func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ReceivedPacket mocks base method.
|
||||
func (m *MockConnectionTracer) ReceivedPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 []Frame) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -68,6 +68,12 @@ func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcCon
|
|||
}
|
||||
}
|
||||
|
||||
func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
|
||||
for _, t := range m.tracers {
|
||||
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connTracerMultiplexer) ClosedConnection(reason CloseReason) {
|
||||
for _, t := range m.tracers {
|
||||
t.ClosedConnection(reason)
|
||||
|
|
|
@ -83,6 +83,25 @@ func (e eventConnectionStarted) MarshalJSONObject(enc *gojay.Encoder) {
|
|||
enc.StringKey("dst_cid", connectionID(e.DestConnectionID).String())
|
||||
}
|
||||
|
||||
type eventVersionNegotiated struct {
|
||||
clientVersions, serverVersions []versionNumber
|
||||
chosenVersion versionNumber
|
||||
}
|
||||
|
||||
func (e eventVersionNegotiated) Category() category { return categoryTransport }
|
||||
func (e eventVersionNegotiated) Name() string { return "version_information" }
|
||||
func (e eventVersionNegotiated) IsNil() bool { return false }
|
||||
|
||||
func (e eventVersionNegotiated) MarshalJSONObject(enc *gojay.Encoder) {
|
||||
if len(e.clientVersions) > 0 {
|
||||
enc.ArrayKey("client_versions", versions(e.clientVersions))
|
||||
}
|
||||
if len(e.serverVersions) > 0 {
|
||||
enc.ArrayKey("server_versions", versions(e.serverVersions))
|
||||
}
|
||||
enc.StringKey("chosen_version", e.chosenVersion.String())
|
||||
}
|
||||
|
||||
type eventConnectionClosed struct {
|
||||
Reason logging.CloseReason
|
||||
}
|
||||
|
|
23
qlog/qlog.go
23
qlog/qlog.go
|
@ -183,6 +183,29 @@ func (t *connectionTracer) StartedConnection(local, remote net.Addr, srcConnID,
|
|||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (t *connectionTracer) NegotiatedVersion(chosen logging.VersionNumber, client, server []logging.VersionNumber) {
|
||||
var clientVersions, serverVersions []versionNumber
|
||||
if len(client) > 0 {
|
||||
clientVersions = make([]versionNumber, len(client))
|
||||
for i, v := range client {
|
||||
clientVersions[i] = versionNumber(v)
|
||||
}
|
||||
}
|
||||
if len(server) > 0 {
|
||||
serverVersions = make([]versionNumber, len(server))
|
||||
for i, v := range server {
|
||||
serverVersions[i] = versionNumber(v)
|
||||
}
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.recordEvent(time.Now(), &eventVersionNegotiated{
|
||||
clientVersions: clientVersions,
|
||||
serverVersions: serverVersions,
|
||||
chosenVersion: versionNumber(chosen),
|
||||
})
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (t *connectionTracer) ClosedConnection(r logging.CloseReason) {
|
||||
t.mutex.Lock()
|
||||
t.recordEvent(time.Now(), &eventConnectionClosed{Reason: r})
|
||||
|
|
|
@ -170,6 +170,30 @@ var _ = Describe("Tracing", func() {
|
|||
Expect(ev).To(HaveKeyWithValue("dst_cid", "05060708"))
|
||||
})
|
||||
|
||||
It("records the version, if no version negotiation happened", func() {
|
||||
tracer.NegotiatedVersion(0x1337, nil, nil)
|
||||
entry := exportAndParseSingle()
|
||||
Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond)))
|
||||
Expect(entry.Name).To(Equal("transport:version_information"))
|
||||
ev := entry.Event
|
||||
Expect(ev).To(HaveLen(1))
|
||||
Expect(ev).To(HaveKeyWithValue("chosen_version", "1337"))
|
||||
})
|
||||
|
||||
It("records the version, if version negotiation happened", func() {
|
||||
tracer.NegotiatedVersion(0x1337, []logging.VersionNumber{1, 2, 3}, []logging.VersionNumber{4, 5, 6})
|
||||
entry := exportAndParseSingle()
|
||||
Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond)))
|
||||
Expect(entry.Name).To(Equal("transport:version_information"))
|
||||
ev := entry.Event
|
||||
Expect(ev).To(HaveLen(3))
|
||||
Expect(ev).To(HaveKeyWithValue("chosen_version", "1337"))
|
||||
Expect(ev).To(HaveKey("client_versions"))
|
||||
Expect(ev["client_versions"].([]interface{})).To(Equal([]interface{}{"1", "2", "3"}))
|
||||
Expect(ev).To(HaveKey("server_versions"))
|
||||
Expect(ev["server_versions"].([]interface{})).To(Equal([]interface{}{"4", "5", "6"}))
|
||||
})
|
||||
|
||||
It("records idle timeouts", func() {
|
||||
tracer.ClosedConnection(logging.NewTimeoutCloseReason(logging.TimeoutReasonIdle))
|
||||
entry := exportAndParseSingle()
|
||||
|
|
13
session.go
13
session.go
|
@ -1100,6 +1100,9 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) {
|
|||
s.logger.Infof("No compatible QUIC version found.")
|
||||
return
|
||||
}
|
||||
if s.tracer != nil {
|
||||
s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions)
|
||||
}
|
||||
|
||||
s.logger.Infof("Switching to QUIC version %s.", newVersion)
|
||||
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
|
||||
|
@ -1121,6 +1124,16 @@ func (s *session) handleUnpackedPacket(
|
|||
|
||||
if !s.receivedFirstPacket {
|
||||
s.receivedFirstPacket = true
|
||||
if !s.versionNegotiated && s.tracer != nil {
|
||||
var clientVersions, serverVersions []protocol.VersionNumber
|
||||
switch s.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
clientVersions = s.config.Versions
|
||||
case protocol.PerspectiveServer:
|
||||
serverVersions = s.config.Versions
|
||||
}
|
||||
s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions)
|
||||
}
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
|
||||
cid := packet.hdr.SrcConnectionID
|
||||
|
|
|
@ -90,6 +90,7 @@ var _ = Describe("Session", func() {
|
|||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
tracer.EXPECT().SentTransportParameters(gomock.Any())
|
||||
tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tracer.EXPECT().UpdatedCongestionState(gomock.Any())
|
||||
|
@ -2465,6 +2466,7 @@ var _ = Describe("Client Session", func() {
|
|||
}
|
||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
tracer.EXPECT().SentTransportParameters(gomock.Any())
|
||||
tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tracer.EXPECT().UpdatedCongestionState(gomock.Any())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue