remove Tracer from Config, put ConnectionTracer constructor there

This commit is contained in:
Marten Seemann 2023-04-28 12:49:11 +02:00
parent 5544f0f9a1
commit 07ad2cbee2
31 changed files with 202 additions and 331 deletions

View file

@ -146,11 +146,7 @@ func dial(
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer.TracerForConnection(
context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
protocol.PerspectiveClient,
c.destConnID,
)
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)

View file

@ -56,10 +56,12 @@ var _ = Describe("Client", func() {
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
originalClientConnConstructor = newClientConnection
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
tr := mocklogging.NewMockTracer(mockCtrl)
tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
config = &Config{
Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer {
return tracer
},
Versions: []protocol.VersionNumber{protocol.Version1},
}
Eventually(areConnsRunning).Should(BeFalse())
packetConn = NewMockSendConn(mockCtrl)
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()

View file

@ -1,14 +1,15 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"reflect"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -46,7 +47,7 @@ var _ = Describe("Config", func() {
}
switch fn := typ.Field(i).Name; fn {
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
@ -88,8 +89,6 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf(true))
case "Allow0RTT":
f.Set(reflect.ValueOf(true))
case "Tracer":
f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
default:
Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
}
@ -109,11 +108,15 @@ var _ = Describe("Config", func() {
Context("cloning", func() {
It("clones function fields", func() {
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
c1 := &Config{
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer {
calledTracer = true
return nil
},
}
c2 := c1.Clone()
c2.RequireAddressValidation(&net.UDPAddr{})
@ -122,6 +125,8 @@ var _ = Describe("Config", func() {
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
Expect(err).To(MatchError("nope"))
c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
Expect(calledTracer).To(BeTrue())
})
It("clones non-function fields", func() {

View file

@ -3,6 +3,7 @@ package main
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"flag"
@ -57,15 +58,15 @@ func main() {
var qconf quic.Config
if *enableQlog {
qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
filename := fmt.Sprintf("client_%x.qlog", connID)
f, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
log.Printf("Creating qlog file %s.\n", filename)
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
})
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
}
}
roundTripper := &http3.RoundTripper{
TLSClientConfig: &tls.Config{

View file

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"crypto/md5"
"errors"
"flag"
@ -162,15 +163,15 @@ func main() {
handler := setupHandler(*www)
quicConf := &quic.Config{}
if *enableQlog {
quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
filename := fmt.Sprintf("server_%x.qlog", connID)
f, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
log.Printf("Creating qlog file %s.\n", filename)
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
})
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
}
}
var wg sync.WaitGroup

View file

@ -75,7 +75,9 @@ var _ = Describe("Key Update tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}),
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return &keyUpdateConnTracer{}
}}),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())

View file

@ -27,7 +27,7 @@ var _ = Describe("Packetization", func() {
getTLSConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
Tracer: newTracer(serverTracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -50,7 +50,7 @@ var _ = Describe("Packetization", func() {
getTLSClientConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
Tracer: newTracer(clientTracer),
}),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -87,7 +87,7 @@ var (
logBuf *syncedBuffer
versionParam string
qlogTracer logging.Tracer
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
enableQlog bool
version quic.VersionNumber
@ -175,7 +175,13 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
if conf.Tracer == nil {
conf.Tracer = qlogTracer
} else if qlogTracer != nil {
conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer)
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
}
}
return conf
@ -232,19 +238,8 @@ func scaleDuration(d time.Duration) time.Duration {
return time.Duration(scaleFactor) * d
}
type tracer struct {
logging.NullTracer
createNewConnTracer func() logging.ConnectionTracer
}
var _ logging.Tracer = &tracer{}
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
return &tracer{createNewConnTracer: c}
}
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return t.createNewConnTracer()
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
}
type packet struct {

View file

@ -201,7 +201,7 @@ var _ = Describe("Timeout tests", func() {
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
Tracer: newTracer(func() logging.ConnectionTracer { return tr }),
Tracer: newTracer(tr),
DisablePathMTUDiscovery: true,
}),
)

View file

@ -19,14 +19,6 @@ import (
. "github.com/onsi/gomega"
)
type customTracer struct{ logging.NullTracer }
func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return &customConnTracer{}
}
type customConnTracer struct{ logging.NullConnectionTracer }
var _ = Describe("Handshake tests", func() {
addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config {
enableQlog := mrand.Int()%3 != 0
@ -34,22 +26,32 @@ var _ = Describe("Handshake tests", func() {
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer)
var tracers []logging.Tracer
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
if enableQlog {
tracers = append(tracers, qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connectionID)
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID)
return nil
}
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connectionID)
return utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil))
}))
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID)
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID)
})
}
if enableCustomTracer {
tracers = append(tracers, &customTracer{})
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return logging.NullConnectionTracer{}
})
}
c := conf.Clone()
c.Tracer = logging.NewMultiplexedTracer(tracers...)
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors))
for _, c := range tracerConstructors {
if tr := c(ctx, p, connID); tr != nil {
tracers = append(tracers, tr)
}
}
return logging.NewMultiplexedConnectionTracer(tracers...)
}
return c
}

View file

@ -223,7 +223,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -277,7 +277,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -359,7 +359,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -435,7 +435,7 @@ var _ = Describe("0-RTT", func() {
getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -496,7 +496,7 @@ var _ = Describe("0-RTT", func() {
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -541,7 +541,7 @@ var _ = Describe("0-RTT", func() {
getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams - 1,
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -569,7 +569,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -596,7 +596,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: false, // application rejects 0-RTT
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -622,7 +622,7 @@ var _ = Describe("0-RTT", func() {
secondConf := getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
})
addFlowControlLimit(secondConf, 100)
ln, err := quic.ListenAddrEarly(
@ -699,7 +699,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -775,7 +775,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -2,23 +2,25 @@ package tools
import (
"bufio"
"context"
"fmt"
"io"
"log"
"os"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
func NewQlogger(logger io.Writer) logging.Tracer {
return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role)
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {
@ -26,6 +28,6 @@ func NewQlogger(logger io.Writer) logging.Tracer {
return nil
}
bw := bufio.NewWriter(f)
return utils.NewBufferedWriteCloser(bw, f)
})
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bw, f), p, connID)
}
}

View file

@ -85,7 +85,9 @@ var _ = Describe("Handshake tests", func() {
serverConfig := &quic.Config{}
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return serverTracer
}
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientTracer := &versionNegotiationTracer{}
@ -93,7 +95,9 @@ var _ = Describe("Handshake tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
return clientTracer
}}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
@ -111,10 +115,12 @@ var _ = Describe("Handshake tests", func() {
expectedVersion := protocol.SupportedVersions[0]
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverTracer := &versionNegotiationTracer{}
serverConfig := &quic.Config{}
serverConfig.Versions = supportedVersions
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return serverTracer
}
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
@ -123,9 +129,11 @@ var _ = Describe("Handshake tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{
maybeAddQLOGTracer(&quic.Config{
Versions: clientVersions,
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return clientTracer
},
}),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() {
context.Background(),
proxy.LocalAddr().String(),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
)
Expect(err).To(HaveOccurred())
expectDurationInRTTs(startTime, 1)

View file

@ -58,7 +58,7 @@ func TestQuicVersionNegotiation(t *testing.T) {
RunSpecs(t, "Version Negotiation Suite")
}
func maybeAddQlogTracer(c *quic.Config) *quic.Config {
func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
if c == nil {
c = &quic.Config{}
}
@ -69,22 +69,13 @@ func maybeAddQlogTracer(c *quic.Config) *quic.Config {
if c.Tracer == nil {
c.Tracer = qlogger
} else if qlogger != nil {
c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer)
origTracer := c.Tracer
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogger(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
}
return c
}
type tracer struct {
logging.NullTracer
createNewConnTracer func() logging.ConnectionTracer
}
var _ logging.Tracer = &tracer{}
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
return &tracer{createNewConnTracer: c}
}
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return t.createNewConnTracer()
}

View file

@ -322,7 +322,7 @@ type Config struct {
Allow0RTT bool
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer logging.Tracer
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
}
type ClientHelloInfo struct {

View file

@ -5,7 +5,6 @@
package mocklogging
import (
context "context"
net "net"
reflect "reflect"
@ -73,17 +72,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
}
// TracerForConnection mocks base method.
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
ret0, _ := ret[0].(logging.ConnectionTracer)
return ret0
}
// TracerForConnection indicates an expected call of TracerForConnection.
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
}

View file

@ -21,7 +21,6 @@ import (
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/interop/http09"
"github.com/quic-go/quic-go/interop/utils"
"github.com/quic-go/quic-go/qlog"
)
var errUnsupported = errors.New("unsupported test case")
@ -65,11 +64,7 @@ func runTestcase(testcase string) error {
flag.Parse()
urls := flag.Args()
getLogWriter, err := utils.GetQLOGWriter()
if err != nil {
return err
}
quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)}
quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
if testcase == "http3" {
r := &http3.RoundTripper{

View file

@ -13,7 +13,6 @@ import (
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/interop/http09"
"github.com/quic-go/quic-go/interop/utils"
"github.com/quic-go/quic-go/qlog"
)
var tlsConf *tls.Config
@ -38,16 +37,10 @@ func main() {
testcase := os.Getenv("TESTCASE")
getLogWriter, err := utils.GetQLOGWriter()
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
// a quic.Config that doesn't do a Retry
quicConf := &quic.Config{
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
Allow0RTT: testcase == "zerortt",
Tracer: qlog.NewTracer(getLogWriter),
Tracer: utils.NewQLOGConnectionTracer,
}
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
if err != nil {

View file

@ -2,14 +2,17 @@ package utils
import (
"bufio"
"context"
"fmt"
"io"
"log"
"os"
"strings"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
// GetSSLKeyLog creates a file for the TLS key log
@ -25,18 +28,17 @@ func GetSSLKeyLog() (io.WriteCloser, error) {
return f, nil
}
// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback
func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) {
// NewQLOGConnectionTracer create a qlog file in QLOGDIR
func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
qlogDir := os.Getenv("QLOGDIR")
if len(qlogDir) == 0 {
return nil, nil
return nil
}
if _, err := os.Stat(qlogDir); os.IsNotExist(err) {
if err := os.MkdirAll(qlogDir, 0o666); err != nil {
return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error())
log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err)
}
}
return func(_ logging.Perspective, connID []byte) io.WriteCloser {
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
f, err := os.Create(path)
if err != nil {
@ -44,6 +46,5 @@ func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.Wr
return nil
}
log.Printf("Created qlog file: %s\n", path)
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
}, nil
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
}

View file

@ -3,7 +3,6 @@
package logging
import (
"context"
"net"
"time"
@ -101,12 +100,6 @@ type ShortHeader struct {
// A Tracer traces events.
type Tracer interface {
// TracerForConnection requests a new tracer for a connection.
// The ODCID is the original destination connection ID:
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
// If nil is returned, tracing will be disabled for this connection.
TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer
SentPacket(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)

View file

@ -5,7 +5,6 @@
package logging
import (
context "context"
net "net"
reflect "reflect"
@ -72,17 +71,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
}
// TracerForConnection mocks base method.
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
ret0, _ := ret[0].(ConnectionTracer)
return ret0
}
// TracerForConnection indicates an expected call of TracerForConnection.
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
}

View file

@ -1,7 +1,6 @@
package logging
import (
"context"
"net"
"time"
)
@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
return &tracerMultiplexer{tracers}
}
func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer
for _, t := range m.tracers {
if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
connTracers = append(connTracers, ct)
}
}
return NewMultiplexedConnectionTracer(connTracers...)
}
func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range m.tracers {
t.SentPacket(remote, hdr, size, frames)

View file

@ -1,7 +1,6 @@
package logging
import (
"context"
"errors"
"net"
"time"
@ -37,46 +36,6 @@ var _ = Describe("Tracing", func() {
tracer = NewMultiplexedTracer(tr1, tr2)
})
It("multiplexes the TracerForConnection call", func() {
ctx := context.Background()
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
tracer.TracerForConnection(ctx, PerspectiveClient, connID)
})
It("uses multiple connection tracers", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl)
ctr2 := NewMockConnectionTracer(mockCtrl)
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2)
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
ctr1.EXPECT().LossTimerCanceled()
ctr2.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled()
})
It("handles tracers that return a nil ConnectionTracer", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID)
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
ctr1.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled()
})
It("returns nil when all tracers return a nil ConnectionTracer", func() {
ctx := context.Background()
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil())
})
It("traces the PacketSent event", func() {
remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)}
hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}

