mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
use a struct to pass callbacks from the session to the crypto setup
This commit is contained in:
parent
743868159f
commit
4fd6a7cc99
6 changed files with 143 additions and 66 deletions
|
@ -73,10 +73,8 @@ type cryptoSetup struct {
|
||||||
messageChan chan []byte
|
messageChan chan []byte
|
||||||
|
|
||||||
paramsChan <-chan []byte
|
paramsChan <-chan []byte
|
||||||
handleParamsCallback func([]byte)
|
|
||||||
|
|
||||||
dropKeyCallback func(protocol.EncryptionLevel)
|
runner handshakeRunner
|
||||||
closeCallback func(error)
|
|
||||||
|
|
||||||
alertChan chan uint8
|
alertChan chan uint8
|
||||||
// HandleData() sends errors on the messageErrChan
|
// HandleData() sends errors on the messageErrChan
|
||||||
|
@ -131,9 +129,7 @@ func NewCryptoSetupClient(
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
remoteAddr net.Addr,
|
remoteAddr net.Addr,
|
||||||
tp *TransportParameters,
|
tp *TransportParameters,
|
||||||
handleParams func([]byte),
|
runner handshakeRunner,
|
||||||
dropKeys func(protocol.EncryptionLevel),
|
|
||||||
close func(error),
|
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||||
|
@ -143,9 +139,7 @@ func NewCryptoSetupClient(
|
||||||
oneRTTStream,
|
oneRTTStream,
|
||||||
connID,
|
connID,
|
||||||
tp,
|
tp,
|
||||||
handleParams,
|
runner,
|
||||||
dropKeys,
|
|
||||||
close,
|
|
||||||
tlsConf,
|
tlsConf,
|
||||||
logger,
|
logger,
|
||||||
protocol.PerspectiveClient,
|
protocol.PerspectiveClient,
|
||||||
|
@ -165,9 +159,7 @@ func NewCryptoSetupServer(
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
remoteAddr net.Addr,
|
remoteAddr net.Addr,
|
||||||
tp *TransportParameters,
|
tp *TransportParameters,
|
||||||
handleParams func([]byte),
|
runner handshakeRunner,
|
||||||
dropKeys func(protocol.EncryptionLevel),
|
|
||||||
close func(error),
|
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
|
@ -177,9 +169,7 @@ func NewCryptoSetupServer(
|
||||||
oneRTTStream,
|
oneRTTStream,
|
||||||
connID,
|
connID,
|
||||||
tp,
|
tp,
|
||||||
handleParams,
|
runner,
|
||||||
dropKeys,
|
|
||||||
close,
|
|
||||||
tlsConf,
|
tlsConf,
|
||||||
logger,
|
logger,
|
||||||
protocol.PerspectiveServer,
|
protocol.PerspectiveServer,
|
||||||
|
@ -197,9 +187,7 @@ func newCryptoSetup(
|
||||||
oneRTTStream io.Writer,
|
oneRTTStream io.Writer,
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
tp *TransportParameters,
|
tp *TransportParameters,
|
||||||
handleParams func([]byte),
|
runner handshakeRunner,
|
||||||
dropKeys func(protocol.EncryptionLevel),
|
|
||||||
close func(error),
|
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
perspective protocol.Perspective,
|
perspective protocol.Perspective,
|
||||||
|
@ -217,9 +205,7 @@ func newCryptoSetup(
|
||||||
oneRTTStream: oneRTTStream,
|
oneRTTStream: oneRTTStream,
|
||||||
readEncLevel: protocol.EncryptionInitial,
|
readEncLevel: protocol.EncryptionInitial,
|
||||||
writeEncLevel: protocol.EncryptionInitial,
|
writeEncLevel: protocol.EncryptionInitial,
|
||||||
handleParamsCallback: handleParams,
|
runner: runner,
|
||||||
dropKeyCallback: dropKeys,
|
|
||||||
closeCallback: close,
|
|
||||||
paramsChan: extHandler.TransportParameters(),
|
paramsChan: extHandler.TransportParameters(),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
perspective: perspective,
|
perspective: perspective,
|
||||||
|
@ -254,7 +240,7 @@ func (h *cryptoSetup) Received1RTTAck() {
|
||||||
if h.initialOpener != nil {
|
if h.initialOpener != nil {
|
||||||
h.initialOpener = nil
|
h.initialOpener = nil
|
||||||
h.initialSealer = nil
|
h.initialSealer = nil
|
||||||
h.dropKeyCallback(protocol.EncryptionInitial)
|
h.runner.DropKeys(protocol.EncryptionInitial)
|
||||||
h.logger.Debugf("Dropping Initial keys.")
|
h.logger.Debugf("Dropping Initial keys.")
|
||||||
}
|
}
|
||||||
// drop handshake keys
|
// drop handshake keys
|
||||||
|
@ -262,7 +248,7 @@ func (h *cryptoSetup) Received1RTTAck() {
|
||||||
h.handshakeOpener = nil
|
h.handshakeOpener = nil
|
||||||
h.handshakeSealer = nil
|
h.handshakeSealer = nil
|
||||||
h.logger.Debugf("Dropping Handshake keys.")
|
h.logger.Debugf("Dropping Handshake keys.")
|
||||||
h.dropKeyCallback(protocol.EncryptionHandshake)
|
h.runner.DropKeys(protocol.EncryptionHandshake)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,14 +272,14 @@ func (h *cryptoSetup) RunHandshake() {
|
||||||
close(h.messageChan)
|
close(h.messageChan)
|
||||||
case alert := <-h.alertChan:
|
case alert := <-h.alertChan:
|
||||||
handshakeErr := <-handshakeErrChan
|
handshakeErr := <-handshakeErrChan
|
||||||
h.closeCallback(qerr.CryptoError(alert, handshakeErr.Error()))
|
h.runner.OnError(qerr.CryptoError(alert, handshakeErr.Error()))
|
||||||
case err := <-h.messageErrChan:
|
case err := <-h.messageErrChan:
|
||||||
// If the handshake errored because of an error that occurred during HandleData(),
|
// If the handshake errored because of an error that occurred during HandleData(),
|
||||||
// that error message will be more useful than the error message generated by Handshake().
|
// that error message will be more useful than the error message generated by Handshake().
|
||||||
// Close the message chan that qtls is receiving messages from.
|
// Close the message chan that qtls is receiving messages from.
|
||||||
// This will make qtls.Handshake() return.
|
// This will make qtls.Handshake() return.
|
||||||
close(h.messageChan)
|
close(h.messageChan)
|
||||||
h.closeCallback(err)
|
h.runner.OnError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,7 +344,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
||||||
h.logger.Debugf("Sending HelloRetryRequest")
|
h.logger.Debugf("Sending HelloRetryRequest")
|
||||||
return false
|
return false
|
||||||
case data := <-h.paramsChan:
|
case data := <-h.paramsChan:
|
||||||
h.handleParamsCallback(data)
|
h.runner.OnReceivedParams(data)
|
||||||
case <-h.handshakeDone:
|
case <-h.handshakeDone:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -423,7 +409,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
||||||
case typeEncryptedExtensions:
|
case typeEncryptedExtensions:
|
||||||
select {
|
select {
|
||||||
case data := <-h.paramsChan:
|
case data := <-h.paramsChan:
|
||||||
h.handleParamsCallback(data)
|
h.runner.OnReceivedParams(data)
|
||||||
case <-h.handshakeDone:
|
case <-h.handshakeDone:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"math/big"
|
"math/big"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
@ -87,9 +88,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
NewMockHandshakeRunner(mockCtrl),
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(err error) { Fail("error callback called") },
|
|
||||||
tlsConf,
|
tlsConf,
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
@ -110,6 +109,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
It("returns Handshake() when an error occurs", func() {
|
It("returns Handshake() when an error occurs", func() {
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||||
_, sInitialStream, sHandshakeStream := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
server, err := NewCryptoSetupServer(
|
server, err := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
|
@ -118,9 +119,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
runner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(e error) { sErrChan <- e },
|
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
@ -148,6 +147,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
It("returns Handshake() when a message is received at the wrong encryption level", func() {
|
It("returns Handshake() when a message is received at the wrong encryption level", func() {
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
_, sInitialStream, sHandshakeStream := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||||
server, err := NewCryptoSetupServer(
|
server, err := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
|
@ -155,9 +156,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
runner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(e error) { sErrChan <- e },
|
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
@ -185,6 +184,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
It("returns Handshake() when handling a message fails", func() {
|
It("returns Handshake() when handling a message fails", func() {
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
_, sInitialStream, sHandshakeStream := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||||
server, err := NewCryptoSetupServer(
|
server, err := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
|
@ -192,9 +193,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
runner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(e error) { sErrChan <- e },
|
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
@ -220,7 +219,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns Handshake() when it is closed", func() {
|
It("returns Handshake() when it is closed", func() {
|
||||||
sErrChan := make(chan error, 1)
|
|
||||||
_, sInitialStream, sHandshakeStream := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
server, err := NewCryptoSetupServer(
|
server, err := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
|
@ -229,9 +227,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
NewMockHandshakeRunner(mockCtrl),
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(e error) { sErrChan <- e },
|
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
@ -241,7 +237,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
server.RunHandshake()
|
server.RunHandshake()
|
||||||
Consistently(sErrChan).ShouldNot(Receive())
|
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
Expect(server.Close()).To(Succeed())
|
Expect(server.Close()).To(Succeed())
|
||||||
|
@ -298,6 +293,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cErrChan := make(chan error, 1)
|
cErrChan := make(chan error, 1)
|
||||||
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
|
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
|
||||||
client, _, err := NewCryptoSetupClient(
|
client, _, err := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
|
@ -305,9 +303,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
cRunner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(e error) { cErrChan <- e },
|
|
||||||
clientConf,
|
clientConf,
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
@ -315,6 +311,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
|
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
|
||||||
var token [16]byte
|
var token [16]byte
|
||||||
server, err := NewCryptoSetupServer(
|
server, err := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
|
@ -323,9 +322,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{StatelessResetToken: &token},
|
&TransportParameters{StatelessResetToken: &token},
|
||||||
func([]byte) {},
|
sRunner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(e error) { sErrChan <- e },
|
|
||||||
serverConf,
|
serverConf,
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
@ -369,6 +366,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("signals when it has written the ClientHello", func() {
|
It("signals when it has written the ClientHello", func() {
|
||||||
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
client, chChan, err := NewCryptoSetupClient(
|
client, chChan, err := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
|
@ -377,9 +375,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
func([]byte) {},
|
runner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(error) {},
|
|
||||||
&tls.Config{InsecureSkipVerify: true},
|
&tls.Config{InsecureSkipVerify: true},
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
@ -401,6 +397,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
Expect(len(ch.data) - 4).To(Equal(length))
|
Expect(len(ch.data) - 4).To(Equal(length))
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
|
runner.EXPECT().OnError(gomock.Any())
|
||||||
client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -409,6 +406,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
var cTransportParametersRcvd, sTransportParametersRcvd []byte
|
var cTransportParametersRcvd, sTransportParametersRcvd []byte
|
||||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
||||||
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
|
||||||
client, _, err := NewCryptoSetupClient(
|
client, _, err := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
|
@ -416,9 +415,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
cTransportParameters,
|
cTransportParameters,
|
||||||
func(p []byte) { sTransportParametersRcvd = p },
|
cRunner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(error) { Fail("error callback called") },
|
|
||||||
clientConf,
|
clientConf,
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
@ -426,6 +423,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
var token [16]byte
|
var token [16]byte
|
||||||
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
|
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b })
|
||||||
sTransportParameters := &TransportParameters{
|
sTransportParameters := &TransportParameters{
|
||||||
IdleTimeout: 0x1337 * time.Second,
|
IdleTimeout: 0x1337 * time.Second,
|
||||||
StatelessResetToken: &token,
|
StatelessResetToken: &token,
|
||||||
|
@ -437,9 +436,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
sTransportParameters,
|
sTransportParameters,
|
||||||
func(p []byte) { cTransportParametersRcvd = p },
|
sRunner,
|
||||||
func(protocol.EncryptionLevel) {},
|
|
||||||
func(error) { Fail("error callback called") },
|
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
utils.DefaultLogger.WithPrefix("server"),
|
utils.DefaultLogger.WithPrefix("server"),
|
||||||
)
|
)
|
||||||
|
|
|
@ -28,6 +28,12 @@ type tlsExtensionHandler interface {
|
||||||
TransportParameters() <-chan []byte
|
TransportParameters() <-chan []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type handshakeRunner interface {
|
||||||
|
OnReceivedParams([]byte)
|
||||||
|
OnError(error)
|
||||||
|
DropKeys(protocol.EncryptionLevel)
|
||||||
|
}
|
||||||
|
|
||||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||||
type CryptoSetup interface {
|
type CryptoSetup interface {
|
||||||
RunHandshake()
|
RunHandshake()
|
||||||
|
|
71
internal/handshake/mock_handshake_runner_test.go
Normal file
71
internal/handshake/mock_handshake_runner_test.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: HandshakeRunner)
|
||||||
|
|
||||||
|
// Package handshake is a generated GoMock package.
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockHandshakeRunner is a mock of HandshakeRunner interface
|
||||||
|
type MockHandshakeRunner struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockHandshakeRunnerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner
|
||||||
|
type MockHandshakeRunnerMockRecorder struct {
|
||||||
|
mock *MockHandshakeRunner
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockHandshakeRunner creates a new mock instance
|
||||||
|
func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner {
|
||||||
|
mock := &MockHandshakeRunner{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockHandshakeRunnerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropKeys mocks base method
|
||||||
|
func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "DropKeys", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropKeys indicates an expected call of DropKeys
|
||||||
|
func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnError mocks base method
|
||||||
|
func (m *MockHandshakeRunner) OnError(arg0 error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "OnError", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnError indicates an expected call of OnError
|
||||||
|
func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnReceivedParams mocks base method
|
||||||
|
func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "OnReceivedParams", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnReceivedParams indicates an expected call of OnReceivedParams
|
||||||
|
func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0)
|
||||||
|
}
|
3
internal/handshake/mockgen.go
Normal file
3
internal/handshake/mockgen.go
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
//go:generate sh -c "../mockgen_internal.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner"
|
26
session.go
26
session.go
|
@ -71,6 +71,16 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type handshakeRunner struct {
|
||||||
|
onReceivedParams func([]byte)
|
||||||
|
onError func(error)
|
||||||
|
dropKeys func(protocol.EncryptionLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) }
|
||||||
|
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
|
||||||
|
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
|
||||||
|
|
||||||
type closeError struct {
|
type closeError struct {
|
||||||
err error
|
err error
|
||||||
remote bool
|
remote bool
|
||||||
|
@ -198,9 +208,11 @@ var newSession = func(
|
||||||
clientDestConnID,
|
clientDestConnID,
|
||||||
conn.RemoteAddr(),
|
conn.RemoteAddr(),
|
||||||
params,
|
params,
|
||||||
s.processTransportParameters,
|
&handshakeRunner{
|
||||||
s.dropEncryptionLevel,
|
onReceivedParams: s.processTransportParameters,
|
||||||
s.closeLocal,
|
onError: s.closeLocal,
|
||||||
|
dropKeys: s.dropEncryptionLevel,
|
||||||
|
},
|
||||||
tlsConf,
|
tlsConf,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
@ -268,9 +280,11 @@ var newClientSession = func(
|
||||||
s.destConnID,
|
s.destConnID,
|
||||||
conn.RemoteAddr(),
|
conn.RemoteAddr(),
|
||||||
params,
|
params,
|
||||||
s.processTransportParameters,
|
&handshakeRunner{
|
||||||
s.dropEncryptionLevel,
|
onReceivedParams: s.processTransportParameters,
|
||||||
s.closeLocal,
|
onError: s.closeLocal,
|
||||||
|
dropKeys: s.dropEncryptionLevel,
|
||||||
|
},
|
||||||
tlsConf,
|
tlsConf,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue