From 95ab7bdc9a392b0357598dd77caae66bd99a6b0b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 5 Aug 2023 16:00:11 -0400 Subject: [PATCH] add tls.ClientHelloInfo.Conn for recursive GetConfigForClient calls (#4016) --- integrationtests/self/handshake_test.go | 37 +++------ internal/handshake/crypto_setup.go | 40 ++++++---- internal/handshake/crypto_setup_test.go | 100 +++++++++++++++++++----- 3 files changed, 118 insertions(+), 59 deletions(-) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index f23ae77c..cccff8a6 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -142,13 +142,20 @@ var _ = Describe("Handshake tests", func() { It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { var local, remote net.Addr + var local2, remote2 net.Addr done := make(chan struct{}) tlsConf := &tls.Config{ GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { - defer close(done) local = info.Conn.LocalAddr() remote = info.Conn.RemoteAddr() - return getTLSConfig(), nil + conf := getTLSConfig() + conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + defer close(done) + local2 = info.Conn.LocalAddr() + remote2 = info.Conn.RemoteAddr() + return &(conf.Certificates[0]), nil + } + return conf, nil }, } runServer(tlsConf) @@ -162,30 +169,8 @@ var _ = Describe("Handshake tests", func() { Eventually(done).Should(BeClosed()) Expect(server.Addr()).To(Equal(local)) Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) - }) - - It("has the right local and remote address on the tls.Config.GetCertificate ClientHelloInfo.Conn", func() { - var local, remote net.Addr - done := make(chan struct{}) - tlsConf := getTLSConfig() - tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - defer close(done) - local = info.Conn.LocalAddr() - remote = info.Conn.RemoteAddr() - cert := tlsConf.Certificates[0] - return &cert, nil - } - runServer(tlsConf) - conn, err := quic.DialAddr( - context.Background(), - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - getTLSClientConfig(), - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - Eventually(done).Should(BeClosed()) - Expect(server.Addr()).To(Equal(local)) - Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) + Expect(local).To(Equal(local2)) + Expect(remote).To(Equal(remote2)) }) It("works with a long certificate chain", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 7011a6fc..543e70e0 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -127,20 +127,7 @@ func NewCryptoSetupServer( quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) - if quicConf.TLSConfig.GetConfigForClient != nil { - gcfc := quicConf.TLSConfig.GetConfigForClient - quicConf.TLSConfig.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} - return gcfc(info) - } - } - if quicConf.TLSConfig.GetCertificate != nil { - gc := quicConf.TLSConfig.GetCertificate - quicConf.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} - return gc(info) - } - } + addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) cs.tlsConf = quicConf.TLSConfig cs.conn = qtls.QUICServer(quicConf) @@ -148,6 +135,31 @@ func NewCryptoSetupServer( return cs } +// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo. +// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn +// that allows the caller to get the local and the remote address. +func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) { + if conf.GetConfigForClient != nil { + gcfc := conf.GetConfigForClient + conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + c, err := gcfc(info) + if c != nil { + // We're returning a tls.Config here, so we need to apply this recursively. + addConnToClientHelloInfo(c, localAddr, remoteAddr) + } + return c, err + } + } + if conf.GetCertificate != nil { + gc := conf.GetCertificate + conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + return gc(info) + } + } +} + func newCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 4b210d20..8b2c5efe 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -29,6 +29,25 @@ const ( ) var _ = Describe("Crypto Setup TLS", func() { + generateCert := func() tls.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), // valid for an hour + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) + Expect(err).ToNot(HaveOccurred()) + return tls.Certificate{ + PrivateKey: priv, + Certificate: [][]byte{certDER}, + } + } + var clientConf, serverConf *tls.Config BeforeEach(func() { @@ -86,26 +105,69 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) }) - Context("doing the handshake", func() { - generateCert := func() tls.Certificate { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{}, - SignatureAlgorithm: x509.SHA256WithRSA, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), // valid for an hour - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) - Expect(err).ToNot(HaveOccurred()) - return tls.Certificate{ - PrivateKey: priv, - Certificate: [][]byte{certDER}, - } - } + Context("filling in a net.Conn in tls.ClientHelloInfo", func() { + var ( + local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} + remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + ) + It("wraps GetCertificate", func() { + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + cert := generateCert() + return &cert, nil + }, + } + addConnToClientHelloInfo(tlsConf, local, remote) + _, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(localAddr).To(Equal(local)) + Expect(remoteAddr).To(Equal(remote)) + }) + + It("wraps GetConfigForClient", func() { + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{ + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + return &tls.Config{}, nil + }, + } + addConnToClientHelloInfo(tlsConf, local, remote) + _, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(localAddr).To(Equal(local)) + Expect(remoteAddr).To(Equal(remote)) + }) + + It("wraps GetConfigForClient, recursively", func() { + var localAddr, remoteAddr net.Addr + tlsConf := &tls.Config{} + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + conf := tlsConf.Clone() + conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + cert := generateCert() + return &cert, nil + } + return conf, nil + } + addConnToClientHelloInfo(tlsConf, local, remote) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) + Expect(err).ToNot(HaveOccurred()) + Expect(localAddr).To(Equal(local)) + Expect(remoteAddr).To(Equal(remote)) + }) + }) + + Context("doing the handshake", func() { newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { rttStats := &utils.RTTStats{} rttStats.UpdateRTT(rtt, 0, time.Now())