improve client tests

Use a mock newClientSession. That way, it’s a lot easier to test dialing
new connections.
This commit is contained in:
Marten Seemann 2017-05-12 18:40:53 +08:00
parent 8ba1bd817f
commit 9fad63ff50
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
4 changed files with 263 additions and 144 deletions

View file

@ -4,8 +4,6 @@ import (
"bytes"
"errors"
"net"
"reflect"
"unsafe"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
@ -21,27 +19,34 @@ var _ = Describe("Client", func() {
sess *mockSession
packetConn *mockPacketConn
addr net.Addr
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
)
BeforeEach(func() {
originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse())
msess, _, _ := newMockSession(nil, 0, 0, nil, nil)
sess = msess.(*mockSession)
packetConn = &mockPacketConn{}
config = &Config{
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
}
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
sess = &mockSession{connectionID: 0x1337}
cl = &client{
config: config,
connectionID: 0x1337,
session: sess,
version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr},
errorChan: make(chan struct{}),
handshakeChan: make(chan handshakeEvent),
config: config,
connectionID: 0x1337,
session: sess,
version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr},
errorChan: make(chan struct{}),
}
})
AfterEach(func() {
newClientSession = originalClientSessConstructor
})
AfterEach(func() {
if s, ok := cl.session.(*session); ok {
s.Close(nil)
@ -50,13 +55,88 @@ var _ = Describe("Client", func() {
})
Context("Dialing", func() {
PIt("creates a new client", func() {
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
sess.Close(nil)
BeforeEach(func() {
newClientSession = func(
_ connection,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
return sess, sess.handshakeChan, nil
}
})
It("dials non-forward-secure", func(done Done) {
var dialedSess Session
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
}()
Consistently(func() Session { return dialedSess }).Should(BeNil())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(func() Session { return dialedSess }).ShouldNot(BeNil())
close(done)
})
It("Dial only returns after the handshake is complete", func(done Done) {
var dialedSess Session
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Consistently(func() Session { return dialedSess }).Should(BeNil())
close(sess.handshakeComplete)
Eventually(func() Session { return dialedSess }).ShouldNot(BeNil())
close(done)
})
It("resolves the address", func(done Done) {
var cconn connection
newClientSession = func(
conn connection,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
cconn = conn
return sess, nil, nil
}
go DialAddr("localhost:17890", &Config{})
Eventually(func() connection { return cconn }).ShouldNot(BeNil())
Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890"))
close(done)
})
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
testErr := errors.New("early handshake error")
var dialErr error
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
}()
sess.handshakeChan <- handshakeEvent{err: testErr}
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
close(done)
})
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
testErr := errors.New("late handshake error")
var dialErr error
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
close(done)
})
It("uses all supported versions, if none are specified in the quic.Config", func() {
@ -64,18 +144,121 @@ var _ = Describe("Client", func() {
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
})
It("errors when receiving an invalid first packet from the server", func() {
It("errors when receiving an invalid first packet from the server", func(done Done) {
packetConn.dataToRead = []byte{0xff}
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).To(HaveOccurred())
Expect(sess).To(BeNil())
close(done)
})
It("errors when receiving an error from the connection", func() {
It("errors when receiving an error from the connection", func(done Done) {
testErr := errors.New("connection error")
packetConn.readErr = testErr
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).To(MatchError(testErr))
close(done)
})
It("errors if it can't create a session", func() {
testErr := errors.New("error creating session")
newClientSession = func(
_ connection,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
return nil, nil, testErr
}
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).To(MatchError(testErr))
})
Context("version negotiation", func() {
It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() {
ph := PublicHeader{
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
ConnectionID: 0x1337,
}
b := &bytes.Buffer{}
err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(nil, b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue())
})
It("changes the version after receiving a version negotiation packet", func() {
var negotiatedVersions []protocol.VersionNumber
newClientSession = func(
_ connection,
_ string,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
_ *Config,
negotiatedVersionsP []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
negotiatedVersions = negotiatedVersionsP
return &mockSession{
connectionID: connectionID,
}, nil, nil
}
newVersion := protocol.VersionNumber(77)
Expect(config.Versions).To(ContainElement(newVersion))
Expect(newVersion).ToNot(Equal(cl.version))
Expect(sess.packetCount).To(BeZero())
cl.connectionID = 0x1337
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(newVersion))
Expect(cl.versionNegotiated).To(BeTrue())
// it swapped the sessions
// Expect(cl.session).ToNot(Equal(sess))
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
Expect(err).ToNot(HaveOccurred())
// it didn't pass the version negoation packet to the old session (since it has no payload)
Expect(sess.packetCount).To(BeZero())
Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion}))
})
It("errors if no matching version is found", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
v := protocol.SupportedVersions[1]
Expect(v).ToNot(Equal(cl.version))
Expect(config.Versions).ToNot(ContainElement(v))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("changes to the version preferred by the quic.Config", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(config.Versions[1]))
})
It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
cl.versionNegotiated = true
Expect(sess.packetCount).To(BeZero())
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue())
Expect(sess.packetCount).To(BeZero())
})
It("drops version negotiation packets that contain the offered version", func() {
ver := cl.version
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(ver))
})
})
})
@ -84,39 +267,37 @@ var _ = Describe("Client", func() {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
// this test requires a real session
// and a real UDP conn (because it unblocks and errors when it is closed)
PIt("properly closes", func(done Done) {
Eventually(areSessionsRunning).Should(BeFalse())
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
Expect(err).ToNot(HaveOccurred())
cl.conn = &conn{pconn: udpConn, currentAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}}
err = cl.createNewSession(nil)
Expect(err).ToNot(HaveOccurred())
Eventually(areSessionsRunning).Should(BeTrue())
var stoppedListening bool
It("creates new sessions with the right parameters", func(done Done) {
c := make(chan struct{})
var cconn connection
var hostname string
var version protocol.VersionNumber
var conf *Config
newClientSession = func(
connP connection,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
configP *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
cconn = connP
hostname = hostnameP
version = versionP
conf = configP
close(c)
return sess, nil, nil
}
go func() {
cl.listen()
stoppedListening = true
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
}()
testErr := errors.New("test error")
err = cl.session.Close(testErr)
Expect(err).ToNot(HaveOccurred())
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
Eventually(areSessionsRunning).Should(BeFalse())
<-c
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(cl.version))
Expect(conf).To(Equal(config))
close(done)
}, 10)
It("creates new sessions with the right parameters", func() {
cl.session = nil
cl.hostname = "hostname"
err := cl.createNewSession(nil)
Expect(err).ToNot(HaveOccurred())
Expect(cl.session).ToNot(BeNil())
Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID))
Expect(cl.session.(*session).version).To(Equal(cl.version))
})
Context("handling packets", func() {
@ -160,77 +341,4 @@ var _ = Describe("Client", func() {
Expect(sess.closeReason).To(MatchError(testErr))
})
})
Context("version negotiation", func() {
It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() {
ph := PublicHeader{
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen2,
ConnectionID: 0x1337,
}
b := &bytes.Buffer{}
err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(nil, b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue())
})
It("changes the version after receiving a version negotiation packet", func() {
newVersion := protocol.VersionNumber(77)
Expect(config.Versions).To(ContainElement(newVersion))
Expect(newVersion).ToNot(Equal(cl.version))
Expect(sess.packetCount).To(BeZero())
cl.connectionID = 0x1337
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(newVersion))
Expect(cl.versionNegotiated).To(BeTrue())
// it swapped the sessions
Expect(cl.session).ToNot(Equal(sess))
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
Expect(err).ToNot(HaveOccurred())
// it didn't pass the version negoation packet to the old session (since it has no payload)
Expect(sess.packetCount).To(BeZero())
// if the version negotiation packet was passed to the new session, it would end up as an undecryptable packet there
Expect(cl.session.(*session).undecryptablePackets).To(BeEmpty())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{newVersion}))
})
It("errors if no matching version is found", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
v := protocol.SupportedVersions[1]
Expect(v).ToNot(Equal(cl.version))
Expect(config.Versions).ToNot(ContainElement(v))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("changes to the version preferred by the quic.Config", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(config.Versions[1]))
})
It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
cl.versionNegotiated = true
Expect(sess.packetCount).To(BeZero())
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue())
Expect(sess.packetCount).To(BeZero())
})
It("drops version negotiation packets that contain the offered version", func() {
ver := cl.version
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(ver))
})
})
})

View file

@ -18,12 +18,13 @@ import (
)
type mockSession struct {
connectionID protocol.ConnectionID
packetCount int
closed bool
closeReason error
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan handshakeEvent
connectionID protocol.ConnectionID
packetCount int
closed bool
closeReason error
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan handshakeEvent
handshakeComplete chan error // for WaitUntilHandshakeComplete
}
func (s *mockSession) handlePacket(*receivedPacket) {
@ -34,9 +35,16 @@ func (s *mockSession) run() error {
<-s.stopRunLoop
return s.closeReason
}
func (s *mockSession) WaitUntilHandshakeComplete() error {
return <-s.handshakeComplete
}
func (s *mockSession) Close(e error) error {
if s.closed {
return nil
}
s.closeReason = e
s.closed = true
close(s.stopRunLoop)
return nil
}
func (s *mockSession) AcceptStream() (Stream, error) {
@ -56,6 +64,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
}
var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{}
func newMockSession(
_ connection,
@ -65,9 +74,10 @@ func newMockSession(
_ *Config,
) (packetHandler, <-chan handshakeEvent, error) {
s := mockSession{
connectionID: connectionID,
handshakeChan: make(chan handshakeEvent),
stopRunLoop: make(chan struct{}),
connectionID: connectionID,
handshakeChan: make(chan handshakeEvent),
handshakeComplete: make(chan error),
stopRunLoop: make(chan struct{}),
}
return &s, s.handshakeChan, nil
}
@ -211,11 +221,11 @@ var _ = Describe("Server", func() {
})
It("closes sessions and the connection when Close is called", func() {
session := &mockSession{}
session, _, _ := newMockSession(nil, 0, 0, nil, nil)
serv.sessions[1] = session
err := serv.Close()
Expect(err).NotTo(HaveOccurred())
Expect(session.closed).To(BeTrue())
Expect(session.(*mockSession).closed).To(BeTrue())
Expect(conn.closed).To(BeTrue())
})
@ -254,14 +264,14 @@ var _ = Describe("Server", func() {
}, 0.5)
It("closes all sessions when encountering a connection error", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveKey(connID))
Expect(serv.sessions[connID].(*mockSession).closed).To(BeFalse())
session, _, _ := newMockSession(nil, 0, 0, nil, nil)
serv.sessions[0x12345] = session
Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse())
testErr := errors.New("connection error")
conn.readErr = testErr
go serv.serve()
Eventually(func() bool { return serv.sessions[connID].(*mockSession).closed }).Should(BeTrue())
Eventually(func() Session { return serv.sessions[connID] }).Should(BeNil())
Eventually(func() bool { return session.(*mockSession).closed }).Should(BeTrue())
Expect(serv.Close()).To(Succeed())
})

View file

@ -154,14 +154,15 @@ func newSession(
return s, handshakeChan, err
}
func newClientSession(
// declare this as a variable, such that we can it mock it in the tests
var newClientSession = func(
conn connection,
hostname string,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
config *Config,
negotiatedVersions []protocol.VersionNumber,
) (*session, <-chan handshakeEvent, error) {
) (packetHandler, <-chan handshakeEvent, error) {
s := &session{
conn: conn,
connectionID: connectionID,

View file

@ -1520,8 +1520,7 @@ var _ = Describe("Client Session", func() {
mconn = &mockConnection{
remoteAddr: &net.UDPAddr{},
}
var err error
sess, _, err = newClientSession(
sessP, _, err := newClientSession(
mconn,
"hostname",
protocol.Version35,
@ -1529,6 +1528,7 @@ var _ = Describe("Client Session", func() {
populateClientConfig(&Config{}),
nil,
)
sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream
// we need an aeadChanged chan that we can write to