diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 95972ac6..41afd691 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -139,7 +139,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) }) - It("has the right local and remote address on the ClientHelloInfo.Conn", func() { + It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { var local, remote net.Addr done := make(chan struct{}) tlsConf := &tls.Config{ @@ -163,6 +163,30 @@ var _ = Describe("Handshake tests", func() { Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) }) + It("has the right local and remote address on the tls.Config.GetCertificate ClientHelloInfo.Conn", func() { + var local, remote net.Addr + done := make(chan struct{}) + tlsConf := getTLSConfig() + tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + defer close(done) + local = info.Conn.LocalAddr() + remote = info.Conn.RemoteAddr() + cert := tlsConf.Certificates[0] + return &cert, nil + } + runServer(tlsConf) + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(server.Addr()).To(Equal(local)) + Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) + }) + It("works with a long certificate chain", func() { runServer(getTLSConfigWithLongCertChain()) _, err := quic.DialAddr( diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 3de241bd..5b9ec8a8 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -135,6 +135,13 @@ func NewCryptoSetupServer( return gcfc(info) } } + if quicConf.TLSConfig.GetCertificate != nil { + gc := quicConf.TLSConfig.GetCertificate + quicConf.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + return gc(info) + } + } cs.tlsConf = quicConf.TLSConfig cs.conn = qtls.QUICServer(quicConf)