mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
fix error handling in the TLS crypto setup
There are two ways that an error can occur during the handshake: 1. as a return value from qtls.Handshake() 2. when new data is passed to the crypto setup via HandleData() We need to make sure that the RunHandshake() as well as HandleData() both return if an error occurs at any step during the handshake.
This commit is contained in:
parent
82508f1562
commit
2dbc29a5bd
6 changed files with 375 additions and 242 deletions
|
@ -48,25 +48,6 @@ func (s *stream) Write(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
var _ = Describe("Crypto Setup TLS", func() {
|
||||
generateCert := func() tls.Certificate {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{},
|
||||
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return tls.Certificate{
|
||||
PrivateKey: priv,
|
||||
Certificate: [][]byte{certDER},
|
||||
}
|
||||
}
|
||||
|
||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||
chunkChan := make(chan chunk, 100)
|
||||
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
||||
|
@ -74,172 +55,16 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
return chunkChan, initialStream, handshakeStream
|
||||
}
|
||||
|
||||
handshake := func(
|
||||
client CryptoSetupTLS,
|
||||
cChunkChan <-chan chunk,
|
||||
server CryptoSetupTLS,
|
||||
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
err := server.HandleData(c.data, c.encLevel)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
case c := <-sChunkChan:
|
||||
err := client.HandleData(c.data, c.encLevel)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
case <-done: // handshake complete
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
serverErrChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serverErrChan <- server.RunHandshake()
|
||||
}()
|
||||
|
||||
clientErr := client.RunHandshake()
|
||||
var serverErr error
|
||||
Eventually(serverErrChan).Should(Receive(&serverErr))
|
||||
return clientErr, serverErr
|
||||
}
|
||||
|
||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, _, err := NewCryptoSetupTLSClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
clientConf,
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
It("returns Handshake() when an error occurs", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
serverConf,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return handshake(client, cChunkChan, server, sChunkChan)
|
||||
}
|
||||
|
||||
It("handshakes", func() {
|
||||
clientConf := &tls.Config{ServerName: "quic.clemente.io"}
|
||||
serverConf := testdata.GetTLSConfig()
|
||||
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("handshakes with client auth", func() {
|
||||
clientConf := &tls.Config{
|
||||
ServerName: "quic.clemente.io",
|
||||
Certificates: []tls.Certificate{generateCert()},
|
||||
}
|
||||
serverConf := testdata.GetTLSConfig()
|
||||
serverConf.ClientAuth = qtls.RequireAnyClientCert
|
||||
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("signals when it has written the ClientHello", func() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, chChan, err := NewCryptoSetupTLSClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
client.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
var ch chunk
|
||||
Eventually(cChunkChan).Should(Receive(&ch))
|
||||
Eventually(chChan).Should(BeClosed())
|
||||
// make sure the whole ClientHello was written
|
||||
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
||||
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
||||
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
|
||||
Expect(len(ch.data) - 4).To(Equal(length))
|
||||
|
||||
// make the go routine return
|
||||
client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("receives transport parameters", func() {
|
||||
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
||||
client, _, err := NewCryptoSetupTLSClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
cTransportParameters,
|
||||
func(p *TransportParameters) { sTransportParametersRcvd = p },
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
&tls.Config{ServerName: "quic.clemente.io"},
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sTransportParameters := &TransportParameters{
|
||||
IdleTimeout: 0x1337 * time.Second,
|
||||
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
|
||||
}
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
sTransportParameters,
|
||||
func(p *TransportParameters) { cTransportParametersRcvd = p },
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
|
@ -251,15 +76,253 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
err := server.RunHandshake()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("received unexpected handshake message"))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||
server.HandleData(fakeCH, protocol.EncryptionInitial)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cTransportParametersRcvd).ToNot(BeNil())
|
||||
Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))
|
||||
Expect(sTransportParametersRcvd).ToNot(BeNil())
|
||||
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
|
||||
})
|
||||
|
||||
It("returns Handshake() when handling a message fails", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.RunHandshake()
|
||||
Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("doing the handshake", func() {
|
||||
generateCert := func() tls.Certificate {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{},
|
||||
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return tls.Certificate{
|
||||
PrivateKey: priv,
|
||||
Certificate: [][]byte{certDER},
|
||||
}
|
||||
}
|
||||
|
||||
handshake := func(
|
||||
client CryptoSetupTLS,
|
||||
cChunkChan <-chan chunk,
|
||||
server CryptoSetupTLS,
|
||||
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
server.HandleData(c.data, c.encLevel)
|
||||
case c := <-sChunkChan:
|
||||
client.HandleData(c.data, c.encLevel)
|
||||
case <-done: // handshake complete
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
serverErrChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serverErrChan <- server.RunHandshake()
|
||||
}()
|
||||
|
||||
clientErr := client.RunHandshake()
|
||||
var serverErr error
|
||||
Eventually(serverErrChan).Should(Receive(&serverErr))
|
||||
return clientErr, serverErr
|
||||
}
|
||||
|
||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, _, err := NewCryptoSetupTLSClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
clientConf,
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
serverConf,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return handshake(client, cChunkChan, server, sChunkChan)
|
||||
}
|
||||
|
||||
It("handshakes", func() {
|
||||
clientConf := &tls.Config{ServerName: "quic.clemente.io"}
|
||||
serverConf := testdata.GetTLSConfig()
|
||||
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("handshakes with client auth", func() {
|
||||
clientConf := &tls.Config{
|
||||
ServerName: "quic.clemente.io",
|
||||
Certificates: []tls.Certificate{generateCert()},
|
||||
}
|
||||
serverConf := testdata.GetTLSConfig()
|
||||
serverConf.ClientAuth = qtls.RequireAnyClientCert
|
||||
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("signals when it has written the ClientHello", func() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, chChan, err := NewCryptoSetupTLSClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
client.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
var ch chunk
|
||||
Eventually(cChunkChan).Should(Receive(&ch))
|
||||
Eventually(chChan).Should(BeClosed())
|
||||
// make sure the whole ClientHello was written
|
||||
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
||||
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
||||
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
|
||||
Expect(len(ch.data) - 4).To(Equal(length))
|
||||
|
||||
// make the go routine return
|
||||
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("receives transport parameters", func() {
|
||||
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
||||
client, _, err := NewCryptoSetupTLSClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
cTransportParameters,
|
||||
func(p *TransportParameters) { sTransportParametersRcvd = p },
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
&tls.Config{ServerName: "quic.clemente.io"},
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sTransportParameters := &TransportParameters{
|
||||
IdleTimeout: 0x1337 * time.Second,
|
||||
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
|
||||
}
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
sTransportParameters,
|
||||
func(p *TransportParameters) { cTransportParametersRcvd = p },
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cTransportParametersRcvd).ToNot(BeNil())
|
||||
Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))
|
||||
Expect(sTransportParametersRcvd).ToNot(BeNil())
|
||||
Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue