use separate chans to signal handshake events and handshake completion

This commit is contained in:
Marten Seemann 2018-10-02 18:25:16 -07:00
parent 5fc2e12038
commit 5102294991
8 changed files with 149 additions and 113 deletions

View file

@ -50,8 +50,9 @@ type cryptoSetupClient struct {
secureAEAD crypto.AEAD secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{} handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
params *TransportParameters params *TransportParameters
@ -75,6 +76,7 @@ func NewCryptoSetupClient(
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
logger utils.Logger, logger utils.Logger,
@ -85,17 +87,18 @@ func NewCryptoSetupClient(
} }
divNonceChan := make(chan struct{}) divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{ cs := &cryptoSetupClient{
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
hostname: tlsConf.ServerName, hostname: tlsConf.ServerName,
connID: connID, connID: connID,
version: version, version: version,
certManager: crypto.NewCertManager(tlsConf), certManager: crypto.NewCertManager(tlsConf),
params: params, params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
paramsChan: paramsChan, paramsChan: paramsChan,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
initialVersion: initialVersion, handshakeComplete: handshakeComplete,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet. // The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag. // We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions), negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
@ -158,7 +161,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
// blocks until the session has received the parameters // blocks until the session has received the parameters
h.paramsChan <- *params h.paramsChan <- *params
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
close(h.handshakeEvent) close(h.handshakeComplete)
default: default:
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }

View file

@ -91,6 +91,7 @@ var _ = Describe("Client Crypto Setup", func() {
keyDerivationCalledWith *keyDerivationValues keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte shloMap map[Tag][]byte
handshakeEvent chan struct{} handshakeEvent chan struct{}
handshakeComplete chan struct{}
paramsChan chan TransportParameters paramsChan chan TransportParameters
) )
@ -120,6 +121,7 @@ var _ = Describe("Client Crypto Setup", func() {
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
handshakeEvent = make(chan struct{}, 2) handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
csInt, err := NewCryptoSetupClient( csInt, err := NewCryptoSetupClient(
stream, stream,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
@ -128,6 +130,7 @@ var _ = Describe("Client Crypto Setup", func() {
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
paramsChan, paramsChan,
handshakeEvent, handshakeEvent,
handshakeComplete,
protocol.Version39, protocol.Version39,
nil, nil,
utils.DefaultLogger, utils.DefaultLogger,
@ -445,7 +448,7 @@ var _ = Describe("Client Crypto Setup", func() {
Expect(params.IdleTimeout).To(Equal(13 * time.Second)) Expect(params.IdleTimeout).To(Equal(13 * time.Second))
}) })
It("closes the handshakeEvent chan when receiving an SHLO", func() { It("closes the handshakeComplete chan when receiving an SHLO", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -455,7 +458,7 @@ var _ = Describe("Client Crypto Setup", func() {
close(done) close(done)
}() }()
Eventually(handshakeEvent).Should(Receive()) Eventually(handshakeEvent).Should(Receive())
Eventually(handshakeEvent).Should(BeClosed()) Eventually(handshakeComplete).Should(BeClosed())
// make the go routine return // make the go routine return
stream.close() stream.close()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())

View file

@ -42,9 +42,10 @@ type cryptoSetupServer struct {
receivedSecurePacket bool receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool receivedParams bool
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{} handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction keyExchange KeyExchangeFunction
@ -77,6 +78,7 @@ func NewCryptoSetup(
acceptSTK func(net.Addr, *Cookie) bool, acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
@ -99,6 +101,7 @@ func NewCryptoSetup(
sentSHLO: make(chan struct{}), sentSHLO: make(chan struct{}),
paramsChan: paramsChan, paramsChan: paramsChan,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
logger: logger, logger: logger,
}, nil }, nil
} }
@ -210,7 +213,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
h.receivedForwardSecurePacket = true h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan // wait for the send on the handshakeEvent chan
<-h.sentSHLO <-h.sentSHLO
close(h.handshakeEvent) close(h.handshakeComplete)
} }
return res, protocol.EncryptionForwardSecure, nil return res, protocol.EncryptionForwardSecure, nil
} }

View file

@ -122,6 +122,7 @@ var _ = Describe("Server Crypto Setup", func() {
stream *mockStream stream *mockStream
paramsChan chan TransportParameters paramsChan chan TransportParameters
handshakeEvent chan struct{} handshakeEvent chan struct{}
handshakeComplete chan struct{}
nonce32 []byte nonce32 []byte
versionTag []byte versionTag []byte
validSTK []byte validSTK []byte
@ -144,6 +145,7 @@ var _ = Describe("Server Crypto Setup", func() {
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
handshakeEvent = make(chan struct{}, 2) handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
stream = newMockStream() stream = newMockStream()
kex = &mockKEX{} kex = &mockKEX{}
signer = &mockSigner{} signer = &mockSigner{}
@ -169,6 +171,7 @@ var _ = Describe("Server Crypto Setup", func() {
nil, nil,
paramsChan, paramsChan,
handshakeEvent, handshakeEvent,
handshakeComplete,
utils.DefaultLogger, utils.DefaultLogger,
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -318,7 +321,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(handshakeEvent).To(Receive()) // for the switch to secure Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed()) Expect(handshakeComplete).ToNot(BeClosed())
}) })
It("rejects client nonces that have the wrong length", func() { It("rejects client nonces that have the wrong length", func() {
@ -351,7 +354,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(handshakeEvent).To(Receive()) // for the switch to secure Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed()) Expect(handshakeComplete).ToNot(BeClosed())
}) })
It("recognizes inchoate CHLOs missing SCID", func() { It("recognizes inchoate CHLOs missing SCID", func() {
@ -629,7 +632,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{}) cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{}) _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(BeClosed()) Expect(handshakeComplete).To(BeClosed())
}) })
}) })

View file

@ -23,9 +23,10 @@ type cryptoSetupTLS struct {
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
aead crypto.AEAD aead crypto.AEAD
tls mintTLS tls mintTLS
conn *cryptoStreamConn conn *cryptoStreamConn
handshakeEvent chan<- struct{} handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
} }
var _ CryptoSetupTLS = &cryptoSetupTLS{} var _ CryptoSetupTLS = &cryptoSetupTLS{}
@ -36,6 +37,7 @@ func NewCryptoSetupTLSServer(
connID protocol.ConnectionID, connID protocol.ConnectionID,
config *mint.Config, config *mint.Config,
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetupTLS, error) { ) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
@ -45,12 +47,13 @@ func NewCryptoSetupTLSServer(
conn := newCryptoStreamConn(cryptoStream) conn := newCryptoStreamConn(cryptoStream)
tls := mint.Server(conn, config) tls := mint.Server(conn, config)
return &cryptoSetupTLS{ return &cryptoSetupTLS{
tls: tls, tls: tls,
conn: conn, conn: conn,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
}, nil }, nil
} }
@ -60,6 +63,7 @@ func NewCryptoSetupTLSClient(
connID protocol.ConnectionID, connID protocol.ConnectionID,
config *mint.Config, config *mint.Config,
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetupTLS, error) { ) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
@ -69,12 +73,13 @@ func NewCryptoSetupTLSClient(
conn := newCryptoStreamConn(cryptoStream) conn := newCryptoStreamConn(cryptoStream)
tls := mint.Client(conn, config) tls := mint.Client(conn, config)
return &cryptoSetupTLS{ return &cryptoSetupTLS{
tls: tls, tls: tls,
conn: conn, conn: conn,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
}, nil }, nil
} }
@ -101,7 +106,7 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error {
h.mutex.Unlock() h.mutex.Unlock()
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
close(h.handshakeEvent) close(h.handshakeComplete)
return nil return nil
} }

View file

@ -20,17 +20,20 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() { var _ = Describe("TLS Crypto Setup", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
handshakeEvent chan struct{} handshakeEvent chan struct{}
handshakeComplete chan struct{}
) )
BeforeEach(func() { BeforeEach(func() {
handshakeEvent = make(chan struct{}, 2) handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
css, err := NewCryptoSetupTLSServer( css, err := NewCryptoSetupTLSServer(
newCryptoStreamConn(bytes.NewBuffer([]byte{})), newCryptoStreamConn(bytes.NewBuffer([]byte{})),
protocol.ConnectionID{}, protocol.ConnectionID{},
&mint.Config{}, &mint.Config{},
handshakeEvent, handshakeEvent,
handshakeComplete,
protocol.VersionTLS, protocol.VersionTLS,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -54,7 +57,7 @@ var _ = Describe("TLS Crypto Setup", func() {
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(Receive()) Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).To(BeClosed()) Expect(handshakeComplete).To(BeClosed())
}) })
It("handshakes until it is connected", func() { It("handshakes until it is connected", func() {

View file

@ -119,8 +119,9 @@ type session struct {
paramsChan <-chan handshake.TransportParameters paramsChan <-chan handshake.TransportParameters
// the handshakeEvent channel is passed to the CryptoSetup. // the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets. // It receives when it makes sense to try decrypting undecryptable packets.
handshakeEvent <-chan struct{} handshakeEvent <-chan struct{}
handshakeComplete bool handshakeCompleteChan <-chan struct{} // is closed when the handshake completes
handshakeComplete bool
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
receivedFirstForwardSecurePacket bool receivedFirstForwardSecurePacket bool
@ -162,17 +163,19 @@ func newSession(
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID) logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: sessionRunner, sessionRunner: sessionRunner,
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config, config: config,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, handshakeCompleteChan: handshakeCompleteChan,
logger: logger, paramsChan: paramsChan,
logger: logger,
} }
s.preSetup() s.preSetup()
transportParams := &handshake.TransportParameters{ transportParams := &handshake.TransportParameters{
@ -197,6 +200,7 @@ func newSession(
s.config.AcceptCookie, s.config.AcceptCookie,
paramsChan, paramsChan,
handshakeEvent, handshakeEvent,
handshakeCompleteChan,
s.logger, s.logger,
) )
if err != nil { if err != nil {
@ -238,17 +242,19 @@ var newClientSession = func(
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID) logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
paramsChan := make(chan handshake.TransportParameters) paramsChan := make(chan handshake.TransportParameters)
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: sessionRunner, sessionRunner: sessionRunner,
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
config: config, config: config,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, handshakeCompleteChan: handshakeCompleteChan,
logger: logger, paramsChan: paramsChan,
logger: logger,
} }
s.preSetup() s.preSetup()
transportParams := &handshake.TransportParameters{ transportParams := &handshake.TransportParameters{
@ -266,6 +272,7 @@ var newClientSession = func(
transportParams, transportParams,
paramsChan, paramsChan,
handshakeEvent, handshakeEvent,
handshakeCompleteChan,
initialVersion, initialVersion,
negotiatedVersions, negotiatedVersions,
s.logger, s.logger,
@ -307,16 +314,18 @@ func newTLSServerSession(
v protocol.VersionNumber, v protocol.VersionNumber,
) (quicSession, error) { ) (quicSession, error) {
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: runner, sessionRunner: runner,
config: config, config: config,
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
logger: logger, handshakeCompleteChan: handshakeCompleteChan,
logger: logger,
} }
s.preSetup() s.preSetup()
cs, err := handshake.NewCryptoSetupTLSServer( cs, err := handshake.NewCryptoSetupTLSServer(
@ -324,6 +333,7 @@ func newTLSServerSession(
origConnID, origConnID,
mintConf, mintConf,
handshakeEvent, handshakeEvent,
handshakeCompleteChan,
v, v,
) )
if err != nil { if err != nil {
@ -370,17 +380,19 @@ var newTLSClientSession = func(
v protocol.VersionNumber, v protocol.VersionNumber,
) (quicSession, error) { ) (quicSession, error) {
handshakeEvent := make(chan struct{}, 1) handshakeEvent := make(chan struct{}, 1)
handshakeCompleteChan := make(chan struct{})
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: runner, sessionRunner: runner,
config: conf, config: conf,
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
paramsChan: paramsChan, handshakeCompleteChan: handshakeCompleteChan,
logger: logger, paramsChan: paramsChan,
logger: logger,
} }
s.preSetup() s.preSetup()
cs, err := handshake.NewCryptoSetupTLSClient( cs, err := handshake.NewCryptoSetupTLSClient(
@ -388,6 +400,7 @@ var newTLSClientSession = func(
s.destConnID, s.destConnID,
mintConf, mintConf,
handshakeEvent, handshakeEvent,
handshakeCompleteChan,
v, v,
) )
if err != nil { if err != nil {
@ -463,9 +476,10 @@ runLoop:
select { select {
case closeErr = <-s.closeChan: case closeErr = <-s.closeChan:
break runLoop break runLoop
case _, ok := <-s.handshakeEvent: case <-s.handshakeEvent:
// when the handshake is completed, the channel will be closed s.tryDecryptingQueuedPackets()
s.handleHandshakeEvent(!ok) case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
default: default:
} }
@ -497,9 +511,10 @@ runLoop:
case p := <-s.paramsChan: case p := <-s.paramsChan:
s.processTransportParameters(&p) s.processTransportParameters(&p)
continue continue
case _, ok := <-s.handshakeEvent: case <-s.handshakeEvent:
// when the handshake is completed, the channel will be closed s.tryDecryptingQueuedPackets()
s.handleHandshakeEvent(!ok) case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
} }
now := time.Now() now := time.Now()
@ -590,13 +605,9 @@ func (s *session) maybeResetTimer() {
s.timer.Reset(deadline) s.timer.Reset(deadline)
} }
func (s *session) handleHandshakeEvent(completed bool) { func (s *session) handleHandshakeComplete() {
if !completed {
s.tryDecryptingQueuedPackets()
return
}
s.handshakeComplete = true s.handshakeComplete = true
s.handshakeEvent = nil // prevent this case from ever being selected again s.handshakeCompleteChan = nil // prevent this case from ever being selected again
s.sessionRunner.onHandshakeComplete(s) s.sessionRunner.onHandshakeComplete(s)
// In gQUIC, the server completes the handshake first (after sending the SHLO). // In gQUIC, the server completes the handshake first (after sending the SHLO).

View file

@ -93,14 +93,15 @@ func areSessionsRunning() bool {
var _ = Describe("Session", func() { var _ = Describe("Session", func() {
var ( var (
sess *session sess *session
sessionRunner *MockSessionRunner sessionRunner *MockSessionRunner
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
mconn *mockConnection mconn *mockConnection
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
streamManager *MockStreamManager streamManager *MockStreamManager
packer *MockPacker packer *MockPacker
handshakeChan chan<- struct{} handshakeChan chan<- struct{}
handshakeCompleteChan chan<- struct{}
) )
BeforeEach(func() { BeforeEach(func() {
@ -119,9 +120,11 @@ var _ = Describe("Session", func() {
_ func(net.Addr, *Cookie) bool, _ func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters, _ chan<- handshake.TransportParameters,
handshakeChanP chan<- struct{}, handshakeChanP chan<- struct{},
handshakeCompleteChanP chan<- struct{},
_ utils.Logger, _ utils.Logger,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
handshakeChan = handshakeChanP handshakeChan = handshakeChanP
handshakeCompleteChan = handshakeCompleteChanP
return cryptoSetup, nil return cryptoSetup, nil
} }
@ -178,6 +181,7 @@ var _ = Describe("Session", func() {
cookieFunc func(net.Addr, *Cookie) bool, cookieFunc func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters, _ chan<- handshake.TransportParameters,
_ chan<- struct{}, _ chan<- struct{},
_ chan<- struct{},
_ utils.Logger, _ utils.Logger,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
cookieVerify = cookieFunc cookieVerify = cookieFunc
@ -1255,7 +1259,7 @@ var _ = Describe("Session", func() {
}) })
It("calls the onHandshakeComplete callback when the handshake completes", func() { It("calls the onHandshakeComplete callback when the handshake completes", func() {
close(handshakeChan) close(handshakeCompleteChan)
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1469,7 +1473,7 @@ var _ = Describe("Session", func() {
return &packedPacket{}, nil return &packedPacket{}, nil
}) })
sess.config.IdleTimeout = 0 sess.config.IdleTimeout = 0
close(handshakeChan) close(handshakeCompleteChan)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -1578,11 +1582,11 @@ var _ = Describe("Session", func() {
var _ = Describe("Client Session", func() { var _ = Describe("Client Session", func() {
var ( var (
sess *session sess *session
sessionRunner *MockSessionRunner sessionRunner *MockSessionRunner
packer *MockPacker packer *MockPacker
mconn *mockConnection mconn *mockConnection
handshakeChan chan<- struct{} handshakeCompleteChan chan<- struct{}
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
) )
@ -1598,12 +1602,13 @@ var _ = Describe("Client Session", func() {
_ *tls.Config, _ *tls.Config,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ chan<- handshake.TransportParameters, _ chan<- handshake.TransportParameters,
handshakeChanP chan<- struct{}, _ chan<- struct{},
handshakeCompleteChanP chan<- struct{},
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
handshakeChan = handshakeChanP handshakeCompleteChan = handshakeCompleteChanP
return cryptoSetup, nil return cryptoSetup, nil
} }
@ -1641,7 +1646,7 @@ var _ = Describe("Client Session", func() {
}), }),
packer.EXPECT().PackPacket().AnyTimes(), packer.EXPECT().PackPacket().AnyTimes(),
) )
close(handshakeChan) close(handshakeCompleteChan)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess.run() sess.run()