use separate chans to signal handshake events and handshake completion

This commit is contained in:
Marten Seemann 2018-10-02 18:25:16 -07:00
parent 5fc2e12038
commit 5102294991
8 changed files with 149 additions and 113 deletions

View file

@ -50,8 +50,9 @@ type cryptoSetupClient struct {
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
params *TransportParameters
@ -75,6 +76,7 @@ func NewCryptoSetupClient(
params *TransportParameters,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
@ -85,17 +87,18 @@ func NewCryptoSetupClient(
}
divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: tlsConf.ServerName,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConf),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
cryptoStream: cryptoStream,
hostname: tlsConf.ServerName,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConf),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
@ -158,7 +161,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
// blocks until the session has received the parameters
h.paramsChan <- *params
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
close(h.handshakeComplete)
default:
return qerr.InvalidCryptoMessageType
}

View file

@ -91,6 +91,7 @@ var _ = Describe("Client Crypto Setup", func() {
keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte
handshakeEvent chan struct{}
handshakeComplete chan struct{}
paramsChan chan TransportParameters
)
@ -120,6 +121,7 @@ var _ = Describe("Client Crypto Setup", func() {
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
csInt, err := NewCryptoSetupClient(
stream,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
@ -128,6 +130,7 @@ var _ = Describe("Client Crypto Setup", func() {
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
paramsChan,
handshakeEvent,
handshakeComplete,
protocol.Version39,
nil,
utils.DefaultLogger,
@ -445,7 +448,7 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
})
It("closes the handshakeEvent chan when receiving an SHLO", func() {
It("closes the handshakeComplete chan when receiving an SHLO", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{})
go func() {
@ -455,7 +458,7 @@ var _ = Describe("Client Crypto Setup", func() {
close(done)
}()
Eventually(handshakeEvent).Should(Receive())
Eventually(handshakeEvent).Should(BeClosed())
Eventually(handshakeComplete).Should(BeClosed())
// make the go routine return
stream.close()
Eventually(done).Should(BeClosed())

View file

@ -42,9 +42,10 @@ type cryptoSetupServer struct {
receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
receivedParams bool
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
@ -77,6 +78,7 @@ func NewCryptoSetup(
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
@ -99,6 +101,7 @@ func NewCryptoSetup(
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
logger: logger,
}, nil
}
@ -210,7 +213,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
close(h.handshakeEvent)
close(h.handshakeComplete)
}
return res, protocol.EncryptionForwardSecure, nil
}

View file

@ -122,6 +122,7 @@ var _ = Describe("Server Crypto Setup", func() {
stream *mockStream
paramsChan chan TransportParameters
handshakeEvent chan struct{}
handshakeComplete chan struct{}
nonce32 []byte
versionTag []byte
validSTK []byte
@ -144,6 +145,7 @@ var _ = Describe("Server Crypto Setup", func() {
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
stream = newMockStream()
kex = &mockKEX{}
signer = &mockSigner{}
@ -169,6 +171,7 @@ var _ = Describe("Server Crypto Setup", func() {
nil,
paramsChan,
handshakeEvent,
handshakeComplete,
utils.DefaultLogger,
)
Expect(err).NotTo(HaveOccurred())
@ -318,7 +321,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed())
Expect(handshakeComplete).ToNot(BeClosed())
})
It("rejects client nonces that have the wrong length", func() {
@ -351,7 +354,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed())
Expect(handshakeComplete).ToNot(BeClosed())
})
It("recognizes inchoate CHLOs missing SCID", func() {
@ -629,7 +632,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(BeClosed())
Expect(handshakeComplete).To(BeClosed())
})
})

View file

@ -23,9 +23,10 @@ type cryptoSetupTLS struct {
nullAEAD crypto.AEAD
aead crypto.AEAD
tls mintTLS
conn *cryptoStreamConn
handshakeEvent chan<- struct{}
tls mintTLS
conn *cryptoStreamConn
handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
}
var _ CryptoSetupTLS = &cryptoSetupTLS{}
@ -36,6 +37,7 @@ func NewCryptoSetupTLSServer(
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
@ -45,12 +47,13 @@ func NewCryptoSetupTLSServer(
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Server(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
tls: tls,
conn: conn,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
}, nil
}
@ -60,6 +63,7 @@ func NewCryptoSetupTLSClient(
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
@ -69,12 +73,13 @@ func NewCryptoSetupTLSClient(
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Client(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
tls: tls,
conn: conn,
perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
}, nil
}
@ -101,7 +106,7 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error {
h.mutex.Unlock()
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
close(h.handshakeComplete)
return nil
}

View file

@ -20,17 +20,20 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() {
var (
cs *cryptoSetupTLS
handshakeEvent chan struct{}
cs *cryptoSetupTLS
handshakeEvent chan struct{}
handshakeComplete chan struct{}
)
BeforeEach(func() {
handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
css, err := NewCryptoSetupTLSServer(
newCryptoStreamConn(bytes.NewBuffer([]byte{})),
protocol.ConnectionID{},
&mint.Config{},
handshakeEvent,
handshakeComplete,
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
@ -54,7 +57,7 @@ var _ = Describe("TLS Crypto Setup", func() {
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).To(BeClosed())
Expect(handshakeComplete).To(BeClosed())
})
It("handshakes until it is connected", func() {

View file

@ -119,8 +119,9 @@ type session struct {
paramsChan <-chan handshake.TransportParameters
// the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets.
handshakeEvent <-chan struct{}
handshakeComplete bool
handshakeEvent <-chan struct{}
handshakeCompleteChan <-chan struct{} // is closed when the handshake completes
handshakeComplete bool
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
receivedFirstForwardSecurePacket bool
@ -162,17 +163,19 @@ func newSession(
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
paramsChan := make(chan handshake.TransportParameters)
handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
sessionRunner: sessionRunner,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveServer,
version: v,
config: config,
handshakeEvent: handshakeEvent,
paramsChan: paramsChan,
logger: logger,
conn: conn,
sessionRunner: sessionRunner,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveServer,
version: v,
config: config,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
paramsChan: paramsChan,
logger: logger,
}
s.preSetup()
transportParams := &handshake.TransportParameters{
@ -197,6 +200,7 @@ func newSession(
s.config.AcceptCookie,
paramsChan,
handshakeEvent,
handshakeCompleteChan,
s.logger,
)
if err != nil {
@ -238,17 +242,19 @@ var newClientSession = func(
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
paramsChan := make(chan handshake.TransportParameters)
handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
sessionRunner: sessionRunner,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
version: v,
config: config,
handshakeEvent: handshakeEvent,
paramsChan: paramsChan,
logger: logger,
conn: conn,
sessionRunner: sessionRunner,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
version: v,
config: config,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
paramsChan: paramsChan,
logger: logger,
}
s.preSetup()
transportParams := &handshake.TransportParameters{
@ -266,6 +272,7 @@ var newClientSession = func(
transportParams,
paramsChan,
handshakeEvent,
handshakeCompleteChan,
initialVersion,
negotiatedVersions,
s.logger,
@ -307,16 +314,18 @@ func newTLSServerSession(
v protocol.VersionNumber,
) (quicSession, error) {
handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
sessionRunner: runner,
config: config,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveServer,
version: v,
handshakeEvent: handshakeEvent,
logger: logger,
conn: conn,
sessionRunner: runner,
config: config,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveServer,
version: v,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
logger: logger,
}
s.preSetup()
cs, err := handshake.NewCryptoSetupTLSServer(
@ -324,6 +333,7 @@ func newTLSServerSession(
origConnID,
mintConf,
handshakeEvent,
handshakeCompleteChan,
v,
)
if err != nil {
@ -370,17 +380,19 @@ var newTLSClientSession = func(
v protocol.VersionNumber,
) (quicSession, error) {
handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
sessionRunner: runner,
config: conf,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
version: v,
handshakeEvent: handshakeEvent,
paramsChan: paramsChan,
logger: logger,
conn: conn,
sessionRunner: runner,
config: conf,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
version: v,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
paramsChan: paramsChan,
logger: logger,
}
s.preSetup()
cs, err := handshake.NewCryptoSetupTLSClient(
@ -388,6 +400,7 @@ var newTLSClientSession = func(
s.destConnID,
mintConf,
handshakeEvent,
handshakeCompleteChan,
v,
)
if err != nil {
@ -463,9 +476,10 @@ runLoop:
select {
case closeErr = <-s.closeChan:
break runLoop
case _, ok := <-s.handshakeEvent:
// when the handshake is completed, the channel will be closed
s.handleHandshakeEvent(!ok)
case <-s.handshakeEvent:
s.tryDecryptingQueuedPackets()
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
default:
}
@ -497,9 +511,10 @@ runLoop:
case p := <-s.paramsChan:
s.processTransportParameters(&p)
continue
case _, ok := <-s.handshakeEvent:
// when the handshake is completed, the channel will be closed
s.handleHandshakeEvent(!ok)
case <-s.handshakeEvent:
s.tryDecryptingQueuedPackets()
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
now := time.Now()
@ -590,13 +605,9 @@ func (s *session) maybeResetTimer() {
s.timer.Reset(deadline)
}
func (s *session) handleHandshakeEvent(completed bool) {
if !completed {
s.tryDecryptingQueuedPackets()
return
}
func (s *session) handleHandshakeComplete() {
s.handshakeComplete = true
s.handshakeEvent = nil // prevent this case from ever being selected again
s.handshakeCompleteChan = nil // prevent this case from ever being selected again
s.sessionRunner.onHandshakeComplete(s)
// In gQUIC, the server completes the handshake first (after sending the SHLO).

View file

@ -93,14 +93,15 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() {
var (
sess *session
sessionRunner *MockSessionRunner
scfg *handshake.ServerConfig
mconn *mockConnection
cryptoSetup *mockCryptoSetup
streamManager *MockStreamManager
packer *MockPacker
handshakeChan chan<- struct{}
sess *session
sessionRunner *MockSessionRunner
scfg *handshake.ServerConfig
mconn *mockConnection
cryptoSetup *mockCryptoSetup
streamManager *MockStreamManager
packer *MockPacker
handshakeChan chan<- struct{}
handshakeCompleteChan chan<- struct{}
)
BeforeEach(func() {
@ -119,9 +120,11 @@ var _ = Describe("Session", func() {
_ func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters,
handshakeChanP chan<- struct{},
handshakeCompleteChanP chan<- struct{},
_ utils.Logger,
) (handshake.CryptoSetup, error) {
handshakeChan = handshakeChanP
handshakeCompleteChan = handshakeCompleteChanP
return cryptoSetup, nil
}
@ -178,6 +181,7 @@ var _ = Describe("Session", func() {
cookieFunc func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters,
_ chan<- struct{},
_ chan<- struct{},
_ utils.Logger,
) (handshake.CryptoSetup, error) {
cookieVerify = cookieFunc
@ -1255,7 +1259,7 @@ var _ = Describe("Session", func() {
})
It("calls the onHandshakeComplete callback when the handshake completes", func() {
close(handshakeChan)
close(handshakeCompleteChan)
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
go func() {
defer GinkgoRecover()
@ -1469,7 +1473,7 @@ var _ = Describe("Session", func() {
return &packedPacket{}, nil
})
sess.config.IdleTimeout = 0
close(handshakeChan)
close(handshakeCompleteChan)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -1578,11 +1582,11 @@ var _ = Describe("Session", func() {
var _ = Describe("Client Session", func() {
var (
sess *session
sessionRunner *MockSessionRunner
packer *MockPacker
mconn *mockConnection
handshakeChan chan<- struct{}
sess *session
sessionRunner *MockSessionRunner
packer *MockPacker
mconn *mockConnection
handshakeCompleteChan chan<- struct{}
cryptoSetup *mockCryptoSetup
)
@ -1598,12 +1602,13 @@ var _ = Describe("Client Session", func() {
_ *tls.Config,
_ *handshake.TransportParameters,
_ chan<- handshake.TransportParameters,
handshakeChanP chan<- struct{},
_ chan<- struct{},
handshakeCompleteChanP chan<- struct{},
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (handshake.CryptoSetup, error) {
handshakeChan = handshakeChanP
handshakeCompleteChan = handshakeCompleteChanP
return cryptoSetup, nil
}
@ -1641,7 +1646,7 @@ var _ = Describe("Client Session", func() {
}),
packer.EXPECT().PackPacket().AnyTimes(),
)
close(handshakeChan)
close(handshakeCompleteChan)
go func() {
defer GinkgoRecover()
sess.run()