mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
523 lines
16 KiB
Go
523 lines
16 KiB
Go
package handshake
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
"io/ioutil"
|
|
"math/big"
|
|
"time"
|
|
|
|
gomock "github.com/golang/mock/gomock"
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/marten-seemann/qtls"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
type chunk struct {
|
|
data []byte
|
|
encLevel protocol.EncryptionLevel
|
|
}
|
|
|
|
type stream struct {
|
|
encLevel protocol.EncryptionLevel
|
|
chunkChan chan<- chunk
|
|
}
|
|
|
|
func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream {
|
|
return &stream{
|
|
chunkChan: chunkChan,
|
|
encLevel: encLevel,
|
|
}
|
|
}
|
|
|
|
func (s *stream) Write(b []byte) (int, error) {
|
|
data := make([]byte, len(b))
|
|
copy(data, b)
|
|
select {
|
|
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
|
|
default:
|
|
panic("chunkChan too small")
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
var _ = Describe("Crypto Setup TLS", func() {
|
|
var clientConf *tls.Config
|
|
|
|
// initStreams creates one buffered chunk channel (capacity 100) and two
// streams writing to it, one per handshake encryption level. It returns the
// channel, the Initial stream and the Handshake stream, in that order.
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
	chunks := make(chan chunk, 100)
	return chunks,
		newStream(chunks, protocol.EncryptionInitial),
		newStream(chunks, protocol.EncryptionHandshake)
}
|
|
|
|
// Give every spec a fresh client config for "localhost" whose root pool comes
// from testdata.GetRootCA().
BeforeEach(func() {
	conf := new(tls.Config)
	conf.ServerName = "localhost"
	conf.RootCAs = testdata.GetRootCA()
	clientConf = conf
})
|
|
|
|
// The config stored on the crypto setup must keep the server name, must still
// invoke the callbacks of the tls.Config it was built from, and the config
// returned by GetConfigForClient must have AlternativeRecordLayer,
// GetExtensions and ReceivedExtensions set.
It("creates a qtls.Config", func() {
	// Each callback is identifiable by the error (or value) it returns.
	getCert := func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
		return nil, errors.New("GetCertificate")
	}
	getClientCert := func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
		return nil, errors.New("GetClientCertificate")
	}
	getConf := func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
		return &tls.Config{ServerName: ch.ServerName}, nil
	}
	tlsConf := &tls.Config{
		ServerName:           "quic.clemente.io",
		GetCertificate:       getCert,
		GetClientCertificate: getClientCert,
		GetConfigForClient:   getConf,
	}

	server, err := NewCryptoSetupServer(
		&bytes.Buffer{},
		&bytes.Buffer{},
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		NewMockHandshakeRunner(mockCtrl),
		tlsConf,
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	qtlsConf := server.(*cryptoSetup).tlsConf
	Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))

	_, certErr := qtlsConf.GetCertificate(nil)
	Expect(certErr).To(MatchError("GetCertificate"))

	_, clientCertErr := qtlsConf.GetClientCertificate(nil)
	Expect(clientCertErr).To(MatchError("GetClientCertificate"))

	perClientConf, err := qtlsConf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "foo.bar"})
	Expect(err).ToNot(HaveOccurred())
	Expect(perClientConf.ServerName).To(Equal("foo.bar"))
	Expect(perClientConf.AlternativeRecordLayer).ToNot(BeNil())
	Expect(perClientConf.GetExtensions).ToNot(BeNil())
	Expect(perClientConf.ReceivedExtensions).ToNot(BeNil())
})
|
|
|
|
// Feeds the server a fake ClientHello and expects RunHandshake() to return
// after the runner's OnError was called with the CRYPTO_ERROR below.
It("returns Handshake() when an error occurs", func() {
	_, sInitialStream, sHandshakeStream := initStreams()
	reported := make(chan error, 1)
	runner := NewMockHandshakeRunner(mockCtrl)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { reported <- e })
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		testdata.GetTLSConfig(),
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		Expect(reported).To(Receive(MatchError("CRYPTO_ERROR: local error: tls: unexpected message")))
		close(handshakeReturned)
	}()

	// Message header (type ClientHello, length 6) followed by the 6 bytes "foobar".
	fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, "foobar"...)
	messageHandled := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.HandleMessage(fakeCH, protocol.EncryptionInitial)
		close(messageHandled)
	}()

	Eventually(messageHandled).Should(BeClosed())
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
// Hands the server a ClientHello at the Handshake level and expects
// RunHandshake() to return with a crypto error (unexpected_message alert)
// reported through OnError.
It("returns Handshake() when a message is received at the wrong encryption level", func() {
	runner := NewMockHandshakeRunner(mockCtrl)
	reported := make(chan error, 1)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { reported <- e })
	_, sInitialStream, sHandshakeStream := initStreams()
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		testdata.GetTLSConfig(),
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()

		var handshakeErr error
		Expect(reported).To(Receive(&handshakeErr))
		Expect(handshakeErr).To(BeAssignableToTypeOf(&qerr.QuicError{}))
		// Named quicErr so the local does not shadow the qerr package.
		quicErr := handshakeErr.(*qerr.QuicError)
		Expect(quicErr.IsCryptoError()).To(BeTrue())
		Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
		Expect(handshakeErr.Error()).To(ContainSubstring("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
		close(handshakeReturned)
	}()

	fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, "foobar"...)
	// The specs' error message expects Initial; Handshake is the wrong level.
	server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
// Hands the server a ServerHello and expects RunHandshake() to return with a
// crypto error (unexpected_message alert) reported through OnError.
It("returns Handshake() when handling a message fails", func() {
	runner := NewMockHandshakeRunner(mockCtrl)
	reported := make(chan error, 1)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { reported <- e })
	_, sInitialStream, sHandshakeStream := initStreams()
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		testdata.GetTLSConfig(),
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()

		var handshakeErr error
		Expect(reported).To(Receive(&handshakeErr))
		Expect(handshakeErr).To(BeAssignableToTypeOf(&qerr.QuicError{}))
		// Named quicErr so the local does not shadow the qerr package.
		quicErr := handshakeErr.(*qerr.QuicError)
		Expect(quicErr.IsCryptoError()).To(BeTrue())
		Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
		Expect(handshakeErr.Error()).To(ContainSubstring("unexpected handshake message"))
		close(handshakeReturned)
	}()

	// Here the message type (ServerHello) is what gets rejected, per the
	// "unexpected handshake message" expectation above.
	fakeSH := append([]byte{byte(typeServerHello), 0, 0, 6}, "foobar"...)
	server.HandleMessage(fakeSH, protocol.EncryptionInitial)
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
// Closing the crypto setup must make a running RunHandshake() return.
It("returns Handshake() when it is closed", func() {
	_, sInitialStream, sHandshakeStream := initStreams()
	runner := NewMockHandshakeRunner(mockCtrl)
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		testdata.GetTLSConfig(),
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		close(handshakeReturned)
	}()

	Expect(server.Close()).To(Succeed())
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
Context("doing the handshake", func() {
|
|
generateCert := func() tls.Certificate {
|
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{},
|
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
return tls.Certificate{
|
|
PrivateKey: priv,
|
|
Certificate: [][]byte{certDER},
|
|
}
|
|
}
|
|
|
|
// handshake runs client and server against each other. A relay goroutine
// passes every chunk the client wrote to the server and vice versa, until the
// server's RunHandshake() has returned. The function itself blocks in the
// client's RunHandshake() and then waits for the server side to finish.
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
	server CryptoSetup, sChunkChan <-chan chunk) {
	serverDone := make(chan struct{})

	relay := func() {
		defer GinkgoRecover()
		for {
			select {
			case fromClient := <-cChunkChan:
				server.HandleMessage(fromClient.data, fromClient.encLevel)
			case fromServer := <-sChunkChan:
				client.HandleMessage(fromServer.data, fromServer.encLevel)
			case <-serverDone: // handshake complete
				return
			}
		}
	}
	runServer := func() {
		defer GinkgoRecover()
		server.RunHandshake()
		close(serverDone)
	}

	go relay()
	go runServer()

	client.RunHandshake()
	Eventually(serverDone).Should(BeClosed())
}
|
|
|
|
// handshakeWithTLSConf builds a client and a server from the given TLS configs
// and runs a handshake between them. It returns the error each side reported
// through its runner's OnError (client first, then server), or nil for a side
// that reported none. A side that reported no error must have had
// OnHandshakeComplete called.
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
	// client
	var clientComplete bool
	clientErrs := make(chan error, 1)
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any())
	cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { clientErrs <- e }).MaxTimes(1)
	cRunner.EXPECT().OnHandshakeComplete().Do(func() { clientComplete = true }).MaxTimes(1)
	clientChunks, cInitialStream, cHandshakeStream := initStreams()
	client, _, err := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		cRunner,
		clientConf,
		utils.DefaultLogger.WithPrefix("client"),
	)
	Expect(err).ToNot(HaveOccurred())

	// server
	var serverComplete bool
	serverErrs := make(chan error, 1)
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any())
	sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { serverErrs <- e }).MaxTimes(1)
	sRunner.EXPECT().OnHandshakeComplete().Do(func() { serverComplete = true }).MaxTimes(1)
	serverChunks, sInitialStream, sHandshakeStream := initStreams()
	var resetToken [16]byte
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{StatelessResetToken: &resetToken},
		sRunner,
		serverConf,
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshake(client, clientChunks, server, serverChunks)

	// outcome returns the error a side reported, if one is pending. Otherwise
	// that side must have completed the handshake.
	outcome := func(errs <-chan error, complete *bool) error {
		select {
		case e := <-errs:
			return e
		default:
			Expect(*complete).To(BeTrue())
			return nil
		}
	}
	sErr := outcome(serverErrs, &serverComplete)
	cErr := outcome(clientErrs, &clientComplete)
	return cErr, sErr
}
|
|
|
|
// Plain handshake with the default test configs: neither side reports an error.
It("handshakes", func() {
	cErr, sErr := handshakeWithTLSConf(clientConf, testdata.GetTLSConfig())
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
// The server only accepts P-384. Presumably the client's first key share is
// for another curve, so the server has to answer with a HelloRetryRequest
// (TODO confirm against the client's default curve preferences). The
// handshake must still succeed on both sides.
It("performs a HelloRetryRequst", func() {
	p384Only := testdata.GetTLSConfig()
	p384Only.CurvePreferences = []tls.CurveID{tls.CurveP384}

	cErr, sErr := handshakeWithTLSConf(clientConf, p384Only)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
// The server requires a client certificate (any certificate will do) and the
// client offers a freshly generated one; the handshake must succeed.
It("handshakes with client auth", func() {
	clientConf.Certificates = []tls.Certificate{generateCert()}

	requireCert := testdata.GetTLSConfig()
	requireCert.ClientAuth = qtls.RequireAnyClientCert

	cErr, sErr := handshakeWithTLSConf(clientConf, requireCert)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
// The channel returned by NewCryptoSetupClient must be closed once the
// ClientHello was written, and that ClientHello must arrive as one complete
// handshake message.
It("signals when it has written the ClientHello", func() {
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	runner := NewMockHandshakeRunner(mockCtrl)
	client, chChan, err := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		&tls.Config{InsecureSkipVerify: true},
		utils.DefaultLogger.WithPrefix("client"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		client.RunHandshake()
		close(handshakeReturned)
	}()

	var hello chunk
	Eventually(cChunkChan).Should(Receive(&hello))
	Eventually(chChan).Should(BeClosed())

	// The chunk must hold the whole message: a 1-byte type, a 3-byte
	// big-endian body length, and exactly that many body bytes.
	Expect(len(hello.data)).To(BeNumerically(">=", 4))
	Expect(messageType(hello.data[0])).To(Equal(typeClientHello))
	bodyLen := int(hello.data[1])<<16 | int(hello.data[2])<<8 | int(hello.data[3])
	Expect(len(hello.data) - 4).To(Equal(bodyLen))

	// Feed the client an unknown handshake message type (42) so that
	// RunHandshake() fails and the goroutine above returns.
	runner.EXPECT().OnError(gomock.Any())
	client.HandleMessage([]byte{42, 0, 0, 1, 0}, protocol.EncryptionInitial)
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
// After a handshake, each runner must have been handed the peer's transport
// parameters: the bytes given to the server's runner decode to the client's
// idle timeout, and the bytes given to the client's runner decode to the
// server's.
It("receives transport parameters", func() {
	// paramsFromClient is what the server's runner received, and
	// paramsFromServer is what the client's runner received.
	var paramsFromClient, paramsFromServer []byte

	// client
	cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { paramsFromServer = b })
	cRunner.EXPECT().OnHandshakeComplete()
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	client, _, err := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		cTransportParameters,
		cRunner,
		clientConf,
		utils.DefaultLogger.WithPrefix("client"),
	)
	Expect(err).ToNot(HaveOccurred())

	// server
	var resetToken [16]byte
	sTransportParameters := &TransportParameters{
		IdleTimeout:         0x1337 * time.Second,
		StatelessResetToken: &resetToken,
	}
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { paramsFromClient = b })
	sRunner.EXPECT().OnHandshakeComplete()
	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		sTransportParameters,
		sRunner,
		testdata.GetTLSConfig(),
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeDone := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(handshakeDone)
	}()
	Eventually(handshakeDone).Should(BeClosed())

	Expect(paramsFromClient).ToNot(BeNil())
	decodedClient := &TransportParameters{}
	Expect(decodedClient.Unmarshal(paramsFromClient, protocol.PerspectiveClient)).To(Succeed())
	Expect(decodedClient.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout))

	Expect(paramsFromServer).ToNot(BeNil())
	decodedServer := &TransportParameters{}
	Expect(decodedServer.Unmarshal(paramsFromServer, protocol.PerspectiveServer)).To(Succeed())
	Expect(decodedServer.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
})
|
|
|
|
// Completes a handshake, then hands the client a malformed NewSessionTicket at
// the 1-RTT level and expects a crypto error to be reported through OnError.
It("errors when handling the NewSessionTicket fails", func() {
	// client
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any())
	cRunner.EXPECT().OnHandshakeComplete()
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	client, _, err := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		cRunner,
		clientConf,
		utils.DefaultLogger.WithPrefix("client"),
	)
	Expect(err).ToNot(HaveOccurred())

	// server
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any())
	sRunner.EXPECT().OnHandshakeComplete()
	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	server, err := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		ioutil.Discard,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		sRunner,
		testdata.GetTLSConfig(),
		utils.DefaultLogger.WithPrefix("server"),
	)
	Expect(err).ToNot(HaveOccurred())

	handshakeDone := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(handshakeDone)
	}()
	Eventually(handshakeDone).Should(BeClosed())

	// inject an invalid session ticket
	cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) {
		Expect(e).To(BeAssignableToTypeOf(&qerr.QuicError{}))
		// Named quicErr so the local does not shadow the qerr package.
		quicErr := e.(*qerr.QuicError)
		Expect(quicErr.IsCryptoError()).To(BeTrue())
	})
	badTicket := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, "foobar"...)
	client.HandleMessage(badTicket, protocol.Encryption1RTT)
})
|
|
})
|
|
})
|