update mint

mint recently changed the signature of ExtensionList.Find.
This commit is contained in:
Marten Seemann 2017-12-22 10:42:33 +07:00
parent a6eca911de
commit 14d4e5e9ec
13 changed files with 415 additions and 236 deletions

View file

@ -24,6 +24,7 @@ type extensionHandlerClient struct {
var _ mint.AppExtensionHandler = &extensionHandlerClient{} var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{} var _ TLSExtensionHandler = &extensionHandlerClient{}
// NewExtensionHandlerClient creates a new extension handler for the client.
func NewExtensionHandlerClient( func NewExtensionHandlerClient(
params *TransportParameters, params *TransportParameters,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
@ -57,7 +58,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found := el.Find(ext) found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket {
if found { if found {

View file

@ -39,7 +39,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1)) Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found := el.Find(ext) found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue()) Expect(found).To(BeTrue())
chtp := &clientHelloTransportParameters{} chtp := &clientHelloTransportParameters{}
_, err = syntax.Unmarshal(ext.data, chtp) _, err = syntax.Unmarshal(ext.data, chtp)

View file

@ -24,6 +24,7 @@ type extensionHandlerServer struct {
var _ mint.AppExtensionHandler = &extensionHandlerServer{} var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{} var _ TLSExtensionHandler = &extensionHandlerServer{}
// NewExtensionHandlerServer creates a new extension handler for the server
func NewExtensionHandlerServer( func NewExtensionHandlerServer(
params *TransportParameters, params *TransportParameters,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
@ -66,7 +67,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found := el.Find(ext) found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeClientHello { if hType != mint.HandshakeTypeClientHello {
if found { if found {

View file

@ -48,7 +48,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1)) Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{} ext := &tlsExtensionBody{}
found := el.Find(ext) found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue()) Expect(found).To(BeTrue())
eetp := &encryptedExtensionsTransportParameters{} eetp := &encryptedExtensionsTransportParameters{}
_, err = syntax.Unmarshal(ext.data, eetp) _, err = syntax.Unmarshal(ext.data, eetp)

View file

@ -88,7 +88,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
logf(logTypeHandshake, "opts: %+v", state.Opts) logf(logTypeHandshake, "opts: %+v", state.Opts)
// supported_versions, supported_groups, signature_algorithms, server_name // supported_versions, supported_groups, signature_algorithms, server_name
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}}
sni := ServerNameExtension(state.Opts.ServerName) sni := ServerNameExtension(state.Opts.ServerName)
sg := SupportedGroupsExtension{Groups: state.Caps.Groups} sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
@ -265,6 +265,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
} }
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
nextState := ClientStateWaitSH{ nextState := ClientStateWaitSH{
Caps: state.Caps, Caps: state.Caps,
Opts: state.Opts, Opts: state.Opts,
@ -323,41 +324,55 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
if alert != AlertNoAlert { if alert != AlertNoAlert {
return nil, nil, alert return nil, nil, alert
} }
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") if hm == nil || hm.msgType != HandshakeTypeServerHello {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
} }
bodyGeneric, err := hm.ToBody() sh := &ServerHelloBody{}
if err != nil { if _, err := sh.Unmarshal(hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message")
return nil, nil, AlertDecodeError return nil, nil, AlertUnexpectedMessage
} }
switch body := bodyGeneric.(type) { // Common SH/HRR processing first.
case *HelloRetryRequestBody: // 1. Check that sh.version is TLS 1.2
hrr := body if sh.Version != tls12Version {
logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version)
return nil, nil, AlertIllegalParameter
}
if state.helloRetryRequest != nil { // 2. Check that it responded with a valid version.
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello}
return nil, nil, AlertUnexpectedMessage foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions)
} if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err)
return nil, nil, AlertDecodeError
}
if !foundSupportedVersions {
logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension")
return nil, nil, AlertMissingExtension
}
if supportedVersions.Versions[0] != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0])
return nil, nil, AlertProtocolVersion
}
// 3. Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Check that the version sent by the server is the one we support // Now check for the sentinel.
if hrr.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite if sh.Random == hrrRandomSentinel {
supportedCipherSuite := false // This is actually HRR.
for _, suite := range state.Caps.CipherSuites { hrr := sh
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Narrow the supported ciphersuites to the server-provided one // Narrow the supported ciphersuites to the server-provided one
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
@ -373,10 +388,14 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
// The only thing we know how to respond to in an HRR is the Cookie // The only thing we know how to respond to in an HRR is the Cookie
// extension, so if there is either no Cookie extension or anything other // extension, so if there is either no Cookie extension or anything other
// than a Cookie extension, we have to fail. // than a Cookie extension and SupportedVersions we have to fail.
serverCookie := new(CookieExtension) serverCookie := new(CookieExtension)
foundCookie := hrr.Extensions.Find(serverCookie) foundCookie, err := hrr.Extensions.Find(serverCookie)
if !foundCookie || len(hrr.Extensions) != 1 { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err)
return nil, nil, AlertDecodeError
}
if !foundCookie || len(hrr.Extensions) != 2 {
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
return nil, nil, AlertIllegalParameter return nil, nil, AlertIllegalParameter
} }
@ -400,131 +419,117 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
firstClientHello: firstClientHello, firstClientHello: firstClientHello,
helloRetryRequest: hm, helloRetryRequest: hm,
}, nil, AlertNoAlert }, nil, AlertNoAlert
case *ServerHelloBody:
sh := body
// Check that the version sent by the server is the one we support
if sh.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// Do PSK or key agreement depending on extensions
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
foundPSK := sh.Extensions.Find(&serverPSK)
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
if foundPSK && (serverPSK.SelectedIdentity == 0) {
state.Params.UsingPSK = true
}
var dhSecret []byte
if foundKeyShare {
sks := serverKeyShare.Shares[0]
priv, ok := state.OfferedDH[sks.Group]
if !ok {
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
return nil, nil, AlertIllegalParameter
}
state.Params.UsingDH = true
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
}
suite := sh.CipherSuite
state.Params.CipherSuite = suite
params, ok := cipherSuiteMap[suite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(hm.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
if params.Hash != state.earlyHash {
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
state.earlyHash, suite, params.Hash)
}
earlySecret = state.earlySecret
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if dhSecret == nil {
dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{
Caps: state.Caps,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: params,
handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
}
toSend := []HandshakeAction{
RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
}
return nextState, toSend, AlertNoAlert
} }
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) // This is SH.
return nil, nil, AlertUnexpectedMessage // Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// Do PSK or key agreement depending on extensions
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
foundExts, err := sh.Extensions.Parse(
[]ExtensionBody{
&serverPSK,
&serverKeyShare,
})
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err)
return nil, nil, AlertDecodeError
}
if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) {
state.Params.UsingPSK = true
}
var dhSecret []byte
if foundExts[ExtensionTypeKeyShare] {
sks := serverKeyShare.Shares[0]
priv, ok := state.OfferedDH[sks.Group]
if !ok {
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
return nil, nil, AlertIllegalParameter
}
state.Params.UsingDH = true
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
}
suite := sh.CipherSuite
state.Params.CipherSuite = suite
params, ok := cipherSuiteMap[suite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(hm.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
if params.Hash != state.earlyHash {
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
state.earlyHash, suite, params.Hash)
}
earlySecret = state.earlySecret
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if dhSecret == nil {
dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{
Caps: state.Caps,
Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: params,
handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
}
toSend := []HandshakeAction{
RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
}
return nextState, toSend, AlertNoAlert
} }
type ClientStateWaitEE struct { type ClientStateWaitEE struct {
@ -557,7 +562,7 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
} }
ee := EncryptedExtensionsBody{} ee := EncryptedExtensionsBody{}
if _, err := ee.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(&ee, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -571,13 +576,22 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
} }
} }
serverALPN := ALPNExtension{} serverALPN := &ALPNExtension{}
serverEarlyData := EarlyDataExtension{} serverEarlyData := &EarlyDataExtension{}
gotALPN := ee.Extensions.Find(&serverALPN) foundExts, err := ee.Extensions.Parse(
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) []ExtensionBody{
serverALPN,
serverEarlyData,
})
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err)
return nil, nil, AlertDecodeError
}
if gotALPN && len(serverALPN.Protocols) > 0 { state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData]
if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 {
state.Params.NextProto = serverALPN.Protocols[0] state.Params.NextProto = serverALPN.Protocols[0]
} }
@ -726,7 +740,7 @@ func (state ClientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
} }
cert := &CertificateBody{} cert := &CertificateBody{}
if _, err := cert.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(cert, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -783,7 +797,7 @@ func (state ClientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
} }
certVerify := CertificateVerifyBody{} certVerify := CertificateVerifyBody{}
if _, err := certVerify.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(&certVerify, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -863,7 +877,7 @@ func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
if _, err := fin.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(fin, hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -913,7 +927,11 @@ func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
if state.Params.UsingClientAuth { if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest // Extract constraints from certicateRequest
schemes := SignatureAlgorithmsExtension{} schemes := SignatureAlgorithmsExtension{}
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err)
return nil, nil, AlertDecodeError
}
if !gotSchemes { if !gotSchemes {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found") logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found")
return nil, nil, AlertIllegalParameter return nil, nil, AlertIllegalParameter

View file

@ -5,9 +5,13 @@ import (
"strconv" "strconv"
) )
var ( const (
supportedVersion uint16 = 0x7f15 // draft-21 supportedVersion uint16 = 0x7f16 // draft-22
tls12Version uint16 = 0x0303
tls10Version uint16 = 0x0301
)
var (
// Flags for some minor compat issues // Flags for some minor compat issues
allowWrongVersionNumber = true allowWrongVersionNumber = true
allowPKCS1 = true allowPKCS1 = true
@ -42,6 +46,13 @@ const (
HandshakeTypeMessageHash HandshakeType = 254 HandshakeTypeMessageHash HandshakeType = 254
) )
var hrrRandomSentinel = [32]byte{
0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11,
0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e,
0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
}
// uint8 CipherSuite[2]; // uint8 CipherSuite[2];
type CipherSuite uint16 type CipherSuite uint16

View file

@ -3,7 +3,6 @@ package mint
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/bifurcation/mint/syntax" "github.com/bifurcation/mint/syntax"
) )
@ -77,14 +76,40 @@ func (el *ExtensionList) Add(src ExtensionBody) error {
return nil return nil
} }
func (el ExtensionList) Find(dst ExtensionBody) bool { func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) {
for _, ext := range el { found := make(map[ExtensionType]bool)
if ext.ExtensionType == dst.Type() {
_, err := dst.Unmarshal(ext.ExtensionData) for _, dst := range dsts {
return err == nil for _, ext := range el {
if ext.ExtensionType == dst.Type() {
if found[dst.Type()] {
return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
}
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return nil, err
}
found[dst.Type()] = true
}
} }
} }
return false
return found, nil
}
func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return true, err
}
return true, nil
}
}
return false, nil
} }
// struct { // struct {
@ -529,19 +554,56 @@ func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
// ProtocolVersion versions<2..254>; // ProtocolVersion versions<2..254>;
// } SupportedVersions; // } SupportedVersions;
type SupportedVersionsExtension struct { type SupportedVersionsExtension struct {
HandshakeType HandshakeType
Versions []uint16
}
type SupportedVersionsClientHelloInner struct {
Versions []uint16 `tls:"head=1,min=2,max=254"` Versions []uint16 `tls:"head=1,min=2,max=254"`
} }
type SupportedVersionsServerHelloInner struct {
Version uint16
}
func (sv SupportedVersionsExtension) Type() ExtensionType { func (sv SupportedVersionsExtension) Type() ExtensionType {
return ExtensionTypeSupportedVersions return ExtensionTypeSupportedVersions
} }
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sv) switch sv.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions})
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]})
default:
return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
}
} }
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sv) switch sv.HandshakeType {
case HandshakeTypeClientHello:
var inner SupportedVersionsClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
sv.Versions = inner.Versions
return read, nil
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
var inner SupportedVersionsServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
sv.Versions = []uint16{inner.Version}
return read, nil
default:
return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
}
} }
// struct { // struct {

View file

@ -99,7 +99,7 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
} }
_, err := body.Unmarshal(hm.body) err := safeUnmarshal(body, hm.body)
return body, err return body, err
} }
@ -324,8 +324,6 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
return nil, err return nil, err
} }
for { for {
// TODO(ekr@rtfm.com): Discard partial DTLS frames.
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 { if h.frame.needed() > 0 {
logf(logTypeVerbose, "Trying to read a new record") logf(logTypeVerbose, "Trying to read a new record")
@ -481,3 +479,19 @@ func decodeUint(in []byte, size int) (uint64, []byte) {
} }
return val, in[size:] return val, in[size:]
} }
type marshalledPDU interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
read, err := pdu.Unmarshal(data)
if err != nil {
return err
}
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
}
return nil
}

View file

@ -26,11 +26,10 @@ type HandshakeMessageBody interface {
// } ClientHello; // } ClientHello;
type ClientHelloBody struct { type ClientHelloBody struct {
// Omitted: clientVersion // Omitted: clientVersion
// Omitted: legacySessionID Random [32]byte
// Omitted: legacyCompressionMethods LegacySessionID []byte
Random [32]byte CipherSuites []CipherSuite
CipherSuites []CipherSuite Extensions ExtensionList
Extensions ExtensionList
} }
type clientHelloBodyInner struct { type clientHelloBodyInner struct {
@ -48,7 +47,7 @@ func (ch ClientHelloBody) Type() HandshakeType {
func (ch ClientHelloBody) Marshal() ([]byte, error) { func (ch ClientHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(clientHelloBodyInner{ return syntax.Marshal(clientHelloBodyInner{
LegacyVersion: 0x0303, LegacyVersion: tls12Version,
Random: ch.Random, Random: ch.Random,
LegacySessionID: []byte{}, LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites, CipherSuites: ch.CipherSuites,
@ -65,7 +64,7 @@ func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
} }
// We are strict about these things because we only support 1.3 // We are strict about these things because we only support 1.3
if inner.LegacyVersion != 0x0303 { if inner.LegacyVersion != tls12Version {
return 0, fmt.Errorf("tls.clienthello: Incorrect version number") return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
} }
@ -74,6 +73,7 @@ func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
} }
ch.Random = inner.Random ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions ch.Extensions = inner.Extensions
return read, nil return read, nil
@ -144,16 +144,20 @@ func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
} }
// struct { // struct {
// ProtocolVersion version; // ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random; // Random random;
// opaque legacy_session_id_echo<0..32>;
// CipherSuite cipher_suite; // CipherSuite cipher_suite;
// Extension extensions<0..2^16-1>; // uint8 legacy_compression_method = 0;
// Extension extensions<6..2^16-1>;
// } ServerHello; // } ServerHello;
type ServerHelloBody struct { type ServerHelloBody struct {
Version uint16 Version uint16
Random [32]byte Random [32]byte
CipherSuite CipherSuite LegacySessionID []byte `tls:"head=1,max=32"`
Extensions ExtensionList `tls:"head=2"` CipherSuite CipherSuite
LegacyCompressionMethod uint8
Extensions ExtensionList `tls:"head=2"`
} }
func (sh ServerHelloBody) Type() HandshakeType { func (sh ServerHelloBody) Type() HandshakeType {

View file

@ -23,7 +23,7 @@ func (err DecryptError) Error() string {
// struct { // struct {
// ContentType type; // ContentType type;
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ // ProtocolVersion record_version [0301 for CH, 0303 for others]
// uint16 length; // uint16 length;
// opaque fragment[TLSPlaintext.length]; // opaque fragment[TLSPlaintext.length];
// } TLSPlaintext; // } TLSPlaintext;
@ -45,6 +45,7 @@ type cipherState struct {
type RecordLayer struct { type RecordLayer struct {
sync.Mutex sync.Mutex
version uint16 // The current version number
conn io.ReadWriter // The underlying connection conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader frame *frameReader // The buffered frame reader
nextData []byte // The next record to send nextData []byte // The next record to send
@ -92,6 +93,7 @@ func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer {
r.conn = conn r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{false}) r.frame = newFrameReader(recordLayerFrameDetails{false})
r.cipher = newCipherStateNull() r.cipher = newCipherStateNull()
r.version = tls10Version
return &r return &r
} }
@ -104,6 +106,10 @@ func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer {
return &r return &r
} }
func (r *RecordLayer) SetVersion(v uint16) {
r.version = v
}
func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
cipher, err := newCipherStateAead(epoch, factory, key, iv) cipher, err := newCipherStateAead(epoch, factory, key, iv)
if err != nil { if err != nil {
@ -349,7 +355,9 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta
length := len(pt.fragment) length := len(pt.fragment)
var header []byte var header []byte
if !r.datagram { if !r.datagram {
header = []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} header = []byte{byte(pt.contentType),
byte(r.version >> 8), byte(r.version & 0xff),
byte(length >> 8), byte(length)}
} else { } else {
// TODO(ekr@rtfm.com): Double check version // TODO(ekr@rtfm.com): Double check version
seq := cipher.seq seq := cipher.seq

View file

@ -93,7 +93,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
} }
ch := &ClientHelloBody{} ch := &ClientHelloBody{}
if _, err := ch.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(ch, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -101,7 +101,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
clientHello := hm clientHello := hm
connParams := ConnectionParameters{} connParams := ConnectionParameters{}
supportedVersions := new(SupportedVersionsExtension) supportedVersions := &SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello}
serverName := new(ServerNameExtension) serverName := new(ServerNameExtension)
supportedGroups := new(SupportedGroupsExtension) supportedGroups := new(SupportedGroupsExtension)
signatureAlgorithms := new(SignatureAlgorithmsExtension) signatureAlgorithms := new(SignatureAlgorithmsExtension)
@ -121,26 +121,34 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
} }
} }
gotSupportedVersions := ch.Extensions.Find(supportedVersions) foundExts, err := ch.Extensions.Parse(
gotServerName := ch.Extensions.Find(serverName) []ExtensionBody{
gotSupportedGroups := ch.Extensions.Find(supportedGroups) supportedVersions,
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) serverName,
gotEarlyData := ch.Extensions.Find(clientEarlyData) supportedGroups,
ch.Extensions.Find(clientKeyShares) signatureAlgorithms,
ch.Extensions.Find(clientPSK) clientEarlyData,
ch.Extensions.Find(clientALPN) clientKeyShares,
ch.Extensions.Find(clientPSKModes) clientPSK,
ch.Extensions.Find(clientCookie) clientALPN,
clientPSKModes,
clientCookie,
})
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error parsing extensions [%v]", err)
return nil, nil, AlertDecodeError
}
clientSentCookie := len(clientCookie.Cookie) > 0 clientSentCookie := len(clientCookie.Cookie) > 0
if gotServerName { if foundExts[ExtensionTypeServerName] {
connParams.ServerName = string(*serverName) connParams.ServerName = string(*serverName)
} }
// If the client didn't send supportedVersions or doesn't support 1.3, // If the client didn't send supportedVersions or doesn't support 1.3,
// then we're done here. // then we're done here.
if !gotSupportedVersions { if !foundExts[ExtensionTypeSupportedVersions] {
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
return nil, nil, AlertProtocolVersion return nil, nil, AlertProtocolVersion
} }
@ -160,7 +168,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
return nil, nil, AlertDecryptError return nil, nil, AlertDecryptError
} }
cookie := &cookie{} cookie := &cookie{}
if _, err := syntax.Unmarshal(plainCookie, cookie); err != nil { // this should never happen if rb, err := syntax.Unmarshal(plainCookie, cookie); err != nil && rb != len(plainCookie) { // this should never happen
logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error unmarshaling cookie [%v]", err)) logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error unmarshaling cookie [%v]", err))
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
@ -182,6 +190,11 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
} }
} }
if len(ch.LegacySessionID) != 0 && len(ch.LegacySessionID) != 32 {
logf(logTypeHandshake, "[ServerStateStart] invalid session ID")
return nil, nil, AlertIllegalParameter
}
// Figure out if we can do DH // Figure out if we can do DH
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
@ -196,7 +209,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
contextBase = append(contextBase, firstClientHello.Marshal()...) contextBase = append(contextBase, firstClientHello.Marshal()...)
// fill in the cookie sent by the client. Needed to calculate the correct hash // fill in the cookie sent by the client. Needed to calculate the correct hash
cookieExt := &CookieExtension{Cookie: clientCookie.Cookie} cookieExt := &CookieExtension{Cookie: clientCookie.Cookie}
hrr, err := state.generateHRR(params.Suite, cookieExt) hrr, err := state.generateHRR(params.Suite,
ch.LegacySessionID, cookieExt)
if err != nil { if err != nil {
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
@ -224,7 +238,6 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
// Select a ciphersuite // Select a ciphersuite
var err error
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
@ -285,7 +298,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// Ignoring errors because everything here is newly constructed, so there // Ignoring errors because everything here is newly constructed, so there
// shouldn't be marshal errors // shouldn't be marshal errors
if shouldSendHRR || clientSentCookie { if shouldSendHRR || clientSentCookie {
helloRetryRequest, err = state.generateHRR(connParams.CipherSuite, cookieExt) helloRetryRequest, err = state.generateHRR(connParams.CipherSuite,
ch.LegacySessionID, cookieExt)
if err != nil { if err != nil {
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
@ -316,9 +330,10 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
psk = nil psk = nil
// If we're not using a PSK mode, then we need to have certain extensions // If we're not using a PSK mode, then we need to have certain extensions
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { if !(foundExts[ExtensionTypeServerName] &&
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", foundExts[ExtensionTypeSupportedGroups] &&
gotServerName, gotSupportedGroups, gotSignatureAlgorithms) foundExts[ExtensionTypeSignatureAlgorithms]) {
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v)", foundExts)
return nil, nil, AlertMissingExtension return nil, nil, AlertMissingExtension
} }
@ -338,8 +353,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// Figure out if we're going to do early data // Figure out if we're going to do early data
var clientEarlyTrafficSecret []byte var clientEarlyTrafficSecret []byte
connParams.ClientSendingEarlyData = gotEarlyData connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData]
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Caps.AllowEarlyData)
if connParams.UsingEarlyData { if connParams.UsingEarlyData {
h := params.Hash.New() h := params.Hash.New()
h.Write(clientHello.Marshal()) h.Write(clientHello.Marshal())
@ -358,6 +373,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
} }
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
return ServerStateNegotiated{ return ServerStateNegotiated{
Caps: state.Caps, Caps: state.Caps,
Params: connParams, Params: connParams,
@ -369,6 +385,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
selectedPSK: selectedPSK, selectedPSK: selectedPSK,
cert: cert, cert: cert,
certScheme: certScheme, certScheme: certScheme,
legacySessionId: ch.LegacySessionID,
clientEarlyTrafficSecret: clientEarlyTrafficSecret, clientEarlyTrafficSecret: clientEarlyTrafficSecret,
firstClientHello: firstClientHello, firstClientHello: firstClientHello,
@ -377,12 +394,27 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}, nil, AlertNoAlert }, nil, AlertNoAlert
} }
func (state *ServerStateStart) generateHRR(cs CipherSuite, cookieExt *CookieExtension) (*HandshakeMessage, error) { func (state *ServerStateStart) generateHRR(cs CipherSuite, legacySessionId []byte,
cookieExt *CookieExtension) (*HandshakeMessage, error) {
var helloRetryRequest *HandshakeMessage var helloRetryRequest *HandshakeMessage
hrr := &HelloRetryRequestBody{ hrr := &ServerHelloBody{
Version: supportedVersion, Version: tls12Version,
CipherSuite: cs, Random: hrrRandomSentinel,
CipherSuite: cs,
LegacySessionID: legacySessionId,
LegacyCompressionMethod: 0,
} }
sv := &SupportedVersionsExtension{
HandshakeType: HandshakeTypeServerHello,
Versions: []uint16{supportedVersion},
}
if err := hrr.Extensions.Add(sv); err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error adding SupportedVersion [%v]", err)
return nil, err
}
if err := hrr.Extensions.Add(cookieExt); err != nil { if err := hrr.Extensions.Add(cookieExt); err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error adding CookieExtension [%v]", err) logf(logTypeHandshake, "[ServerStateStart] Error adding CookieExtension [%v]", err)
return nil, err return nil, err
@ -415,10 +447,10 @@ type ServerStateNegotiated struct {
selectedPSK int selectedPSK int
cert *Certificate cert *Certificate
certScheme SignatureScheme certScheme SignatureScheme
legacySessionId []byte
firstClientHello *HandshakeMessage firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage clientHello *HandshakeMessage
} }
var _ HandshakeState = &ServerStateNegotiated{} var _ HandshakeState = &ServerStateNegotiated{}
@ -430,13 +462,24 @@ func (state ServerStateNegotiated) State() State {
func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
// Create the ServerHello // Create the ServerHello
sh := &ServerHelloBody{ sh := &ServerHelloBody{
Version: supportedVersion, Version: tls12Version,
CipherSuite: state.Params.CipherSuite, CipherSuite: state.Params.CipherSuite,
LegacySessionID: state.legacySessionId,
LegacyCompressionMethod: 0,
} }
if _, err := prng.Read(sh.Random[:]); err != nil { if _, err := prng.Read(sh.Random[:]); err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
err := sh.Extensions.Add(&SupportedVersionsExtension{
HandshakeType: HandshakeTypeServerHello,
Versions: []uint16{supportedVersion},
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported_versions extension [%v]", err)
return nil, nil, AlertInternalError
}
if state.Params.UsingDH { if state.Params.UsingDH {
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
err := sh.Extensions.Add(&KeyShareExtension{ err := sh.Extensions.Add(&KeyShareExtension{
@ -845,7 +888,7 @@ func (state ServerStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
} }
cert := &CertificateBody{} cert := &CertificateBody{}
if _, err := cert.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(cert, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -921,7 +964,7 @@ func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
} }
certVerify := &CertificateVerifyBody{} certVerify := &CertificateVerifyBody{}
if _, err := certVerify.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(certVerify, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -995,7 +1038,7 @@ func (state ServerStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
} }
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
if _, err := fin.Unmarshal(hm.body); err != nil { if err := safeUnmarshal(fin, hm.body); err != nil {
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }

View file

@ -95,6 +95,15 @@ type HandshakeContext struct {
hIn, hOut *HandshakeLayer hIn, hOut *HandshakeLayer
} }
func (hc *HandshakeContext) SetVersion(version uint16) {
if hc.hIn.conn != nil {
hc.hIn.conn.SetVersion(version)
}
if hc.hOut.conn != nil {
hc.hOut.conn.SetVersion(version)
}
}
// StateConnected is symmetric between client and server // StateConnected is symmetric between client and server
type StateConnected struct { type StateConnected struct {
Params ConnectionParameters Params ConnectionParameters

6
vendor/vendor.json vendored
View file

@ -3,10 +3,10 @@
"ignore": "test", "ignore": "test",
"package": [ "package": [
{ {
"checksumSHA1": "7wknjQuGJ8gr8VsM8cTLgu1M8vw=", "checksumSHA1": "nxj6lkDUEZ81SO0lP8YUhm+4BAM=",
"path": "github.com/bifurcation/mint", "path": "github.com/bifurcation/mint",
"revision": "a544bfbca6a083ce9ddeb2c5f570cb240837355a", "revision": "f699e8d03646cb8e6e15410ced7bff37fcf8dddd",
"revisionTime": "2017-12-09T20:11:46Z" "revisionTime": "2017-12-21T19:05:27Z"
}, },
{ {
"checksumSHA1": "PZNcjO1c9gV/LZzppwpVRl6+QAY=", "checksumSHA1": "PZNcjO1c9gV/LZzppwpVRl6+QAY=",