simplify the Tracer interface by combining the TracerFor... methods

This commit is contained in:
Marten Seemann 2020-07-11 10:46:35 +07:00
parent ece3592544
commit ee24d3899e
32 changed files with 139 additions and 194 deletions

View file

@ -176,7 +176,7 @@ func dialContext(
c.packetHandlers = packetHandlers c.packetHandlers = packetHandlers
if c.config.Tracer != nil { if c.config.Tracer != nil {
c.tracer = c.config.Tracer.TracerForClient(c.destConnID) c.tracer = c.config.Tracer.TracerForConnection(protocol.PerspectiveClient, c.destConnID)
} }
if err := c.dial(ctx); err != nil { if err := c.dial(ctx); err != nil {
return nil, err return nil, err

View file

@ -68,7 +68,7 @@ var _ = Describe("Client", func() {
originalClientSessConstructor = newClientSession originalClientSessConstructor = newClientSession
tracer = mocks.NewMockConnectionTracer(mockCtrl) tracer = mocks.NewMockConnectionTracer(mockCtrl)
tr := mocks.NewMockTracer(mockCtrl) tr := mocks.NewMockTracer(mockCtrl)
tr.EXPECT().TracerForClient(gomock.Any()).Return(tracer).MaxTimes(1) tr.EXPECT().TracerForConnection(protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr} config = &Config{Tracer: tr}
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockQuicSession(mockCtrl) // sess = NewMockQuicSession(mockCtrl)

View file

@ -17,6 +17,7 @@ import (
"github.com/lucas-clemente/quic-go/http3" "github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/qlog"
) )
@ -56,7 +57,7 @@ func main() {
var qconf quic.Config var qconf quic.Config
if *enableQlog { if *enableQlog {
qconf.Tracer = qlog.NewTracer(func(connID []byte) io.WriteCloser { qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
filename := fmt.Sprintf("client_%x.qlog", connID) filename := fmt.Sprintf("client_%x.qlog", connID)
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {

View file

@ -22,6 +22,7 @@ import (
"github.com/lucas-clemente/quic-go/http3" "github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/qlog"
"github.com/lucas-clemente/quic-go/quictrace" "github.com/lucas-clemente/quic-go/quictrace"
) )
@ -210,7 +211,7 @@ func main() {
quicConf.QuicTracer = tracer quicConf.QuicTracer = tracer
} }
if *enableQlog { if *enableQlog {
quicConf.Tracer = qlog.NewTracer(func(connID []byte) io.WriteCloser { quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
filename := fmt.Sprintf("server_%x.qlog", connID) filename := fmt.Sprintf("server_%x.qlog", connID)
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {

View file

@ -28,7 +28,7 @@ var _ = Describe("Stream Cancelations", func() {
runServer := func() <-chan int32 { runServer := func() <-chan int32 {
numCanceledStreamsChan := make(chan int32) numCanceledStreamsChan := make(chan int32)
var err error var err error
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfigForServer(nil)) server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var canceledCounter int32 var canceledCounter int32
@ -71,7 +71,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -115,7 +115,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -163,7 +163,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -302,7 +302,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -381,7 +381,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -456,7 +456,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: numStreams / 3}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 3}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -539,7 +539,7 @@ var _ = Describe("Stream Cancelations", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIncomingUniStreams: 5}), getQuicConfig(&quic.Config{MaxIncomingUniStreams: 5}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -60,11 +60,11 @@ var _ = Describe("Connection ID lengths tests", func() {
} }
It("downloads a file using a 0-byte connection ID for the client", func() { It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := getQuicConfigForServer(&quic.Config{ serverConf := getQuicConfig(&quic.Config{
ConnectionIDLength: randomConnIDLen(), ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS}, Versions: []protocol.VersionNumber{protocol.VersionTLS},
}) })
clientConf := getQuicConfigForClient(&quic.Config{ clientConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS}, Versions: []protocol.VersionNumber{protocol.VersionTLS},
}) })
@ -74,11 +74,11 @@ var _ = Describe("Connection ID lengths tests", func() {
}) })
It("downloads a file when both client and server use a random connection ID length", func() { It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := getQuicConfigForServer(&quic.Config{ serverConf := getQuicConfig(&quic.Config{
ConnectionIDLength: randomConnIDLen(), ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS}, Versions: []protocol.VersionNumber{protocol.VersionTLS},
}) })
clientConf := getQuicConfigForClient(&quic.Config{ clientConf := getQuicConfig(&quic.Config{
ConnectionIDLength: randomConnIDLen(), ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS}, Versions: []protocol.VersionNumber{protocol.VersionTLS},
}) })

View file

@ -31,7 +31,7 @@ var _ = Describe("Drop Tests", func() {
ln, err = quic.ListenAddr( ln, err = quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port serverPort := ln.Addr().(*net.UDPAddr).Port
@ -104,7 +104,7 @@ var _ = Describe("Drop Tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer sess.CloseWithError(0, "") defer sess.CloseWithError(0, "")

View file

@ -25,7 +25,7 @@ var _ = Describe("early data", func() {
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
@ -56,7 +56,7 @@ var _ = Describe("early data", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptUniStream(context.Background()) str, err := sess.AcceptUniStream(context.Background())

View file

@ -34,7 +34,7 @@ var _ = Describe("Handshake drop tests", func() {
const timeout = 10 * time.Minute const timeout = 10 * time.Minute
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfigForServer(&quic.Config{ conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
HandshakeTimeout: timeout, HandshakeTimeout: timeout,
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
@ -86,7 +86,7 @@ var _ = Describe("Handshake drop tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
HandshakeTimeout: timeout, HandshakeTimeout: timeout,
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
@ -122,7 +122,7 @@ var _ = Describe("Handshake drop tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
HandshakeTimeout: timeout, HandshakeTimeout: timeout,
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
@ -156,7 +156,7 @@ var _ = Describe("Handshake drop tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
HandshakeTimeout: timeout, HandshakeTimeout: timeout,
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},

View file

@ -29,7 +29,7 @@ var _ = Describe("Handshake RTT tests", func() {
BeforeEach(func() { BeforeEach(func() {
acceptStopped = make(chan struct{}) acceptStopped = make(chan struct{})
serverConfig = getQuicConfigForServer(nil) serverConfig = getQuicConfig(nil)
serverTLSConfig = getTLSConfig() serverTLSConfig = getTLSConfig()
}) })
@ -79,7 +79,7 @@ var _ = Describe("Handshake RTT tests", func() {
} }
serverConfig.Versions = protocol.SupportedVersions[:1] serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy() runServerAndProxy()
clientConfig := getQuicConfigForClient(&quic.Config{Versions: protocol.SupportedVersions[1:2]}) clientConfig := getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]})
_, err := quic.DialAddr( _, err := quic.DialAddr(
proxy.LocalAddr().String(), proxy.LocalAddr().String(),
getTLSClientConfig(), getTLSClientConfig(),
@ -94,7 +94,7 @@ var _ = Describe("Handshake RTT tests", func() {
BeforeEach(func() { BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS} serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}) clientConfig = getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}})
clientConfig := getTLSClientConfig() clientConfig := getTLSClientConfig()
clientConfig.InsecureSkipVerify = true clientConfig.InsecureSkipVerify = true
}) })

View file

@ -56,7 +56,7 @@ var _ = Describe("Handshake tests", func() {
BeforeEach(func() { BeforeEach(func() {
server = nil server = nil
acceptStopped = make(chan struct{}) acceptStopped = make(chan struct{})
serverConfig = getQuicConfigForServer(nil) serverConfig = getQuicConfig(nil)
}) })
AfterEach(func() { AfterEach(func() {
@ -121,7 +121,7 @@ var _ = Describe("Handshake tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}, Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10},
}), }),
) )
@ -185,7 +185,7 @@ var _ = Describe("Handshake tests", func() {
BeforeEach(func() { BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{version} serverConfig.Versions = []protocol.VersionNumber{version}
clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}) clientConfig = getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}})
}) })
It("accepts the certificate", func() { It("accepts the certificate", func() {
@ -203,7 +203,7 @@ var _ = Describe("Handshake tests", func() {
_, err := quic.DialAddr( _, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -428,7 +428,7 @@ var _ = Describe("Handshake tests", func() {
gets := make(chan string, 100) gets := make(chan string, 100)
puts := make(chan string, 100) puts := make(chan string, 100)
tokenStore := newTokenStore(gets, puts) tokenStore := newTokenStore(gets, puts)
quicConf := getQuicConfigForClient(&quic.Config{TokenStore: tokenStore}) quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore})
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),

View file

@ -77,7 +77,7 @@ var _ = Describe("HTTP tests", func() {
Handler: mux, Handler: mux,
TLSConfig: testdata.GetTLSConfig(), TLSConfig: testdata.GetTLSConfig(),
}, },
QuicConfig: getQuicConfigForServer(&quic.Config{Versions: versions}), QuicConfig: getQuicConfig(&quic.Config{Versions: versions}),
} }
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
@ -111,7 +111,7 @@ var _ = Describe("HTTP tests", func() {
RootCAs: testdata.GetRootCA(), RootCAs: testdata.GetRootCA(),
}, },
DisableCompression: true, DisableCompression: true,
QuicConfig: getQuicConfigForClient(&quic.Config{ QuicConfig: getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
MaxIdleTimeout: 10 * time.Second, MaxIdleTimeout: 10 * time.Second,
}), }),

View file

@ -63,7 +63,7 @@ var _ = Describe("MITM test", func() {
} }
BeforeEach(func() { BeforeEach(func() {
serverConfig = getQuicConfigForServer(&quic.Config{ serverConfig = getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
}) })
@ -128,7 +128,7 @@ var _ = Describe("MITM test", func() {
raddr, raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
}), }),
@ -174,7 +174,7 @@ var _ = Describe("MITM test", func() {
raddr, raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
}), }),
@ -335,7 +335,7 @@ var _ = Describe("MITM test", func() {
raddr, raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
HandshakeTimeout: 2 * time.Second, HandshakeTimeout: 2 * time.Second,

View file

@ -46,7 +46,7 @@ var _ = Describe("Multiplexing", func() {
addr, addr,
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer sess.CloseWithError(0, "") defer sess.CloseWithError(0, "")
@ -62,7 +62,7 @@ var _ = Describe("Multiplexing", func() {
ln, err := quic.ListenAddr( ln, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return ln return ln
@ -145,7 +145,7 @@ var _ = Describe("Multiplexing", func() {
server, err := quic.Listen( server, err := quic.Listen(
conn, conn,
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runServer(server) runServer(server)
@ -181,7 +181,7 @@ var _ = Describe("Multiplexing", func() {
server1, err := quic.Listen( server1, err := quic.Listen(
conn1, conn1,
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runServer(server1) runServer(server1)
@ -190,7 +190,7 @@ var _ = Describe("Multiplexing", func() {
server2, err := quic.Listen( server2, err := quic.Listen(
conn2, conn2,
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runServer(server2) runServer(server2)

View file

@ -29,7 +29,7 @@ var _ = Describe("Packetization", func() {
server, err = quic.ListenAddr( server, err = quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{AcceptToken: func(net.Addr, *quic.Token) bool { return true }}), getQuicConfig(&quic.Config{AcceptToken: func(net.Addr, *quic.Token) bool { return true }}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
@ -63,7 +63,7 @@ var _ = Describe("Packetization", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(nil), getQuicConfig(nil),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -23,7 +23,7 @@ var _ = Describe("non-zero RTT", func() {
ln, err := quic.ListenAddr( ln, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go func() {
@ -43,7 +43,7 @@ var _ = Describe("non-zero RTT", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", port), fmt.Sprintf("localhost:%d", port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream(context.Background()) str, err := sess.AcceptStream(context.Background())
@ -79,7 +79,7 @@ var _ = Describe("non-zero RTT", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream(context.Background()) str, err := sess.AcceptStream(context.Background())

View file

@ -19,6 +19,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/qlog"
@ -243,15 +245,7 @@ func getTLSClientConfig() *tls.Config {
return tlsClientConfig.Clone() return tlsClientConfig.Clone()
} }
func getQuicConfigForClient(conf *quic.Config) *quic.Config { func getQuicConfig(conf *quic.Config) *quic.Config {
return getQuicConfigForRole("client", conf)
}
func getQuicConfigForServer(conf *quic.Config) *quic.Config {
return getQuicConfigForRole("server", conf)
}
func getQuicConfigForRole(role string, conf *quic.Config) *quic.Config {
if conf == nil { if conf == nil {
conf = &quic.Config{} conf = &quic.Config{}
} else { } else {
@ -260,7 +254,11 @@ func getQuicConfigForRole(role string, conf *quic.Config) *quic.Config {
if !enableQlog { if !enableQlog {
return conf return conf
} }
conf.Tracer = qlog.NewTracer(func(connectionID []byte) io.WriteCloser { conf.Tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename) fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename)
f, err := os.Create(filename) f, err := os.Create(filename)

View file

@ -24,7 +24,7 @@ var _ = Describe("Stateless Resets", func() {
It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() { It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() {
statelessResetKey := make([]byte, 32) statelessResetKey := make([]byte, 32)
rand.Read(statelessResetKey) rand.Read(statelessResetKey)
serverConfig := getQuicConfigForServer(&quic.Config{StatelessResetKey: statelessResetKey}) serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: statelessResetKey})
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -58,7 +58,7 @@ var _ = Describe("Stateless Resets", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
ConnectionIDLength: connIDLen, ConnectionIDLength: connIDLen,
MaxIdleTimeout: 2 * time.Second, MaxIdleTimeout: 2 * time.Second,
}), }),

View file

@ -33,7 +33,7 @@ var _ = Describe("Bidirectional streams", func() {
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
MaxIncomingStreams: 0, MaxIncomingStreams: 0,
} }
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfigForServer(qconf)) server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(qconf))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
}) })
@ -100,7 +100,7 @@ var _ = Describe("Bidirectional streams", func() {
client, err := quic.DialAddr( client, err := quic.DialAddr(
serverAddr, serverAddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(qconf), getQuicConfig(qconf),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runSendingPeer(client) runSendingPeer(client)
@ -118,7 +118,7 @@ var _ = Describe("Bidirectional streams", func() {
client, err := quic.DialAddr( client, err := quic.DialAddr(
serverAddr, serverAddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(qconf), getQuicConfig(qconf),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client) runReceivingPeer(client)
@ -145,7 +145,7 @@ var _ = Describe("Bidirectional streams", func() {
client, err := quic.DialAddr( client, err := quic.DialAddr(
serverAddr, serverAddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(qconf), getQuicConfig(qconf),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{}) done2 := make(chan struct{})

View file

@ -51,7 +51,7 @@ var _ = Describe("Timeout tests", func() {
_, err := quic.DialAddr( _, err := quic.DialAddr(
"localhost:12345", "localhost:12345",
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{HandshakeTimeout: 10 * time.Millisecond}), getQuicConfig(&quic.Config{HandshakeTimeout: 10 * time.Millisecond}),
) )
errChan <- err errChan <- err
}() }()
@ -69,7 +69,7 @@ var _ = Describe("Timeout tests", func() {
ctx, ctx,
"localhost:12345", "localhost:12345",
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(nil), getQuicConfig(nil),
) )
errChan <- err errChan <- err
}() }()
@ -85,7 +85,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr( server, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(nil), getQuicConfig(nil),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close() defer server.Close()
@ -114,7 +114,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIdleTimeout: idleTimeout}), getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
strIn, err := sess.AcceptStream(context.Background()) strIn, err := sess.AcceptStream(context.Background())
@ -164,7 +164,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr( server, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(nil), getQuicConfig(nil),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close() defer server.Close()
@ -181,7 +181,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIdleTimeout: idleTimeout}), getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
startTime := time.Now() startTime := time.Now()
@ -209,7 +209,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr( server, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(nil), getQuicConfig(nil),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close() defer server.Close()
@ -226,7 +226,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{MaxIdleTimeout: idleTimeout}), getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -266,7 +266,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr( server, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(nil), getQuicConfig(nil),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close() defer server.Close()
@ -294,7 +294,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout, MaxIdleTimeout: idleTimeout,
KeepAlive: true, KeepAlive: true,
}), }),
@ -357,7 +357,7 @@ var _ = Describe("Timeout tests", func() {
ln, err := quic.Listen( ln, err := quic.Listen(
&faultyConn{PacketConn: conn, Timeout: time.Now().Add(timeout)}, &faultyConn{PacketConn: conn, Timeout: time.Now().Add(timeout)},
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(nil), getQuicConfig(nil),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -380,7 +380,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{ getQuicConfig(&quic.Config{
HandshakeTimeout: time.Second, HandshakeTimeout: time.Second,
MaxIdleTimeout: time.Second, MaxIdleTimeout: time.Second,
}), }),
@ -409,7 +409,7 @@ var _ = Describe("Timeout tests", func() {
ln, err := quic.ListenAddr( ln, err := quic.ListenAddr(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
HandshakeTimeout: time.Second, HandshakeTimeout: time.Second,
MaxIdleTimeout: time.Second, MaxIdleTimeout: time.Second,
KeepAlive: true, KeepAlive: true,
@ -450,7 +450,7 @@ var _ = Describe("Timeout tests", func() {
proxy.LocalAddr(), proxy.LocalAddr(),
"localhost", "localhost",
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(nil), getQuicConfig(nil),
) )
if err != nil { if err != nil {
clientErrChan <- err clientErrChan <- err

View file

@ -26,7 +26,7 @@ var _ = Describe("Unidirectional Streams", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfigForServer(qconf)) server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(qconf))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
}) })
@ -81,7 +81,7 @@ var _ = Describe("Unidirectional Streams", func() {
client, err := quic.DialAddr( client, err := quic.DialAddr(
serverAddr, serverAddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(qconf), getQuicConfig(qconf),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runSendingPeer(client) runSendingPeer(client)
@ -99,7 +99,7 @@ var _ = Describe("Unidirectional Streams", func() {
client, err := quic.DialAddr( client, err := quic.DialAddr(
serverAddr, serverAddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(qconf), getQuicConfig(qconf),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client) runReceivingPeer(client)
@ -125,7 +125,7 @@ var _ = Describe("Unidirectional Streams", func() {
client, err := quic.DialAddr( client, err := quic.DialAddr(
serverAddr, serverAddr,
getTLSClientConfig(), getTLSClientConfig(),
getQuicConfigForClient(qconf), getQuicConfig(qconf),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{}) done2 := make(chan struct{})

View file

@ -59,7 +59,7 @@ var _ = Describe("0-RTT", func() {
sess, err := quic.DialAddr( sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxyPort), fmt.Sprintf("localhost:%d", proxyPort),
clientConf, clientConf,
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(puts).Should(Receive()) Eventually(puts).Should(Receive())
@ -93,7 +93,7 @@ var _ = Describe("0-RTT", func() {
sess, err := quic.DialAddrEarly( sess, err := quic.DialAddrEarly(
fmt.Sprintf("localhost:%d", proxyPort), fmt.Sprintf("localhost:%d", proxyPort),
clientConf, clientConf,
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenUniStream() str, err := sess.OpenUniStream()
@ -109,7 +109,7 @@ var _ = Describe("0-RTT", func() {
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
}), }),
@ -171,7 +171,7 @@ var _ = Describe("0-RTT", func() {
sess, err := quic.DialAddrEarly( sess, err := quic.DialAddrEarly(
fmt.Sprintf("localhost:%d", proxy.LocalPort()), fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf, clientConf,
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sent0RTT := make(chan struct{}) sent0RTT := make(chan struct{})
@ -210,7 +210,7 @@ var _ = Describe("0-RTT", func() {
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
}), }),
@ -264,7 +264,7 @@ var _ = Describe("0-RTT", func() {
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfigForServer(&quic.Config{Versions: []protocol.VersionNumber{version}}), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer ln.Close() defer ln.Close()
@ -330,7 +330,7 @@ var _ = Describe("0-RTT", func() {
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
MaxIncomingStreams: maxStreams, MaxIncomingStreams: maxStreams,
@ -345,7 +345,7 @@ var _ = Describe("0-RTT", func() {
ln, err = quic.ListenAddrEarly( ln, err = quic.ListenAddrEarly(
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
MaxIncomingStreams: maxStreams + 1, MaxIncomingStreams: maxStreams + 1,
@ -367,7 +367,7 @@ var _ = Describe("0-RTT", func() {
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
}), }),
@ -383,7 +383,7 @@ var _ = Describe("0-RTT", func() {
ln, err = quic.ListenAddrEarly( ln, err = quic.ListenAddrEarly(
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfigForServer(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true },
}), }),

View file

@ -35,30 +35,16 @@ func (m *MockTracer) EXPECT() *MockTracerMockRecorder {
return m.recorder return m.recorder
} }
// TracerForClient mocks base method // TracerForConnection mocks base method
func (m *MockTracer) TracerForClient(arg0 protocol.ConnectionID) logging.ConnectionTracer { func (m *MockTracer) TracerForConnection(arg0 protocol.Perspective, arg1 protocol.ConnectionID) logging.ConnectionTracer {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForClient", arg0) ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1)
ret0, _ := ret[0].(logging.ConnectionTracer) ret0, _ := ret[0].(logging.ConnectionTracer)
return ret0 return ret0
} }
// TracerForClient indicates an expected call of TracerForClient // TracerForConnection indicates an expected call of TracerForConnection
func (mr *MockTracerMockRecorder) TracerForClient(arg0 interface{}) *gomock.Call { func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForClient", reflect.TypeOf((*MockTracer)(nil).TracerForClient), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1)
}
// TracerForServer mocks base method
func (m *MockTracer) TracerForServer(arg0 protocol.ConnectionID) logging.ConnectionTracer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForServer", arg0)
ret0, _ := ret[0].(logging.ConnectionTracer)
return ret0
}
// TracerForServer indicates an expected call of TracerForServer
func (mr *MockTracerMockRecorder) TracerForServer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForServer", reflect.TypeOf((*MockTracer)(nil).TracerForServer), arg0)
} }

View file

@ -8,6 +8,8 @@ import (
"os" "os"
"strings" "strings"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
@ -25,7 +27,7 @@ func GetSSLKeyLog() (io.WriteCloser, error) {
} }
// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback // GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback
func GetQLOGWriter() (func(connID []byte) io.WriteCloser, error) { func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) {
qlogDir := os.Getenv("QLOGDIR") qlogDir := os.Getenv("QLOGDIR")
if len(qlogDir) == 0 { if len(qlogDir) == 0 {
return nil, nil return nil, nil
@ -35,7 +37,7 @@ func GetQLOGWriter() (func(connID []byte) io.WriteCloser, error) {
return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error()) return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error())
} }
} }
return func(connID []byte) io.WriteCloser { return func(_ logging.Perspective, connID []byte) io.WriteCloser {
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID) path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
f, err := os.Create(path) f, err := os.Create(path)
if err != nil { if err != nil {

View file

@ -79,12 +79,11 @@ const (
// A Tracer traces events. // A Tracer traces events.
type Tracer interface { type Tracer interface {
// TracerForServer requests a new tracer for a connection that was accepted by the server. // ConnectionTracer requests a new tracer for a connection.
// The ODCID is the original destination connection ID:
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
// If nil is returned, tracing will be disabled for this connection. // If nil is returned, tracing will be disabled for this connection.
TracerForServer(odcid ConnectionID) ConnectionTracer TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer
// TracerForServer requests a new tracer for a connection that was dialed by the client.
// If nil is returned, tracing will be disabled for this connection.
TracerForClient(odcid ConnectionID) ConnectionTracer
} }
// A ConnectionTracer records events. // A ConnectionTracer records events.

View file

@ -33,30 +33,16 @@ func (m *MockTracer) EXPECT() *MockTracerMockRecorder {
return m.recorder return m.recorder
} }
// TracerForClient mocks base method // TracerForConnection mocks base method
func (m *MockTracer) TracerForClient(arg0 protocol.ConnectionID) ConnectionTracer { func (m *MockTracer) TracerForConnection(arg0 protocol.Perspective, arg1 protocol.ConnectionID) ConnectionTracer {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForClient", arg0) ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1)
ret0, _ := ret[0].(ConnectionTracer) ret0, _ := ret[0].(ConnectionTracer)
return ret0 return ret0
} }
// TracerForClient indicates an expected call of TracerForClient // TracerForConnection indicates an expected call of TracerForConnection
func (mr *MockTracerMockRecorder) TracerForClient(arg0 interface{}) *gomock.Call { func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForClient", reflect.TypeOf((*MockTracer)(nil).TracerForClient), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1)
}
// TracerForServer mocks base method
func (m *MockTracer) TracerForServer(arg0 protocol.ConnectionID) ConnectionTracer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForServer", arg0)
ret0, _ := ret[0].(ConnectionTracer)
return ret0
}
// TracerForServer indicates an expected call of TracerForServer
func (mr *MockTracerMockRecorder) TracerForServer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForServer", reflect.TypeOf((*MockTracer)(nil).TracerForServer), arg0)
} }

View file

@ -22,20 +22,10 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
return &tracerMultiplexer{tracers} return &tracerMultiplexer{tracers}
} }
func (m *tracerMultiplexer) TracerForServer(odcid ConnectionID) ConnectionTracer { func (m *tracerMultiplexer) TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer var connTracers []ConnectionTracer
for _, t := range m.tracers { for _, t := range m.tracers {
if ct := t.TracerForServer(odcid); ct != nil { if ct := t.TracerForConnection(p, odcid); ct != nil {
connTracers = append(connTracers, ct)
}
}
return newConnectionMultiplexer(connTracers...)
}
func (m *tracerMultiplexer) TracerForClient(odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer
for _, t := range m.tracers {
if ct := t.TracerForClient(odcid); ct != nil {
connTracers = append(connTracers, ct) connTracers = append(connTracers, ct)
} }
} }

View file

@ -23,24 +23,18 @@ var _ = Describe("Tracing", func() {
tracer = NewMultiplexedTracer(tr1, tr2) tracer = NewMultiplexedTracer(tr1, tr2)
}) })
It("multiplexes the TracerForServer call", func() { It("multiplexes the TracerForConnection call", func() {
tr1.EXPECT().TracerForServer(ConnectionID{1, 2, 3}) tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForServer(ConnectionID{1, 2, 3}) tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
tracer.TracerForServer(ConnectionID{1, 2, 3}) tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
})
It("multiplexes the TracerForClient call", func() {
tr1.EXPECT().TracerForClient(ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForClient(ConnectionID{1, 2, 3})
tracer.TracerForClient(ConnectionID{1, 2, 3})
}) })
It("uses multiple connection tracers", func() { It("uses multiple connection tracers", func() {
ctr1 := NewMockConnectionTracer(mockCtrl) ctr1 := NewMockConnectionTracer(mockCtrl)
ctr2 := NewMockConnectionTracer(mockCtrl) ctr2 := NewMockConnectionTracer(mockCtrl)
tr1.EXPECT().TracerForClient(ConnectionID{1, 2, 3}).Return(ctr1) tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForClient(ConnectionID{1, 2, 3}).Return(ctr2) tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2)
tr := tracer.TracerForClient(ConnectionID{1, 2, 3}) tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3})
ctr1.EXPECT().LossTimerCanceled() ctr1.EXPECT().LossTimerCanceled()
ctr2.EXPECT().LossTimerCanceled() ctr2.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled() tr.LossTimerCanceled()
@ -48,17 +42,17 @@ var _ = Describe("Tracing", func() {
It("handles tracers that return a nil ConnectionTracer", func() { It("handles tracers that return a nil ConnectionTracer", func() {
ctr1 := NewMockConnectionTracer(mockCtrl) ctr1 := NewMockConnectionTracer(mockCtrl)
tr1.EXPECT().TracerForClient(ConnectionID{1, 2, 3}).Return(ctr1) tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForClient(ConnectionID{1, 2, 3}) tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3})
tr := tracer.TracerForClient(ConnectionID{1, 2, 3}) tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3})
ctr1.EXPECT().LossTimerCanceled() ctr1.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled() tr.LossTimerCanceled()
}) })
It("returns nil when all tracers return a nil ConnectionTracer", func() { It("returns nil when all tracers return a nil ConnectionTracer", func() {
tr1.EXPECT().TracerForClient(ConnectionID{1, 2, 3}) tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForClient(ConnectionID{1, 2, 3}) tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
Expect(tracer.TracerForClient(ConnectionID{1, 2, 3})).To(BeNil()) Expect(tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil())
}) })
}) })

View file

@ -44,12 +44,8 @@ var _ logging.Tracer = &tracer{}
// NewTracer creates a new metrics tracer. // NewTracer creates a new metrics tracer.
func NewTracer() logging.Tracer { return &tracer{} } func NewTracer() logging.Tracer { return &tracer{} }
func (t *tracer) TracerForServer(logging.ConnectionID) logging.ConnectionTracer { func (t *tracer) TracerForConnection(p logging.Perspective, _ logging.ConnectionID) logging.ConnectionTracer {
return newConnTracer(t, logging.PerspectiveServer) return newConnTracer(t, p)
}
func (t *tracer) TracerForClient(logging.ConnectionID) logging.ConnectionTracer {
return newConnTracer(t, logging.PerspectiveClient)
} }
type connTracer struct { type connTracer struct {

View file

@ -20,26 +20,19 @@ import (
const eventChanSize = 50 const eventChanSize = 50
type tracer struct { type tracer struct {
getLogWriter func(connectionID []byte) io.WriteCloser getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser
} }
var _ logging.Tracer = &tracer{} var _ logging.Tracer = &tracer{}
// NewTracer creates a new qlog tracer. // NewTracer creates a new qlog tracer.
func NewTracer(getLogWriter func(connectionID []byte) io.WriteCloser) logging.Tracer { func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer {
return &tracer{getLogWriter: getLogWriter} return &tracer{getLogWriter: getLogWriter}
} }
func (t *tracer) TracerForServer(odcid protocol.ConnectionID) logging.ConnectionTracer { func (t *tracer) TracerForConnection(p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
if w := t.getLogWriter(odcid.Bytes()); w != nil { if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
return newConnectionTracer(w, protocol.PerspectiveServer, odcid) return newConnectionTracer(w, p, odcid)
}
return nil
}
func (t *tracer) TracerForClient(odcid protocol.ConnectionID) logging.ConnectionTracer {
if w := t.getLogWriter(odcid.Bytes()); w != nil {
return newConnectionTracer(w, protocol.PerspectiveClient, odcid)
} }
return nil return nil
} }

View file

@ -51,9 +51,8 @@ type entry struct {
var _ = Describe("Tracing", func() { var _ = Describe("Tracing", func() {
Context("tracer", func() { Context("tracer", func() {
It("returns nil when there's no io.WriteCloser", func() { It("returns nil when there's no io.WriteCloser", func() {
t := NewTracer(func([]byte) io.WriteCloser { return nil }) t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil })
Expect(t.TracerForClient(logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) Expect(t.TracerForConnection(logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil())
Expect(t.TracerForServer(logging.ConnectionID{1, 2, 3, 4})).To(BeNil())
}) })
}) })
@ -65,8 +64,8 @@ var _ = Describe("Tracing", func() {
BeforeEach(func() { BeforeEach(func() {
buf = &bytes.Buffer{} buf = &bytes.Buffer{}
t := NewTracer(func([]byte) io.WriteCloser { return nopWriteCloser(buf) }) t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) })
tracer = t.TracerForServer(logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) tracer = t.TracerForConnection(logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef})
}) })
It("exports a trace that has the right metadata", func() { It("exports a trace that has the right metadata", func() {

View file

@ -453,7 +453,7 @@ func (s *baseServer) createNewSession(
if origDestConnID.Len() > 0 { if origDestConnID.Len() > 0 {
connID = origDestConnID connID = origDestConnID
} }
tracer = s.config.Tracer.TracerForServer(connID) tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID)
} }
sess = s.newSession( sess = s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr}, &conn{pconn: s.conn, currentAddr: remoteAddr},