mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
add tls.ClientHelloInfo.Conn for recursive GetConfigForClient calls (#4016)
This commit is contained in:
parent
f8dd2f2c8f
commit
fba8d784a8
3 changed files with 118 additions and 59 deletions
|
@ -141,13 +141,20 @@ var _ = Describe("Handshake tests", func() {
|
||||||
|
|
||||||
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
|
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
|
||||||
var local, remote net.Addr
|
var local, remote net.Addr
|
||||||
|
var local2, remote2 net.Addr
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
tlsConf := &tls.Config{
|
tlsConf := &tls.Config{
|
||||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
defer close(done)
|
|
||||||
local = info.Conn.LocalAddr()
|
local = info.Conn.LocalAddr()
|
||||||
remote = info.Conn.RemoteAddr()
|
remote = info.Conn.RemoteAddr()
|
||||||
return getTLSConfig(), nil
|
conf := getTLSConfig()
|
||||||
|
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
defer close(done)
|
||||||
|
local2 = info.Conn.LocalAddr()
|
||||||
|
remote2 = info.Conn.RemoteAddr()
|
||||||
|
return &(conf.Certificates[0]), nil
|
||||||
|
}
|
||||||
|
return conf, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
runServer(tlsConf)
|
runServer(tlsConf)
|
||||||
|
@ -161,30 +168,8 @@ var _ = Describe("Handshake tests", func() {
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
Expect(server.Addr()).To(Equal(local))
|
Expect(server.Addr()).To(Equal(local))
|
||||||
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
|
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
|
||||||
})
|
Expect(local).To(Equal(local2))
|
||||||
|
Expect(remote).To(Equal(remote2))
|
||||||
It("has the right local and remote address on the tls.Config.GetCertificate ClientHelloInfo.Conn", func() {
|
|
||||||
var local, remote net.Addr
|
|
||||||
done := make(chan struct{})
|
|
||||||
tlsConf := getTLSConfig()
|
|
||||||
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
defer close(done)
|
|
||||||
local = info.Conn.LocalAddr()
|
|
||||||
remote = info.Conn.RemoteAddr()
|
|
||||||
cert := tlsConf.Certificates[0]
|
|
||||||
return &cert, nil
|
|
||||||
}
|
|
||||||
runServer(tlsConf)
|
|
||||||
conn, err := quic.DialAddr(
|
|
||||||
context.Background(),
|
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
|
||||||
getTLSClientConfig(),
|
|
||||||
getQuicConfig(nil),
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(server.Addr()).To(Equal(local))
|
|
||||||
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("works with a long certificate chain", func() {
|
It("works with a long certificate chain", func() {
|
||||||
|
|
|
@ -128,20 +128,7 @@ func NewCryptoSetupServer(
|
||||||
|
|
||||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
||||||
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
|
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
|
||||||
if quicConf.TLSConfig.GetConfigForClient != nil {
|
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
|
||||||
gcfc := quicConf.TLSConfig.GetConfigForClient
|
|
||||||
quicConf.TLSConfig.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
|
||||||
return gcfc(info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if quicConf.TLSConfig.GetCertificate != nil {
|
|
||||||
gc := quicConf.TLSConfig.GetCertificate
|
|
||||||
quicConf.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
|
||||||
return gc(info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cs.tlsConf = quicConf.TLSConfig
|
cs.tlsConf = quicConf.TLSConfig
|
||||||
cs.conn = qtls.QUICServer(quicConf)
|
cs.conn = qtls.QUICServer(quicConf)
|
||||||
|
@ -149,6 +136,31 @@ func NewCryptoSetupServer(
|
||||||
return cs
|
return cs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
|
||||||
|
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
|
||||||
|
// that allows the caller to get the local and the remote address.
|
||||||
|
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
|
||||||
|
if conf.GetConfigForClient != nil {
|
||||||
|
gcfc := conf.GetConfigForClient
|
||||||
|
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||||
|
c, err := gcfc(info)
|
||||||
|
if c != nil {
|
||||||
|
// We're returning a tls.Config here, so we need to apply this recursively.
|
||||||
|
addConnToClientHelloInfo(c, localAddr, remoteAddr)
|
||||||
|
}
|
||||||
|
return c, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conf.GetCertificate != nil {
|
||||||
|
gc := conf.GetCertificate
|
||||||
|
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||||
|
return gc(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newCryptoSetup(
|
func newCryptoSetup(
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
tp *wire.TransportParameters,
|
tp *wire.TransportParameters,
|
||||||
|
|
|
@ -30,6 +30,25 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Crypto Setup TLS", func() {
|
var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
generateCert := func() tls.Certificate {
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{},
|
||||||
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return tls.Certificate{
|
||||||
|
PrivateKey: priv,
|
||||||
|
Certificate: [][]byte{certDER},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var clientConf, serverConf *tls.Config
|
var clientConf, serverConf *tls.Config
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -87,26 +106,69 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
|
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("doing the handshake", func() {
|
Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
|
||||||
generateCert := func() tls.Certificate {
|
var (
|
||||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
|
||||||
Expect(err).ToNot(HaveOccurred())
|
remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||||
tmpl := &x509.Certificate{
|
)
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
Subject: pkix.Name{},
|
|
||||||
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return tls.Certificate{
|
|
||||||
PrivateKey: priv,
|
|
||||||
Certificate: [][]byte{certDER},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
It("wraps GetCertificate", func() {
|
||||||
|
var localAddr, remoteAddr net.Addr
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
localAddr = info.Conn.LocalAddr()
|
||||||
|
remoteAddr = info.Conn.RemoteAddr()
|
||||||
|
cert := generateCert()
|
||||||
|
return &cert, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addConnToClientHelloInfo(tlsConf, local, remote)
|
||||||
|
_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(localAddr).To(Equal(local))
|
||||||
|
Expect(remoteAddr).To(Equal(remote))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("wraps GetConfigForClient", func() {
|
||||||
|
var localAddr, remoteAddr net.Addr
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
localAddr = info.Conn.LocalAddr()
|
||||||
|
remoteAddr = info.Conn.RemoteAddr()
|
||||||
|
return &tls.Config{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addConnToClientHelloInfo(tlsConf, local, remote)
|
||||||
|
_, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(localAddr).To(Equal(local))
|
||||||
|
Expect(remoteAddr).To(Equal(remote))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("wraps GetConfigForClient, recursively", func() {
|
||||||
|
var localAddr, remoteAddr net.Addr
|
||||||
|
tlsConf := &tls.Config{}
|
||||||
|
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
conf := tlsConf.Clone()
|
||||||
|
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
localAddr = info.Conn.LocalAddr()
|
||||||
|
remoteAddr = info.Conn.RemoteAddr()
|
||||||
|
cert := generateCert()
|
||||||
|
return &cert, nil
|
||||||
|
}
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
|
addConnToClientHelloInfo(tlsConf, local, remote)
|
||||||
|
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(localAddr).To(Equal(local))
|
||||||
|
Expect(remoteAddr).To(Equal(remote))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("doing the handshake", func() {
|
||||||
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
|
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
|
||||||
rttStats := &utils.RTTStats{}
|
rttStats := &utils.RTTStats{}
|
||||||
rttStats.UpdateRTT(rtt, 0, time.Now())
|
rttStats.UpdateRTT(rtt, 0, time.Now())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue