fix some race conditions in client tests

This commit is contained in:
Marten Seemann 2017-07-25 13:30:49 +07:00
parent 43279c37a8
commit 4c73b935f5

View file

@ -75,50 +75,53 @@ var _ = Describe("Client", func() {
}) })
It("dials non-forward-secure", func(done Done) { It("dials non-forward-secure", func(done Done) {
var dialedSess Session dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}() }()
Consistently(func() Session { return dialedSess }).Should(BeNil()) Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(func() Session { return dialedSess }).ShouldNot(BeNil()) Eventually(dialed).Should(BeClosed())
close(done) close(done)
}) })
It("dials a non-forward-secure address", func(done Done) { It("dials a non-forward-secure address", func(done Done) {
var dialedSess Session dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error s, err := DialAddrNonFWSecure("localhost:18901", nil, config)
dialedSess, err = DialAddrNonFWSecure("localhost:18901", nil, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}() }()
Consistently(func() Session { return dialedSess }).Should(BeNil()) Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(func() Session { return dialedSess }).ShouldNot(BeNil()) Eventually(dialed).Should(BeClosed())
close(done) close(done)
}) })
It("Dial only returns after the handshake is complete", func(done Done) { It("Dial only returns after the handshake is complete", func(done Done) {
var dialedSess Session dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}() }()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Consistently(func() Session { return dialedSess }).Should(BeNil()) Consistently(dialed).ShouldNot(BeClosed())
close(sess.handshakeComplete) close(sess.handshakeComplete)
Eventually(func() Session { return dialedSess }).ShouldNot(BeNil()) Eventually(dialed).Should(BeClosed())
close(done) close(done)
}) })
It("resolves the address", func(done Done) { It("resolves the address", func(done Done) {
var cconn connection remoteAddrChan := make(chan string)
newClientSession = func( newClientSession = func(
conn connection, conn connection,
_ string, _ string,
@ -128,17 +131,16 @@ var _ = Describe("Client", func() {
_ *Config, _ *Config,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) { ) (packetHandler, <-chan handshakeEvent, error) {
cconn = conn remoteAddrChan <- conn.RemoteAddr().String()
return sess, nil, nil return sess, nil, nil
} }
go DialAddr("localhost:17890", nil, &Config{}) go DialAddr("localhost:17890", nil, &Config{})
Eventually(func() connection { return cconn }).ShouldNot(BeNil()) Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890"))
close(done) close(done)
}) })
It("uses the tls.Config.ServerName as the hostname, if present", func(done Done) { It("uses the tls.Config.ServerName as the hostname, if present", func(done Done) {
var hostname string hostnameChan := make(chan string)
newClientSession = func( newClientSession = func(
_ connection, _ connection,
h string, h string,
@ -148,11 +150,11 @@ var _ = Describe("Client", func() {
_ *Config, _ *Config,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) { ) (packetHandler, <-chan handshakeEvent, error) {
hostname = h hostnameChan <- h
return sess, nil, nil return sess, nil, nil
} }
go DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil) go DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil)
Eventually(func() string { return hostname }).Should(Equal("foobar")) Eventually(hostnameChan).Should(Receive(Equal("foobar")))
close(done) close(done)
}) })
@ -161,10 +163,10 @@ var _ = Describe("Client", func() {
var dialErr error var dialErr error
go func() { go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(dialErr).To(MatchError(testErr))
close(done)
}() }()
sess.handshakeChan <- handshakeEvent{err: testErr} sess.handshakeChan <- handshakeEvent{err: testErr}
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
close(done)
}) })
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) { It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
@ -172,11 +174,11 @@ var _ = Describe("Client", func() {
var dialErr error var dialErr error
go func() { go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(dialErr).To(MatchError(testErr))
close(done)
}() }()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr sess.handshakeComplete <- testErr
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
close(done)
}) })
It("setups with the right values", func() { It("setups with the right values", func() {
@ -336,7 +338,7 @@ var _ = Describe("Client", func() {
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
<-c Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn)) Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io")) Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(cl.version)) Expect(version).To(Equal(cl.version))
@ -357,16 +359,16 @@ var _ = Describe("Client", func() {
packetConn.dataToRead = b.Bytes() packetConn.dataToRead = b.Bytes()
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
var stoppedListening bool stoppedListening := make(chan struct{})
go func() { go func() {
cl.listen() cl.listen()
// it should continue listening when receiving valid packets // it should continue listening when receiving valid packets
stoppedListening = true close(stoppedListening)
}() }()
Eventually(func() int { return sess.packetCount }).Should(Equal(1)) Eventually(func() int { return sess.packetCount }).Should(Equal(1))
Expect(sess.closed).To(BeFalse()) Expect(sess.closed).To(BeFalse())
Consistently(func() bool { return stoppedListening }).Should(BeFalse()) Consistently(stoppedListening).ShouldNot(BeClosed())
}) })
It("closes the session when encountering an error while reading from the connection", func() { It("closes the session when encountering an error while reading from the connection", func() {