remove non forward-secure dialing

This was broken for a long time, and won't be available when using the
TLS 1.3 handshake.
This commit is contained in:
Marten Seemann 2017-12-26 17:50:08 +07:00
parent bd60e996dc
commit 99a2853e7d
15 changed files with 219 additions and 441 deletions

View file

@ -1,8 +1,9 @@
# Changelog # Changelog
## v0.6.1 (unreleased) ## v0.7 (unreleased)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored. - The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
## v0.6.0 (2017-12-12) ## v0.6.0 (2017-12-12)

View file

@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
return Dial(udpConn, udpAddr, addr, tlsConf, config) return Dial(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialAddrNonFWSecure establishes a new QUIC connection to a server. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func DialNonFWSecure( func Dial(
pconn net.PacketConn, pconn net.PacketConn,
remoteAddr net.Addr, remoteAddr net.Addr,
host string, host string,
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
) (NonFWSession, error) { ) (Session, error) {
connID, err := generateConnectionID() connID, err := generateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
@ -119,26 +101,7 @@ func DialNonFWSecure(
if err := c.dial(); err != nil { if err := c.dial(); err != nil {
return nil, err return nil, err
} }
return c.session.(NonFWSession), nil return c.session, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil {
return nil, err
}
if err := sess.WaitUntilHandshakeComplete(); err != nil {
return nil, err
}
return sess, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set // populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -268,14 +231,8 @@ func (c *client) establishSecureConnection() error {
select { select {
case <-errorChan: case <-errorChan:
return runErr return runErr
case ev := <-c.session.handshakeStatus(): case err := <-c.session.handshakeStatus():
if ev.err != nil { return err
return ev.err
}
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
} }
} }

View file

@ -101,57 +101,7 @@ var _ = Describe("Client", func() {
generateConnectionID = origGenerateConnectionID generateConnectionID = origGenerateConnectionID
}) })
It("dials non-forward-secure", func() { It("returns after the handshake is complete", func() {
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(dialed).Should(BeClosed())
})
It("dials a non-forward-secure address", func() {
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
server, err := net.ListenUDP("udp", serverAddr)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
for {
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
if err != nil {
return
}
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
Expect(err).ToNot(HaveOccurred())
}
}()
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Eventually(dialed).Should(BeClosed())
server.Close()
Eventually(done).Should(BeClosed())
})
It("Dial only returns after the handshake is complete", func() {
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID) packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
@ -161,9 +111,7 @@ var _ = Describe("Client", func() {
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
close(dialed) close(dialed)
}() }()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} close(sess.handshakeChan)
Consistently(dialed).ShouldNot(BeClosed())
close(sess.handshakeComplete)
Eventually(dialed).Should(BeClosed()) Eventually(dialed).Should(BeClosed())
}) })
@ -249,22 +197,7 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
sess.handshakeChan <- handshakeEvent{err: testErr} sess.handshakeChan <- testErr
Eventually(done).Should(BeClosed())
})
It("returns an error that occurs while waiting for the handshake to complete", func() {
testErr := errors.New("late handshake error")
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
close(done)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -309,7 +242,7 @@ var _ = Describe("Client", func() {
) (packetHandler, error) { ) (packetHandler, error) {
return nil, testErr return nil, testErr
} }
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -335,7 +268,7 @@ var _ = Describe("Client", func() {
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion)) Expect(config.Versions).To(ContainElement(newVersion))
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
handshakeChan := make(chan handshakeEvent) handshakeChan := make(chan error)
newClientSession = func( newClientSession = func(
_ connection, _ connection,
_ string, _ string,
@ -386,7 +319,7 @@ var _ = Describe("Client", func() {
Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(negotiatedVersions).To(ContainElement(newVersion))
Expect(initialVersion).To(Equal(actualInitialVersion)) Expect(initialVersion).To(Equal(actualInitialVersion))
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} close(handshakeChan)
Eventually(established).Should(BeClosed()) Eventually(established).Should(BeClosed())
}) })

View file

@ -2,14 +2,12 @@ package self_test
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"time" "time"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
@ -111,18 +109,6 @@ var _ = Describe("Handshake RTT tests", func() {
expectDurationInRTTs(4) expectDurationInRTTs(4)
}) })
// 1 RTT for verifying the source address
// 1 RTT to become secure
// TODO (marten-seemann): enable this test (see #625)
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() { It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool { serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true return true

View file

@ -130,13 +130,6 @@ type Session interface {
Context() context.Context Context() context.Context
} }
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
}
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
type Config struct { type Config struct {
// The QUIC versions that can be negotiated. // The QUIC versions that can be negotiated.

View file

@ -51,8 +51,8 @@ type cryptoSetupClient struct {
secureAEAD crypto.AEAD secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
params *TransportParameters params *TransportParameters
} }
@ -74,7 +74,7 @@ func NewCryptoSetupClient(
tlsConfig *tls.Config, tlsConfig *tls.Config,
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
@ -93,7 +93,7 @@ func NewCryptoSetupClient(
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
paramsChan: paramsChan, paramsChan: paramsChan,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
initialVersion: initialVersion, initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions, negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte), divNonceChan: make(chan []byte),
@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
} }
// blocks until the session has received the parameters // blocks until the session has received the parameters
h.paramsChan <- *params h.paramsChan <- *params
h.aeadChanged <- protocol.EncryptionForwardSecure h.handshakeEvent <- struct{}{}
close(h.aeadChanged) close(h.handshakeEvent)
default: default:
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
@ -496,10 +496,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if err != nil { if err != nil {
return err return err
} }
h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionSecure
} }
return nil return nil
} }

