mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
use a mock qlogger in the client tests
This commit is contained in:
parent
6ba147119d
commit
80418be227
2 changed files with 44 additions and 13 deletions
|
@ -56,6 +56,8 @@ var (
|
||||||
// make it possible to mock connection ID generation in the tests
|
// make it possible to mock connection ID generation in the tests
|
||||||
generateConnectionID = protocol.GenerateConnectionID
|
generateConnectionID = protocol.GenerateConnectionID
|
||||||
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||||
|
// make it possible to the qlogger
|
||||||
|
newQlogger = qlog.NewTracer
|
||||||
)
|
)
|
||||||
|
|
||||||
// DialAddr establishes a new QUIC connection to a server.
|
// DialAddr establishes a new QUIC connection to a server.
|
||||||
|
@ -181,7 +183,7 @@ func dialContext(
|
||||||
|
|
||||||
if c.config.GetLogWriter != nil {
|
if c.config.GetLogWriter != nil {
|
||||||
if w := c.config.GetLogWriter(c.destConnID); w != nil {
|
if w := c.config.GetLogWriter(c.destConnID); w != nil {
|
||||||
c.qlogger = qlog.NewTracer(w, protocol.PerspectiveClient, c.destConnID)
|
c.qlogger = newQlogger(w, protocol.PerspectiveClient, c.destConnID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := c.dial(ctx); err != nil {
|
if err := c.dial(ctx); err != nil {
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
@ -12,6 +15,7 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/qlog"
|
"github.com/lucas-clemente/quic-go/qlog"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
@ -30,6 +34,8 @@ var _ = Describe("Client", func() {
|
||||||
mockMultiplexer *MockMultiplexer
|
mockMultiplexer *MockMultiplexer
|
||||||
origMultiplexer multiplexer
|
origMultiplexer multiplexer
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
|
qlogger *mocks.MockTracer
|
||||||
|
config *Config
|
||||||
|
|
||||||
originalClientSessConstructor func(
|
originalClientSessConstructor func(
|
||||||
conn connection,
|
conn connection,
|
||||||
|
@ -45,6 +51,7 @@ var _ = Describe("Client", func() {
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
v protocol.VersionNumber,
|
v protocol.VersionNumber,
|
||||||
) quicSession
|
) quicSession
|
||||||
|
originalQlogConstructor func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer
|
||||||
)
|
)
|
||||||
|
|
||||||
// generate a packet sent by the server that accepts the QUIC version suggested by the client
|
// generate a packet sent by the server that accepts the QUIC version suggested by the client
|
||||||
|
@ -72,6 +79,21 @@ var _ = Describe("Client", func() {
|
||||||
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||||
originalClientSessConstructor = newClientSession
|
originalClientSessConstructor = newClientSession
|
||||||
|
originalQlogConstructor = newQlogger
|
||||||
|
qlogger = mocks.NewMockTracer(mockCtrl)
|
||||||
|
newQlogger = func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer {
|
||||||
|
return qlogger
|
||||||
|
}
|
||||||
|
config = &Config{
|
||||||
|
GetLogWriter: func([]byte) io.WriteCloser {
|
||||||
|
// Since we're mocking the qlogger, it doesn't matter what we return here,
|
||||||
|
// as long as it's not nil.
|
||||||
|
return utils.NewBufferedWriteCloser(
|
||||||
|
bufio.NewWriter(&bytes.Buffer{}),
|
||||||
|
ioutil.NopCloser(&bytes.Buffer{}),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
Eventually(areSessionsRunning).Should(BeFalse())
|
||||||
// sess = NewMockQuicSession(mockCtrl)
|
// sess = NewMockQuicSession(mockCtrl)
|
||||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||||
|
@ -83,6 +105,7 @@ var _ = Describe("Client", func() {
|
||||||
destConnID: connID,
|
destConnID: connID,
|
||||||
version: protocol.SupportedVersions[0],
|
version: protocol.SupportedVersions[0],
|
||||||
conn: &conn{pconn: packetConn, currentAddr: addr},
|
conn: &conn{pconn: packetConn, currentAddr: addr},
|
||||||
|
qlogger: qlogger,
|
||||||
logger: utils.DefaultLogger,
|
logger: utils.DefaultLogger,
|
||||||
}
|
}
|
||||||
getMultiplexer() // make the sync.Once execute
|
getMultiplexer() // make the sync.Once execute
|
||||||
|
@ -95,6 +118,7 @@ var _ = Describe("Client", func() {
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
connMuxer = origMultiplexer
|
connMuxer = origMultiplexer
|
||||||
newClientSession = originalClientSessConstructor
|
newClientSession = originalClientSessConstructor
|
||||||
|
newQlogger = originalQlogConstructor
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
@ -219,12 +243,13 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().run()
|
sess.EXPECT().run()
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||||
_, err := Dial(
|
_, err := Dial(
|
||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"test.com",
|
"test.com",
|
||||||
tlsConf,
|
tlsConf,
|
||||||
&Config{},
|
config,
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(hostnameChan).Should(Receive(Equal("test.com")))
|
Eventually(hostnameChan).Should(Receive(Equal("test.com")))
|
||||||
|
@ -258,12 +283,13 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().HandshakeComplete().Return(ctx)
|
sess.EXPECT().HandshakeComplete().Return(ctx)
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||||
s, err := Dial(
|
s, err := Dial(
|
||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
tlsConf,
|
tlsConf,
|
||||||
&Config{},
|
config,
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(s).ToNot(BeNil())
|
Expect(s).ToNot(BeNil())
|
||||||
|
@ -302,12 +328,13 @@ var _ = Describe("Client", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||||
s, err := DialEarly(
|
s, err := DialEarly(
|
||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
tlsConf,
|
tlsConf,
|
||||||
&Config{},
|
config,
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(s).ToNot(BeNil())
|
Expect(s).ToNot(BeNil())
|
||||||
|
@ -343,12 +370,13 @@ var _ = Describe("Client", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
|
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
|
||||||
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||||
_, err := Dial(
|
_, err := Dial(
|
||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
tlsConf,
|
tlsConf,
|
||||||
&Config{},
|
config,
|
||||||
)
|
)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
@ -385,13 +413,14 @@ var _ = Describe("Client", func() {
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||||
_, err := DialContext(
|
_, err := DialContext(
|
||||||
ctx,
|
ctx,
|
||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
tlsConf,
|
tlsConf,
|
||||||
&Config{},
|
config,
|
||||||
)
|
)
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
Expect(err).To(MatchError(context.Canceled))
|
||||||
close(dialed)
|
close(dialed)
|
||||||
|
@ -513,8 +542,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("uses 0-byte connection IDs when dialing an address", func() {
|
It("uses 0-byte connection IDs when dialing an address", func() {
|
||||||
config := &Config{}
|
c := populateClientConfig(&Config{}, true)
|
||||||
c := populateClientConfig(config, true)
|
|
||||||
Expect(c.ConnectionIDLength).To(BeZero())
|
Expect(c.ConnectionIDLength).To(BeZero())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -606,12 +634,13 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||||
_, err := Dial(
|
_, err := Dial(
|
||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
tlsConf,
|
tlsConf,
|
||||||
&Config{},
|
config,
|
||||||
)
|
)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
@ -620,7 +649,7 @@ var _ = Describe("Client", func() {
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
sess.EXPECT().handlePacket(gomock.Any())
|
sess.EXPECT().handlePacket(gomock.Any())
|
||||||
cl.session = sess
|
cl.session = sess
|
||||||
cl.config = &Config{}
|
cl.config = config
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
Expect((&wire.ExtendedHeader{
|
Expect((&wire.ExtendedHeader{
|
||||||
Header: wire.Header{
|
Header: wire.Header{
|
||||||
|
@ -686,7 +715,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops version negotiation packets that contain the offered version", func() {
|
It("drops version negotiation packets that contain the offered version", func() {
|
||||||
cl.config = &Config{}
|
cl.config = config
|
||||||
ver := cl.version
|
ver := cl.version
|
||||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}))
|
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}))
|
||||||
Expect(cl.version).To(Equal(ver))
|
Expect(cl.version).To(Equal(ver))
|
||||||
|
@ -709,7 +738,7 @@ var _ = Describe("Client", func() {
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
sess.EXPECT().handlePacket(gomock.Any())
|
sess.EXPECT().handlePacket(gomock.Any())
|
||||||
cl.session = sess
|
cl.session = sess
|
||||||
cl.config = &Config{}
|
cl.config = config
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
Expect((&wire.ExtendedHeader{
|
Expect((&wire.ExtendedHeader{
|
||||||
Header: wire.Header{
|
Header: wire.Header{
|
||||||
|
@ -722,7 +751,7 @@ var _ = Describe("Client", func() {
|
||||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
||||||
|
|
||||||
// Version negotiation is now ignored
|
// Version negotiation is now ignored
|
||||||
cl.config = &Config{}
|
cl.config = config
|
||||||
ver := cl.version
|
ver := cl.version
|
||||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234}))
|
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234}))
|
||||||
Expect(cl.version).To(Equal(ver))
|
Expect(cl.version).To(Equal(ver))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue