use a mock qlogger in the client tests

This commit is contained in:
Marten Seemann 2020-04-14 17:15:36 +07:00
parent 6ba147119d
commit 80418be227
2 changed files with 44 additions and 13 deletions

View file

@ -56,6 +56,8 @@ var (
// make it possible to mock connection ID generation in the tests // make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// make it possible to the qlogger
newQlogger = qlog.NewTracer
) )
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
@ -181,7 +183,7 @@ func dialContext(
if c.config.GetLogWriter != nil { if c.config.GetLogWriter != nil {
if w := c.config.GetLogWriter(c.destConnID); w != nil { if w := c.config.GetLogWriter(c.destConnID); w != nil {
c.qlogger = qlog.NewTracer(w, protocol.PerspectiveClient, c.destConnID) c.qlogger = newQlogger(w, protocol.PerspectiveClient, c.destConnID)
} }
} }
if err := c.dial(ctx); err != nil { if err := c.dial(ctx); err != nil {

View file

@ -1,10 +1,13 @@
package quic package quic
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io"
"io/ioutil"
"net" "net"
"os" "os"
"time" "time"
@ -12,6 +15,7 @@ import (
"github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/qlog"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -30,6 +34,8 @@ var _ = Describe("Client", func() {
mockMultiplexer *MockMultiplexer mockMultiplexer *MockMultiplexer
origMultiplexer multiplexer origMultiplexer multiplexer
tlsConf *tls.Config tlsConf *tls.Config
qlogger *mocks.MockTracer
config *Config
originalClientSessConstructor func( originalClientSessConstructor func(
conn connection, conn connection,
@ -45,6 +51,7 @@ var _ = Describe("Client", func() {
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) quicSession ) quicSession
originalQlogConstructor func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer
) )
// generate a packet sent by the server that accepts the QUIC version suggested by the client // generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -72,6 +79,21 @@ var _ = Describe("Client", func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}} tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
originalClientSessConstructor = newClientSession originalClientSessConstructor = newClientSession
originalQlogConstructor = newQlogger
qlogger = mocks.NewMockTracer(mockCtrl)
newQlogger = func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer {
return qlogger
}
config = &Config{
GetLogWriter: func([]byte) io.WriteCloser {
// Since we're mocking the qlogger, it doesn't matter what we return here,
// as long as it's not nil.
return utils.NewBufferedWriteCloser(
bufio.NewWriter(&bytes.Buffer{}),
ioutil.NopCloser(&bytes.Buffer{}),
)
},
}
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockQuicSession(mockCtrl) // sess = NewMockQuicSession(mockCtrl)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
@ -83,6 +105,7 @@ var _ = Describe("Client", func() {
destConnID: connID, destConnID: connID,
version: protocol.SupportedVersions[0], version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr}, conn: &conn{pconn: packetConn, currentAddr: addr},
qlogger: qlogger,
logger: utils.DefaultLogger, logger: utils.DefaultLogger,
} }
getMultiplexer() // make the sync.Once execute getMultiplexer() // make the sync.Once execute
@ -95,6 +118,7 @@ var _ = Describe("Client", func() {
AfterEach(func() { AfterEach(func() {
connMuxer = origMultiplexer connMuxer = origMultiplexer
newClientSession = originalClientSessConstructor newClientSession = originalClientSessConstructor
newQlogger = originalQlogConstructor
}) })
AfterEach(func() { AfterEach(func() {
@ -219,12 +243,13 @@ var _ = Describe("Client", func() {
sess.EXPECT().run() sess.EXPECT().run()
return sess return sess
} }
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial( _, err := Dial(
packetConn, packetConn,
addr, addr,
"test.com", "test.com",
tlsConf, tlsConf,
&Config{}, config,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(hostnameChan).Should(Receive(Equal("test.com"))) Eventually(hostnameChan).Should(Receive(Equal("test.com")))
@ -258,12 +283,13 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().HandshakeComplete().Return(ctx)
return sess return sess
} }
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
s, err := Dial( s, err := Dial(
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
tlsConf, tlsConf,
&Config{}, config,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
@ -302,12 +328,13 @@ var _ = Describe("Client", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer close(done) defer close(done)
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
s, err := DialEarly( s, err := DialEarly(
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
tlsConf, tlsConf,
&Config{}, config,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
@ -343,12 +370,13 @@ var _ = Describe("Client", func() {
return sess return sess
} }
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial( _, err := Dial(
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
tlsConf, tlsConf,
&Config{}, config,
) )
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -385,13 +413,14 @@ var _ = Describe("Client", func() {
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := DialContext( _, err := DialContext(
ctx, ctx,
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
tlsConf, tlsConf,
&Config{}, config,
) )
Expect(err).To(MatchError(context.Canceled)) Expect(err).To(MatchError(context.Canceled))
close(dialed) close(dialed)
@ -513,8 +542,7 @@ var _ = Describe("Client", func() {
}) })
It("uses 0-byte connection IDs when dialing an address", func() { It("uses 0-byte connection IDs when dialing an address", func() {
config := &Config{} c := populateClientConfig(&Config{}, true)
c := populateClientConfig(config, true)
Expect(c.ConnectionIDLength).To(BeZero()) Expect(c.ConnectionIDLength).To(BeZero())
}) })
@ -606,12 +634,13 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess return sess
} }
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial( _, err := Dial(
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
tlsConf, tlsConf,
&Config{}, config,
) )
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -620,7 +649,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess cl.session = sess
cl.config = &Config{} cl.config = config
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{ Expect((&wire.ExtendedHeader{
Header: wire.Header{ Header: wire.Header{
@ -686,7 +715,7 @@ var _ = Describe("Client", func() {
}) })
It("drops version negotiation packets that contain the offered version", func() { It("drops version negotiation packets that contain the offered version", func() {
cl.config = &Config{} cl.config = config
ver := cl.version ver := cl.version
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})) cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}))
Expect(cl.version).To(Equal(ver)) Expect(cl.version).To(Equal(ver))
@ -709,7 +738,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess cl.session = sess
cl.config = &Config{} cl.config = config
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{ Expect((&wire.ExtendedHeader{
Header: wire.Header{ Header: wire.Header{
@ -722,7 +751,7 @@ var _ = Describe("Client", func() {
cl.handlePacket(&receivedPacket{data: buf.Bytes()}) cl.handlePacket(&receivedPacket{data: buf.Bytes()})
// Version negotiation is now ignored // Version negotiation is now ignored
cl.config = &Config{} cl.config = config
ver := cl.version ver := cl.version
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})) cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234}))
Expect(cl.version).To(Equal(ver)) Expect(cl.version).To(Equal(ver))