View file

@ -79,7 +79,7 @@ var _ = Describe("Client Crypto Setup", func() {
stream *mockStream stream *mockStream
keyDerivationCalledWith *keyDerivationValues keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte shloMap map[Tag][]byte
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
paramsChan chan TransportParameters paramsChan chan TransportParameters
) )
@ -108,7 +108,7 @@ var _ = Describe("Client Crypto Setup", func() {
version := protocol.Version39 version := protocol.Version39
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{}, 2)
csInt, err := NewCryptoSetupClient( csInt, err := NewCryptoSetupClient(
stream, stream,
"hostname", "hostname",
@ -117,7 +117,7 @@ var _ = Describe("Client Crypto Setup", func() {
nil, nil,
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
protocol.Version39, protocol.Version39,
nil, nil,
) )
@ -385,22 +385,22 @@ var _ = Describe("Client Crypto Setup", func() {
cs.receivedSecurePacket = false cs.receivedSecurePacket = false
_, err := cs.handleSHLOMessage(shloMap) _, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("rejects SHLOs without a PUBS", func() { It("rejects SHLOs without a PUBS", func() {
delete(shloMap, TagPUBS) delete(shloMap, TagPUBS)
_, err := cs.handleSHLOMessage(shloMap) _, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("rejects SHLOs without a version list", func() { It("rejects SHLOs without a version list", func() {
delete(shloMap, TagVER) delete(shloMap, TagVER)
_, err := cs.handleSHLOMessage(shloMap) _, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("accepts a SHLO after a version negotiation", func() { It("accepts a SHLO after a version negotiation", func() {
@ -435,7 +435,7 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(params.IdleTimeout).To(Equal(13 * time.Second)) Expect(params.IdleTimeout).To(Equal(13 * time.Second))
}) })
It("closes the aeadChanged when receiving an SHLO", func() { It("closes the handshakeEvent chan when receiving an SHLO", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -444,8 +444,8 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done) close(done)
}() }()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) Eventually(handshakeEvent).Should(Receive())
Eventually(aeadChanged).Should(BeClosed()) Eventually(handshakeEvent).Should(BeClosed())
// make the go routine return // make the go routine return
stream.close() stream.close()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -652,9 +652,9 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert)) Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert))
Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce)) Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce))
Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient)) Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("uses the server nonce, if the server sent one", func() { It("uses the server nonce, if the server sent one", func() {
@ -664,24 +664,24 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...))) Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...)))
Expect(aeadChanged).To(Receive()) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() { It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() {
err := cs.maybeUpgradeCrypto() err := cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
cs.serverVerified = true cs.serverVerified = true
// make sure we really had all necessary values before, and only serverVerified was missing // make sure we really had all necessary values before, and only serverVerified was missing
err = cs.maybeUpgradeCrypto() err = cs.maybeUpgradeCrypto()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("tries to escalate before reading a handshake message", func() { It("tries to escalate before reading a handshake message", func() {
@ -694,10 +694,10 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
close(done) close(done)
}() }()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Eventually(handshakeEvent).Should(Receive())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
// make the go routine return // make the go routine return
stream.close() stream.close()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -715,10 +715,10 @@ var _ = Describe("Client Crypto Setup", func() {
cs.serverVerified = true cs.serverVerified = true
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
cs.SetDiversificationNonce([]byte("div")) cs.SetDiversificationNonce([]byte("div"))
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Eventually(handshakeEvent).Should(Receive())
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(handshakeEvent).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
// make the go routine return // make the go routine return
stream.close() stream.close()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())

View file

@ -42,7 +42,7 @@ type cryptoSetupServer struct {
receivedParams bool receivedParams bool
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction keyExchange KeyExchangeFunction
@ -76,7 +76,7 @@ func NewCryptoSetup(
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool, acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
) (CryptoSetup, error) { ) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil { if err != nil {
@ -96,7 +96,7 @@ func NewCryptoSetup(
acceptSTKCallback: acceptSTK, acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}), sentSHLO: make(chan struct{}),
paramsChan: paramsChan, paramsChan: paramsChan,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
}, nil }, nil
} }
@ -182,7 +182,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if _, err := h.cryptoStream.Write(reply); err != nil { if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err return false, err
} }
h.aeadChanged <- protocol.EncryptionForwardSecure h.handshakeEvent <- struct{}{}
close(h.sentSHLO) close(h.sentSHLO)
return true, nil return true, nil
} }
@ -206,9 +206,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
if err == nil { if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.receivedForwardSecurePacket = true h.receivedForwardSecurePacket = true
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan // wait for the send on the handshakeEvent chan
<-h.sentSHLO <-h.sentSHLO
close(h.aeadChanged) close(h.handshakeEvent)
} }
return res, protocol.EncryptionForwardSecure, nil return res, protocol.EncryptionForwardSecure, nil
} }
@ -396,8 +396,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil { if err != nil {
return nil, err return nil, err
} }
h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionSecure
// Generate a new curve instance to derive the forward secure key // Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer var fsNonce bytes.Buffer

View file

@ -124,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs *cryptoSetupServer cs *cryptoSetupServer
stream *mockStream stream *mockStream
paramsChan chan TransportParameters paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
nonce32 []byte nonce32 []byte
versionTag []byte versionTag []byte
validSTK []byte validSTK []byte
@ -146,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() {
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{}, 2)
stream = newMockStream() stream = newMockStream()
kex = &mockKEX{} kex = &mockKEX{}
signer = &mockSigner{} signer = &mockSigner{}
@ -170,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() {
supportedVersions, supportedVersions,
nil, nil,
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer) cs = csInt.(*cryptoSetupServer)
@ -343,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() {
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("rejects client nonces that have the wrong length", func() { It("rejects client nonces that have the wrong length", func() {
@ -377,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(aeadChanged).ToNot(BeClosed()) Expect(handshakeEvent).ToNot(BeClosed())
}) })
It("recognizes inchoate CHLOs missing SCID", func() { It("recognizes inchoate CHLOs missing SCID", func() {
@ -535,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagKEXS: kexs, TagKEXS: kexs,
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(handshakeEvent).To(Receive()) // for the switch to secure
close(cs.sentSHLO) close(cs.sentSHLO)
} }
@ -657,7 +657,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{}) cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{}) _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(BeClosed()) Expect(handshakeEvent).To(BeClosed())
}) })
}) })

View file

@ -26,9 +26,9 @@ type cryptoSetupTLS struct {
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
aead crypto.AEAD aead crypto.AEAD
tls MintTLS tls MintTLS
cryptoStream *CryptoStreamConn cryptoStream *CryptoStreamConn
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
} }
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
@ -36,16 +36,16 @@ func NewCryptoSetupTLSServer(
tls MintTLS, tls MintTLS,
cryptoStream *CryptoStreamConn, cryptoStream *CryptoStreamConn,
nullAEAD crypto.AEAD, nullAEAD crypto.AEAD,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
version protocol.VersionNumber, version protocol.VersionNumber,
) CryptoSetup { ) CryptoSetup {
return &cryptoSetupTLS{ return &cryptoSetupTLS{
tls: tls, tls: tls,
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
} }
} }
@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter, cryptoStream io.ReadWriter,
connID protocol.ConnectionID, connID protocol.ConnectionID,
hostname string, hostname string,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
tls MintTLS, tls MintTLS,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
@ -64,11 +64,11 @@ func NewCryptoSetupTLSClient(
} }
return &cryptoSetupTLS{ return &cryptoSetupTLS{
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
tls: tls, tls: tls,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
}, nil }, nil
} }
@ -102,9 +102,8 @@ handshakeLoop:
h.aead = aead h.aead = aead
h.mutex.Unlock() h.mutex.Unlock()
// signal to the outside world that the handshake completed h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionForwardSecure close(h.handshakeEvent)
close(h.aeadChanged)
return nil return nil
} }

View file

@ -20,17 +20,17 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() { var _ = Describe("TLS Crypto Setup", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
) )
BeforeEach(func() { BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{}, 2)
cs = NewCryptoSetupTLSServer( cs = NewCryptoSetupTLSServer(
nil, nil,
NewCryptoStreamConn(nil), NewCryptoStreamConn(nil),
nil, // AEAD nil, // AEAD
aeadChanged, handshakeEvent,
protocol.VersionTLS, protocol.VersionTLS,
).(*cryptoSetupTLS) ).(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() {
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(handshakeEvent).To(Receive())
Expect(aeadChanged).To(BeClosed()) Expect(handshakeEvent).To(BeClosed())
}) })
It("handshakes until it is connected", func() { It("handshakes until it is connected", func() {
@ -63,7 +63,7 @@ var _ = Describe("TLS Crypto Setup", func() {
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive()) Expect(handshakeEvent).To(Receive())
}) })
Context("escalating crypto", func() { Context("escalating crypto", func() {
@ -180,17 +180,17 @@ var _ = Describe("TLS Crypto Setup", func() {
var _ = Describe("TLS Crypto Setup, for the client", func() { var _ = Describe("TLS Crypto Setup, for the client", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel handshakeEvent chan struct{}
) )
BeforeEach(func() { BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2) handshakeEvent = make(chan struct{})
csInt, err := NewCryptoSetupTLSClient( csInt, err := NewCryptoSetupTLSClient(
nil, nil,
0, 0,
"quic.clemente.io", "quic.clemente.io",
aeadChanged, handshakeEvent,
nil, // mintTLS nil, // mintTLS
protocol.VersionTLS, protocol.VersionTLS,
) )

View file

@ -20,7 +20,7 @@ import (
type packetHandler interface { type packetHandler interface {
Session Session
getCryptoStream() cryptoStreamI getCryptoStream() cryptoStreamI
handshakeStatus() <-chan handshakeEvent handshakeStatus() <-chan error
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
run() error run() error
@ -391,14 +391,8 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C
}() }()
go func() { go func() {
for { if err := <-session.handshakeStatus(); err != nil {
ev := <-session.handshakeStatus() return
if ev.err != nil {
return
}
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
} }
s.sessionQueue <- session s.sessionQueue <- session
}() }()

View file

@ -22,14 +22,13 @@ import (
) )
type mockSession struct { type mockSession struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
packetCount int packetCount int
closed bool closed bool
closeReason error closeReason error
closedRemote bool closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan handshakeEvent handshakeChan chan error
handshakeComplete chan error // for WaitUntilHandshakeComplete
} }
func (s *mockSession) handlePacket(*receivedPacket) { func (s *mockSession) handlePacket(*receivedPacket) {
@ -40,9 +39,6 @@ func (s *mockSession) run() error {
<-s.stopRunLoop <-s.stopRunLoop
return s.closeReason return s.closeReason
} }
func (s *mockSession) WaitUntilHandshakeComplete() error {
return <-s.handshakeComplete
}
func (s *mockSession) Close(e error) error { func (s *mockSession) Close(e error) error {
if s.closed { if s.closed {
return nil return nil
@ -61,17 +57,16 @@ func (s *mockSession) closeRemote(e error) {
func (s *mockSession) OpenStream() (Stream, error) { func (s *mockSession) OpenStream() (Stream, error) {
return &stream{}, nil return &stream{}, nil
} }
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
var _ Session = &mockSession{} var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{}
func newMockSession( func newMockSession(
_ connection, _ connection,
@ -82,10 +77,9 @@ func newMockSession(
_ *Config, _ *Config,
) (packetHandler, error) { ) (packetHandler, error) {
s := mockSession{ s := mockSession{
connectionID: connectionID, connectionID: connectionID,
handshakeChan: make(chan handshakeEvent), handshakeChan: make(chan error),
handshakeComplete: make(chan error), stopRunLoop: make(chan struct{}),
stopRunLoop: make(chan struct{}),
} }
return &s, nil return &s, nil
} }
@ -155,9 +149,8 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession) sess := serv.sessions[connID].(*mockSession)
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Consistently(func() Session { return acceptedSess }).Should(BeNil()) Consistently(func() Session { return acceptedSess }).Should(BeNil())
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure} close(sess.handshakeChan)
Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
close(done) close(done)
}, 0.5) }, 0.5)
@ -173,7 +166,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[connID].(*mockSession) sess := serv.sessions[connID].(*mockSession)
sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")} sess.handshakeChan <- errors.New("handshake failed")
Consistently(func() bool { return accepted }).Should(BeFalse()) Consistently(func() bool { return accepted }).Should(BeFalse())
close(done) close(done)
}) })

View file

@ -39,11 +39,6 @@ var (
newCryptoSetupClient = handshake.NewCryptoSetupClient newCryptoSetupClient = handshake.NewCryptoSetupClient
) )
type handshakeEvent struct {
encLevel protocol.EncryptionLevel
err error
}
type closeError struct { type closeError struct {
err error err error
remote bool remote bool
@ -90,17 +85,14 @@ type session struct {
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
paramsChan <-chan handshake.TransportParameters paramsChan <-chan handshake.TransportParameters
// this channel is passed to the CryptoSetup and receives the current encryption level // the handshakeEvent channel is passed to the CryptoSetup.
// it is closed as soon as the handshake is complete // It receives when it makes sense to try decrypting undecryptable packets.
aeadChanged <-chan protocol.EncryptionLevel handshakeEvent <-chan struct{}
// handshakeChan is returned by handshakeStatus.
// It receives any error that might occur during the handshake.
// It is closed when the handshake is complete.
handshakeChan chan error
handshakeComplete bool handshakeComplete bool
// will be closed as soon as the handshake completes, and receive any error that might occur until then
// it is used to block WaitUntilHandshakeComplete()
handshakeCompleteChan chan error
// handshakeChan receives handshake events and is closed as soon the handshake completes
// the receiving end of this channel is passed to the creator of the session
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
handshakeChan chan handshakeEvent
lastRcvdPacketNumber protocol.PacketNumber lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire // Used to calculate the next packet number from the truncated wire
@ -131,15 +123,15 @@ func newSession(
config *Config, config *Config,
) (packetHandler, error) { ) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config, config: config,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, paramsChan: paramsChan,
} }
s.preSetup() s.preSetup()
transportParams := &handshake.TransportParameters{ transportParams := &handshake.TransportParameters{
@ -158,7 +150,7 @@ func newSession(
s.config.Versions, s.config.Versions,
s.config.AcceptCookie, s.config.AcceptCookie,
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -179,15 +171,15 @@ var newClientSession = func(
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
) (packetHandler, error) { ) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
config: config, config: config,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, paramsChan: paramsChan,
} }
s.preSetup() s.preSetup()
transportParams := &handshake.TransportParameters{ transportParams := &handshake.TransportParameters{
@ -205,7 +197,7 @@ var newClientSession = func(
tlsConf, tlsConf,
transportParams, transportParams,
paramsChan, paramsChan,
aeadChanged, handshakeEvent,
initialVersion, initialVersion,
negotiatedVersions, negotiatedVersions,
) )
@ -227,21 +219,21 @@ func newTLSServerSession(
peerParams *handshake.TransportParameters, peerParams *handshake.TransportParameters,
v protocol.VersionNumber, v protocol.VersionNumber,
) (packetHandler, error) { ) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
config: config, config: config,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
} }
s.preSetup() s.preSetup()
s.cryptoSetup = handshake.NewCryptoSetupTLSServer( s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
tls, tls,
cryptoStreamConn, cryptoStreamConn,
nullAEAD, nullAEAD,
aeadChanged, handshakeEvent,
v, v,
) )
if err := s.postSetup(initialPacketNumber); err != nil { if err := s.postSetup(initialPacketNumber); err != nil {
@ -264,15 +256,15 @@ var newTLSClientSession = func(
paramsChan <-chan handshake.TransportParameters, paramsChan <-chan handshake.TransportParameters,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
) (packetHandler, error) { ) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) handshakeEvent := make(chan struct{}, 1)
s := &session{ s := &session{
conn: conn, conn: conn,
config: config, config: config,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, paramsChan: paramsChan,
} }
s.preSetup() s.preSetup()
tls.SetCryptoStream(s.cryptoStream) tls.SetCryptoStream(s.cryptoStream)
@ -280,7 +272,7 @@ var newTLSClientSession = func(
s.cryptoStream, s.cryptoStream,
s.connectionID, s.connectionID,
hostname, hostname,
aeadChanged, handshakeEvent,
tls, tls,
v, v,
) )
@ -302,8 +294,7 @@ func (s *session) preSetup() {
} }
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.handshakeChan = make(chan handshakeEvent, 3) s.handshakeChan = make(chan error, 1)
s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
@ -343,7 +334,7 @@ func (s *session) run() error {
}() }()
var closeErr closeError var closeErr closeError
aeadChanged := s.aeadChanged handshakeEvent := s.handshakeEvent
runLoop: runLoop:
for { for {
@ -381,16 +372,14 @@ runLoop:
putPacketBuffer(p.header.Raw) putPacketBuffer(p.header.Raw)
case p := <-s.paramsChan: case p := <-s.paramsChan:
s.processTransportParameters(&p) s.processTransportParameters(&p)
case l, ok := <-aeadChanged: case _, ok := <-handshakeEvent:
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
s.handshakeComplete = true s.handshakeComplete = true
aeadChanged = nil // prevent this case from ever being selected again handshakeEvent = nil // prevent this case from ever being selected again
s.sentPacketHandler.SetHandshakeComplete() s.sentPacketHandler.SetHandshakeComplete()
close(s.handshakeChan) close(s.handshakeChan)
close(s.handshakeCompleteChan)
} else { } else {
s.tryDecryptingQueuedPackets() s.tryDecryptingQueuedPackets()
s.handshakeChan <- handshakeEvent{encLevel: l}
} }
} }
@ -428,8 +417,7 @@ runLoop:
// only send the error the handshakeChan when the handshake is not completed yet // only send the error the handshakeChan when the handshake is not completed yet
// otherwise this chan will already be closed // otherwise this chan will already be closed
if !s.handshakeComplete { if !s.handshakeComplete {
s.handshakeCompleteChan <- closeErr.err s.handshakeChan <- closeErr.err
s.handshakeChan <- handshakeEvent{err: closeErr.err}
} }
s.handleCloseError(closeErr) s.handleCloseError(closeErr)
return closeErr.err return closeErr.err
@ -878,10 +866,6 @@ func (s *session) OpenStreamSync() (Stream, error) {
return s.streamsMap.OpenStreamSync() return s.streamsMap.OpenStreamSync()
} }
func (s *session) WaitUntilHandshakeComplete() error {
return <-s.handshakeCompleteChan
}
func (s *session) newStream(id protocol.StreamID) streamI { func (s *session) newStream(id protocol.StreamID) streamI {
var initialSendWindow protocol.ByteCount var initialSendWindow protocol.ByteCount
if s.peerParams != nil { if s.peerParams != nil {
@ -970,7 +954,7 @@ func (s *session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()
} }
func (s *session) handshakeStatus() <-chan handshakeEvent { func (s *session) handshakeStatus() <-chan error {
return s.handshakeChan return s.handshakeChan
} }

View file

@ -78,11 +78,11 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() { var _ = Describe("Session", func() {
var ( var (
sess *session sess *session
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
mconn *mockConnection mconn *mockConnection
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
aeadChanged chan<- protocol.EncryptionLevel handshakeChan chan<- struct{}
) )
BeforeEach(func() { BeforeEach(func() {
@ -99,9 +99,9 @@ var _ = Describe("Session", func() {
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
_ func(net.Addr, *Cookie) bool, _ func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters, _ chan<- handshake.TransportParameters,
aeadChangedP chan<- protocol.EncryptionLevel, handshakeChanP chan<- struct{},
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
aeadChanged = aeadChangedP handshakeChan = handshakeChanP
return cryptoSetup, nil return cryptoSetup, nil
} }
@ -149,7 +149,7 @@ var _ = Describe("Session", func() {
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
cookieFunc func(net.Addr, *Cookie) bool, cookieFunc func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters, _ chan<- handshake.TransportParameters,
_ chan<- protocol.EncryptionLevel, _ chan<- struct{},
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
cookieVerify = cookieFunc cookieVerify = cookieFunc
return cryptoSetup, nil return cryptoSetup, nil
@ -516,61 +516,6 @@ var _ = Describe("Session", func() {
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
}) })
Context("waiting until the handshake completes", func() {
It("waits until the handshake is complete", func() {
go func() {
defer GinkgoRecover()
sess.run()
}()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
err := sess.WaitUntilHandshakeComplete()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
aeadChanged <- protocol.EncryptionForwardSecure
Consistently(done).ShouldNot(BeClosed())
close(aeadChanged)
Eventually(done).Should(BeClosed())
Expect(sess.Close(nil)).To(Succeed())
})
It("errors if the handshake fails", func(done Done) {
testErr := errors.New("crypto error")
sess.cryptoSetup = &mockCryptoSetup{handleErr: testErr}
go sess.run()
err := sess.WaitUntilHandshakeComplete()
Expect(err).To(MatchError(testErr))
close(done)
}, 0.5)
It("returns when Close is called", func(done Done) {
testErr := errors.New("close error")
go sess.run()
var waitReturned bool
go func() {
defer GinkgoRecover()
err := sess.WaitUntilHandshakeComplete()
Expect(err).To(MatchError(testErr))
waitReturned = true
}()
sess.Close(testErr)
Eventually(func() bool { return waitReturned }).Should(BeTrue())
close(done)
})
It("doesn't wait if the handshake is already completed", func(done Done) {
go sess.run()
close(aeadChanged)
err := sess.WaitUntilHandshakeComplete()
Expect(err).ToNot(HaveOccurred())
Expect(sess.Close(nil)).To(Succeed())
close(done)
})
})
Context("accepting streams", func() { Context("accepting streams", func() {
BeforeEach(func() { BeforeEach(func() {
// don't use the mock here // don't use the mock here
@ -1362,46 +1307,48 @@ var _ = Describe("Session", func() {
}) })
}) })
It("send a handshake event on the handshakeChan when the AEAD changes to secure", func(done Done) { It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() {
go sess.run() done := make(chan struct{})
aeadChanged <- protocol.EncryptionSecure go func() {
Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure})) defer GinkgoRecover()
err := sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
}()
handshakeChan <- struct{}{}
Consistently(sess.handshakeStatus()).ShouldNot(Receive())
// make sure the go routine returns
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
close(done) Eventually(done).Should(BeClosed())
}) })
It("send a handshake event on the handshakeChan when the AEAD changes to forward-secure", func(done Done) { It("closes the handshakeChan when the handshake completes", func() {
go sess.run() done := make(chan struct{})
aeadChanged <- protocol.EncryptionForwardSecure go func() {
Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure})) defer GinkgoRecover()
Expect(sess.Close(nil)).To(Succeed()) err := sess.run()
close(done) Expect(err).ToNot(HaveOccurred())
}) close(done)
}()
It("closes the handshakeChan when the handshake completes", func(done Done) { close(handshakeChan)
go sess.run()
close(aeadChanged)
Eventually(sess.handshakeStatus()).Should(BeClosed()) Eventually(sess.handshakeStatus()).Should(BeClosed())
// make sure the go routine returns
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
close(done) Eventually(done).Should(BeClosed())
}) })
It("passes errors to the handshakeChan", func(done Done) { It("passes errors to the handshakeChan", func() {
testErr := errors.New("handshake error") testErr := errors.New("handshake error")
go sess.run() done := make(chan struct{})
Expect(sess.Close(nil)).To(Succeed()) go func() {
Expect(sess.handshakeStatus()).To(Receive(&handshakeEvent{err: testErr})) defer GinkgoRecover()
close(done) err := sess.run()
}) Expect(err).To(MatchError(testErr))
close(done)
It("does not block if an error occurs", func(done Done) { }()
// this test basically tests that the handshakeChan has a capacity of 3 sess.Close(testErr)
// The session needs to run (and close) properly, even if no one is receiving from the handshakeChan Expect(sess.handshakeStatus()).To(Receive(Equal(testErr)))
go sess.run() Eventually(done).Should(BeClosed())
aeadChanged <- protocol.EncryptionSecure
aeadChanged <- protocol.EncryptionForwardSecure
Expect(sess.Close(nil)).To(Succeed())
close(done)
}) })
It("process transport parameters received from the peer", func() { It("process transport parameters received from the peer", func() {
@ -1503,7 +1450,7 @@ var _ = Describe("Session", func() {
It("closes the session due to the idle timeout after handshake", func() { It("closes the session due to the idle timeout after handshake", func() {
sess.config.IdleTimeout = 0 sess.config.IdleTimeout = 0
close(aeadChanged) close(handshakeChan)
errChan := make(chan error) errChan := make(chan error)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1619,9 +1566,9 @@ var _ = Describe("Session", func() {
var _ = Describe("Client Session", func() { var _ = Describe("Client Session", func() {
var ( var (
sess *session sess *session
mconn *mockConnection mconn *mockConnection
aeadChanged chan<- protocol.EncryptionLevel handshakeChan chan<- struct{}
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
) )
@ -1638,11 +1585,11 @@ var _ = Describe("Client Session", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ chan<- handshake.TransportParameters, _ chan<- handshake.TransportParameters,
aeadChangedP chan<- protocol.EncryptionLevel, handshakeChanP chan<- struct{},
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
aeadChanged = aeadChangedP handshakeChan = handshakeChanP
return cryptoSetup, nil return cryptoSetup, nil
} }
@ -1674,10 +1621,13 @@ var _ = Describe("Client Session", func() {
sess.unpacker = &mockUnpacker{} sess.unpacker = &mockUnpacker{}
}) })
It("passes the diversification nonce to the cryptoSetup", func() { It("passes the diversification nonce to the crypto setup", func() {
done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess.run() err := sess.run()
Expect(err).ToNot(HaveOccurred())
close(done)
}() }()
hdr.PacketNumber = 5 hdr.PacketNumber = 5
hdr.DiversificationNonce = []byte("foobar") hdr.DiversificationNonce = []byte("foobar")
@ -1685,16 +1635,7 @@ var _ = Describe("Client Session", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce)) Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce))
Expect(sess.Close(nil)).To(Succeed()) Expect(sess.Close(nil)).To(Succeed())
Eventually(done).Should(BeClosed())
}) })
}) })
It("does not block if an error occurs", func(done Done) {
// this test basically tests that the handshakeChan has a capacity of 3
// The session needs to run (and close) properly, even if no one is receiving from the handshakeChan
go sess.run()
aeadChanged <- protocol.EncryptionSecure
aeadChanged <- protocol.EncryptionForwardSecure
Expect(sess.Close(nil)).To(Succeed())
close(done)
})
}) })