diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index f2dfa958..20d2d06b 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -24,6 +24,7 @@ type extensionHandlerClient struct { var _ mint.AppExtensionHandler = &extensionHandlerClient{} var _ TLSExtensionHandler = &extensionHandlerClient{} +// NewExtensionHandlerClient creates a new extension handler for the client. func NewExtensionHandlerClient( params *TransportParameters, initialVersion protocol.VersionNumber, @@ -57,7 +58,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + if err != nil { + return err + } if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { if found { diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index 52822b85..05cfae5e 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -39,7 +39,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() { Expect(err).ToNot(HaveOccurred()) Expect(el).To(HaveLen(1)) ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + Expect(err).ToNot(HaveOccurred()) Expect(found).To(BeTrue()) chtp := &clientHelloTransportParameters{} _, err = syntax.Unmarshal(ext.data, chtp) diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index 941f5115..313751c9 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -24,6 +24,7 @@ type extensionHandlerServer struct { var _ mint.AppExtensionHandler = &extensionHandlerServer{} var _ TLSExtensionHandler = &extensionHandlerServer{} +// NewExtensionHandlerServer creates a new extension handler for the server func NewExtensionHandlerServer( params *TransportParameters, supportedVersions []protocol.VersionNumber, @@ -66,7 +67,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + if err != nil { + return err + } if hType != mint.HandshakeTypeClientHello { if found { diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index ceab29b1..8bb8d102 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -48,7 +48,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() { Expect(err).ToNot(HaveOccurred()) Expect(el).To(HaveLen(1)) ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + Expect(err).ToNot(HaveOccurred()) Expect(found).To(BeTrue()) eetp := &encryptedExtensionsTransportParameters{} _, err = syntax.Unmarshal(ext.data, eetp) diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go index 0ab8a41a..b72c584d 100644 --- a/vendor/github.com/bifurcation/mint/client-state-machine.go +++ b/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -88,7 +88,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ logf(logTypeHandshake, "opts: %+v", state.Opts) // supported_versions, supported_groups, signature_algorithms, server_name - sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} + sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}} sni := ServerNameExtension(state.Opts.ServerName) sg := SupportedGroupsExtension{Groups: state.Caps.Groups} sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} @@ -265,6 +265,7 @@ func (state ClientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") + state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. nextState := ClientStateWaitSH{ Caps: state.Caps, Opts: state.Opts, @@ -323,41 +324,55 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, if alert != AlertNoAlert { return nil, nil, alert } - if hm == nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") + + if hm == nil || hm.msgType != HandshakeTypeServerHello { + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message") return nil, nil, AlertUnexpectedMessage } - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) - return nil, nil, AlertDecodeError + sh := &ServerHelloBody{} + if _, err := sh.Unmarshal(hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message") + return nil, nil, AlertUnexpectedMessage } - switch body := bodyGeneric.(type) { - case *HelloRetryRequestBody: - hrr := body + // Common SH/HRR processing first. + // 1. Check that sh.version is TLS 1.2 + if sh.Version != tls12Version { + logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version) + return nil, nil, AlertIllegalParameter + } - if state.helloRetryRequest != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") - return nil, nil, AlertUnexpectedMessage - } + // 2. Check that it responded with a valid version. + supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello} + foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !foundSupportedVersions { + logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension") + return nil, nil, AlertMissingExtension + } + if supportedVersions.Versions[0] != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0]) + return nil, nil, AlertProtocolVersion + } + // 3. Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Caps.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } - // Check that the version sent by the server is the one we support - if hrr.Version != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version) - return nil, nil, AlertProtocolVersion - } + // Now check for the sentinel. - // Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Caps.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) - return nil, nil, AlertHandshakeFailure - } + if sh.Random == hrrRandomSentinel { + // This is actually HRR. + hrr := sh // Narrow the supported ciphersuites to the server-provided one state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} @@ -373,10 +388,14 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, // The only thing we know how to respond to in an HRR is the Cookie // extension, so if there is either no Cookie extension or anything other - // than a Cookie extension, we have to fail. + // than a Cookie extension and SupportedVersions we have to fail. serverCookie := new(CookieExtension) - foundCookie := hrr.Extensions.Find(serverCookie) - if !foundCookie || len(hrr.Extensions) != 1 { + foundCookie, err := hrr.Extensions.Find(serverCookie) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !foundCookie || len(hrr.Extensions) != 2 { logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) return nil, nil, AlertIllegalParameter } @@ -400,131 +419,117 @@ func (state ClientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, firstClientHello: firstClientHello, helloRetryRequest: hm, }, nil, AlertNoAlert - - case *ServerHelloBody: - sh := body - - // Check that the version sent by the server is the one we support - if sh.Version != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version) - return nil, nil, AlertProtocolVersion - } - - // Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Caps.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Do PSK or key agreement depending on extensions - serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} - serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} - - foundPSK := sh.Extensions.Find(&serverPSK) - foundKeyShare := sh.Extensions.Find(&serverKeyShare) - - if foundPSK && (serverPSK.SelectedIdentity == 0) { - state.Params.UsingPSK = true - } - - var dhSecret []byte - if foundKeyShare { - sks := serverKeyShare.Shares[0] - priv, ok := state.OfferedDH[sks.Group] - if !ok { - logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") - return nil, nil, AlertIllegalParameter - } - - state.Params.UsingDH = true - dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) - } - - suite := sh.CipherSuite - state.Params.CipherSuite = suite - - params, ok := cipherSuiteMap[suite] - if !ok { - logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) - return nil, nil, AlertHandshakeFailure - } - - // Start up the handshake hash - handshakeHash := params.Hash.New() - handshakeHash.Write(state.firstClientHello.Marshal()) - handshakeHash.Write(state.helloRetryRequest.Marshal()) - handshakeHash.Write(state.clientHello.Marshal()) - handshakeHash.Write(hm.Marshal()) - - // Compute handshake secrets - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - var earlySecret []byte - if state.Params.UsingPSK { - if params.Hash != state.earlyHash { - logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", - state.earlyHash, suite, params.Hash) - } - - earlySecret = state.earlySecret - } else { - earlySecret = HkdfExtract(params.Hash, zero, zero) - } - - if dhSecret == nil { - dhSecret = zero - } - - h0 := params.Hash.New().Sum(nil) - h2 := handshakeHash.Sum(nil) - preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) - handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) - clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) - serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) - preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) - masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) - - logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) - logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) - logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) - logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) - logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) - - serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") - nextState := ClientStateWaitEE{ - Caps: state.Caps, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: params, - handshakeHash: handshakeHash, - certificates: state.Caps.Certificates, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, - } - toSend := []HandshakeAction{ - RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, - } - return nextState, toSend, AlertNoAlert } - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) - return nil, nil, AlertUnexpectedMessage + // This is SH. + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Do PSK or key agreement depending on extensions + serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} + serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} + + foundExts, err := sh.Extensions.Parse( + []ExtensionBody{ + &serverPSK, + &serverKeyShare, + }) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err) + return nil, nil, AlertDecodeError + } + + if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) { + state.Params.UsingPSK = true + } + + var dhSecret []byte + if foundExts[ExtensionTypeKeyShare] { + sks := serverKeyShare.Shares[0] + priv, ok := state.OfferedDH[sks.Group] + if !ok { + logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingDH = true + dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) + } + + suite := sh.CipherSuite + state.Params.CipherSuite = suite + + params, ok := cipherSuiteMap[suite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(hm.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + if params.Hash != state.earlyHash { + logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", + state.earlyHash, suite, params.Hash) + } + + earlySecret = state.earlySecret + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if dhSecret == nil { + dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") + nextState := ClientStateWaitEE{ + Caps: state.Caps, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: params, + handshakeHash: handshakeHash, + certificates: state.Caps.Certificates, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, + } + toSend := []HandshakeAction{ + RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, + } + return nextState, toSend, AlertNoAlert } type ClientStateWaitEE struct { @@ -557,7 +562,7 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, } ee := EncryptedExtensionsBody{} - if _, err := ee.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(&ee, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -571,13 +576,22 @@ func (state ClientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, } } - serverALPN := ALPNExtension{} - serverEarlyData := EarlyDataExtension{} + serverALPN := &ALPNExtension{} + serverEarlyData := &EarlyDataExtension{} - gotALPN := ee.Extensions.Find(&serverALPN) - state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) + foundExts, err := ee.Extensions.Parse( + []ExtensionBody{ + serverALPN, + serverEarlyData, + }) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err) + return nil, nil, AlertDecodeError + } - if gotALPN && len(serverALPN.Protocols) > 0 { + state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData] + + if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 { state.Params.NextProto = serverALPN.Protocols[0] } @@ -726,7 +740,7 @@ func (state ClientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState } cert := &CertificateBody{} - if _, err := cert.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(cert, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -783,7 +797,7 @@ func (state ClientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, } certVerify := CertificateVerifyBody{} - if _, err := certVerify.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(&certVerify, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -863,7 +877,7 @@ func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} - if _, err := fin.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(fin, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -913,7 +927,11 @@ func (state ClientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS if state.Params.UsingClientAuth { // Extract constraints from certicateRequest schemes := SignatureAlgorithmsExtension{} - gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) + gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err) + return nil, nil, AlertDecodeError + } if !gotSchemes { logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found") return nil, nil, AlertIllegalParameter diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go index 8c97de98..3dc44f5e 100644 --- a/vendor/github.com/bifurcation/mint/common.go +++ b/vendor/github.com/bifurcation/mint/common.go @@ -5,9 +5,13 @@ import ( "strconv" ) -var ( - supportedVersion uint16 = 0x7f15 // draft-21 +const ( + supportedVersion uint16 = 0x7f16 // draft-22 + tls12Version uint16 = 0x0303 + tls10Version uint16 = 0x0301 +) +var ( // Flags for some minor compat issues allowWrongVersionNumber = true allowPKCS1 = true @@ -42,6 +46,13 @@ const ( HandshakeTypeMessageHash HandshakeType = 254 ) +var hrrRandomSentinel = [32]byte{ + 0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, + 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91, + 0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, + 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c, +} + // uint8 CipherSuite[2]; type CipherSuite uint16 diff --git a/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/bifurcation/mint/extensions.go index f239e16f..07cb16c6 100644 --- a/vendor/github.com/bifurcation/mint/extensions.go +++ b/vendor/github.com/bifurcation/mint/extensions.go @@ -3,7 +3,6 @@ package mint import ( "bytes" "fmt" - "github.com/bifurcation/mint/syntax" ) @@ -77,14 +76,40 @@ func (el *ExtensionList) Add(src ExtensionBody) error { return nil } -func (el ExtensionList) Find(dst ExtensionBody) bool { - for _, ext := range el { - if ext.ExtensionType == dst.Type() { - _, err := dst.Unmarshal(ext.ExtensionData) - return err == nil +func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) { + found := make(map[ExtensionType]bool) + + for _, dst := range dsts { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + if found[dst.Type()] { + return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type()) + } + + err := safeUnmarshal(dst, ext.ExtensionData) + if err != nil { + return nil, err + } + + found[dst.Type()] = true + } } } - return false + + return found, nil +} + +func (el ExtensionList) Find(dst ExtensionBody) (bool, error) { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + err := safeUnmarshal(dst, ext.ExtensionData) + if err != nil { + return true, err + } + return true, nil + } + } + return false, nil } // struct { @@ -529,19 +554,56 @@ func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { // ProtocolVersion versions<2..254>; // } SupportedVersions; type SupportedVersionsExtension struct { + HandshakeType HandshakeType + Versions []uint16 +} + +type SupportedVersionsClientHelloInner struct { Versions []uint16 `tls:"head=1,min=2,max=254"` } +type SupportedVersionsServerHelloInner struct { + Version uint16 +} + func (sv SupportedVersionsExtension) Type() ExtensionType { return ExtensionTypeSupportedVersions } func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sv) + switch sv.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions}) + case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: + return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]}) + default: + return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed") + } } func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sv) + switch sv.HandshakeType { + case HandshakeTypeClientHello: + var inner SupportedVersionsClientHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + sv.Versions = inner.Versions + return read, nil + + case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: + var inner SupportedVersionsServerHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + sv.Versions = []uint16{inner.Version} + return read, nil + + default: + return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed") + } } // struct { diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go index 583b95eb..48d086b6 100644 --- a/vendor/github.com/bifurcation/mint/handshake-layer.go +++ b/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -99,7 +99,7 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") } - _, err := body.Unmarshal(hm.body) + err := safeUnmarshal(body, hm.body) return body, err } @@ -324,8 +324,6 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { return nil, err } for { - - // TODO(ekr@rtfm.com): Discard partial DTLS frames. logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) if h.frame.needed() > 0 { logf(logTypeVerbose, "Trying to read a new record") @@ -481,3 +479,19 @@ func decodeUint(in []byte, size int) (uint64, []byte) { } return val, in[size:] } + +type marshalledPDU interface { + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +func safeUnmarshal(pdu marshalledPDU, data []byte) error { + read, err := pdu.Unmarshal(data) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") + } + return nil +} diff --git a/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/bifurcation/mint/handshake-messages.go index 789a6547..b29faa37 100644 --- a/vendor/github.com/bifurcation/mint/handshake-messages.go +++ b/vendor/github.com/bifurcation/mint/handshake-messages.go @@ -26,11 +26,10 @@ type HandshakeMessageBody interface { // } ClientHello; type ClientHelloBody struct { // Omitted: clientVersion - // Omitted: legacySessionID - // Omitted: legacyCompressionMethods - Random [32]byte - CipherSuites []CipherSuite - Extensions ExtensionList + Random [32]byte + LegacySessionID []byte + CipherSuites []CipherSuite + Extensions ExtensionList } type clientHelloBodyInner struct { @@ -48,7 +47,7 @@ func (ch ClientHelloBody) Type() HandshakeType { func (ch ClientHelloBody) Marshal() ([]byte, error) { return syntax.Marshal(clientHelloBodyInner{ - LegacyVersion: 0x0303, + LegacyVersion: tls12Version, Random: ch.Random, LegacySessionID: []byte{}, CipherSuites: ch.CipherSuites, @@ -65,7 +64,7 @@ func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { } // We are strict about these things because we only support 1.3 - if inner.LegacyVersion != 0x0303 { + if inner.LegacyVersion != tls12Version { return 0, fmt.Errorf("tls.clienthello: Incorrect version number") } @@ -74,6 +73,7 @@ func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { } ch.Random = inner.Random + ch.LegacySessionID = inner.LegacySessionID ch.CipherSuites = inner.CipherSuites ch.Extensions = inner.Extensions return read, nil @@ -144,16 +144,20 @@ func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) { } // struct { -// ProtocolVersion version; +// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ // Random random; +// opaque legacy_session_id_echo<0..32>; // CipherSuite cipher_suite; -// Extension extensions<0..2^16-1>; +// uint8 legacy_compression_method = 0; +// Extension extensions<6..2^16-1>; // } ServerHello; type ServerHelloBody struct { - Version uint16 - Random [32]byte - CipherSuite CipherSuite - Extensions ExtensionList `tls:"head=2"` + Version uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuite CipherSuite + LegacyCompressionMethod uint8 + Extensions ExtensionList `tls:"head=2"` } func (sh ServerHelloBody) Type() HandshakeType { diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go index 32a393ae..ee9fe43a 100644 --- a/vendor/github.com/bifurcation/mint/record-layer.go +++ b/vendor/github.com/bifurcation/mint/record-layer.go @@ -23,7 +23,7 @@ func (err DecryptError) Error() string { // struct { // ContentType type; -// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ +// ProtocolVersion record_version [0301 for CH, 0303 for others] // uint16 length; // opaque fragment[TLSPlaintext.length]; // } TLSPlaintext; @@ -45,6 +45,7 @@ type cipherState struct { type RecordLayer struct { sync.Mutex + version uint16 // The current version number conn io.ReadWriter // The underlying connection frame *frameReader // The buffered frame reader nextData []byte // The next record to send @@ -92,6 +93,7 @@ func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{false}) r.cipher = newCipherStateNull() + r.version = tls10Version return &r } @@ -104,6 +106,10 @@ func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer { return &r } +func (r *RecordLayer) SetVersion(v uint16) { + r.version = v +} + func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { cipher, err := newCipherStateAead(epoch, factory, key, iv) if err != nil { @@ -349,7 +355,9 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta length := len(pt.fragment) var header []byte if !r.datagram { - header = []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} + header = []byte{byte(pt.contentType), + byte(r.version >> 8), byte(r.version & 0xff), + byte(length >> 8), byte(length)} } else { // TODO(ekr@rtfm.com): Double check version seq := cipher.seq diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go index a0f28287..a57cdfca 100644 --- a/vendor/github.com/bifurcation/mint/server-state-machine.go +++ b/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -93,7 +93,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } ch := &ClientHelloBody{} - if _, err := ch.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(ch, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -101,7 +101,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ clientHello := hm connParams := ConnectionParameters{} - supportedVersions := new(SupportedVersionsExtension) + supportedVersions := &SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello} serverName := new(ServerNameExtension) supportedGroups := new(SupportedGroupsExtension) signatureAlgorithms := new(SignatureAlgorithmsExtension) @@ -121,26 +121,34 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } } - gotSupportedVersions := ch.Extensions.Find(supportedVersions) - gotServerName := ch.Extensions.Find(serverName) - gotSupportedGroups := ch.Extensions.Find(supportedGroups) - gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) - gotEarlyData := ch.Extensions.Find(clientEarlyData) - ch.Extensions.Find(clientKeyShares) - ch.Extensions.Find(clientPSK) - ch.Extensions.Find(clientALPN) - ch.Extensions.Find(clientPSKModes) - ch.Extensions.Find(clientCookie) + foundExts, err := ch.Extensions.Parse( + []ExtensionBody{ + supportedVersions, + serverName, + supportedGroups, + signatureAlgorithms, + clientEarlyData, + clientKeyShares, + clientPSK, + clientALPN, + clientPSKModes, + clientCookie, + }) + + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error parsing extensions [%v]", err) + return nil, nil, AlertDecodeError + } clientSentCookie := len(clientCookie.Cookie) > 0 - if gotServerName { + if foundExts[ExtensionTypeServerName] { connParams.ServerName = string(*serverName) } // If the client didn't send supportedVersions or doesn't support 1.3, // then we're done here. - if !gotSupportedVersions { + if !foundExts[ExtensionTypeSupportedVersions] { logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") return nil, nil, AlertProtocolVersion } @@ -160,7 +168,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertDecryptError } cookie := &cookie{} - if _, err := syntax.Unmarshal(plainCookie, cookie); err != nil { // this should never happen + if rb, err := syntax.Unmarshal(plainCookie, cookie); err != nil && rb != len(plainCookie) { // this should never happen logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error unmarshaling cookie [%v]", err)) return nil, nil, AlertInternalError } @@ -182,6 +190,11 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } } + if len(ch.LegacySessionID) != 0 && len(ch.LegacySessionID) != 32 { + logf(logTypeHandshake, "[ServerStateStart] invalid session ID") + return nil, nil, AlertIllegalParameter + } + // Figure out if we can do DH canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) @@ -196,7 +209,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ contextBase = append(contextBase, firstClientHello.Marshal()...) // fill in the cookie sent by the client. Needed to calculate the correct hash cookieExt := &CookieExtension{Cookie: clientCookie.Cookie} - hrr, err := state.generateHRR(params.Suite, cookieExt) + hrr, err := state.generateHRR(params.Suite, + ch.LegacySessionID, cookieExt) if err != nil { return nil, nil, AlertInternalError } @@ -224,7 +238,6 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) // Select a ciphersuite - var err error connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) if err != nil { logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) @@ -285,7 +298,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ // Ignoring errors because everything here is newly constructed, so there // shouldn't be marshal errors if shouldSendHRR || clientSentCookie { - helloRetryRequest, err = state.generateHRR(connParams.CipherSuite, cookieExt) + helloRetryRequest, err = state.generateHRR(connParams.CipherSuite, + ch.LegacySessionID, cookieExt) if err != nil { return nil, nil, AlertInternalError } @@ -316,9 +330,10 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ psk = nil // If we're not using a PSK mode, then we need to have certain extensions - if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { - logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", - gotServerName, gotSupportedGroups, gotSignatureAlgorithms) + if !(foundExts[ExtensionTypeServerName] && + foundExts[ExtensionTypeSupportedGroups] && + foundExts[ExtensionTypeSignatureAlgorithms]) { + logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v)", foundExts) return nil, nil, AlertMissingExtension } @@ -338,8 +353,8 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ // Figure out if we're going to do early data var clientEarlyTrafficSecret []byte - connParams.ClientSendingEarlyData = gotEarlyData - connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) + connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] + connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Caps.AllowEarlyData) if connParams.UsingEarlyData { h := params.Hash.New() h.Write(clientHello.Marshal()) @@ -358,6 +373,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") + state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. return ServerStateNegotiated{ Caps: state.Caps, Params: connParams, @@ -369,6 +385,7 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ selectedPSK: selectedPSK, cert: cert, certScheme: certScheme, + legacySessionId: ch.LegacySessionID, clientEarlyTrafficSecret: clientEarlyTrafficSecret, firstClientHello: firstClientHello, @@ -377,12 +394,27 @@ func (state ServerStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ }, nil, AlertNoAlert } -func (state *ServerStateStart) generateHRR(cs CipherSuite, cookieExt *CookieExtension) (*HandshakeMessage, error) { +func (state *ServerStateStart) generateHRR(cs CipherSuite, legacySessionId []byte, + cookieExt *CookieExtension) (*HandshakeMessage, error) { var helloRetryRequest *HandshakeMessage - hrr := &HelloRetryRequestBody{ - Version: supportedVersion, - CipherSuite: cs, + hrr := &ServerHelloBody{ + Version: tls12Version, + Random: hrrRandomSentinel, + CipherSuite: cs, + LegacySessionID: legacySessionId, + LegacyCompressionMethod: 0, } + + sv := &SupportedVersionsExtension{ + HandshakeType: HandshakeTypeServerHello, + Versions: []uint16{supportedVersion}, + } + + if err := hrr.Extensions.Add(sv); err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error adding SupportedVersion [%v]", err) + return nil, err + } + if err := hrr.Extensions.Add(cookieExt); err != nil { logf(logTypeHandshake, "[ServerStateStart] Error adding CookieExtension [%v]", err) return nil, err @@ -415,10 +447,10 @@ type ServerStateNegotiated struct { selectedPSK int cert *Certificate certScheme SignatureScheme - - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - clientHello *HandshakeMessage + legacySessionId []byte + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage } var _ HandshakeState = &ServerStateNegotiated{} @@ -430,13 +462,24 @@ func (state ServerStateNegotiated) State() State { func (state ServerStateNegotiated) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { // Create the ServerHello sh := &ServerHelloBody{ - Version: supportedVersion, - CipherSuite: state.Params.CipherSuite, + Version: tls12Version, + CipherSuite: state.Params.CipherSuite, + LegacySessionID: state.legacySessionId, + LegacyCompressionMethod: 0, } if _, err := prng.Read(sh.Random[:]); err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) return nil, nil, AlertInternalError } + + err := sh.Extensions.Add(&SupportedVersionsExtension{ + HandshakeType: HandshakeTypeServerHello, + Versions: []uint16{supportedVersion}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported_versions extension [%v]", err) + return nil, nil, AlertInternalError + } if state.Params.UsingDH { logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") err := sh.Extensions.Add(&KeyShareExtension{ @@ -845,7 +888,7 @@ func (state ServerStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState } cert := &CertificateBody{} - if _, err := cert.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(cert, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") return nil, nil, AlertDecodeError } @@ -921,7 +964,7 @@ func (state ServerStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, } certVerify := &CertificateVerifyBody{} - if _, err := certVerify.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(certVerify, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) return nil, nil, AlertDecodeError } @@ -995,7 +1038,7 @@ func (state ServerStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS } fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} - if _, err := fin.Unmarshal(hm.body); err != nil { + if err := safeUnmarshal(fin, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) return nil, nil, AlertDecodeError } diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go index cd51dc03..f3d3593b 100644 --- a/vendor/github.com/bifurcation/mint/state-machine.go +++ b/vendor/github.com/bifurcation/mint/state-machine.go @@ -95,6 +95,15 @@ type HandshakeContext struct { hIn, hOut *HandshakeLayer } +func (hc *HandshakeContext) SetVersion(version uint16) { + if hc.hIn.conn != nil { + hc.hIn.conn.SetVersion(version) + } + if hc.hOut.conn != nil { + hc.hOut.conn.SetVersion(version) + } +} + // StateConnected is symmetric between client and server type StateConnected struct { Params ConnectionParameters diff --git a/vendor/vendor.json b/vendor/vendor.json index 37bb0499..96b2c9c3 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -3,10 +3,10 @@ "ignore": "test", "package": [ { - "checksumSHA1": "7wknjQuGJ8gr8VsM8cTLgu1M8vw=", + "checksumSHA1": "nxj6lkDUEZ81SO0lP8YUhm+4BAM=", "path": "github.com/bifurcation/mint", - "revision": "a544bfbca6a083ce9ddeb2c5f570cb240837355a", - "revisionTime": "2017-12-09T20:11:46Z" + "revision": "f699e8d03646cb8e6e15410ced7bff37fcf8dddd", + "revisionTime": "2017-12-21T19:05:27Z" }, { "checksumSHA1": "PZNcjO1c9gV/LZzppwpVRl6+QAY=",