diff --git a/Changelog.md b/Changelog.md index 839aec33..04f98432 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,8 +1,9 @@ # Changelog -## v0.6.1 (unreleased) +## v0.7 (unreleased) - The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored. +- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`. ## v0.6.0 (2017-12-12) diff --git a/client.go b/client.go index bc47cd0c..21808536 100644 --- a/client.go +++ b/client.go @@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) return Dial(udpConn, udpAddr, addr, tlsConf, config) } -// DialAddrNonFWSecure establishes a new QUIC connection to a server. -// The hostname for SNI is taken from the given address. -func DialAddrNonFWSecure( - addr string, - tlsConf *tls.Config, - config *Config, -) (NonFWSession, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config) -} - -// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. +// Dial establishes a new QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. -func DialNonFWSecure( +func Dial( pconn net.PacketConn, remoteAddr net.Addr, host string, tlsConf *tls.Config, config *Config, -) (NonFWSession, error) { +) (Session, error) { connID, err := generateConnectionID() if err != nil { return nil, err @@ -119,26 +101,7 @@ func DialNonFWSecure( if err := c.dial(); err != nil { return nil, err } - return c.session.(NonFWSession), nil -} - -// Dial establishes a new QUIC connection to a server using a net.PacketConn. -// The host parameter is used for SNI. -func Dial( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Session, error) { - sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config) - if err != nil { - return nil, err - } - if err := sess.WaitUntilHandshakeComplete(); err != nil { - return nil, err - } - return sess, nil + return c.session, nil } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -268,14 +231,8 @@ func (c *client) establishSecureConnection() error { select { case <-errorChan: return runErr - case ev := <-c.session.handshakeStatus(): - if ev.err != nil { - return ev.err - } - if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure { - return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) - } - return nil + case err := <-c.session.handshakeStatus(): + return err } } diff --git a/client_test.go b/client_test.go index fe3f8e09..a955b5ea 100644 --- a/client_test.go +++ b/client_test.go @@ -101,57 +101,7 @@ var _ = Describe("Client", func() { generateConnectionID = origGenerateConnectionID }) - It("dials non-forward-secure", func() { - packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - Eventually(dialed).Should(BeClosed()) - }) - - It("dials a non-forward-secure address", func() { - serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - Expect(err).ToNot(HaveOccurred()) - server, err := net.ListenUDP("udp", serverAddr) - Expect(err).ToNot(HaveOccurred()) - defer server.Close() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - for { - _, clientAddr, err := server.ReadFromUDP(make([]byte, 200)) - if err != nil { - return - } - _, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr) - Expect(err).ToNot(HaveOccurred()) - } - }() - - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - Eventually(dialed).Should(BeClosed()) - server.Close() - Eventually(done).Should(BeClosed()) - }) - - It("Dial only returns after the handshake is complete", func() { + It("returns after the handshake is complete", func() { packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) dialed := make(chan struct{}) go func() { @@ -161,9 +111,7 @@ var _ = Describe("Client", func() { Expect(s).ToNot(BeNil()) close(dialed) }() - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - Consistently(dialed).ShouldNot(BeClosed()) - close(sess.handshakeComplete) + close(sess.handshakeChan) Eventually(dialed).Should(BeClosed()) }) @@ -249,22 +197,7 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) close(done) }() - sess.handshakeChan <- handshakeEvent{err: testErr} - Eventually(done).Should(BeClosed()) - }) - - It("returns an error that occurs while waiting for the handshake to complete", func() { - testErr := errors.New("late handshake error") - packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) - Expect(err).To(MatchError(testErr)) - close(done) - }() - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} - sess.handshakeComplete <- testErr + sess.handshakeChan <- testErr Eventually(done).Should(BeClosed()) }) @@ -309,7 +242,7 @@ var _ = Describe("Client", func() { ) (packetHandler, error) { return nil, testErr } - _, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).To(MatchError(testErr)) }) @@ -335,7 +268,7 @@ var _ = Describe("Client", func() { Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) sessionChan := make(chan *mockSession) - handshakeChan := make(chan handshakeEvent) + handshakeChan := make(chan error) newClientSession = func( _ connection, _ string, @@ -386,7 +319,7 @@ var _ = Describe("Client", func() { Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(initialVersion).To(Equal(actualInitialVersion)) - handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + close(handshakeChan) Eventually(established).Should(BeClosed()) }) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 1e49b0b9..183bfff6 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -2,14 +2,12 @@ package self_test import ( "crypto/tls" - "fmt" "net" "time" quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -111,18 +109,6 @@ var _ = Describe("Handshake RTT tests", func() { expectDurationInRTTs(4) }) - // 1 RTT for verifying the source address - // 1 RTT to become secure - // TODO (marten-seemann): enable this test (see #625) - PIt("is secure after 2 RTTs", func() { - utils.SetLogLevel(utils.LogLevelDebug) - runServerAndProxy() - _, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil) - fmt.Println("#### is non fw secure ###") - Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(2) - }) - It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() { serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool { return true diff --git a/interface.go b/interface.go index 8025dd27..9a97f86d 100644 --- a/interface.go +++ b/interface.go @@ -130,13 +130,6 @@ type Session interface { Context() context.Context } -// A NonFWSession is a QUIC connection between two peers half-way through the handshake. -// The communication is encrypted, but not yet forward secure. -type NonFWSession interface { - Session - WaitUntilHandshakeComplete() error -} - // Config contains all configuration data needed for a QUIC server or client. type Config struct { // The QUIC versions that can be negotiated. diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index 2df6d6b8..11e43e83 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -51,8 +51,8 @@ type cryptoSetupClient struct { secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - paramsChan chan<- TransportParameters - aeadChanged chan<- protocol.EncryptionLevel + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} params *TransportParameters } @@ -74,7 +74,7 @@ func NewCryptoSetupClient( tlsConfig *tls.Config, params *TransportParameters, paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, ) (CryptoSetup, error) { @@ -93,7 +93,7 @@ func NewCryptoSetupClient( keyExchange: getEphermalKEX, nullAEAD: nullAEAD, paramsChan: paramsChan, - aeadChanged: aeadChanged, + handshakeEvent: handshakeEvent, initialVersion: initialVersion, negotiatedVersions: negotiatedVersions, divNonceChan: make(chan []byte), @@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { } // blocks until the session has received the parameters h.paramsChan <- *params - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) default: return qerr.InvalidCryptoMessageType } @@ -496,10 +496,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { if err != nil { return err } - - h.aeadChanged <- protocol.EncryptionSecure + h.handshakeEvent <- struct{}{} } - return nil } diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index e19593a8..51d5c4f6 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -79,7 +79,7 @@ var _ = Describe("Client Crypto Setup", func() { stream *mockStream keyDerivationCalledWith *keyDerivationValues shloMap map[Tag][]byte - aeadChanged chan protocol.EncryptionLevel + handshakeEvent chan struct{} paramsChan chan TransportParameters ) @@ -108,7 +108,7 @@ var _ = Describe("Client Crypto Setup", func() { version := protocol.Version39 // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking paramsChan = make(chan TransportParameters, 1) - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}, 2) csInt, err := NewCryptoSetupClient( stream, "hostname", @@ -117,7 +117,7 @@ var _ = Describe("Client Crypto Setup", func() { nil, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, paramsChan, - aeadChanged, + handshakeEvent, protocol.Version39, nil, ) @@ -385,22 +385,22 @@ var _ = Describe("Client Crypto Setup", func() { cs.receivedSecurePacket = false _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("rejects SHLOs without a PUBS", func() { delete(shloMap, TagPUBS) _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS"))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("rejects SHLOs without a version list", func() { delete(shloMap, TagVER) _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list"))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("accepts a SHLO after a version negotiation", func() { @@ -435,7 +435,7 @@ var _ = Describe("Client Crypto Setup", func() { Expect(params.IdleTimeout).To(Equal(13 * time.Second)) }) - It("closes the aeadChanged when receiving an SHLO", func() { + It("closes the handshakeEvent chan when receiving an SHLO", func() { HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) done := make(chan struct{}) go func() { @@ -444,8 +444,8 @@ var _ = Describe("Client Crypto Setup", func() { Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() - Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) - Eventually(aeadChanged).Should(BeClosed()) + Eventually(handshakeEvent).Should(Receive()) + Eventually(handshakeEvent).Should(BeClosed()) // make the go routine return stream.close() Eventually(done).Should(BeClosed()) @@ -652,9 +652,9 @@ var _ = Describe("Client Crypto Setup", func() { Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert)) Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce)) Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient)) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("uses the server nonce, if the server sent one", func() { @@ -664,24 +664,24 @@ var _ = Describe("Client Crypto Setup", func() { Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...))) - Expect(aeadChanged).To(Receive()) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() { err := cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).To(BeNil()) - Expect(aeadChanged).ToNot(Receive()) + Expect(handshakeEvent).ToNot(Receive()) cs.serverVerified = true // make sure we really had all necessary values before, and only serverVerified was missing err = cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) }) It("tries to escalate before reading a handshake message", func() { @@ -694,10 +694,10 @@ var _ = Describe("Client Crypto Setup", func() { Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() - Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) + Eventually(handshakeEvent).Should(Receive()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) // make the go routine return stream.close() Eventually(done).Should(BeClosed()) @@ -715,10 +715,10 @@ var _ = Describe("Client Crypto Setup", func() { cs.serverVerified = true Expect(cs.secureAEAD).To(BeNil()) cs.SetDiversificationNonce([]byte("div")) - Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) + Eventually(handshakeEvent).Should(Receive()) Expect(cs.secureAEAD).ToNot(BeNil()) - Expect(aeadChanged).ToNot(Receive()) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).ToNot(Receive()) + Expect(handshakeEvent).ToNot(BeClosed()) // make the go routine return stream.close() Eventually(done).Should(BeClosed()) diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index 6ff11ab6..4bec2d4a 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -42,7 +42,7 @@ type cryptoSetupServer struct { receivedParams bool paramsChan chan<- TransportParameters - aeadChanged chan<- protocol.EncryptionLevel + handshakeEvent chan<- struct{} keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction @@ -76,7 +76,7 @@ func NewCryptoSetup( supportedVersions []protocol.VersionNumber, acceptSTK func(net.Addr, *Cookie) bool, paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, ) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) if err != nil { @@ -96,7 +96,7 @@ func NewCryptoSetup( acceptSTKCallback: acceptSTK, sentSHLO: make(chan struct{}), paramsChan: paramsChan, - aeadChanged: aeadChanged, + handshakeEvent: handshakeEvent, }, nil } @@ -182,7 +182,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if _, err := h.cryptoStream.Write(reply); err != nil { return false, err } - h.aeadChanged <- protocol.EncryptionForwardSecure + h.handshakeEvent <- struct{}{} close(h.sentSHLO) return true, nil } @@ -206,9 +206,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client h.receivedForwardSecurePacket = true - // wait until protocol.EncryptionForwardSecure was sent on the aeadChan + // wait for the send on the handshakeEvent chan <-h.sentSHLO - close(h.aeadChanged) + close(h.handshakeEvent) } return res, protocol.EncryptionForwardSecure, nil } @@ -396,8 +396,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } - - h.aeadChanged <- protocol.EncryptionSecure + h.handshakeEvent <- struct{}{} // Generate a new curve instance to derive the forward secure key var fsNonce bytes.Buffer diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 344c2383..55b21032 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -124,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() { cs *cryptoSetupServer stream *mockStream paramsChan chan TransportParameters - aeadChanged chan protocol.EncryptionLevel + handshakeEvent chan struct{} nonce32 []byte versionTag []byte validSTK []byte @@ -146,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() { // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking paramsChan = make(chan TransportParameters, 1) - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}, 2) stream = newMockStream() kex = &mockKEX{} signer = &mockSigner{} @@ -170,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() { supportedVersions, nil, paramsChan, - aeadChanged, + handshakeEvent, ) Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) @@ -343,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() { err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(handshakeEvent).To(Receive()) // for the switch to secure Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) // for the switch to forward secure + Expect(handshakeEvent).ToNot(BeClosed()) }) It("rejects client nonces that have the wrong length", func() { @@ -377,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).ToNot(BeClosed()) + Expect(handshakeEvent).To(Receive()) // for the switch to secure + Expect(handshakeEvent).To(Receive()) // for the switch to forward secure + Expect(handshakeEvent).ToNot(BeClosed()) }) It("recognizes inchoate CHLOs missing SCID", func() { @@ -535,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() { TagKEXS: kexs, }) Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + Expect(handshakeEvent).To(Receive()) // for the switch to secure close(cs.sentSHLO) } @@ -657,7 +657,7 @@ var _ = Describe("Server Crypto Setup", func() { cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{}) _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{}) Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(BeClosed()) + Expect(handshakeEvent).To(BeClosed()) }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 041c0b42..f25bacad 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -26,9 +26,9 @@ type cryptoSetupTLS struct { nullAEAD crypto.AEAD aead crypto.AEAD - tls MintTLS - cryptoStream *CryptoStreamConn - aeadChanged chan<- protocol.EncryptionLevel + tls MintTLS + cryptoStream *CryptoStreamConn + handshakeEvent chan<- struct{} } // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server @@ -36,16 +36,16 @@ func NewCryptoSetupTLSServer( tls MintTLS, cryptoStream *CryptoStreamConn, nullAEAD crypto.AEAD, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, version protocol.VersionNumber, ) CryptoSetup { return &cryptoSetupTLS{ - tls: tls, - cryptoStream: cryptoStream, - nullAEAD: nullAEAD, - perspective: protocol.PerspectiveServer, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, + tls: tls, + cryptoStream: cryptoStream, + nullAEAD: nullAEAD, + perspective: protocol.PerspectiveServer, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, } } @@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient( cryptoStream io.ReadWriter, connID protocol.ConnectionID, hostname string, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, tls MintTLS, version protocol.VersionNumber, ) (CryptoSetup, error) { @@ -64,11 +64,11 @@ func NewCryptoSetupTLSClient( } return &cryptoSetupTLS{ - perspective: protocol.PerspectiveClient, - tls: tls, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, + perspective: protocol.PerspectiveClient, + tls: tls, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, }, nil } @@ -102,9 +102,8 @@ handshakeLoop: h.aead = aead h.mutex.Unlock() - // signal to the outside world that the handshake completed - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) return nil } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 03b486e6..4b8a2725 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -20,17 +20,17 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e var _ = Describe("TLS Crypto Setup", func() { var ( - cs *cryptoSetupTLS - aeadChanged chan protocol.EncryptionLevel + cs *cryptoSetupTLS + handshakeEvent chan struct{} ) BeforeEach(func() { - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}, 2) cs = NewCryptoSetupTLSServer( nil, NewCryptoStreamConn(nil), nil, // AEAD - aeadChanged, + handshakeEvent, protocol.VersionTLS, ).(*cryptoSetupTLS) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) @@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() { cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).To(BeClosed()) + Expect(handshakeEvent).To(Receive()) + Expect(handshakeEvent).To(BeClosed()) }) It("handshakes until it is connected", func() { @@ -63,7 +63,7 @@ var _ = Describe("TLS Crypto Setup", func() { cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive()) + Expect(handshakeEvent).To(Receive()) }) Context("escalating crypto", func() { @@ -180,17 +180,17 @@ var _ = Describe("TLS Crypto Setup", func() { var _ = Describe("TLS Crypto Setup, for the client", func() { var ( - cs *cryptoSetupTLS - aeadChanged chan protocol.EncryptionLevel + cs *cryptoSetupTLS + handshakeEvent chan struct{} ) BeforeEach(func() { - aeadChanged = make(chan protocol.EncryptionLevel, 2) + handshakeEvent = make(chan struct{}) csInt, err := NewCryptoSetupTLSClient( nil, 0, "quic.clemente.io", - aeadChanged, + handshakeEvent, nil, // mintTLS protocol.VersionTLS, ) diff --git a/server.go b/server.go index e7ddae8f..33d58831 100644 --- a/server.go +++ b/server.go @@ -20,7 +20,7 @@ import ( type packetHandler interface { Session getCryptoStream() cryptoStreamI - handshakeStatus() <-chan handshakeEvent + handshakeStatus() <-chan error handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber run() error @@ -391,14 +391,8 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C }() go func() { - for { - ev := <-session.handshakeStatus() - if ev.err != nil { - return - } - if ev.encLevel == protocol.EncryptionForwardSecure { - break - } + if err := <-session.handshakeStatus(); err != nil { + return } s.sessionQueue <- session }() diff --git a/server_test.go b/server_test.go index 1ca7e652..31ce2fbc 100644 --- a/server_test.go +++ b/server_test.go @@ -22,14 +22,13 @@ import ( ) type mockSession struct { - connectionID protocol.ConnectionID - packetCount int - closed bool - closeReason error - closedRemote bool - stopRunLoop chan struct{} // run returns as soon as this channel receives a value - handshakeChan chan handshakeEvent - handshakeComplete chan error // for WaitUntilHandshakeComplete + connectionID protocol.ConnectionID + packetCount int + closed bool + closeReason error + closedRemote bool + stopRunLoop chan struct{} // run returns as soon as this channel receives a value + handshakeChan chan error } func (s *mockSession) handlePacket(*receivedPacket) { @@ -40,9 +39,6 @@ func (s *mockSession) run() error { <-s.stopRunLoop return s.closeReason } -func (s *mockSession) WaitUntilHandshakeComplete() error { - return <-s.handshakeComplete -} func (s *mockSession) Close(e error) error { if s.closed { return nil @@ -61,17 +57,16 @@ func (s *mockSession) closeRemote(e error) { func (s *mockSession) OpenStream() (Stream, error) { return &stream{}, nil } -func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } -func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } -func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } -func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } -func (*mockSession) Context() context.Context { panic("not implemented") } -func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } -func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } -func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } +func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } +func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } +func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } +func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } +func (*mockSession) Context() context.Context { panic("not implemented") } +func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } +func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan } +func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } var _ Session = &mockSession{} -var _ NonFWSession = &mockSession{} func newMockSession( _ connection, @@ -82,10 +77,9 @@ func newMockSession( _ *Config, ) (packetHandler, error) { s := mockSession{ - connectionID: connectionID, - handshakeChan: make(chan handshakeEvent), - handshakeComplete: make(chan error), - stopRunLoop: make(chan struct{}), + connectionID: connectionID, + handshakeChan: make(chan error), + stopRunLoop: make(chan struct{}), } return &s, nil } @@ -155,9 +149,8 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[connID].(*mockSession) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} Consistently(func() Session { return acceptedSess }).Should(BeNil()) - sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure} + close(sess.handshakeChan) Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) close(done) }, 0.5) @@ -173,7 +166,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[connID].(*mockSession) - sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")} + sess.handshakeChan <- errors.New("handshake failed") Consistently(func() bool { return accepted }).Should(BeFalse()) close(done) }) diff --git a/session.go b/session.go index 9b269c78..64a05887 100644 --- a/session.go +++ b/session.go @@ -39,11 +39,6 @@ var ( newCryptoSetupClient = handshake.NewCryptoSetupClient ) -type handshakeEvent struct { - encLevel protocol.EncryptionLevel - err error -} - type closeError struct { err error remote bool @@ -90,17 +85,14 @@ type session struct { // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them paramsChan <-chan handshake.TransportParameters - // this channel is passed to the CryptoSetup and receives the current encryption level - // it is closed as soon as the handshake is complete - aeadChanged <-chan protocol.EncryptionLevel + // the handshakeEvent channel is passed to the CryptoSetup. + // It receives when it makes sense to try decrypting undecryptable packets. + handshakeEvent <-chan struct{} + // handshakeChan is returned by handshakeStatus. + // It receives any error that might occur during the handshake. + // It is closed when the handshake is complete. + handshakeChan chan error handshakeComplete bool - // will be closed as soon as the handshake completes, and receive any error that might occur until then - // it is used to block WaitUntilHandshakeComplete() - handshakeCompleteChan chan error - // handshakeChan receives handshake events and is closed as soon the handshake completes - // the receiving end of this channel is passed to the creator of the session - // it receives at most 3 handshake events: 2 when the encryption level changes, and one error - handshakeChan chan handshakeEvent lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire @@ -131,15 +123,15 @@ func newSession( config *Config, ) (packetHandler, error) { paramsChan := make(chan handshake.TransportParameters) - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveServer, - version: v, - config: config, - aeadChanged: aeadChanged, - paramsChan: paramsChan, + conn: conn, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -158,7 +150,7 @@ func newSession( s.config.Versions, s.config.AcceptCookie, paramsChan, - aeadChanged, + handshakeEvent, ) if err != nil { return nil, err @@ -179,15 +171,15 @@ var newClientSession = func( negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton ) (packetHandler, error) { paramsChan := make(chan handshake.TransportParameters) - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveClient, - version: v, - config: config, - aeadChanged: aeadChanged, - paramsChan: paramsChan, + conn: conn, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -205,7 +197,7 @@ var newClientSession = func( tlsConf, transportParams, paramsChan, - aeadChanged, + handshakeEvent, initialVersion, negotiatedVersions, ) @@ -227,21 +219,21 @@ func newTLSServerSession( peerParams *handshake.TransportParameters, v protocol.VersionNumber, ) (packetHandler, error) { - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - config: config, - connectionID: connectionID, - perspective: protocol.PerspectiveServer, - version: v, - aeadChanged: aeadChanged, + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + handshakeEvent: handshakeEvent, } s.preSetup() s.cryptoSetup = handshake.NewCryptoSetupTLSServer( tls, cryptoStreamConn, nullAEAD, - aeadChanged, + handshakeEvent, v, ) if err := s.postSetup(initialPacketNumber); err != nil { @@ -264,15 +256,15 @@ var newTLSClientSession = func( paramsChan <-chan handshake.TransportParameters, initialPacketNumber protocol.PacketNumber, ) (packetHandler, error) { - aeadChanged := make(chan protocol.EncryptionLevel, 2) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - config: config, - connectionID: connectionID, - perspective: protocol.PerspectiveClient, - version: v, - aeadChanged: aeadChanged, - paramsChan: paramsChan, + conn: conn, + config: config, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, } s.preSetup() tls.SetCryptoStream(s.cryptoStream) @@ -280,7 +272,7 @@ var newTLSClientSession = func( s.cryptoStream, s.connectionID, hostname, - aeadChanged, + handshakeEvent, tls, v, ) @@ -302,8 +294,7 @@ func (s *session) preSetup() { } func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { - s.handshakeChan = make(chan handshakeEvent, 3) - s.handshakeCompleteChan = make(chan error, 1) + s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -343,7 +334,7 @@ func (s *session) run() error { }() var closeErr closeError - aeadChanged := s.aeadChanged + handshakeEvent := s.handshakeEvent runLoop: for { @@ -381,16 +372,14 @@ runLoop: putPacketBuffer(p.header.Raw) case p := <-s.paramsChan: s.processTransportParameters(&p) - case l, ok := <-aeadChanged: + case _, ok := <-handshakeEvent: if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true - aeadChanged = nil // prevent this case from ever being selected again + handshakeEvent = nil // prevent this case from ever being selected again s.sentPacketHandler.SetHandshakeComplete() close(s.handshakeChan) - close(s.handshakeCompleteChan) } else { s.tryDecryptingQueuedPackets() - s.handshakeChan <- handshakeEvent{encLevel: l} } } @@ -428,8 +417,7 @@ runLoop: // only send the error the handshakeChan when the handshake is not completed yet // otherwise this chan will already be closed if !s.handshakeComplete { - s.handshakeCompleteChan <- closeErr.err - s.handshakeChan <- handshakeEvent{err: closeErr.err} + s.handshakeChan <- closeErr.err } s.handleCloseError(closeErr) return closeErr.err @@ -878,10 +866,6 @@ func (s *session) OpenStreamSync() (Stream, error) { return s.streamsMap.OpenStreamSync() } -func (s *session) WaitUntilHandshakeComplete() error { - return <-s.handshakeCompleteChan -} - func (s *session) newStream(id protocol.StreamID) streamI { var initialSendWindow protocol.ByteCount if s.peerParams != nil { @@ -970,7 +954,7 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *session) handshakeStatus() <-chan handshakeEvent { +func (s *session) handshakeStatus() <-chan error { return s.handshakeChan } diff --git a/session_test.go b/session_test.go index 58bcb23f..250fdad5 100644 --- a/session_test.go +++ b/session_test.go @@ -78,11 +78,11 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - scfg *handshake.ServerConfig - mconn *mockConnection - cryptoSetup *mockCryptoSetup - aeadChanged chan<- protocol.EncryptionLevel + sess *session + scfg *handshake.ServerConfig + mconn *mockConnection + cryptoSetup *mockCryptoSetup + handshakeChan chan<- struct{} ) BeforeEach(func() { @@ -99,9 +99,9 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, _ func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, - aeadChangedP chan<- protocol.EncryptionLevel, + handshakeChanP chan<- struct{}, ) (handshake.CryptoSetup, error) { - aeadChanged = aeadChangedP + handshakeChan = handshakeChanP return cryptoSetup, nil } @@ -149,7 +149,7 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, cookieFunc func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, - _ chan<- protocol.EncryptionLevel, + _ chan<- struct{}, ) (handshake.CryptoSetup, error) { cookieVerify = cookieFunc return cryptoSetup, nil @@ -516,61 +516,6 @@ var _ = Describe("Session", func() { Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) }) - Context("waiting until the handshake completes", func() { - It("waits until the handshake is complete", func() { - go func() { - defer GinkgoRecover() - sess.run() - }() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - err := sess.WaitUntilHandshakeComplete() - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - aeadChanged <- protocol.EncryptionForwardSecure - Consistently(done).ShouldNot(BeClosed()) - close(aeadChanged) - Eventually(done).Should(BeClosed()) - Expect(sess.Close(nil)).To(Succeed()) - }) - - It("errors if the handshake fails", func(done Done) { - testErr := errors.New("crypto error") - sess.cryptoSetup = &mockCryptoSetup{handleErr: testErr} - go sess.run() - err := sess.WaitUntilHandshakeComplete() - Expect(err).To(MatchError(testErr)) - close(done) - }, 0.5) - - It("returns when Close is called", func(done Done) { - testErr := errors.New("close error") - go sess.run() - var waitReturned bool - go func() { - defer GinkgoRecover() - err := sess.WaitUntilHandshakeComplete() - Expect(err).To(MatchError(testErr)) - waitReturned = true - }() - sess.Close(testErr) - Eventually(func() bool { return waitReturned }).Should(BeTrue()) - close(done) - }) - - It("doesn't wait if the handshake is already completed", func(done Done) { - go sess.run() - close(aeadChanged) - err := sess.WaitUntilHandshakeComplete() - Expect(err).ToNot(HaveOccurred()) - Expect(sess.Close(nil)).To(Succeed()) - close(done) - }) - }) - Context("accepting streams", func() { BeforeEach(func() { // don't use the mock here @@ -1362,46 +1307,48 @@ var _ = Describe("Session", func() { }) }) - It("send a handshake event on the handshakeChan when the AEAD changes to secure", func(done Done) { - go sess.run() - aeadChanged <- protocol.EncryptionSecure - Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure})) + It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + handshakeChan <- struct{}{} + Consistently(sess.handshakeStatus()).ShouldNot(Receive()) + // make sure the go routine returns Expect(sess.Close(nil)).To(Succeed()) - close(done) + Eventually(done).Should(BeClosed()) }) - It("send a handshake event on the handshakeChan when the AEAD changes to forward-secure", func(done Done) { - go sess.run() - aeadChanged <- protocol.EncryptionForwardSecure - Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure})) - Expect(sess.Close(nil)).To(Succeed()) - close(done) - }) - - It("closes the handshakeChan when the handshake completes", func(done Done) { - go sess.run() - close(aeadChanged) + It("closes the handshakeChan when the handshake completes", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + close(handshakeChan) Eventually(sess.handshakeStatus()).Should(BeClosed()) + // make sure the go routine returns Expect(sess.Close(nil)).To(Succeed()) - close(done) + Eventually(done).Should(BeClosed()) }) - It("passes errors to the handshakeChan", func(done Done) { + It("passes errors to the handshakeChan", func() { testErr := errors.New("handshake error") - go sess.run() - Expect(sess.Close(nil)).To(Succeed()) - Expect(sess.handshakeStatus()).To(Receive(&handshakeEvent{err: testErr})) - close(done) - }) - - It("does not block if an error occurs", func(done Done) { - // this test basically tests that the handshakeChan has a capacity of 3 - // The session needs to run (and close) properly, even if no one is receiving from the handshakeChan - go sess.run() - aeadChanged <- protocol.EncryptionSecure - aeadChanged <- protocol.EncryptionForwardSecure - Expect(sess.Close(nil)).To(Succeed()) - close(done) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := sess.run() + Expect(err).To(MatchError(testErr)) + close(done) + }() + sess.Close(testErr) + Expect(sess.handshakeStatus()).To(Receive(Equal(testErr))) + Eventually(done).Should(BeClosed()) }) It("process transport parameters received from the peer", func() { @@ -1503,7 +1450,7 @@ var _ = Describe("Session", func() { It("closes the session due to the idle timeout after handshake", func() { sess.config.IdleTimeout = 0 - close(aeadChanged) + close(handshakeChan) errChan := make(chan error) go func() { defer GinkgoRecover() @@ -1619,9 +1566,9 @@ var _ = Describe("Session", func() { var _ = Describe("Client Session", func() { var ( - sess *session - mconn *mockConnection - aeadChanged chan<- protocol.EncryptionLevel + sess *session + mconn *mockConnection + handshakeChan chan<- struct{} cryptoSetup *mockCryptoSetup ) @@ -1638,11 +1585,11 @@ var _ = Describe("Client Session", func() { _ *tls.Config, _ *handshake.TransportParameters, _ chan<- handshake.TransportParameters, - aeadChangedP chan<- protocol.EncryptionLevel, + handshakeChanP chan<- struct{}, _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (handshake.CryptoSetup, error) { - aeadChanged = aeadChangedP + handshakeChan = handshakeChanP return cryptoSetup, nil } @@ -1674,10 +1621,13 @@ var _ = Describe("Client Session", func() { sess.unpacker = &mockUnpacker{} }) - It("passes the diversification nonce to the cryptoSetup", func() { + It("passes the diversification nonce to the crypto setup", func() { + done := make(chan struct{}) go func() { defer GinkgoRecover() - sess.run() + err := sess.run() + Expect(err).ToNot(HaveOccurred()) + close(done) }() hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") @@ -1685,16 +1635,7 @@ var _ = Describe("Client Session", func() { Expect(err).ToNot(HaveOccurred()) Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce)) Expect(sess.Close(nil)).To(Succeed()) + Eventually(done).Should(BeClosed()) }) }) - - It("does not block if an error occurs", func(done Done) { - // this test basically tests that the handshakeChan has a capacity of 3 - // The session needs to run (and close) properly, even if no one is receiving from the handshakeChan - go sess.run() - aeadChanged <- protocol.EncryptionSecure - aeadChanged <- protocol.EncryptionForwardSecure - Expect(sess.Close(nil)).To(Succeed()) - close(done) - }) })