View file

@ -1,7 +1,6 @@
package logging
import (
"context"
"net"
"time"
)
@ -12,9 +11,6 @@ type NullTracer struct{}
var _ Tracer = &NullTracer{}
func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer {
return NullConnectionTracer{}
}
func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {}
func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
}

View file

@ -2,7 +2,6 @@ package qlog
import (
"bytes"
"context"
"fmt"
"io"
"log"
@ -49,26 +48,6 @@ func init() {
const eventChanSize = 50
type tracer struct {
logging.NullTracer
getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser
}
var _ logging.Tracer = &tracer{}
// NewTracer creates a new qlog tracer.
func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer {
return &tracer{getLogWriter: getLogWriter}
}
func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
return NewConnectionTracer(w, p, odcid)
}
return nil
}
type connectionTracer struct {
mutex sync.Mutex

View file

@ -2,7 +2,6 @@ package qlog
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
@ -51,17 +50,6 @@ type entry struct {
}
var _ = Describe("Tracing", func() {
Context("tracer", func() {
It("returns nil when there's no io.WriteCloser", func() {
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil })
Expect(t.TracerForConnection(
context.Background(),
logging.PerspectiveClient,
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
)).To(BeNil())
})
})
It("stops writing when encountering an error", func() {
buf := &bytes.Buffer{}
t := NewConnectionTracer(
@ -88,9 +76,8 @@ var _ = Describe("Tracing", func() {
BeforeEach(func() {
buf = &bytes.Buffer{}
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) })
tracer = t.TracerForConnection(
context.Background(),
tracer = NewConnectionTracer(
nopWriteCloser(buf),
logging.PerspectiveServer,
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
)

108
server.go
View file

@ -106,6 +106,8 @@ type baseServer struct {
connQueue chan quicConn
connQueueLen int32 // to be used as an atomic
tracer logging.Tracer
logger utils.Logger
}
@ -212,7 +214,16 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
return tr.ListenEarly(tlsConf, config)
}
func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
tlsConf *tls.Config,
config *Config,
tracer logging.Tracer,
onClose func(),
acceptEarly bool,
) (*baseServer, error) {
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
@ -229,6 +240,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator C
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
onClose: onClose,
@ -318,8 +330,8 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
case s.receivedPackets <- p:
default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
}
}
}
@ -331,8 +343,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.")
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@ -345,16 +357,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) {
if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize {
s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size())
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@ -366,8 +378,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@ -378,16 +390,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing packet: %s", err)
return false
}
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@ -397,8 +409,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
// There's little point in sending a Stateless Reset, since the client
// might not have received the token yet.
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@ -416,8 +428,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
}
@ -430,8 +442,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
@ -440,8 +452,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
@ -468,8 +480,8 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
continue
}
for _, p := range q.packets {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
p.buffer.Release()
}
@ -504,8 +516,8 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release()
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return errors.New("too short connection ID")
}
@ -585,19 +597,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
var conn quicConn
tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
var tracer logging.ConnectionTracer
if s.config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = s.config.Tracer.TracerForConnection(
context.WithValue(context.Background(), ConnectionTracingKey, tracingID),
protocol.PerspectiveServer,
connID,
)
}
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
@ -607,6 +606,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
config = populateConfig(conf)
}
var tracer logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info),
s.connHandler,
@ -715,8 +723,8 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...)
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
return err
@ -729,8 +737,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
}
// don't return the error here. Just drop the packet.
return nil
@ -738,8 +746,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
// don't return the error here. Just drop the packet.
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
}
return nil
}
@ -792,8 +800,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true)
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
return err
@ -803,8 +811,8 @@ func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest pro
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.config.Tracer != nil {
s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
if s.tracer != nil {
s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
}
if _, err := s.conn.WritePacket(data, remote, oob); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)

