update mint, and adapt to the new use of the extension handler

This commit is contained in:
Marten Seemann 2018-02-01 10:57:02 +08:00
parent 65eaf8131d
commit 1cc209e4fb
17 changed files with 362 additions and 340 deletions

View file

@ -166,13 +166,6 @@ func (c *client) dialGQUIC() error {
}
func (c *client) dialTLS() error {
csc := handshake.NewCryptoStreamConn(nil)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
@ -180,11 +173,17 @@ func (c *client) dialTLS() error {
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
}
eh := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
if err := c.tls.SetExtensionHandler(eh); err != nil {
csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
go c.listen()
@ -193,7 +192,7 @@ func (c *client) dialTLS() error {
return err
}
utils.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {

View file

@ -33,7 +33,6 @@ type MintTLS interface {
ConnectionState() mint.ConnectionState
SetCryptoStream(io.ReadWriter)
SetExtensionHandler(mint.AppExtensionHandler) error
}
// CryptoSetup is a crypto setup

View file

@ -94,18 +94,6 @@ func (mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
}
// SetExtensionHandler mocks base method
func (m *MockMintTLS) SetExtensionHandler(arg0 mint.AppExtensionHandler) error {
ret := m.ctrl.Call(m, "SetExtensionHandler", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetExtensionHandler indicates an expected call of SetExtensionHandler
func (mr *MockMintTLSMockRecorder) SetExtensionHandler(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetExtensionHandler", reflect.TypeOf((*MockMintTLS)(nil).SetExtensionHandler), arg0)
}
// State mocks base method
func (m *MockMintTLS) State() mint.State {
ret := m.ctrl.Call(m, "State")

View file

@ -43,7 +43,7 @@ func newMintController(
}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
return mc.conn.ConnectionState().CipherSuite
}
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
@ -55,21 +55,17 @@ func (mc *mintController) Handshake() mint.Alert {
}
func (mc *mintController) State() mint.State {
return mc.conn.State().HandshakeState
return mc.conn.ConnectionState().HandshakeState
}
func (mc *mintController) ConnectionState() mint.ConnectionState {
return mc.conn.State()
return mc.conn.ConnectionState()
}
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
mc.csc.SetStream(stream)
}
func (mc *mintController) SetExtensionHandler(h mint.AppExtensionHandler) error {
return mc.conn.SetExtensionHandler(h)
}
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
mconf := &mint.Config{
NonBlocking: true,

View file

@ -89,14 +89,10 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []
// will be set to s.newMintConn by the constructor
func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
conn := mint.Server(bc, s.mintConf)
extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v)
if err := conn.SetExtensionHandler(extHandler); err != nil {
return nil, nil, err
}
tls := newMintController(bc, s.mintConf, protocol.PerspectiveServer)
tls.SetExtensionHandler(extHandler)
return tls, extHandler.GetPeerParams(), nil
conf := s.mintConf.Clone()
conf.ExtensionHandler = extHandler
return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil
}
func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error {

View file

@ -3,6 +3,7 @@ package mint
import (
"bytes"
"crypto"
"crypto/x509"
"hash"
"time"
)
@ -50,7 +51,7 @@ import (
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ClientStateStart struct {
Caps Capabilities
Config *Config
Opts ConnectionOptions
Params ConnectionParameters
@ -71,9 +72,9 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
offeredDH := map[NamedGroup][]byte{}
ks := KeyShareExtension{
HandshakeType: HandshakeTypeClientHello,
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
Shares: make([]KeyShareEntry, len(state.Config.Groups)),
}
for i, group := range state.Caps.Groups {
for i, group := range state.Config.Groups {
pub, priv, err := newKeyShare(group)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
@ -90,8 +91,8 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// supported_versions, supported_groups, signature_algorithms, server_name
sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}}
sni := ServerNameExtension(state.Opts.ServerName)
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
sg := SupportedGroupsExtension{Groups: state.Config.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes}
state.Params.ServerName = state.Opts.ServerName
@ -103,7 +104,8 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// Construct base ClientHello
ch := &ClientHelloBody{
CipherSuites: state.Caps.CipherSuites,
LegacyVersion: wireVersion(state.hsCtx.hIn),
CipherSuites: state.Config.CipherSuites,
}
_, err := prng.Read(ch.Random[:])
if err != nil {
@ -135,8 +137,8 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
@ -152,7 +154,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
var earlySecret []byte
var clientEarlyTrafficKeys keySet
var clientHello *HandshakeMessage
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok {
offeredPSK = key
// Narrow ciphersuites to ones that match PSK hash
@ -182,11 +184,11 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
// Signal supported PSK key exchange modes
if len(state.Caps.PSKModes) == 0 {
if len(state.Config.PSKModes) == 0 {
logf(logTypeHandshake, "PSK selected, but no PSKModes")
return nil, nil, AlertInternalError
}
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
kem := &PSKKeyExchangeModesExtension{KEModes: state.Config.PSKModes}
err = ch.Extensions.Add(kem)
if err != nil {
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
@ -267,7 +269,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
nextState := ClientStateWaitSH{
Caps: state.Caps,
Config: state.Config,
Opts: state.Opts,
Params: state.Params,
hsCtx: state.hsCtx,
@ -297,7 +299,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
type ClientStateWaitSH struct {
Caps Capabilities
Config *Config
Opts ConnectionOptions
Params ConnectionParameters
hsCtx HandshakeContext
@ -360,7 +362,7 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
}
// 3. Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
for _, suite := range state.Config.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
@ -375,11 +377,11 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
hrr := sh
// Narrow the supported ciphersuites to the server-provided one
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
state.Config.CipherSuites = []CipherSuite{hrr.CipherSuite}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
@ -412,7 +414,7 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return ClientStateStart{
Caps: state.Caps,
Config: state.Config,
Opts: state.Opts,
hsCtx: state.hsCtx,
cookie: serverCookie.Cookie,
@ -423,8 +425,8 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
// This is SH.
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
@ -516,12 +518,11 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{
Caps: state.Caps,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: params,
handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
@ -533,13 +534,11 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
}
type ClientStateWaitEE struct {
Caps Capabilities
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
@ -568,8 +567,8 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
@ -604,7 +603,7 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
certificates: state.Config.Certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
@ -614,12 +613,11 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
nextState := ClientStateWaitCertCR{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
@ -628,12 +626,11 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
}
type ClientStateWaitCertCR struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
@ -667,12 +664,11 @@ func (state ClientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta
case *CertificateBody:
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
@ -691,12 +687,11 @@ func (state ClientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
nextState := ClientStateWaitCert{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
@ -709,13 +704,12 @@ func (state ClientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta
}
type ClientStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
@ -749,12 +743,11 @@ func (state ClientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: cert,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
@ -765,13 +758,12 @@ func (state ClientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
}
type ClientStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificate *CertificateBody
serverCertificateRequest *CertificateRequestBody
@ -811,14 +803,41 @@ func (state ClientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.serverCertificate.CertificateList)
certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList))
rawCerts := make([][]byte, len(state.serverCertificate.CertificateList))
for i, certEntry := range state.serverCertificate.CertificateList {
certs[i] = certEntry.CertData
rawCerts[i] = certEntry.CertData.Raw
}
var verifiedChains [][]*x509.Certificate
if !state.Config.InsecureSkipVerify {
opts := x509.VerifyOptions{
Roots: state.Config.RootCAs,
CurrentTime: state.Config.time(),
DNSName: state.Config.ServerName,
Intermediates: x509.NewCertPool(),
}
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
var err error
verifiedChains, err = certs[0].Verify(opts)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err)
return nil, nil, AlertBadCertificate
}
}
if state.Config.VerifyPeerCertificate != nil {
if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err)
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
}
state.handshakeHash.Write(hm.Marshal())
@ -829,11 +848,13 @@ func (state ClientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
certificates: state.Config.Certificates,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
peerCertificates: certs,
verifiedChains: verifiedChains,
}
return nextState, nil, AlertNoAlert
}
@ -846,6 +867,8 @@ type ClientStateWaitFinished struct {
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
peerCertificates []*x509.Certificate
verifiedChains [][]*x509.Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
@ -1032,6 +1055,8 @@ func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
peerCertificates: state.peerCertificates,
verifiedChains: state.verifiedChains,
}
return nextState, toSend, AlertNoAlert
}

View file

@ -9,6 +9,7 @@ const (
supportedVersion uint16 = 0x7f16 // draft-22
tls12Version uint16 = 0x0303
tls10Version uint16 = 0x0301
dtls12WireVersion uint16 = 0xfefd
)
var (

View file

@ -89,11 +89,39 @@ type Config struct {
// If non-blocking mode is used, and cookies are required, this field has to be set.
// In blocking mode, a default cookie protector is used, if this is unused.
CookieProtector CookieProtector
// The ExtensionHandler is used to add custom extensions.
ExtensionHandler AppExtensionHandler
RequireClientAuth bool
// Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses time.Now.
Time func() time.Time
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
// If InsecureSkipVerify is true, TLS accepts any certificate
// presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureSkipVerify bool
// Shared fields
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
// VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a TLS client or server. It
// receives the raw ASN.1 certificates provided by the peer and also
// any verified chains that normal processing found. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled by
// setting InsecureSkipVerify then this callback will be considered but
// the verifiedChains argument will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
@ -125,10 +153,14 @@ func (c *Config) Clone() *Config {
RequireCookie: c.RequireCookie,
CookieHandler: c.CookieHandler,
CookieProtector: c.CookieProtector,
ExtensionHandler: c.ExtensionHandler,
RequireClientAuth: c.RequireClientAuth,
Time: c.Time,
RootCAs: c.RootCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
Certificates: c.Certificates,
AuthCertificate: c.AuthCertificate,
VerifyPeerCertificate: c.VerifyPeerCertificate,
CipherSuites: c.CipherSuites,
Groups: c.Groups,
SignatureSchemes: c.SignatureSchemes,
@ -163,28 +195,6 @@ func (c *Config) Init(isClient bool) error {
if len(c.PSKModes) == 0 {
c.PSKModes = defaultPSKModes
}
// If there is no certificate, generate one
if !isClient && len(c.Certificates) == 0 {
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
priv, err := newSigningKey(RSA_PSS_SHA256)
if err != nil {
return err
}
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
if err != nil {
return err
}
c.Certificates = []*Certificate{
{
Chain: []*x509.Certificate{cert},
PrivateKey: priv,
},
}
}
return nil
}
@ -199,6 +209,14 @@ func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0
}
func (c *Config) time() time.Time {
t := c.Time
if t == nil {
t = time.Now
}
return t()
}
var (
defaultSupportedCipherSuites = []CipherSuite{
TLS_AES_128_GCM_SHA256,
@ -232,7 +250,8 @@ var (
type ConnectionState struct {
HandshakeState State
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
NextProto string // Selected ALPN proto
}
@ -255,8 +274,6 @@ type Conn struct {
readBuffer []byte
in, out *RecordLayer
hsCtx HandshakeContext
extHandler AppExtensionHandler
}
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
@ -637,22 +654,6 @@ func (c *Conn) HandshakeSetup() Alert {
return AlertInternalError
}
// Set things up
caps := Capabilities{
CipherSuites: c.config.CipherSuites,
Groups: c.config.Groups,
SignatureSchemes: c.config.SignatureSchemes,
PSKs: c.config.PSKs,
PSKModes: c.config.PSKModes,
AllowEarlyData: c.config.AllowEarlyData,
RequireCookie: c.config.RequireCookie,
CookieProtector: c.config.CookieProtector,
CookieHandler: c.config.CookieHandler,
RequireClientAuth: c.config.RequireClientAuth,
NextProtos: c.config.NextProtos,
Certificates: c.config.Certificates,
ExtensionHandler: c.extHandler,
}
opts := ConnectionOptions{
ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos,
@ -660,7 +661,7 @@ func (c *Conn) HandshakeSetup() Alert {
}
if c.isClient {
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
state, actions, alert = ClientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert
@ -681,13 +682,13 @@ func (c *Conn) HandshakeSetup() Alert {
return AlertInternalError
}
var err error
caps.CookieProtector, err = NewDefaultCookieProtector()
c.config.CookieProtector, err = NewDefaultCookieProtector()
if err != nil {
logf(logTypeHandshake, "Error initializing cookie source: %v", alert)
return AlertInternalError
}
}
state = ServerStateStart{Caps: caps, conn: c, hsCtx: c.hsCtx}
state = ServerStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
}
c.hState = state
@ -867,7 +868,7 @@ func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]b
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
}
func (c *Conn) State() ConnectionState {
func (c *Conn) ConnectionState() ConnectionState {
state := ConnectionState{
HandshakeState: c.GetHsState(),
}
@ -875,16 +876,9 @@ func (c *Conn) State() ConnectionState {
if c.handshakeComplete {
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
state.NextProto = c.state.Params.NextProto
state.VerifiedChains = c.state.verifiedChains
state.PeerCertificates = c.state.peerCertificates
}
return state
}
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
if c.hState != nil {
return fmt.Errorf("Can't set extension handler after setup")
}
c.extHandler = h
return nil
}

View file

@ -11,11 +11,9 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"time"
"golang.org/x/crypto/curve25519"
@ -331,40 +329,6 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
}
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}
// XXX(rlb): Copied from crypto/x509
type ecdsaSignature struct {
R, S *big.Int

View file

@ -1,7 +1,28 @@
package mint
import (
"fmt"
)
// This file is a placeholder. DTLS-specific stuff (timer management,
// ACKs, retransmits, etc. will eventually go here.
const (
initialMtu = 1200
)
func wireVersion(h *HandshakeLayer) uint16 {
if h.datagram {
return dtls12WireVersion
}
return tls12Version
}
func dtlsConvertVersion(version uint16) uint16 {
if version == tls12Version {
return dtls12WireVersion
}
if version == tls10Version {
return 0xfeff
}
panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
}

View file

@ -77,8 +77,6 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
body = new(ClientHelloBody)
case HandshakeTypeServerHello:
body = new(ServerHelloBody)
case HandshakeTypeHelloRetryRequest:
body = new(HelloRetryRequestBody)
case HandshakeTypeEncryptedExtensions:
body = new(EncryptedExtensionsBody)
case HandshakeTypeCertificate:

View file

@ -25,14 +25,14 @@ type HandshakeMessageBody interface {
// Extension extensions<0..2^16-1>;
// } ClientHello;
type ClientHelloBody struct {
// Omitted: clientVersion
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte
CipherSuites []CipherSuite
Extensions ExtensionList
}
type clientHelloBodyInner struct {
type clientHelloBodyInnerTLS struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
@ -41,41 +41,86 @@ type clientHelloBodyInner struct {
Extensions []Extension `tls:"head=2"`
}
type clientHelloBodyInnerDTLS struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
EmptyCookie uint8
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
func (ch ClientHelloBody) Type() HandshakeType {
return HandshakeTypeClientHello
}
func (ch ClientHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(clientHelloBodyInner{
LegacyVersion: tls12Version,
if ch.LegacyVersion == tls12Version {
return syntax.Marshal(clientHelloBodyInnerTLS{
LegacyVersion: ch.LegacyVersion,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
} else {
return syntax.Marshal(clientHelloBodyInnerDTLS{
LegacyVersion: ch.LegacyVersion,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
}
}
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
var inner clientHelloBodyInner
read, err := syntax.Unmarshal(data, &inner)
var read int
var err error
// Note that this might be 0, in which case we do TLS. That
// makes the tests easier.
if ch.LegacyVersion != dtls12WireVersion {
var inner clientHelloBodyInnerTLS
read, err = syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
// We are strict about these things because we only support 1.3
if inner.LegacyVersion != tls12Version {
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.LegacyVersion = inner.LegacyVersion
ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
} else {
var inner clientHelloBodyInnerDTLS
read, err = syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if inner.EmptyCookie != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid cookie")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.LegacyVersion = inner.LegacyVersion
ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
}
return read, nil
}
@ -120,29 +165,6 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
return chData[:chLen-binderLen], nil
}
// struct {
// ProtocolVersion server_version;
// CipherSuite cipher_suite;
// Extension extensions<2..2^16-1>;
// } HelloRetryRequest;
type HelloRetryRequestBody struct {
Version uint16
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2,min=2"`
}
func (hrr HelloRetryRequestBody) Type() HandshakeType {
return HandshakeTypeHelloRetryRequest
}
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(hrr)
}
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, hrr)
}
// struct {
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random;

View file

@ -119,6 +119,15 @@ func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []b
return nil
}
func (c *cipherState) formatSeq(datagram bool) []byte {
seq := append([]byte{}, c.seq...)
if datagram {
seq[0] = byte(c.epoch >> 8)
seq[1] = byte(c.epoch & 0xff)
}
return seq
}
func (c *cipherState) computeNonce(seq []byte) []byte {
nonce := make([]byte, len(c.iv))
copy(nonce, c.iv)
@ -143,9 +152,9 @@ func (c *cipherState) incrementSequenceNumber() {
if i < 0 {
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bothec.
// Not likely enough to bother.
// TODO(ekr@rtfm.com): Check for DTLS here
// because the limit is soonec.
// because the limit is sooner.
panic("TLS: sequence number wraparound")
}
}
@ -157,7 +166,8 @@ func (c *cipherState) overhead() int {
return c.cipher.Overhead()
}
func (r *RecordLayer) encrypt(cipher *cipherState, pt *TLSPlaintext, padLen int) *TLSPlaintext {
func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, padLen int) *TLSPlaintext {
logf(logTypeIO, "Encrypt seq=[%x]", seq)
// Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen
@ -165,6 +175,7 @@ func (r *RecordLayer) encrypt(cipher *cipherState, pt *TLSPlaintext, padLen int)
// Assemble the revised plaintext
out := &TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: make([]byte, ciphertextLen),
}
@ -176,11 +187,12 @@ func (r *RecordLayer) encrypt(cipher *cipherState, pt *TLSPlaintext, padLen int)
// Encrypt the fragment
payload := out.fragment[:plaintextLen]
cipher.cipher.Seal(payload[:0], cipher.computeNonce(cipher.seq), payload, nil)
cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
return out
}
func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, error) {
logf(logTypeIO, "Decrypt seq=[%x]", seq)
if len(pt.fragment) < r.cipher.overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
return nil, 0, DecryptError(msg)
@ -312,6 +324,8 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
if r.datagram {
seq = header[3:11]
}
// TODO(ekr@rtfm.com): Handle the wrong epoch.
// TODO(ekr@rtfm.com): Handle duplicates.
logf(logTypeIO, "RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), seq, pt.contentType, pt.fragment)
pt, _, err = r.decrypt(pt, seq)
if err != nil {
@ -341,9 +355,11 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error
}
func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
seq := cipher.formatSeq(r.datagram)
if cipher.cipher != nil {
logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
pt = r.encrypt(cipher, pt, padLen)
pt = r.encrypt(cipher, seq, pt, padLen)
} else if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
}
@ -354,16 +370,17 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta
length := len(pt.fragment)
var header []byte
if !r.datagram {
header = []byte{byte(pt.contentType),
byte(r.version >> 8), byte(r.version & 0xff),
byte(length >> 8), byte(length)}
} else {
// TODO(ekr@rtfm.com): Double check version
seq := cipher.seq
header = []byte{byte(pt.contentType), 0xfe, 0xff,
0x00, 0x00, // TODO(ekr@rtfm.com): double-check epoch
seq[2], seq[3], seq[4], seq[5], seq[6], seq[7],
version := dtlsConvertVersion(r.version)
header = []byte{byte(pt.contentType),
byte(version >> 8), byte(version & 0xff),
seq[0], seq[1], seq[2], seq[3],
seq[4], seq[5], seq[6], seq[7],
byte(length >> 8), byte(length)}
}
record := append(header, pt.fragment...)

View file

@ -2,6 +2,7 @@ package mint
import (
"bytes"
"crypto/x509"
"fmt"
"hash"
"reflect"
@ -71,7 +72,7 @@ type cookie struct {
}
type ServerStateStart struct {
Caps Capabilities
Config *Config
conn *Conn
hsCtx HandshakeContext
}
@ -92,12 +93,18 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
return nil, nil, AlertUnexpectedMessage
}
ch := &ClientHelloBody{}
ch := &ClientHelloBody{LegacyVersion: wireVersion(state.hsCtx.hIn)}
if err := safeUnmarshal(ch, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
// We are strict about these things because we only support 1.3
if ch.LegacyVersion != wireVersion(state.hsCtx.hIn) {
logf(logTypeHandshake, "[ServerStateStart] Invalid version number: %v", ch.LegacyVersion)
return nil, nil, AlertDecodeError
}
clientHello := hm
connParams := ConnectionParameters{}
@ -113,8 +120,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
clientCookie := new(CookieExtension)
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
@ -162,7 +169,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
var firstClientHello *HandshakeMessage
var initialCipherSuite CipherSuiteParams // the cipher suite that was negotiated when sending the HelloRetryRequest
if clientSentCookie {
plainCookie, err := state.Caps.CookieProtector.DecodeToken(clientCookie.Cookie)
plainCookie, err := state.Config.CookieProtector.DecodeToken(clientCookie.Cookie)
if err != nil {
logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error decoding token [%v]", err))
return nil, nil, AlertDecryptError
@ -178,7 +185,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
body: cookie.ClientHelloHash,
}
// have the application validate its part of the cookie
if state.Caps.CookieHandler != nil && !state.Caps.CookieHandler.Validate(state.conn, cookie.ApplicationCookie) {
if state.Config.CookieHandler != nil && !state.Config.CookieHandler.Validate(state.conn, cookie.ApplicationCookie) {
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
return nil, nil, AlertAccessDenied
}
@ -196,7 +203,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
// Figure out if we can do DH
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Config.Groups)
// Figure out if we can do PSK
var canDoPSK bool
@ -223,7 +230,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
context := append(contextBase, chTrunc...)
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Config.PSKs)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
return nil, nil, AlertInternalError
@ -238,7 +245,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
// Select a ciphersuite
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Config.CipherSuites)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
return nil, nil, AlertHandshakeFailure
@ -249,7 +256,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
var helloRetryRequest *HandshakeMessage
if state.Caps.RequireCookie {
if state.Config.RequireCookie {
// Send a cookie if required
// NB: Need to do this here because it's after ciphersuite selection, which
// has to be after PSK selection.
@ -257,11 +264,11 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
var cookieExt *CookieExtension
if !clientSentCookie { // this is the first ClientHello that we receive
var appCookie []byte
if state.Caps.CookieHandler == nil { // if Config.RequireCookie is set, but no CookieHandler was provided, we definitely need to send a cookie
if state.Config.CookieHandler == nil { // if Config.RequireCookie is set, but no CookieHandler was provided, we definitely need to send a cookie
shouldSendHRR = true
} else { // if the CookieHandler was set, we just send a cookie when the application provides one
var err error
appCookie, err = state.Caps.CookieHandler.Generate(state.conn)
appCookie, err = state.Config.CookieHandler.Generate(state.conn)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
return nil, nil, AlertInternalError
@ -281,7 +288,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
logf(logTypeHandshake, "[ServerStateStart] Error marshalling cookie [%v]", err)
return nil, nil, AlertInternalError
}
cookieData, err := state.Caps.CookieProtector.NewToken(plainCookie)
cookieData, err := state.Config.CookieProtector.NewToken(plainCookie)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error encoding cookie [%v]", err)
return nil, nil, AlertInternalError
@ -340,7 +347,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// Select a certificate
name := string(*serverName)
var err error
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Config.Certificates)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
return nil, nil, AlertAccessDenied
@ -354,7 +361,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// Figure out if we're going to do early data
var clientEarlyTrafficSecret []byte
connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData]
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Caps.AllowEarlyData)
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData)
if connParams.UsingEarlyData {
h := params.Hash.New()
h.Write(clientHello.Marshal())
@ -366,7 +373,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
// Select a next protocol
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Config.NextProtos)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
return nil, nil, AlertNoApplicationProtocol
@ -375,7 +382,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
return ServerStateNegotiated{
Caps: state.Caps,
Config: state.Config,
Params: connParams,
hsCtx: state.hsCtx,
dhGroup: dhGroup,
@ -420,8 +427,8 @@ func (state *ServerStateStart) generateHRR(cs CipherSuite, legacySessionId []byt
return nil, err
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
return nil, err
@ -436,7 +443,7 @@ func (state *ServerStateStart) generateHRR(cs CipherSuite, legacySessionId []byt
}
type ServerStateNegotiated struct {
Caps Capabilities
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
dhGroup NamedGroup
@ -504,8 +511,8 @@ func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
@ -585,8 +592,8 @@ func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
ee := &EncryptedExtensionsBody{eeList}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if state.Config.ExtensionHandler != nil {
err := state.Config.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
@ -610,13 +617,13 @@ func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
// Authenticate with a certificate if required
if !state.Params.UsingPSK {
// Send a CertificateRequest message if we want client auth
if state.Caps.RequireClientAuth {
if state.Config.RequireClientAuth {
state.Params.UsingClientAuth = true
// XXX: We don't support sending any constraints besides a list of
// supported signature algorithms
cr := &CertificateRequestBody{}
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes}
err := cr.Extensions.Add(schemes)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
@ -711,7 +718,7 @@ func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
nextState := ServerStateWaitEOED{
AuthCertificate: state.Caps.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: params,
@ -735,7 +742,7 @@ func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
ReadPastEarlyData{},
}...)
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.Caps.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: params,
@ -750,7 +757,7 @@ func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
}
type ServerStateWaitEOED struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
@ -792,7 +799,7 @@ func (state ServerStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState
RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys},
}
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
@ -807,7 +814,7 @@ func (state ServerStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState
}
type ServerStateWaitFlight2 struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
@ -829,7 +836,7 @@ func (state ServerStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeSta
if state.Params.UsingClientAuth {
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
nextState := ServerStateWaitCert{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
@ -859,7 +866,7 @@ func (state ServerStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeSta
}
type ServerStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
@ -915,7 +922,7 @@ func (state ServerStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
nextState := ServerStateWaitCV{
AuthCertificate: state.AuthCertificate,
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams,
@ -931,7 +938,7 @@ func (state ServerStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
}
type ServerStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
cryptoParams CipherSuiteParams
@ -969,6 +976,13 @@ func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
return nil, nil, AlertDecodeError
}
rawCerts := make([][]byte, len(state.clientCertificate.CertificateList))
certs := make([]*x509.Certificate, len(state.clientCertificate.CertificateList))
for i, certEntry := range state.clientCertificate.CertificateList {
certs[i] = certEntry.CertData
rawCerts[i] = certEntry.CertData.Raw
}
// Verify client signature over handshake hash
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
@ -979,14 +993,12 @@ func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.clientCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
if state.Config.VerifyPeerCertificate != nil {
// TODO(#171): pass in the verified chains, once we support different client auth types
if err := state.Config.VerifyPeerCertificate(rawCerts, nil); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate: %s", err)
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
}
// If it passes, record the certificateVerify in the transcript hash
@ -1003,6 +1015,8 @@ func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
peerCertificates: certs,
verifiedChains: nil, // TODO(#171): set this value
}
return nextState, nil, AlertNoAlert
}
@ -1014,6 +1028,8 @@ type ServerStateWaitFinished struct {
masterSecret []byte
clientHandshakeTrafficSecret []byte
peerCertificates []*x509.Certificate
verifiedChains [][]*x509.Certificate
handshakeHash hash.Hash
clientTrafficSecret []byte
@ -1076,6 +1092,8 @@ func (state ServerStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
peerCertificates: state.peerCertificates,
verifiedChains: state.verifiedChains,
}
toSend := []HandshakeAction{
RekeyIn{epoch: EpochApplicationData, KeySet: clientTrafficKeys},

View file

@ -1,6 +1,7 @@
package mint
import (
"crypto/x509"
"time"
)
@ -44,30 +45,6 @@ type AppExtensionHandler interface {
Receive(hs HandshakeType, el *ExtensionList) error
}
// Capabilities objects represent the capabilities of a TLS client or server,
// as an input to TLS negotiation
type Capabilities struct {
// For both client and server
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
PSKs PreSharedKeyCache
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
ExtensionHandler AppExtensionHandler
UseDTLS bool
// For client
PSKModes []PSKKeyExchangeMode
// For server
NextProtos []string
AllowEarlyData bool
RequireCookie bool
CookieProtector CookieProtector
CookieHandler CookieHandler
RequireClientAuth bool
}
// ConnectionOptions objects represent per-connection settings for a client
// initiating a connection
type ConnectionOptions struct {
@ -114,6 +91,8 @@ type StateConnected struct {
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
peerCertificates []*x509.Certificate
verifiedChains [][]*x509.Certificate
}
var _ HandshakeState = &StateConnected{}

View file

@ -93,6 +93,7 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
if config != nil && config.NonBlocking {
return nil, errors.New("dialing not possible in non-blocking mode")
}
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
@ -127,16 +128,20 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
if config == nil {
config = &Config{}
} else {
config = config.Clone()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
config.ServerName = hostname
}
// Set up DTLS as needed.
config.UseDTLS = (network == "udp")
conn := Client(rawConn, config)
if timeout == 0 {

6
vendor/vendor.json vendored
View file

@ -3,10 +3,10 @@
"ignore": "test",
"package": [
{
"checksumSHA1": "nxj6lkDUEZ81SO0lP8YUhm+4BAM=",
"checksumSHA1": "8nsuzBKJY9oVNM/NPU0kbY/YpDw=",
"path": "github.com/bifurcation/mint",
"revision": "f699e8d03646cb8e6e15410ced7bff37fcf8dddd",
"revisionTime": "2017-12-21T19:05:27Z"
"revision": "a00905133cda39e4ac20677ae9332dce13e785eb",
"revisionTime": "2018-02-01T00:42:34Z"
},
{
"checksumSHA1": "PZNcjO1c9gV/LZzppwpVRl6+QAY=",