hysteria/app/internal/utils/certloader.go
Haruue bcf830c29a
chore: only init cert.Leaf when not populated
since Go 1.23, cert.Leaf will be populated after loaded.

see doc of tls.LoadX509KeyPair for details
2024-08-24 13:46:25 +08:00

198 lines
4.4 KiB
Go

package utils
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
// LocalCertificateLoader serves a TLS certificate loaded from a certificate
// file and a key file on disk. The loaded key pair is cached together with the
// modification times of both files, and is reloaded when those times change.
type LocalCertificateLoader struct {
// CertFile is the path of the certificate file given to tls.LoadX509KeyPair.
CertFile string
// KeyFile is the path of the private key file given to tls.LoadX509KeyPair.
KeyFile string
// SNIGuard, when non-nil, is called by GetCertificate with the client hello
// and the certificate about to be served; a non-nil error rejects the request.
SNIGuard SNIGuardFunc
// lock serializes rebuilding of the cache.
lock sync.Mutex
// cache points to the most recently loaded certificate; it is nil until the
// first successful load.
cache atomic.Pointer[localCertificateCache]
}
// SNIGuardFunc inspects a client hello together with the certificate that is
// about to be served for it. LocalCertificateLoader.GetCertificate returns the
// guard's error instead of the certificate when the error is non-nil.
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
// localCertificateCache holds a loaded certificate together with the
// modification times of the files it was loaded from.
// This struct is designed to be read-only.
//
// To update the cache, build a new value with LocalCertificateLoader.makeCache
// and store it in the LocalCertificateLoader.cache field.
type localCertificateCache struct {
// certificate is the key pair loaded from the certificate and key files.
certificate *tls.Certificate
// certModTime is the certificate file's modification time recorded when
// this cache entry was built.
certModTime time.Time
// keyModTime is the key file's modification time recorded when this cache
// entry was built.
keyModTime time.Time
}
// InitializeCache loads the certificate and key from disk while holding the
// loader's lock and publishes the result as the current cache.
//
// It returns the error from makeCache when loading fails, in which case the
// existing cache (if any) is left untouched.
func (l *LocalCertificateLoader) InitializeCache() error {
	l.lock.Lock()
	defer l.lock.Unlock()

	fresh, err := l.makeCache()
	if err == nil {
		l.cache.Store(fresh)
	}
	return err
}
// GetCertificate returns the certificate to present for the given client
// hello. Its signature matches the GetCertificate hook of tls.Config.
//
// The certificate comes from getCertificateWithCache. When SNIGuard is set, it
// is consulted with the client hello and that certificate, and any error it
// reports is returned in place of the certificate.
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cert, err := l.getCertificateWithCache()
	if err != nil {
		return nil, err
	}
	if guard := l.SNIGuard; guard != nil {
		if guardErr := guard(info, cert); guardErr != nil {
			return nil, guardErr
		}
	}
	return cert, nil
}
// checkModTime stats the certificate file and then the key file and reports
// their modification times.
//
// The first stat failure stops the check and is returned wrapped in a message
// naming which of the two files could not be inspected.
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
	certInfo, statErr := os.Stat(l.CertFile)
	if statErr != nil {
		return time.Time{}, time.Time{}, fmt.Errorf("failed to stat certificate file: %w", statErr)
	}
	certModTime = certInfo.ModTime()

	keyInfo, statErr := os.Stat(l.KeyFile)
	if statErr != nil {
		return certModTime, time.Time{}, fmt.Errorf("failed to stat key file: %w", statErr)
	}

	return certModTime, keyInfo.ModTime(), nil
}
// makeCache builds a fresh cache entry from the files on disk without touching
// the loader's current cache.
//
// The modification times are recorded before the key pair is loaded. The leaf
// certificate is parsed here only when tls.LoadX509KeyPair left it nil, so the
// returned certificate always has Leaf populated. On any failure the returned
// cache is nil.
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
	certTime, keyTime, err := l.checkModTime()
	if err != nil {
		return nil, err
	}

	pair, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
	if err != nil {
		return nil, err
	}

	if pair.Leaf == nil {
		// tls.LoadX509KeyPair did not populate Leaf before Go 1.23
		leaf, parseErr := x509.ParseCertificate(pair.Certificate[0])
		if parseErr != nil {
			return nil, parseErr
		}
		pair.Leaf = leaf
	}

	return &localCertificateCache{
		certificate: &pair,
		certModTime: certTime,
		keyModTime:  keyTime,
	}, nil
}
// getCertificateWithCache returns the certificate to serve, reloading it from
// disk when the modification time of the certificate file or the key file no
// longer matches the one recorded in the cache.
//
// A previously cached certificate keeps being served whenever fresh data cannot
// be obtained: when the files cannot be stat'ed, when the lock is already held
// by another goroutine, or when reloading fails. An error is returned only when
// there is no cached certificate to fall back on.
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
cache := l.cache.Load()
certModTime, keyModTime, terr := l.checkModTime()
if terr != nil {
if cache != nil {
// use cache when file is temporarily unavailable
return cache.certificate, nil
}
return nil, terr
}
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
// cache is up-to-date
return cache.certificate, nil
}
if cache != nil {
// a stale certificate is available, so do not block on the lock
if !l.lock.TryLock() {
// another goroutine is updating the cache
return cache.certificate, nil
}
} else {
// nothing cached to fall back on: wait for the lock
l.lock.Lock()
}
defer l.lock.Unlock()
// re-check under the lock: the cache pointer may have been replaced since
// it was loaded at the top of this function
if l.cache.Load() != cache {
// another goroutine updated the cache
return l.cache.Load().certificate, nil
}
newCache, err := l.makeCache()
if err != nil {
if cache != nil {
// use cache when loading failed
return cache.certificate, nil
}
return nil, err
}
l.cache.Store(newCache)
return newCache.certificate, nil
}
// getNameFromClientHello returns a normalized form of hello.ServerName:
// surrounding whitespace is trimmed and the name is lowercased.
// If hello.ServerName is empty (i.e. client did not use SNI), then the
// associated connection's local address is used to extract an IP address,
// with any port and IPv6 zone ("%eth0") removed.
//
// It returns "" when the client sent no SNI and hello.Conn is nil or reports
// no local address.
//
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
	if name := strings.ToLower(strings.TrimSpace(hello.ServerName)); name != "" {
		return name
	}

	if hello.Conn == nil {
		return ""
	}
	addr := hello.Conn.LocalAddr()
	if addr == nil {
		// a net.Conn implementation may report no local address; treat it
		// like a missing connection instead of panicking on addr.String()
		return ""
	}

	local := addr.String()
	ip, _, err := net.SplitHostPort(local)
	if err != nil {
		// no port to split off; use the address as is
		ip = local
	}
	// IPv6 addresses may carry a zone, e.g. "fe80::1%eth0"; drop it
	ip, _, _ = strings.Cut(ip, "%")
	return ip
}
// SNIGuardDNSSAN is an SNIGuardFunc that enforces the SNI check only for
// certificates whose leaf lists at least one DNS subject alternative name.
// Such certificates are checked with SNIGuardStrict; certificates without any
// DNS name accept every client hello.
//
// cert.Leaf may be nil (tls.LoadX509KeyPair did not populate it before
// Go 1.23). In that case the leaf is parsed from cert.Certificate[0] and the
// check runs on a shallow copy, so the caller's certificate, which may be
// shared between handshakes, is never modified. An error is returned when cert
// is nil, has an empty chain, or its leaf cannot be parsed.
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
	if cert == nil {
		return fmt.Errorf("sni guard: no certificate")
	}
	if cert.Leaf == nil {
		if len(cert.Certificate) == 0 {
			return fmt.Errorf("sni guard: certificate chain is empty")
		}
		leaf, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return fmt.Errorf("sni guard: %w", err)
		}
		withLeaf := *cert
		withLeaf.Leaf = leaf
		cert = &withLeaf
	}
	if len(cert.Leaf.DNSNames) == 0 {
		return nil
	}
	return SNIGuardStrict(info, cert)
}
// SNIGuardStrict is an SNIGuardFunc that rejects a client hello unless the
// name obtained from getNameFromClientHello is valid for the certificate's
// leaf according to x509.Certificate.VerifyHostname.
//
// cert.Leaf may be nil (tls.LoadX509KeyPair did not populate it before
// Go 1.23); the leaf is then parsed from cert.Certificate[0] for this call
// only, leaving cert unmodified. An error is returned when cert is nil, has an
// empty chain, or its leaf cannot be parsed. All errors carry the "sni guard"
// prefix.
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
	if cert == nil {
		return fmt.Errorf("sni guard: no certificate")
	}
	leaf := cert.Leaf
	if leaf == nil {
		if len(cert.Certificate) == 0 {
			return fmt.Errorf("sni guard: certificate chain is empty")
		}
		parsed, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return fmt.Errorf("sni guard: %w", err)
		}
		leaf = parsed
	}
	hostname := getNameFromClientHello(info)
	if err := leaf.VerifyHostname(hostname); err != nil {
		return fmt.Errorf("sni guard: %w", err)
	}
	return nil
}