mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
remove Tracer from Config, put ConnectionTracer constructor there
This commit is contained in:
parent
5544f0f9a1
commit
07ad2cbee2
31 changed files with 202 additions and 331 deletions
|
@ -146,11 +146,7 @@ func dial(
|
|||
|
||||
c.tracingID = nextConnTracingID()
|
||||
if c.config.Tracer != nil {
|
||||
c.tracer = c.config.Tracer.TracerForConnection(
|
||||
context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
|
||||
protocol.PerspectiveClient,
|
||||
c.destConnID,
|
||||
)
|
||||
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
|
||||
}
|
||||
if c.tracer != nil {
|
||||
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
|
|
|
@ -56,10 +56,12 @@ var _ = Describe("Client", func() {
|
|||
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
|
||||
originalClientConnConstructor = newClientConnection
|
||||
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
tr := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
|
||||
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
config = &Config{
|
||||
Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer {
|
||||
return tracer
|
||||
},
|
||||
Versions: []protocol.VersionNumber{protocol.Version1},
|
||||
}
|
||||
Eventually(areConnsRunning).Should(BeFalse())
|
||||
packetConn = NewMockSendConn(mockCtrl)
|
||||
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -46,7 +47,7 @@ var _ = Describe("Config", func() {
|
|||
}
|
||||
|
||||
switch fn := typ.Field(i).Name; fn {
|
||||
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
|
||||
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
|
||||
// Can't compare functions.
|
||||
case "Versions":
|
||||
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
||||
|
@ -88,8 +89,6 @@ var _ = Describe("Config", func() {
|
|||
f.Set(reflect.ValueOf(true))
|
||||
case "Allow0RTT":
|
||||
f.Set(reflect.ValueOf(true))
|
||||
case "Tracer":
|
||||
f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
|
||||
default:
|
||||
Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
|
||||
}
|
||||
|
@ -109,11 +108,15 @@ var _ = Describe("Config", func() {
|
|||
|
||||
Context("cloning", func() {
|
||||
It("clones function fields", func() {
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
|
||||
c1 := &Config{
|
||||
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
|
||||
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
|
||||
Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer {
|
||||
calledTracer = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
|
@ -122,6 +125,8 @@ var _ = Describe("Config", func() {
|
|||
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
||||
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
|
||||
Expect(err).To(MatchError("nope"))
|
||||
c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
|
||||
Expect(calledTracer).To(BeTrue())
|
||||
})
|
||||
|
||||
It("clones non-function fields", func() {
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"flag"
|
||||
|
@ -57,15 +58,15 @@ func main() {
|
|||
|
||||
var qconf quic.Config
|
||||
if *enableQlog {
|
||||
qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
|
||||
qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
filename := fmt.Sprintf("client_%x.qlog", connID)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("Creating qlog file %s.\n", filename)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
|
||||
})
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
|
||||
}
|
||||
}
|
||||
roundTripper := &http3.RoundTripper{
|
||||
TLSClientConfig: &tls.Config{
|
||||
|
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"flag"
|
||||
|
@ -162,15 +163,15 @@ func main() {
|
|||
handler := setupHandler(*www)
|
||||
quicConf := &quic.Config{}
|
||||
if *enableQlog {
|
||||
quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
|
||||
quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
filename := fmt.Sprintf("server_%x.qlog", connID)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("Creating qlog file %s.\n", filename)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
|
||||
})
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
|
@ -75,7 +75,9 @@ var _ = Describe("Key Update tests", func() {
|
|||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}),
|
||||
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return &keyUpdateConnTracer{}
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
|
|
|
@ -27,7 +27,7 @@ var _ = Describe("Packetization", func() {
|
|||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
|
||||
Tracer: newTracer(serverTracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -50,7 +50,7 @@ var _ = Describe("Packetization", func() {
|
|||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
|
||||
Tracer: newTracer(clientTracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -87,7 +87,7 @@ var (
|
|||
logBuf *syncedBuffer
|
||||
versionParam string
|
||||
|
||||
qlogTracer logging.Tracer
|
||||
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||
enableQlog bool
|
||||
|
||||
version quic.VersionNumber
|
||||
|
@ -175,7 +175,13 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
|
|||
if conf.Tracer == nil {
|
||||
conf.Tracer = qlogTracer
|
||||
} else if qlogTracer != nil {
|
||||
conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer)
|
||||
origTracer := conf.Tracer
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogTracer(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return conf
|
||||
|
@ -232,19 +238,8 @@ func scaleDuration(d time.Duration) time.Duration {
|
|||
return time.Duration(scaleFactor) * d
|
||||
}
|
||||
|
||||
type tracer struct {
|
||||
logging.NullTracer
|
||||
createNewConnTracer func() logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var _ logging.Tracer = &tracer{}
|
||||
|
||||
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
|
||||
return &tracer{createNewConnTracer: c}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
return t.createNewConnTracer()
|
||||
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
|
|
|
@ -201,7 +201,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tr }),
|
||||
Tracer: newTracer(tr),
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -19,14 +19,6 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type customTracer struct{ logging.NullTracer }
|
||||
|
||||
func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
return &customConnTracer{}
|
||||
}
|
||||
|
||||
type customConnTracer struct{ logging.NullConnectionTracer }
|
||||
|
||||
var _ = Describe("Handshake tests", func() {
|
||||
addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config {
|
||||
enableQlog := mrand.Int()%3 != 0
|
||||
|
@ -34,22 +26,32 @@ var _ = Describe("Handshake tests", func() {
|
|||
|
||||
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer)
|
||||
|
||||
var tracers []logging.Tracer
|
||||
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||
if enableQlog {
|
||||
tracers = append(tracers, qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
|
||||
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connectionID)
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID)
|
||||
return nil
|
||||
}
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connectionID)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil))
|
||||
}))
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID)
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID)
|
||||
})
|
||||
}
|
||||
if enableCustomTracer {
|
||||
tracers = append(tracers, &customTracer{})
|
||||
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NullConnectionTracer{}
|
||||
})
|
||||
}
|
||||
c := conf.Clone()
|
||||
c.Tracer = logging.NewMultiplexedTracer(tracers...)
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors))
|
||||
for _, c := range tracerConstructors {
|
||||
if tr := c(ctx, p, connID); tr != nil {
|
||||
tracers = append(tracers, tr)
|
||||
}
|
||||
}
|
||||
return logging.NewMultiplexedConnectionTracer(tracers...)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
|
|
|
@ -223,7 +223,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -277,7 +277,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -359,7 +359,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -435,7 +435,7 @@ var _ = Describe("0-RTT", func() {
|
|||
getQuicConfig(&quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -496,7 +496,7 @@ var _ = Describe("0-RTT", func() {
|
|||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams + 1,
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -541,7 +541,7 @@ var _ = Describe("0-RTT", func() {
|
|||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: maxStreams - 1,
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -569,7 +569,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -596,7 +596,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: false, // application rejects 0-RTT
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -622,7 +622,7 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
secondConf := getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
})
|
||||
addFlowControlLimit(secondConf, 100)
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -699,7 +699,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: 1,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -775,7 +775,7 @@ var _ = Describe("0-RTT", func() {
|
|||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -2,23 +2,25 @@ package tools
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
func NewQlogger(logger io.Writer) logging.Tracer {
|
||||
return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
|
||||
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
role := "server"
|
||||
if p == logging.PerspectiveClient {
|
||||
role = "client"
|
||||
}
|
||||
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
|
||||
filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role)
|
||||
fmt.Fprintf(logger, "Creating %s.\n", filename)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
|
@ -26,6 +28,6 @@ func NewQlogger(logger io.Writer) logging.Tracer {
|
|||
return nil
|
||||
}
|
||||
bw := bufio.NewWriter(f)
|
||||
return utils.NewBufferedWriteCloser(bw, f)
|
||||
})
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bw, f), p, connID)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,7 +85,9 @@ var _ = Describe("Handshake tests", func() {
|
|||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
server, cl := startServer(getTLSConfig(), serverConfig)
|
||||
defer cl()
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
|
@ -93,7 +95,9 @@ var _ = Describe("Handshake tests", func() {
|
|||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
|
||||
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
|
||||
|
@ -111,10 +115,12 @@ var _ = Describe("Handshake tests", func() {
|
|||
expectedVersion := protocol.SupportedVersions[0]
|
||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = supportedVersions
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
server, cl := startServer(getTLSConfig(), serverConfig)
|
||||
defer cl()
|
||||
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
|
||||
|
@ -123,9 +129,11 @@ var _ = Describe("Handshake tests", func() {
|
|||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQlogTracer(&quic.Config{
|
||||
maybeAddQLOGTracer(&quic.Config{
|
||||
Versions: clientVersions,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
},
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
context.Background(),
|
||||
proxy.LocalAddr().String(),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
|
||||
maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
expectDurationInRTTs(startTime, 1)
|
||||
|
|
|
@ -58,7 +58,7 @@ func TestQuicVersionNegotiation(t *testing.T) {
|
|||
RunSpecs(t, "Version Negotiation Suite")
|
||||
}
|
||||
|
||||
func maybeAddQlogTracer(c *quic.Config) *quic.Config {
|
||||
func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
|
||||
if c == nil {
|
||||
c = &quic.Config{}
|
||||
}
|
||||
|
@ -69,22 +69,13 @@ func maybeAddQlogTracer(c *quic.Config) *quic.Config {
|
|||
if c.Tracer == nil {
|
||||
c.Tracer = qlogger
|
||||
} else if qlogger != nil {
|
||||
c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer)
|
||||
origTracer := c.Tracer
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogger(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
)
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type tracer struct {
|
||||
logging.NullTracer
|
||||
createNewConnTracer func() logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var _ logging.Tracer = &tracer{}
|
||||
|
||||
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
|
||||
return &tracer{createNewConnTracer: c}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
return t.createNewConnTracer()
|
||||
}
|
||||
|
|
|
@ -322,7 +322,7 @@ type Config struct {
|
|||
Allow0RTT bool
|
||||
// Enable QUIC datagram support (RFC 9221).
|
||||
EnableDatagrams bool
|
||||
Tracer logging.Tracer
|
||||
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
|
||||
}
|
||||
|
||||
type ClientHelloInfo struct {
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package mocklogging
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
|
@ -73,17 +72,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// TracerForConnection mocks base method.
|
||||
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(logging.ConnectionTracer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TracerForConnection indicates an expected call of TracerForConnection.
|
||||
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/interop/http09"
|
||||
"github.com/quic-go/quic-go/interop/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
var errUnsupported = errors.New("unsupported test case")
|
||||
|
@ -65,11 +64,7 @@ func runTestcase(testcase string) error {
|
|||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
||||
getLogWriter, err := utils.GetQLOGWriter()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)}
|
||||
quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
|
||||
|
||||
if testcase == "http3" {
|
||||
r := &http3.RoundTripper{
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/interop/http09"
|
||||
"github.com/quic-go/quic-go/interop/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
var tlsConf *tls.Config
|
||||
|
@ -38,16 +37,10 @@ func main() {
|
|||
|
||||
testcase := os.Getenv("TESTCASE")
|
||||
|
||||
getLogWriter, err := utils.GetQLOGWriter()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
// a quic.Config that doesn't do a Retry
|
||||
quicConf := &quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
|
||||
Allow0RTT: testcase == "zerortt",
|
||||
Tracer: qlog.NewTracer(getLogWriter),
|
||||
Tracer: utils.NewQLOGConnectionTracer,
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
|
||||
if err != nil {
|
||||
|
|
|
@ -2,14 +2,17 @@ package utils
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
// GetSSLKeyLog creates a file for the TLS key log
|
||||
|
@ -25,18 +28,17 @@ func GetSSLKeyLog() (io.WriteCloser, error) {
|
|||
return f, nil
|
||||
}
|
||||
|
||||
// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback
|
||||
func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) {
|
||||
// NewQLOGConnectionTracer create a qlog file in QLOGDIR
|
||||
func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
qlogDir := os.Getenv("QLOGDIR")
|
||||
if len(qlogDir) == 0 {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
if _, err := os.Stat(qlogDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(qlogDir, 0o666); err != nil {
|
||||
return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error())
|
||||
log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err)
|
||||
}
|
||||
}
|
||||
return func(_ logging.Perspective, connID []byte) io.WriteCloser {
|
||||
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
|
@ -44,6 +46,5 @@ func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.Wr
|
|||
return nil
|
||||
}
|
||||
log.Printf("Created qlog file: %s\n", path)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
|
||||
}, nil
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -101,12 +100,6 @@ type ShortHeader struct {
|
|||
|
||||
// A Tracer traces events.
|
||||
type Tracer interface {
|
||||
// TracerForConnection requests a new tracer for a connection.
|
||||
// The ODCID is the original destination connection ID:
|
||||
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
|
||||
// If nil is returned, tracing will be disabled for this connection.
|
||||
TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer
|
||||
|
||||
SentPacket(net.Addr, *Header, ByteCount, []Frame)
|
||||
SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
|
||||
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
|
@ -72,17 +71,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// TracerForConnection mocks base method.
|
||||
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(ConnectionTracer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TracerForConnection indicates an expected call of TracerForConnection.
|
||||
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
|
|||
return &tracerMultiplexer{tracers}
|
||||
}
|
||||
|
||||
func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
|
||||
var connTracers []ConnectionTracer
|
||||
for _, t := range m.tracers {
|
||||
if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
|
||||
connTracers = append(connTracers, ct)
|
||||
}
|
||||
}
|
||||
return NewMultiplexedConnectionTracer(connTracers...)
|
||||
}
|
||||
|
||||
func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
|
||||
for _, t := range m.tracers {
|
||||
t.SentPacket(remote, hdr, size, frames)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -37,46 +36,6 @@ var _ = Describe("Tracing", func() {
|
|||
tracer = NewMultiplexedTracer(tr1, tr2)
|
||||
})
|
||||
|
||||
It("multiplexes the TracerForConnection call", func() {
|
||||
ctx := context.Background()
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
tracer.TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
})
|
||||
|
||||
It("uses multiple connection tracers", func() {
|
||||
ctx := context.Background()
|
||||
ctr1 := NewMockConnectionTracer(mockCtrl)
|
||||
ctr2 := NewMockConnectionTracer(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2)
|
||||
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
|
||||
ctr1.EXPECT().LossTimerCanceled()
|
||||
ctr2.EXPECT().LossTimerCanceled()
|
||||
tr.LossTimerCanceled()
|
||||
})
|
||||
|
||||
It("handles tracers that return a nil ConnectionTracer", func() {
|
||||
ctx := context.Background()
|
||||
ctr1 := NewMockConnectionTracer(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID)
|
||||
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
|
||||
ctr1.EXPECT().LossTimerCanceled()
|
||||
tr.LossTimerCanceled()
|
||||
})
|
||||
|
||||
It("returns nil when all tracers return a nil ConnectionTracer", func() {
|
||||
ctx := context.Background()
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil())
|
||||
})
|
||||
|
||||
It("traces the PacketSent event", func() {
|
||||
remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)}
|
||||
hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
@ -12,9 +11,6 @@ type NullTracer struct{}
|
|||
|
||||
var _ Tracer = &NullTracer{}
|
||||
|
||||
func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer {
|
||||
return NullConnectionTracer{}
|
||||
}
|
||||
func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {}
|
||||
func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
|
||||
}
|
||||
|
|
21
qlog/qlog.go
21
qlog/qlog.go
|
@ -2,7 +2,6 @@ package qlog
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -49,26 +48,6 @@ func init() {
|
|||
|
||||
const eventChanSize = 50
|
||||
|
||||
type tracer struct {
|
||||
logging.NullTracer
|
||||
|
||||
getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser
|
||||
}
|
||||
|
||||
var _ logging.Tracer = &tracer{}
|
||||
|
||||
// NewTracer creates a new qlog tracer.
|
||||
func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer {
|
||||
return &tracer{getLogWriter: getLogWriter}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
|
||||
if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
|
||||
return NewConnectionTracer(w, p, odcid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type connectionTracer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ package qlog
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
@ -51,17 +50,6 @@ type entry struct {
|
|||
}
|
||||
|
||||
var _ = Describe("Tracing", func() {
|
||||
Context("tracer", func() {
|
||||
It("returns nil when there's no io.WriteCloser", func() {
|
||||
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil })
|
||||
Expect(t.TracerForConnection(
|
||||
context.Background(),
|
||||
logging.PerspectiveClient,
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
It("stops writing when encountering an error", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
t := NewConnectionTracer(
|
||||
|
@ -88,9 +76,8 @@ var _ = Describe("Tracing", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) })
|
||||
tracer = t.TracerForConnection(
|
||||
context.Background(),
|
||||
tracer = NewConnectionTracer(
|
||||
nopWriteCloser(buf),
|
||||
logging.PerspectiveServer,
|
||||
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
|
||||
)
|
||||
|
|
108
server.go
108
server.go
|
@ -106,6 +106,8 @@ type baseServer struct {
|
|||
connQueue chan quicConn
|
||||
connQueueLen int32 // to be used as an atomic
|
||||
|
||||
tracer logging.Tracer
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
|
@ -212,7 +214,16 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
|
|||
return tr.ListenEarly(tlsConf, config)
|
||||
}
|
||||
|
||||
func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) {
|
||||
func newServer(
|
||||
conn rawConn,
|
||||
connHandler packetHandlerManager,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
tracer logging.Tracer,
|
||||
onClose func(),
|
||||
acceptEarly bool,
|
||||
) (*baseServer, error) {
|
||||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -229,6 +240,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator C
|
|||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
newConn: newConnection,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
onClose: onClose,
|
||||
|
@ -318,8 +330,8 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
|
|||
case s.receivedPackets <- p:
|
||||
default:
|
||||
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -331,8 +343,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
|
||||
if wire.IsVersionNegotiationPacket(p.data) {
|
||||
s.logger.Debugf("Dropping Version Negotiation packet.")
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -345,16 +357,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) {
|
||||
if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize {
|
||||
s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size())
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
|
||||
if err != nil { // should never happen
|
||||
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -366,8 +378,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
|
||||
if wire.Is0RTTPacket(p.data) {
|
||||
if !s.acceptEarlyConns {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -378,16 +390,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
// The header will then be parsed again.
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Error parsing packet: %s", err)
|
||||
return false
|
||||
}
|
||||
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
|
||||
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -397,8 +409,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
// There's little point in sending a Stateless Reset, since the client
|
||||
// might not have received the token yet.
|
||||
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -416,8 +428,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
|||
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -430,8 +442,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
|||
|
||||
if q, ok := s.zeroRTTQueues[connID]; ok {
|
||||
if len(q.packets) >= protocol.Max0RTTQueueLen {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -440,8 +452,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
|||
}
|
||||
|
||||
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -468,8 +480,8 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
|
|||
continue
|
||||
}
|
||||
for _, p := range q.packets {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
p.buffer.Release()
|
||||
}
|
||||
|
@ -504,8 +516,8 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
|||
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
p.buffer.Release()
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return errors.New("too short connection ID")
|
||||
}
|
||||
|
@ -585,19 +597,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
|
||||
var tracer logging.ConnectionTracer
|
||||
if s.config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
connID := hdr.DestConnectionID
|
||||
if origDestConnID.Len() > 0 {
|
||||
connID = origDestConnID
|
||||
}
|
||||
tracer = s.config.Tracer.TracerForConnection(
|
||||
context.WithValue(context.Background(), ConnectionTracingKey, tracingID),
|
||||
protocol.PerspectiveServer,
|
||||
connID,
|
||||
)
|
||||
}
|
||||
config := s.config
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
|
||||
|
@ -607,6 +606,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
}
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
var tracer logging.ConnectionTracer
|
||||
if config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
connID := hdr.DestConnectionID
|
||||
if origDestConnID.Len() > 0 {
|
||||
connID = origDestConnID
|
||||
}
|
||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||
}
|
||||
conn = s.newConn(
|
||||
newSendConn(s.conn, p.remoteAddr, p.info),
|
||||
s.connHandler,
|
||||
|
@ -715,8 +723,8 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
|||
// append the Retry integrity tag
|
||||
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
|
||||
buf.Data = append(buf.Data, tag[:]...)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
|
||||
}
|
||||
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
|
||||
return err
|
||||
|
@ -729,8 +737,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
|
|||
data := p.data[:hdr.ParsedLen()+hdr.Length]
|
||||
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
// don't return the error here. Just drop the packet.
|
||||
return nil
|
||||
|
@ -738,8 +746,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
|
|||
hdrLen := extHdr.ParsedLen()
|
||||
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
|
||||
// don't return the error here. Just drop the packet.
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -792,8 +800,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
|
|||
|
||||
replyHdr.Log(s.logger)
|
||||
wire.LogFrame(s.logger, ccf, true)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
|
||||
}
|
||||
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
|
||||
return err
|
||||
|
@ -803,8 +811,8 @@ func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest pro
|
|||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
|
||||
|
||||
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
|
||||
}
|
||||
if _, err := s.conn.WritePacket(data, remote, oob); err != nil {
|
||||
s.logger.Debugf("Error sending Version Negotiation: %s", err)
|
||||
|
|
|
@ -170,7 +170,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
Context("server accepting connections that completed the handshake", func() {
|
||||
var (
|
||||
ln *Listener
|
||||
tr *Transport
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
|
@ -178,8 +178,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
var err error
|
||||
ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer})
|
||||
tr = &Transport{Conn: conn, Tracer: tracer}
|
||||
ln, err := tr.Listen(tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.baseServer
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
|
@ -187,7 +187,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
AfterEach(func() {
|
||||
ln.Close()
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
|
@ -276,7 +276,6 @@ var _ = Describe("Server", func() {
|
|||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
|
@ -478,7 +477,6 @@ var _ = Describe("Server", func() {
|
|||
return ok
|
||||
}),
|
||||
)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
|
@ -537,7 +535,6 @@ var _ = Describe("Server", func() {
|
|||
_, ok := fn()
|
||||
return ok
|
||||
}).AnyTimes()
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
||||
|
||||
acceptConn := make(chan struct{})
|
||||
var counter uint32 // to be used as an atomic, so we query it in Eventually
|
||||
|
@ -665,7 +662,6 @@ var _ = Describe("Server", func() {
|
|||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(protocol.MaxAcceptQueueSize)
|
||||
|
@ -737,7 +733,6 @@ var _ = Describe("Server", func() {
|
|||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
|
||||
serv.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
|
@ -1062,6 +1057,7 @@ var _ = Describe("Server", func() {
|
|||
return ok
|
||||
})
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
rejectHdr := parseHeader(b)
|
||||
|
@ -1119,7 +1115,6 @@ var _ = Describe("Server", func() {
|
|||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||
|
@ -1326,6 +1321,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
Context("0-RTT", func() {
|
||||
var (
|
||||
tr *Transport
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
|
@ -1333,7 +1329,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
|
||||
tr = &Transport{Conn: conn, Tracer: tracer}
|
||||
ln, err := tr.ListenEarly(tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
serv = ln.baseServer
|
||||
|
@ -1342,7 +1339,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
|
@ -1425,7 +1422,6 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
|
|
15
transport.go
15
transport.go
|
@ -95,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
|
|||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(conf, true); err != nil {
|
||||
if err := t.init(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false)
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -124,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
|
|||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(conf, true); err != nil {
|
||||
if err := t.init(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true)
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
|
|||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(conf, false); err != nil {
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
|
@ -157,7 +157,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(conf, false); err != nil {
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
|
@ -200,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
|||
// only print warnings about the UDP receive buffer size once
|
||||
var receiveBufferWarningOnce sync.Once
|
||||
|
||||
func (t *Transport) init(conf *Config, isServer bool) error {
|
||||
func (t *Transport) init(isServer bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
getMultiplexer().AddConn(t.Conn)
|
||||
|
||||
|
@ -210,7 +210,6 @@ func (t *Transport) init(conf *Config, isServer bool) error {
|
|||
return
|
||||
}
|
||||
|
||||
t.Tracer = conf.Tracer
|
||||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
|
||||
|
|
|
@ -61,7 +61,7 @@ var _ = Describe("Transport", func() {
|
|||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{}, true)
|
||||
tr.init(true)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
|
@ -128,8 +128,9 @@ var _ = Describe("Transport", func() {
|
|||
tr := &Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: 10,
|
||||
Tracer: tracer,
|
||||
}
|
||||
tr.init(&Config{Tracer: tracer}, true)
|
||||
tr.init(true)
|
||||
dropped := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||
packetChan <- packetToRead{
|
||||
|
@ -148,7 +149,7 @@ var _ = Describe("Transport", func() {
|
|||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{}, true)
|
||||
tr.init(true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
done := make(chan struct{})
|
||||
|
@ -166,7 +167,7 @@ var _ = Describe("Transport", func() {
|
|||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(&Config{}, true)
|
||||
tr.init(true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
tempErr := deadlineError{}
|
||||
|
@ -188,7 +189,7 @@ var _ = Describe("Transport", func() {
|
|||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: connID.Len(),
|
||||
}
|
||||
tr.init(&Config{}, true)
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
@ -221,7 +222,7 @@ var _ = Describe("Transport", func() {
|
|||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(&Config{}, true)
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
@ -257,7 +258,7 @@ var _ = Describe("Transport", func() {
|
|||
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
|
||||
ConnectionIDLength: connID.Len(),
|
||||
}
|
||||
tr.init(&Config{}, true)
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue