mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
use separate chans to signal handshake events and handshake completion
This commit is contained in:
parent
5fc2e12038
commit
5102294991
8 changed files with 149 additions and 113 deletions
|
@ -50,8 +50,9 @@ type cryptoSetupClient struct {
|
||||||
secureAEAD crypto.AEAD
|
secureAEAD crypto.AEAD
|
||||||
forwardSecureAEAD crypto.AEAD
|
forwardSecureAEAD crypto.AEAD
|
||||||
|
|
||||||
paramsChan chan<- TransportParameters
|
paramsChan chan<- TransportParameters
|
||||||
handshakeEvent chan<- struct{}
|
handshakeEvent chan<- struct{}
|
||||||
|
handshakeComplete chan<- struct{}
|
||||||
|
|
||||||
params *TransportParameters
|
params *TransportParameters
|
||||||
|
|
||||||
|
@ -75,6 +76,7 @@ func NewCryptoSetupClient(
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
paramsChan chan<- TransportParameters,
|
paramsChan chan<- TransportParameters,
|
||||||
handshakeEvent chan<- struct{},
|
handshakeEvent chan<- struct{},
|
||||||
|
handshakeComplete chan<- struct{},
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
negotiatedVersions []protocol.VersionNumber,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
|
@ -85,17 +87,18 @@ func NewCryptoSetupClient(
|
||||||
}
|
}
|
||||||
divNonceChan := make(chan struct{})
|
divNonceChan := make(chan struct{})
|
||||||
cs := &cryptoSetupClient{
|
cs := &cryptoSetupClient{
|
||||||
cryptoStream: cryptoStream,
|
cryptoStream: cryptoStream,
|
||||||
hostname: tlsConf.ServerName,
|
hostname: tlsConf.ServerName,
|
||||||
connID: connID,
|
connID: connID,
|
||||||
version: version,
|
version: version,
|
||||||
certManager: crypto.NewCertManager(tlsConf),
|
certManager: crypto.NewCertManager(tlsConf),
|
||||||
params: params,
|
params: params,
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
initialVersion: initialVersion,
|
handshakeComplete: handshakeComplete,
|
||||||
|
initialVersion: initialVersion,
|
||||||
// The server might have sent greased versions in the Version Negotiation packet.
|
// The server might have sent greased versions in the Version Negotiation packet.
|
||||||
// We need strip those from the list, since they won't be included in the handshake tag.
|
// We need strip those from the list, since they won't be included in the handshake tag.
|
||||||
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
|
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
|
||||||
|
@ -158,7 +161,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||||
// blocks until the session has received the parameters
|
// blocks until the session has received the parameters
|
||||||
h.paramsChan <- *params
|
h.paramsChan <- *params
|
||||||
h.handshakeEvent <- struct{}{}
|
h.handshakeEvent <- struct{}{}
|
||||||
close(h.handshakeEvent)
|
close(h.handshakeComplete)
|
||||||
default:
|
default:
|
||||||
return qerr.InvalidCryptoMessageType
|
return qerr.InvalidCryptoMessageType
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,6 +91,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
keyDerivationCalledWith *keyDerivationValues
|
keyDerivationCalledWith *keyDerivationValues
|
||||||
shloMap map[Tag][]byte
|
shloMap map[Tag][]byte
|
||||||
handshakeEvent chan struct{}
|
handshakeEvent chan struct{}
|
||||||
|
handshakeComplete chan struct{}
|
||||||
paramsChan chan TransportParameters
|
paramsChan chan TransportParameters
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -120,6 +121,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
|
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
|
||||||
paramsChan = make(chan TransportParameters, 1)
|
paramsChan = make(chan TransportParameters, 1)
|
||||||
handshakeEvent = make(chan struct{}, 2)
|
handshakeEvent = make(chan struct{}, 2)
|
||||||
|
handshakeComplete = make(chan struct{})
|
||||||
csInt, err := NewCryptoSetupClient(
|
csInt, err := NewCryptoSetupClient(
|
||||||
stream,
|
stream,
|
||||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
|
@ -128,6 +130,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
||||||
paramsChan,
|
paramsChan,
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeComplete,
|
||||||
protocol.Version39,
|
protocol.Version39,
|
||||||
nil,
|
nil,
|
||||||
utils.DefaultLogger,
|
utils.DefaultLogger,
|
||||||
|
@ -445,7 +448,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
|
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes the handshakeEvent chan when receiving an SHLO", func() {
|
It("closes the handshakeComplete chan when receiving an SHLO", func() {
|
||||||
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -455,7 +458,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
Eventually(handshakeEvent).Should(Receive())
|
Eventually(handshakeEvent).Should(Receive())
|
||||||
Eventually(handshakeEvent).Should(BeClosed())
|
Eventually(handshakeComplete).Should(BeClosed())
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
stream.close()
|
stream.close()
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
|
@ -42,9 +42,10 @@ type cryptoSetupServer struct {
|
||||||
receivedSecurePacket bool
|
receivedSecurePacket bool
|
||||||
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
||||||
|
|
||||||
receivedParams bool
|
receivedParams bool
|
||||||
paramsChan chan<- TransportParameters
|
paramsChan chan<- TransportParameters
|
||||||
handshakeEvent chan<- struct{}
|
handshakeEvent chan<- struct{}
|
||||||
|
handshakeComplete chan<- struct{}
|
||||||
|
|
||||||
keyDerivation QuicCryptoKeyDerivationFunction
|
keyDerivation QuicCryptoKeyDerivationFunction
|
||||||
keyExchange KeyExchangeFunction
|
keyExchange KeyExchangeFunction
|
||||||
|
@ -77,6 +78,7 @@ func NewCryptoSetup(
|
||||||
acceptSTK func(net.Addr, *Cookie) bool,
|
acceptSTK func(net.Addr, *Cookie) bool,
|
||||||
paramsChan chan<- TransportParameters,
|
paramsChan chan<- TransportParameters,
|
||||||
handshakeEvent chan<- struct{},
|
handshakeEvent chan<- struct{},
|
||||||
|
handshakeComplete chan<- struct{},
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||||
|
@ -99,6 +101,7 @@ func NewCryptoSetup(
|
||||||
sentSHLO: make(chan struct{}),
|
sentSHLO: make(chan struct{}),
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
|
handshakeComplete: handshakeComplete,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -210,7 +213,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
||||||
h.receivedForwardSecurePacket = true
|
h.receivedForwardSecurePacket = true
|
||||||
// wait for the send on the handshakeEvent chan
|
// wait for the send on the handshakeEvent chan
|
||||||
<-h.sentSHLO
|
<-h.sentSHLO
|
||||||
close(h.handshakeEvent)
|
close(h.handshakeComplete)
|
||||||
}
|
}
|
||||||
return res, protocol.EncryptionForwardSecure, nil
|
return res, protocol.EncryptionForwardSecure, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,6 +122,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
stream *mockStream
|
stream *mockStream
|
||||||
paramsChan chan TransportParameters
|
paramsChan chan TransportParameters
|
||||||
handshakeEvent chan struct{}
|
handshakeEvent chan struct{}
|
||||||
|
handshakeComplete chan struct{}
|
||||||
nonce32 []byte
|
nonce32 []byte
|
||||||
versionTag []byte
|
versionTag []byte
|
||||||
validSTK []byte
|
validSTK []byte
|
||||||
|
@ -144,6 +145,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
||||||
paramsChan = make(chan TransportParameters, 1)
|
paramsChan = make(chan TransportParameters, 1)
|
||||||
handshakeEvent = make(chan struct{}, 2)
|
handshakeEvent = make(chan struct{}, 2)
|
||||||
|
handshakeComplete = make(chan struct{})
|
||||||
stream = newMockStream()
|
stream = newMockStream()
|
||||||
kex = &mockKEX{}
|
kex = &mockKEX{}
|
||||||
signer = &mockSigner{}
|
signer = &mockSigner{}
|
||||||
|
@ -169,6 +171,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
nil,
|
nil,
|
||||||
paramsChan,
|
paramsChan,
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeComplete,
|
||||||
utils.DefaultLogger,
|
utils.DefaultLogger,
|
||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
@ -318,7 +321,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
|
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||||
Expect(handshakeEvent).ToNot(BeClosed())
|
Expect(handshakeComplete).ToNot(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects client nonces that have the wrong length", func() {
|
It("rejects client nonces that have the wrong length", func() {
|
||||||
|
@ -351,7 +354,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||||
Expect(handshakeEvent).ToNot(BeClosed())
|
Expect(handshakeComplete).ToNot(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs missing SCID", func() {
|
It("recognizes inchoate CHLOs missing SCID", func() {
|
||||||
|
@ -629,7 +632,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
|
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
|
||||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
|
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(handshakeEvent).To(BeClosed())
|
Expect(handshakeComplete).To(BeClosed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -23,9 +23,10 @@ type cryptoSetupTLS struct {
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
aead crypto.AEAD
|
aead crypto.AEAD
|
||||||
|
|
||||||
tls mintTLS
|
tls mintTLS
|
||||||
conn *cryptoStreamConn
|
conn *cryptoStreamConn
|
||||||
handshakeEvent chan<- struct{}
|
handshakeEvent chan<- struct{}
|
||||||
|
handshakeComplete chan<- struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
||||||
|
@ -36,6 +37,7 @@ func NewCryptoSetupTLSServer(
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
config *mint.Config,
|
config *mint.Config,
|
||||||
handshakeEvent chan<- struct{},
|
handshakeEvent chan<- struct{},
|
||||||
|
handshakeComplete chan<- struct{},
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetupTLS, error) {
|
) (CryptoSetupTLS, error) {
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||||
|
@ -45,12 +47,13 @@ func NewCryptoSetupTLSServer(
|
||||||
conn := newCryptoStreamConn(cryptoStream)
|
conn := newCryptoStreamConn(cryptoStream)
|
||||||
tls := mint.Server(conn, config)
|
tls := mint.Server(conn, config)
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
tls: tls,
|
tls: tls,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
|
handshakeComplete: handshakeComplete,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,6 +63,7 @@ func NewCryptoSetupTLSClient(
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
config *mint.Config,
|
config *mint.Config,
|
||||||
handshakeEvent chan<- struct{},
|
handshakeEvent chan<- struct{},
|
||||||
|
handshakeComplete chan<- struct{},
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetupTLS, error) {
|
) (CryptoSetupTLS, error) {
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||||
|
@ -69,12 +73,13 @@ func NewCryptoSetupTLSClient(
|
||||||
conn := newCryptoStreamConn(cryptoStream)
|
conn := newCryptoStreamConn(cryptoStream)
|
||||||
tls := mint.Client(conn, config)
|
tls := mint.Client(conn, config)
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
tls: tls,
|
tls: tls,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
|
handshakeComplete: handshakeComplete,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +106,7 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
|
|
||||||
h.handshakeEvent <- struct{}{}
|
h.handshakeEvent <- struct{}{}
|
||||||
close(h.handshakeEvent)
|
close(h.handshakeComplete)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,17 +20,20 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
|
||||||
|
|
||||||
var _ = Describe("TLS Crypto Setup", func() {
|
var _ = Describe("TLS Crypto Setup", func() {
|
||||||
var (
|
var (
|
||||||
cs *cryptoSetupTLS
|
cs *cryptoSetupTLS
|
||||||
handshakeEvent chan struct{}
|
handshakeEvent chan struct{}
|
||||||
|
handshakeComplete chan struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
handshakeEvent = make(chan struct{}, 2)
|
handshakeEvent = make(chan struct{}, 2)
|
||||||
|
handshakeComplete = make(chan struct{})
|
||||||
css, err := NewCryptoSetupTLSServer(
|
css, err := NewCryptoSetupTLSServer(
|
||||||
newCryptoStreamConn(bytes.NewBuffer([]byte{})),
|
newCryptoStreamConn(bytes.NewBuffer([]byte{})),
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
&mint.Config{},
|
&mint.Config{},
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeComplete,
|
||||||
protocol.VersionTLS,
|
protocol.VersionTLS,
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -54,7 +57,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||||
err := cs.HandleCryptoStream()
|
err := cs.HandleCryptoStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(handshakeEvent).To(Receive())
|
Expect(handshakeEvent).To(Receive())
|
||||||
Expect(handshakeEvent).To(BeClosed())
|
Expect(handshakeComplete).To(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("handshakes until it is connected", func() {
|
It("handshakes until it is connected", func() {
|
||||||
|
|
117
session.go
117
session.go
|
@ -119,8 +119,9 @@ type session struct {
|
||||||
paramsChan <-chan handshake.TransportParameters
|
paramsChan <-chan handshake.TransportParameters
|
||||||
// the handshakeEvent channel is passed to the CryptoSetup.
|
// the handshakeEvent channel is passed to the CryptoSetup.
|
||||||
// It receives when it makes sense to try decrypting undecryptable packets.
|
// It receives when it makes sense to try decrypting undecryptable packets.
|
||||||
handshakeEvent <-chan struct{}
|
handshakeEvent <-chan struct{}
|
||||||
handshakeComplete bool
|
handshakeCompleteChan <-chan struct{} // is closed when the handshake completes
|
||||||
|
handshakeComplete bool
|
||||||
|
|
||||||
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
|
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
|
||||||
receivedFirstForwardSecurePacket bool
|
receivedFirstForwardSecurePacket bool
|
||||||
|
@ -162,17 +163,19 @@ func newSession(
|
||||||
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
|
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
|
||||||
paramsChan := make(chan handshake.TransportParameters)
|
paramsChan := make(chan handshake.TransportParameters)
|
||||||
handshakeEvent := make(chan struct{}, 1)
|
handshakeEvent := make(chan struct{}, 1)
|
||||||
|
handshakeCompleteChan := make(chan struct{})
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
sessionRunner: sessionRunner,
|
sessionRunner: sessionRunner,
|
||||||
srcConnID: srcConnID,
|
srcConnID: srcConnID,
|
||||||
destConnID: destConnID,
|
destConnID: destConnID,
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
version: v,
|
version: v,
|
||||||
config: config,
|
config: config,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
paramsChan: paramsChan,
|
handshakeCompleteChan: handshakeCompleteChan,
|
||||||
logger: logger,
|
paramsChan: paramsChan,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
s.preSetup()
|
s.preSetup()
|
||||||
transportParams := &handshake.TransportParameters{
|
transportParams := &handshake.TransportParameters{
|
||||||
|
@ -197,6 +200,7 @@ func newSession(
|
||||||
s.config.AcceptCookie,
|
s.config.AcceptCookie,
|
||||||
paramsChan,
|
paramsChan,
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeCompleteChan,
|
||||||
s.logger,
|
s.logger,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -238,17 +242,19 @@ var newClientSession = func(
|
||||||
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
|
logger.Debugf("Creating new session. Destination Connection ID: %s, Source Connection ID: %s", destConnID, srcConnID)
|
||||||
paramsChan := make(chan handshake.TransportParameters)
|
paramsChan := make(chan handshake.TransportParameters)
|
||||||
handshakeEvent := make(chan struct{}, 1)
|
handshakeEvent := make(chan struct{}, 1)
|
||||||
|
handshakeCompleteChan := make(chan struct{})
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
sessionRunner: sessionRunner,
|
sessionRunner: sessionRunner,
|
||||||
srcConnID: srcConnID,
|
srcConnID: srcConnID,
|
||||||
destConnID: destConnID,
|
destConnID: destConnID,
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
version: v,
|
version: v,
|
||||||
config: config,
|
config: config,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
paramsChan: paramsChan,
|
handshakeCompleteChan: handshakeCompleteChan,
|
||||||
logger: logger,
|
paramsChan: paramsChan,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
s.preSetup()
|
s.preSetup()
|
||||||
transportParams := &handshake.TransportParameters{
|
transportParams := &handshake.TransportParameters{
|
||||||
|
@ -266,6 +272,7 @@ var newClientSession = func(
|
||||||
transportParams,
|
transportParams,
|
||||||
paramsChan,
|
paramsChan,
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeCompleteChan,
|
||||||
initialVersion,
|
initialVersion,
|
||||||
negotiatedVersions,
|
negotiatedVersions,
|
||||||
s.logger,
|
s.logger,
|
||||||
|
@ -307,16 +314,18 @@ func newTLSServerSession(
|
||||||
v protocol.VersionNumber,
|
v protocol.VersionNumber,
|
||||||
) (quicSession, error) {
|
) (quicSession, error) {
|
||||||
handshakeEvent := make(chan struct{}, 1)
|
handshakeEvent := make(chan struct{}, 1)
|
||||||
|
handshakeCompleteChan := make(chan struct{})
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
sessionRunner: runner,
|
sessionRunner: runner,
|
||||||
config: config,
|
config: config,
|
||||||
srcConnID: srcConnID,
|
srcConnID: srcConnID,
|
||||||
destConnID: destConnID,
|
destConnID: destConnID,
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
version: v,
|
version: v,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
logger: logger,
|
handshakeCompleteChan: handshakeCompleteChan,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
s.preSetup()
|
s.preSetup()
|
||||||
cs, err := handshake.NewCryptoSetupTLSServer(
|
cs, err := handshake.NewCryptoSetupTLSServer(
|
||||||
|
@ -324,6 +333,7 @@ func newTLSServerSession(
|
||||||
origConnID,
|
origConnID,
|
||||||
mintConf,
|
mintConf,
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeCompleteChan,
|
||||||
v,
|
v,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -370,17 +380,19 @@ var newTLSClientSession = func(
|
||||||
v protocol.VersionNumber,
|
v protocol.VersionNumber,
|
||||||
) (quicSession, error) {
|
) (quicSession, error) {
|
||||||
handshakeEvent := make(chan struct{}, 1)
|
handshakeEvent := make(chan struct{}, 1)
|
||||||
|
handshakeCompleteChan := make(chan struct{})
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
sessionRunner: runner,
|
sessionRunner: runner,
|
||||||
config: conf,
|
config: conf,
|
||||||
srcConnID: srcConnID,
|
srcConnID: srcConnID,
|
||||||
destConnID: destConnID,
|
destConnID: destConnID,
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
version: v,
|
version: v,
|
||||||
handshakeEvent: handshakeEvent,
|
handshakeEvent: handshakeEvent,
|
||||||
paramsChan: paramsChan,
|
handshakeCompleteChan: handshakeCompleteChan,
|
||||||
logger: logger,
|
paramsChan: paramsChan,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
s.preSetup()
|
s.preSetup()
|
||||||
cs, err := handshake.NewCryptoSetupTLSClient(
|
cs, err := handshake.NewCryptoSetupTLSClient(
|
||||||
|
@ -388,6 +400,7 @@ var newTLSClientSession = func(
|
||||||
s.destConnID,
|
s.destConnID,
|
||||||
mintConf,
|
mintConf,
|
||||||
handshakeEvent,
|
handshakeEvent,
|
||||||
|
handshakeCompleteChan,
|
||||||
v,
|
v,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -463,9 +476,10 @@ runLoop:
|
||||||
select {
|
select {
|
||||||
case closeErr = <-s.closeChan:
|
case closeErr = <-s.closeChan:
|
||||||
break runLoop
|
break runLoop
|
||||||
case _, ok := <-s.handshakeEvent:
|
case <-s.handshakeEvent:
|
||||||
// when the handshake is completed, the channel will be closed
|
s.tryDecryptingQueuedPackets()
|
||||||
s.handleHandshakeEvent(!ok)
|
case <-s.handshakeCompleteChan:
|
||||||
|
s.handleHandshakeComplete()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -497,9 +511,10 @@ runLoop:
|
||||||
case p := <-s.paramsChan:
|
case p := <-s.paramsChan:
|
||||||
s.processTransportParameters(&p)
|
s.processTransportParameters(&p)
|
||||||
continue
|
continue
|
||||||
case _, ok := <-s.handshakeEvent:
|
case <-s.handshakeEvent:
|
||||||
// when the handshake is completed, the channel will be closed
|
s.tryDecryptingQueuedPackets()
|
||||||
s.handleHandshakeEvent(!ok)
|
case <-s.handshakeCompleteChan:
|
||||||
|
s.handleHandshakeComplete()
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -590,13 +605,9 @@ func (s *session) maybeResetTimer() {
|
||||||
s.timer.Reset(deadline)
|
s.timer.Reset(deadline)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleHandshakeEvent(completed bool) {
|
func (s *session) handleHandshakeComplete() {
|
||||||
if !completed {
|
|
||||||
s.tryDecryptingQueuedPackets()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.handshakeComplete = true
|
s.handshakeComplete = true
|
||||||
s.handshakeEvent = nil // prevent this case from ever being selected again
|
s.handshakeCompleteChan = nil // prevent this case from ever being selected again
|
||||||
s.sessionRunner.onHandshakeComplete(s)
|
s.sessionRunner.onHandshakeComplete(s)
|
||||||
|
|
||||||
// In gQUIC, the server completes the handshake first (after sending the SHLO).
|
// In gQUIC, the server completes the handshake first (after sending the SHLO).
|
||||||
|
|
|
@ -93,14 +93,15 @@ func areSessionsRunning() bool {
|
||||||
|
|
||||||
var _ = Describe("Session", func() {
|
var _ = Describe("Session", func() {
|
||||||
var (
|
var (
|
||||||
sess *session
|
sess *session
|
||||||
sessionRunner *MockSessionRunner
|
sessionRunner *MockSessionRunner
|
||||||
scfg *handshake.ServerConfig
|
scfg *handshake.ServerConfig
|
||||||
mconn *mockConnection
|
mconn *mockConnection
|
||||||
cryptoSetup *mockCryptoSetup
|
cryptoSetup *mockCryptoSetup
|
||||||
streamManager *MockStreamManager
|
streamManager *MockStreamManager
|
||||||
packer *MockPacker
|
packer *MockPacker
|
||||||
handshakeChan chan<- struct{}
|
handshakeChan chan<- struct{}
|
||||||
|
handshakeCompleteChan chan<- struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -119,9 +120,11 @@ var _ = Describe("Session", func() {
|
||||||
_ func(net.Addr, *Cookie) bool,
|
_ func(net.Addr, *Cookie) bool,
|
||||||
_ chan<- handshake.TransportParameters,
|
_ chan<- handshake.TransportParameters,
|
||||||
handshakeChanP chan<- struct{},
|
handshakeChanP chan<- struct{},
|
||||||
|
handshakeCompleteChanP chan<- struct{},
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
) (handshake.CryptoSetup, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
handshakeChan = handshakeChanP
|
handshakeChan = handshakeChanP
|
||||||
|
handshakeCompleteChan = handshakeCompleteChanP
|
||||||
return cryptoSetup, nil
|
return cryptoSetup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,6 +181,7 @@ var _ = Describe("Session", func() {
|
||||||
cookieFunc func(net.Addr, *Cookie) bool,
|
cookieFunc func(net.Addr, *Cookie) bool,
|
||||||
_ chan<- handshake.TransportParameters,
|
_ chan<- handshake.TransportParameters,
|
||||||
_ chan<- struct{},
|
_ chan<- struct{},
|
||||||
|
_ chan<- struct{},
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
) (handshake.CryptoSetup, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
cookieVerify = cookieFunc
|
cookieVerify = cookieFunc
|
||||||
|
@ -1255,7 +1259,7 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("calls the onHandshakeComplete callback when the handshake completes", func() {
|
It("calls the onHandshakeComplete callback when the handshake completes", func() {
|
||||||
close(handshakeChan)
|
close(handshakeCompleteChan)
|
||||||
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
|
sessionRunner.EXPECT().onHandshakeComplete(gomock.Any())
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -1469,7 +1473,7 @@ var _ = Describe("Session", func() {
|
||||||
return &packedPacket{}, nil
|
return &packedPacket{}, nil
|
||||||
})
|
})
|
||||||
sess.config.IdleTimeout = 0
|
sess.config.IdleTimeout = 0
|
||||||
close(handshakeChan)
|
close(handshakeCompleteChan)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -1578,11 +1582,11 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
var _ = Describe("Client Session", func() {
|
var _ = Describe("Client Session", func() {
|
||||||
var (
|
var (
|
||||||
sess *session
|
sess *session
|
||||||
sessionRunner *MockSessionRunner
|
sessionRunner *MockSessionRunner
|
||||||
packer *MockPacker
|
packer *MockPacker
|
||||||
mconn *mockConnection
|
mconn *mockConnection
|
||||||
handshakeChan chan<- struct{}
|
handshakeCompleteChan chan<- struct{}
|
||||||
|
|
||||||
cryptoSetup *mockCryptoSetup
|
cryptoSetup *mockCryptoSetup
|
||||||
)
|
)
|
||||||
|
@ -1598,12 +1602,13 @@ var _ = Describe("Client Session", func() {
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TransportParameters,
|
_ *handshake.TransportParameters,
|
||||||
_ chan<- handshake.TransportParameters,
|
_ chan<- handshake.TransportParameters,
|
||||||
handshakeChanP chan<- struct{},
|
_ chan<- struct{},
|
||||||
|
handshakeCompleteChanP chan<- struct{},
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
) (handshake.CryptoSetup, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
handshakeChan = handshakeChanP
|
handshakeCompleteChan = handshakeCompleteChanP
|
||||||
return cryptoSetup, nil
|
return cryptoSetup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1641,7 +1646,7 @@ var _ = Describe("Client Session", func() {
|
||||||
}),
|
}),
|
||||||
packer.EXPECT().PackPacket().AnyTimes(),
|
packer.EXPECT().PackPacket().AnyTimes(),
|
||||||
)
|
)
|
||||||
close(handshakeChan)
|
close(handshakeCompleteChan)
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
sess.run()
|
sess.run()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue