add tls.ClientHelloInfo.Conn for recursive GetConfigForClient calls (#4016)

This commit is contained in:
Marten Seemann 2023-08-05 16:00:11 -04:00 committed by Gaukas Wang
parent f8dd2f2c8f
commit fba8d784a8
3 changed files with 118 additions and 59 deletions

View file

@ -141,13 +141,20 @@ var _ = Describe("Handshake tests", func() {
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
var local, remote net.Addr
var local2, remote2 net.Addr
done := make(chan struct{})
tlsConf := &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
defer close(done)
local = info.Conn.LocalAddr()
remote = info.Conn.RemoteAddr()
return getTLSConfig(), nil
conf := getTLSConfig()
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
defer close(done)
local2 = info.Conn.LocalAddr()
remote2 = info.Conn.RemoteAddr()
return &(conf.Certificates[0]), nil
}
return conf, nil
},
}
runServer(tlsConf)
@ -161,30 +168,8 @@ var _ = Describe("Handshake tests", func() {
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
})
It("has the right local and remote address on the tls.Config.GetCertificate ClientHelloInfo.Conn", func() {
var local, remote net.Addr
done := make(chan struct{})
tlsConf := getTLSConfig()
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
defer close(done)
local = info.Conn.LocalAddr()
remote = info.Conn.RemoteAddr()
cert := tlsConf.Certificates[0]
return &cert, nil
}
runServer(tlsConf)
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
Expect(local).To(Equal(local2))
Expect(remote).To(Equal(remote2))
})
It("works with a long certificate chain", func() {