View file

@ -170,7 +170,7 @@ var _ = Describe("Server", func() {
Context("server accepting connections that completed the handshake", func() {
var (
ln *Listener
tr *Transport
serv *baseServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
@ -178,8 +178,8 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
var err error
ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer})
tr = &Transport{Conn: conn, Tracer: tracer}
ln, err := tr.Listen(tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
serv = ln.baseServer
phm = NewMockPacketHandlerManager(mockCtrl)
@ -187,7 +187,7 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
ln.Close()
tr.Close()
})
Context("handling packets", func() {
@ -276,7 +276,6 @@ var _ = Describe("Server", func() {
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
@ -478,7 +477,6 @@ var _ = Describe("Server", func() {
return ok
}),
)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
@ -537,7 +535,6 @@ var _ = Describe("Server", func() {
_, ok := fn()
return ok
}).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually
@ -665,7 +662,6 @@ var _ = Describe("Server", func() {
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
@ -737,7 +733,6 @@ var _ = Describe("Server", func() {
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
@ -1062,6 +1057,7 @@ var _ = Describe("Server", func() {
return ok
})
done := make(chan struct{})
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
@ -1119,7 +1115,6 @@ var _ = Describe("Server", func() {
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@ -1326,6 +1321,7 @@ var _ = Describe("Server", func() {
Context("0-RTT", func() {
var (
tr *Transport
serv *baseServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
@ -1333,7 +1329,8 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
tr = &Transport{Conn: conn, Tracer: tracer}
ln, err := tr.ListenEarly(tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
phm = NewMockPacketHandlerManager(mockCtrl)
serv = ln.baseServer
@ -1342,7 +1339,7 @@ var _ = Describe("Server", func() {
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
tr.Close()
})
It("passes packets to existing connections", func() {
@ -1425,7 +1422,6 @@ var _ = Describe("Server", func() {
return conn
}
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())

View file

@ -95,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(conf, true); err != nil {
if err := t.init(true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false)
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
if err != nil {
return nil, err
}
@ -124,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(conf, true); err != nil {
if err := t.init(true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true)
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
if err != nil {
return nil, err
}
@ -141,7 +141,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
return nil, err
}
conf = populateConfig(conf)
if err := t.init(conf, false); err != nil {
if err := t.init(false); err != nil {
return nil, err
}
var onClose func()
@ -157,7 +157,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return nil, err
}
conf = populateConfig(conf)
if err := t.init(conf, false); err != nil {
if err := t.init(false); err != nil {
return nil, err
}
var onClose func()
@ -200,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func (t *Transport) init(conf *Config, isServer bool) error {
func (t *Transport) init(isServer bool) error {
t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn)
@ -210,7 +210,6 @@ func (t *Transport) init(conf *Config, isServer bool) error {
return
}
t.Tracer = conf.Tracer
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)

View file

@ -61,7 +61,7 @@ var _ = Describe("Transport", func() {
It("handles packets for different packet handlers on the same packet conn", func() {
packetChan := make(chan packetToRead)
tr := &Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{}, true)
tr.init(true)
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
@ -128,8 +128,9 @@ var _ = Describe("Transport", func() {
tr := &Transport{
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: 10,
Tracer: tracer,
}
tr.init(&Config{Tracer: tracer}, true)
tr.init(true)
dropped := make(chan struct{})
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
packetChan <- packetToRead{
@ -148,7 +149,7 @@ var _ = Describe("Transport", func() {
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{}, true)
tr.init(true)
tr.handlerMap = phm
done := make(chan struct{})
@ -166,7 +167,7 @@ var _ = Describe("Transport", func() {
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(&Config{}, true)
tr.init(true)
tr.handlerMap = phm
tempErr := deadlineError{}
@ -188,7 +189,7 @@ var _ = Describe("Transport", func() {
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: connID.Len(),
}
tr.init(&Config{}, true)
tr.init(true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
@ -221,7 +222,7 @@ var _ = Describe("Transport", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
tr.init(&Config{}, true)
tr.init(true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
@ -257,7 +258,7 @@ var _ = Describe("Transport", func() {
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
ConnectionIDLength: connID.Len(),
}
tr.init(&Config{}, true)
tr.init(true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm