feat: local cert loader & sni guard

This commit is contained in:
Haruue 2024-08-24 00:27:57 +08:00
parent 903666f18e
commit fd2d20a46a
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148
2 changed files with 221 additions and 16 deletions

View file

@ -83,8 +83,9 @@ type serverConfigObfs struct {
}
type serverConfigTLS struct {
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict"
}
type serverConfigACME struct {
@ -290,31 +291,46 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
if c.TLS != nil && c.ACME != nil {
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
}
// SNI guard
var sniGuard utils.SNIGuardFunc
switch strings.ToLower(c.TLS.SNIGuard) {
case "", "dns-san":
sniGuard = utils.SNIGuardDNSSAN
case "strict":
sniGuard = utils.SNIGuardStrict
case "disable":
sniGuard = nil
default:
return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")}
}
if c.TLS != nil {
// Local TLS cert
if c.TLS.Cert == "" || c.TLS.Key == "" {
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
}
certLoader := &utils.LocalCertificateLoader{
CertFile: c.TLS.Cert,
KeyFile: c.TLS.Key,
SNIGuard: sniGuard,
}
// Try loading the cert-key pair here to catch errors early
// (e.g. invalid files or insufficient permissions)
certPEMBlock, err := os.ReadFile(c.TLS.Cert)
err := certLoader.InitializeCache()
if err != nil {
return configError{Field: "tls.cert", Err: err}
}
keyPEMBlock, err := os.ReadFile(c.TLS.Key)
if err != nil {
return configError{Field: "tls.key", Err: err}
}
_, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)}
var pathErr *os.PathError
if errors.As(err, &pathErr) {
if pathErr.Path == c.TLS.Cert {
return configError{Field: "tls.cert", Err: pathErr}
}
if pathErr.Path == c.TLS.Key {
return configError{Field: "tls.key", Err: pathErr}
}
}
return configError{Field: "tls", Err: err}
}
// Use GetCertificate instead of Certificates so that
// users can update the cert without restarting the server.
hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
return &cert, err
}
hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate
} else {
// ACME
dataDir := c.ACME.Dir

View file

@ -0,0 +1,189 @@
package utils
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
)
type LocalCertificateLoader struct {
CertFile string
KeyFile string
SNIGuard SNIGuardFunc
lock sync.RWMutex
cache *localCertificateCache
}
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
// localCertificateCache holds the certificate and its mod times.
// this struct is designed to be read-only.
//
// to update the cache, use LocalCertificateLoader.makeCache and
// update the LocalCertificateLoader.cache field.
type localCertificateCache struct {
certificate *tls.Certificate
certModTime time.Time
keyModTime time.Time
}
func (l *LocalCertificateLoader) InitializeCache() error {
cache, err := l.makeCache()
if err != nil {
return err
}
l.lock.Lock()
defer l.lock.Unlock()
l.cache = cache
return nil
}
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := l.getCertificateWithCache()
if err != nil {
return nil, err
}
if l.SNIGuard == nil {
return cert, nil
}
err = l.SNIGuard(info, cert)
if err != nil {
return nil, err
}
return cert, nil
}
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
if fi, ferr := os.Stat(l.CertFile); ferr != nil {
err = fmt.Errorf("failed to stat certificate file: %w", ferr)
return
} else {
certModTime = fi.ModTime()
}
if fi, ferr := os.Stat(l.KeyFile); ferr != nil {
err = fmt.Errorf("failed to stat key file: %w", ferr)
return
} else {
keyModTime = fi.ModTime()
}
return
}
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
c := &localCertificateCache{}
c.certModTime, c.keyModTime, err = l.checkModTime()
if err != nil {
return
}
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
if err != nil {
return
}
c.certificate = &cert
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return
}
cache = c
return
}
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
l.lock.RLock()
cache := l.cache
l.lock.RUnlock()
certModTime, keyModTime, terr := l.checkModTime()
if terr != nil {
if cache != nil {
// use cache when file is temporarily unavailable
return cache.certificate, nil
}
return nil, terr
}
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
// cache is up-to-date
return cache.certificate, nil
}
if cache != nil {
if !l.lock.TryLock() {
// another goroutine is updating the cache
return cache.certificate, nil
}
} else {
l.lock.Lock()
}
defer l.lock.Unlock()
newCache, err := l.makeCache()
if err != nil {
if cache != nil {
// use cache when loading failed
return cache.certificate, nil
}
return nil, err
}
l.cache = newCache
return newCache.certificate, nil
}
// getNameFromClientHello returns a normalized form of hello.ServerName.
// If hello.ServerName is empty (i.e. client did not use SNI), then the
// associated connection's local address is used to extract an IP address.
//
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
normalizedName := func(serverName string) string {
return strings.ToLower(strings.TrimSpace(serverName))
}
localIPFromConn := func(c net.Conn) string {
if c == nil {
return ""
}
localAddr := c.LocalAddr().String()
ip, _, err := net.SplitHostPort(localAddr)
if err != nil {
ip = localAddr
}
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 {
ip = ip[:scopeIDStart]
}
return ip
}
if name := normalizedName(hello.ServerName); name != "" {
return name
}
return localIPFromConn(hello.Conn)
}
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
if len(cert.Leaf.DNSNames) == 0 {
return nil
}
return SNIGuardStrict(info, cert)
}
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
hostname := getNameFromClientHello(info)
err := cert.Leaf.VerifyHostname(hostname)
if err != nil {
return fmt.Errorf("sni guard: %w", err)
}
return nil
}