use a mock qlogger in the client tests

This commit is contained in:
Marten Seemann 2020-04-14 17:15:36 +07:00
parent 6ba147119d
commit 80418be227
2 changed files with 44 additions and 13 deletions

View file

@ -1,10 +1,13 @@
package quic
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"io/ioutil"
"net"
"os"
"time"
@ -12,6 +15,7 @@ import (
"github.com/lucas-clemente/quic-go/qlog"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -30,6 +34,8 @@ var _ = Describe("Client", func() {
mockMultiplexer *MockMultiplexer
origMultiplexer multiplexer
tlsConf *tls.Config
qlogger *mocks.MockTracer
config *Config
originalClientSessConstructor func(
conn connection,
@ -45,6 +51,7 @@ var _ = Describe("Client", func() {
logger utils.Logger,
v protocol.VersionNumber,
) quicSession
originalQlogConstructor func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -72,6 +79,21 @@ var _ = Describe("Client", func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
originalClientSessConstructor = newClientSession
originalQlogConstructor = newQlogger
qlogger = mocks.NewMockTracer(mockCtrl)
newQlogger = func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer {
return qlogger
}
config = &Config{
GetLogWriter: func([]byte) io.WriteCloser {
// Since we're mocking the qlogger, it doesn't matter what we return here,
// as long as it's not nil.
return utils.NewBufferedWriteCloser(
bufio.NewWriter(&bytes.Buffer{}),
ioutil.NopCloser(&bytes.Buffer{}),
)
},
}
Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockQuicSession(mockCtrl)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
@ -83,6 +105,7 @@ var _ = Describe("Client", func() {
destConnID: connID,
version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr},
qlogger: qlogger,
logger: utils.DefaultLogger,
}
getMultiplexer() // make the sync.Once execute
@ -95,6 +118,7 @@ var _ = Describe("Client", func() {
AfterEach(func() {
connMuxer = origMultiplexer
newClientSession = originalClientSessConstructor
newQlogger = originalQlogConstructor
})
AfterEach(func() {
@ -219,12 +243,13 @@ var _ = Describe("Client", func() {
sess.EXPECT().run()
return sess
}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"test.com",
tlsConf,
&Config{},
config,
)
Expect(err).ToNot(HaveOccurred())
Eventually(hostnameChan).Should(Receive(Equal("test.com")))
@ -258,12 +283,13 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(ctx)
return sess
}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
s, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
&Config{},
config,
)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
@ -302,12 +328,13 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
defer close(done)
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
s, err := DialEarly(
packetConn,
addr,
"localhost:1337",
tlsConf,
&Config{},
config,
)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
@ -343,12 +370,13 @@ var _ = Describe("Client", func() {
return sess
}
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
&Config{},
config,
)
Expect(err).To(MatchError(testErr))
})
@ -385,13 +413,14 @@ var _ = Describe("Client", func() {
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := DialContext(
ctx,
packetConn,
addr,
"localhost:1337",
tlsConf,
&Config{},
config,
)
Expect(err).To(MatchError(context.Canceled))
close(dialed)
@ -513,8 +542,7 @@ var _ = Describe("Client", func() {
})
It("uses 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, true)
c := populateClientConfig(&Config{}, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
@ -606,12 +634,13 @@ var _ = Describe("Client", func() {
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
&Config{},
config,
)
Expect(err).To(MatchError(testErr))
})
@ -620,7 +649,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
cl.config = &Config{}
cl.config = config
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
@ -686,7 +715,7 @@ var _ = Describe("Client", func() {
})
It("drops version negotiation packets that contain the offered version", func() {
cl.config = &Config{}
cl.config = config
ver := cl.version
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}))
Expect(cl.version).To(Equal(ver))
@ -709,7 +738,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
cl.config = &Config{}
cl.config = config
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
@ -722,7 +751,7 @@ var _ = Describe("Client", func() {
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
// Version negotiation is now ignored
cl.config = &Config{}
cl.config = config
ver := cl.version
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234}))
Expect(cl.version).To(Equal(ver))