handshake: clone the tls.Config returned by GetConfigForClient (#4133)

We modify this tls.Config, so we should clone it first. Otherwise, this could
cause conflicts with how the application is using that config.
This commit is contained in:
Marten Seemann 2023-10-27 12:22:20 +07:00 committed by GitHub
parent e2622bfad8
commit d309060cde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 8 deletions

View file

@ -147,6 +147,7 @@ func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr)
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info) c, err := gcfc(info)
if c != nil { if c != nil {
c = c.Clone()
// We're returning a tls.Config here, so we need to apply this recursively. // We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr) addConnToClientHelloInfo(c, localAddr, remoteAddr)
} }

View file

@ -8,6 +8,7 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"math/big" "math/big"
"net" "net"
"reflect"
"runtime" "runtime"
"strings" "strings"
"time" "time"
@ -148,15 +149,17 @@ var _ = Describe("Crypto Setup TLS", func() {
It("wraps GetConfigForClient, recursively", func() { It("wraps GetConfigForClient, recursively", func() {
var localAddr, remoteAddr net.Addr var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { var innerConf *tls.Config
conf := tlsConf.Clone() getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
localAddr = info.Conn.LocalAddr() localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr() remoteAddr = info.Conn.RemoteAddr()
cert := generateCert() cert := generateCert()
return &cert, nil return &cert, nil
} }
return conf, nil tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
innerConf = tlsConf.Clone()
innerConf.GetCertificate = getCert
return innerConf, nil
} }
addConnToClientHelloInfo(tlsConf, local, remote) addConnToClientHelloInfo(tlsConf, local, remote)
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
@ -165,6 +168,8 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local)) Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote)) Expect(remoteAddr).To(Equal(remote))
// make sure that the tls.Config returned by GetConfigForClient isn't modified
Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
}) })
}) })