mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
use a callback to signal completion of the handshake
This commit is contained in:
parent
4fd6a7cc99
commit
e361d3c5cd
6 changed files with 65 additions and 58 deletions
|
@ -267,6 +267,7 @@ func (h *cryptoSetup) RunHandshake() {
|
|||
|
||||
select {
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
h.runner.OnHandshakeComplete()
|
||||
case <-h.closeChan:
|
||||
// wait until the Handshake() go routine has returned
|
||||
close(h.messageChan)
|
||||
|
|
|
@ -291,11 +291,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
|
||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||
var cHandshakeComplete bool
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cErrChan := make(chan error, 1)
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
|
||||
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
|
||||
client, _, err := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
|
@ -309,11 +311,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var sHandshakeComplete bool
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sErrChan := make(chan error, 1)
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
|
||||
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
|
||||
var token [16]byte
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
|
@ -333,10 +337,12 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
select {
|
||||
case sErr = <-sErrChan:
|
||||
default:
|
||||
Expect(sHandshakeComplete).To(BeTrue())
|
||||
}
|
||||
select {
|
||||
case cErr = <-cErrChan:
|
||||
default:
|
||||
Expect(cHandshakeComplete).To(BeTrue())
|
||||
}
|
||||
return cErr, sErr
|
||||
}
|
||||
|
@ -408,6 +414,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, _, err := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
|
@ -425,6 +432,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
var token [16]byte
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b })
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
sTransportParameters := &TransportParameters{
|
||||
IdleTimeout: 0x1337 * time.Second,
|
||||
StatelessResetToken: &token,
|
||||
|
|
|
@ -30,6 +30,7 @@ type tlsExtensionHandler interface {
|
|||
|
||||
type handshakeRunner interface {
|
||||
OnReceivedParams([]byte)
|
||||
OnHandshakeComplete()
|
||||
OnError(error)
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
|
|
@ -58,6 +58,18 @@ func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Cal
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0)
|
||||
}
|
||||
|
||||
// OnHandshakeComplete mocks base method
|
||||
func (m *MockHandshakeRunner) OnHandshakeComplete() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnHandshakeComplete")
|
||||
}
|
||||
|
||||
// OnHandshakeComplete indicates an expected call of OnHandshakeComplete
|
||||
func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete))
|
||||
}
|
||||
|
||||
// OnReceivedParams mocks base method
|
||||
func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
29
session.go
29
session.go
|
@ -72,14 +72,16 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
|||
}
|
||||
|
||||
type handshakeRunner struct {
|
||||
onReceivedParams func([]byte)
|
||||
onError func(error)
|
||||
dropKeys func(protocol.EncryptionLevel)
|
||||
onReceivedParams func([]byte)
|
||||
onError func(error)
|
||||
dropKeys func(protocol.EncryptionLevel)
|
||||
onHandshakeComplete func()
|
||||
}
|
||||
|
||||
func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) }
|
||||
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
|
||||
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
|
||||
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
|
@ -209,9 +211,10 @@ var newSession = func(
|
|||
conn.RemoteAddr(),
|
||||
params,
|
||||
&handshakeRunner{
|
||||
onReceivedParams: s.processTransportParameters,
|
||||
onError: s.closeLocal,
|
||||
dropKeys: s.dropEncryptionLevel,
|
||||
onReceivedParams: s.processTransportParameters,
|
||||
onError: s.closeLocal,
|
||||
dropKeys: s.dropEncryptionLevel,
|
||||
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
|
||||
},
|
||||
tlsConf,
|
||||
logger,
|
||||
|
@ -281,9 +284,10 @@ var newClientSession = func(
|
|||
conn.RemoteAddr(),
|
||||
params,
|
||||
&handshakeRunner{
|
||||
onReceivedParams: s.processTransportParameters,
|
||||
onError: s.closeLocal,
|
||||
dropKeys: s.dropEncryptionLevel,
|
||||
onReceivedParams: s.processTransportParameters,
|
||||
onError: s.closeLocal,
|
||||
dropKeys: s.dropEncryptionLevel,
|
||||
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
|
||||
},
|
||||
tlsConf,
|
||||
logger,
|
||||
|
@ -353,11 +357,8 @@ func (s *session) postSetup() error {
|
|||
func (s *session) run() error {
|
||||
defer s.ctxCancel()
|
||||
|
||||
go func() {
|
||||
s.cryptoStreamHandler.RunHandshake()
|
||||
// If an error occurred during the handshake, the crypto setup will already have called the close callback.
|
||||
close(s.handshakeCompleteChan)
|
||||
}()
|
||||
go s.cryptoStreamHandler.RunHandshake()
|
||||
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
select {
|
||||
case <-s.clientHelloWritten:
|
||||
|
|
|
@ -328,7 +328,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
Expect(sess.run()).To(MatchError(testErr))
|
||||
}()
|
||||
ccf := &wire.ConnectionCloseFrame{
|
||||
|
@ -347,7 +347,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
Expect(sess.run()).To(MatchError(testErr))
|
||||
}()
|
||||
ccf := &wire.ConnectionCloseFrame{
|
||||
|
@ -383,7 +383,7 @@ var _ = Describe("Session", func() {
|
|||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
runErr = sess.run()
|
||||
}()
|
||||
Eventually(areSessionsRunning).Should(BeTrue())
|
||||
|
@ -565,7 +565,7 @@ var _ = Describe("Session", func() {
|
|||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
sessionRunner.EXPECT().Retire(gomock.Any())
|
||||
|
@ -590,7 +590,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: empty packet"))
|
||||
close(done)
|
||||
|
@ -938,7 +938,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -965,7 +965,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -993,7 +993,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1021,7 +1021,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1043,7 +1043,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1071,7 +1071,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
Consistently(mconn.written).ShouldNot(Receive())
|
||||
|
@ -1106,7 +1106,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
Eventually(mconn.written).Should(Receive())
|
||||
|
@ -1121,25 +1121,7 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
})
|
||||
|
||||
It("calls the onHandshakeComplete callback when the handshake completes", func() {
|
||||
packer.EXPECT().PackPacket().AnyTimes()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any())
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
sess.run()
|
||||
}()
|
||||
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
|
||||
// make sure the go routine returns
|
||||
sessionRunner.EXPECT().Retire(gomock.Any())
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||
cryptoSetup.EXPECT().Close()
|
||||
Expect(sess.Close()).To(Succeed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends a forward-secure packet when the handshake completes", func() {
|
||||
It("sends a 1-RTT packet when the handshake completes", func() {
|
||||
done := make(chan struct{})
|
||||
gomock.InOrder(
|
||||
sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()),
|
||||
|
@ -1155,6 +1137,7 @@ var _ = Describe("Session", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
close(sess.handshakeCompleteChan)
|
||||
sess.run()
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -1171,7 +1154,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
Expect(sess.run()).To(Succeed())
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1188,7 +1171,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError(qerr.Error(0x1337, testErr.Error())))
|
||||
close(done)
|
||||
|
@ -1205,7 +1188,7 @@ var _ = Describe("Session", func() {
|
|||
It("errors if it can't unmarshal the TransportParameters", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
err := sess.run()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("transport parameter"))
|
||||
|
@ -1221,7 +1204,7 @@ var _ = Describe("Session", func() {
|
|||
It("process transport parameters received from the client", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
params := &handshake.TransportParameters{
|
||||
|
@ -1265,7 +1248,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1286,7 +1269,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1307,7 +1290,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -1335,7 +1318,7 @@ var _ = Describe("Session", func() {
|
|||
cryptoSetup.EXPECT().Close()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
err := sess.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
@ -1353,7 +1336,7 @@ var _ = Describe("Session", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
err := sess.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
@ -1375,7 +1358,7 @@ var _ = Describe("Session", func() {
|
|||
// and not on the last network activity
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
|
||||
|
@ -1396,6 +1379,7 @@ var _ = Describe("Session", func() {
|
|||
defer GinkgoRecover()
|
||||
sessionRunner.EXPECT().OnHandshakeComplete(sess)
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
close(sess.handshakeCompleteChan)
|
||||
err := sess.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
@ -1413,7 +1397,7 @@ var _ = Describe("Session", func() {
|
|||
sess.config.IdleTimeout = 30 * time.Second
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
|
||||
|
@ -1553,7 +1537,7 @@ var _ = Describe("Client Session", func() {
|
|||
sess.unpacker = unpacker
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }).AnyTimes()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
sess.run()
|
||||
}()
|
||||
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
|
||||
|
@ -1630,7 +1614,7 @@ var _ = Describe("Client Session", func() {
|
|||
It("errors if it can't unmarshal the TransportParameters", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
err := sess.run()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("transport parameter"))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue