mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-04 13:07:39 +03:00
since Go 1.23, cert.Leaf will be populated after loaded. see doc of tls.LoadX509KeyPair for details
198 lines
4.4 KiB
Go
198 lines
4.4 KiB
Go
package utils
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
type LocalCertificateLoader struct {
|
|
CertFile string
|
|
KeyFile string
|
|
SNIGuard SNIGuardFunc
|
|
|
|
lock sync.Mutex
|
|
cache atomic.Pointer[localCertificateCache]
|
|
}
|
|
|
|
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
|
|
|
|
// localCertificateCache holds the certificate and its mod times.
|
|
// this struct is designed to be read-only.
|
|
//
|
|
// to update the cache, use LocalCertificateLoader.makeCache and
|
|
// update the LocalCertificateLoader.cache field.
|
|
type localCertificateCache struct {
|
|
certificate *tls.Certificate
|
|
certModTime time.Time
|
|
keyModTime time.Time
|
|
}
|
|
|
|
func (l *LocalCertificateLoader) InitializeCache() error {
|
|
l.lock.Lock()
|
|
defer l.lock.Unlock()
|
|
|
|
cache, err := l.makeCache()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
l.cache.Store(cache)
|
|
return nil
|
|
}
|
|
|
|
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
cert, err := l.getCertificateWithCache()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if l.SNIGuard == nil {
|
|
return cert, nil
|
|
}
|
|
err = l.SNIGuard(info, cert)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return cert, nil
|
|
}
|
|
|
|
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
|
|
fi, err := os.Stat(l.CertFile)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to stat certificate file: %w", err)
|
|
return
|
|
}
|
|
certModTime = fi.ModTime()
|
|
|
|
fi, err = os.Stat(l.KeyFile)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to stat key file: %w", err)
|
|
return
|
|
}
|
|
keyModTime = fi.ModTime()
|
|
return
|
|
}
|
|
|
|
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
|
|
c := &localCertificateCache{}
|
|
|
|
c.certModTime, c.keyModTime, err = l.checkModTime()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
|
|
if err != nil {
|
|
return
|
|
}
|
|
c.certificate = &cert
|
|
if c.certificate.Leaf == nil {
|
|
// certificate.Leaf was left nil by tls.LoadX509KeyPair before Go 1.23
|
|
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
cache = c
|
|
return
|
|
}
|
|
|
|
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
|
|
cache := l.cache.Load()
|
|
|
|
certModTime, keyModTime, terr := l.checkModTime()
|
|
if terr != nil {
|
|
if cache != nil {
|
|
// use cache when file is temporarily unavailable
|
|
return cache.certificate, nil
|
|
}
|
|
return nil, terr
|
|
}
|
|
|
|
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
|
|
// cache is up-to-date
|
|
return cache.certificate, nil
|
|
}
|
|
|
|
if cache != nil {
|
|
if !l.lock.TryLock() {
|
|
// another goroutine is updating the cache
|
|
return cache.certificate, nil
|
|
}
|
|
} else {
|
|
l.lock.Lock()
|
|
}
|
|
defer l.lock.Unlock()
|
|
|
|
if l.cache.Load() != cache {
|
|
// another goroutine updated the cache
|
|
return l.cache.Load().certificate, nil
|
|
}
|
|
|
|
newCache, err := l.makeCache()
|
|
if err != nil {
|
|
if cache != nil {
|
|
// use cache when loading failed
|
|
return cache.certificate, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
l.cache.Store(newCache)
|
|
return newCache.certificate, nil
|
|
}
|
|
|
|
// getNameFromClientHello returns a normalized form of hello.ServerName.
|
|
// If hello.ServerName is empty (i.e. client did not use SNI), then the
|
|
// associated connection's local address is used to extract an IP address.
|
|
//
|
|
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
|
|
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
|
|
normalizedName := func(serverName string) string {
|
|
return strings.ToLower(strings.TrimSpace(serverName))
|
|
}
|
|
localIPFromConn := func(c net.Conn) string {
|
|
if c == nil {
|
|
return ""
|
|
}
|
|
localAddr := c.LocalAddr().String()
|
|
ip, _, err := net.SplitHostPort(localAddr)
|
|
if err != nil {
|
|
ip = localAddr
|
|
}
|
|
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 {
|
|
ip = ip[:scopeIDStart]
|
|
}
|
|
return ip
|
|
}
|
|
|
|
if name := normalizedName(hello.ServerName); name != "" {
|
|
return name
|
|
}
|
|
return localIPFromConn(hello.Conn)
|
|
}
|
|
|
|
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
|
if len(cert.Leaf.DNSNames) == 0 {
|
|
return nil
|
|
}
|
|
return SNIGuardStrict(info, cert)
|
|
}
|
|
|
|
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
|
hostname := getNameFromClientHello(info)
|
|
err := cert.Leaf.VerifyHostname(hostname)
|
|
if err != nil {
|
|
return fmt.Errorf("sni guard: %w", err)
|
|
}
|
|
return nil
|
|
}
|