mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
feat: local cert loader & sni guard
This commit is contained in:
parent
903666f18e
commit
fd2d20a46a
2 changed files with 221 additions and 16 deletions
|
@ -83,8 +83,9 @@ type serverConfigObfs struct {
|
|||
}
|
||||
|
||||
type serverConfigTLS struct {
|
||||
Cert string `mapstructure:"cert"`
|
||||
Key string `mapstructure:"key"`
|
||||
Cert string `mapstructure:"cert"`
|
||||
Key string `mapstructure:"key"`
|
||||
SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict"
|
||||
}
|
||||
|
||||
type serverConfigACME struct {
|
||||
|
@ -290,31 +291,46 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
|
|||
if c.TLS != nil && c.ACME != nil {
|
||||
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
||||
}
|
||||
// SNI guard
|
||||
var sniGuard utils.SNIGuardFunc
|
||||
switch strings.ToLower(c.TLS.SNIGuard) {
|
||||
case "", "dns-san":
|
||||
sniGuard = utils.SNIGuardDNSSAN
|
||||
case "strict":
|
||||
sniGuard = utils.SNIGuardStrict
|
||||
case "disable":
|
||||
sniGuard = nil
|
||||
default:
|
||||
return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")}
|
||||
}
|
||||
if c.TLS != nil {
|
||||
// Local TLS cert
|
||||
if c.TLS.Cert == "" || c.TLS.Key == "" {
|
||||
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
||||
}
|
||||
certLoader := &utils.LocalCertificateLoader{
|
||||
CertFile: c.TLS.Cert,
|
||||
KeyFile: c.TLS.Key,
|
||||
SNIGuard: sniGuard,
|
||||
}
|
||||
// Try loading the cert-key pair here to catch errors early
|
||||
// (e.g. invalid files or insufficient permissions)
|
||||
certPEMBlock, err := os.ReadFile(c.TLS.Cert)
|
||||
err := certLoader.InitializeCache()
|
||||
if err != nil {
|
||||
return configError{Field: "tls.cert", Err: err}
|
||||
}
|
||||
keyPEMBlock, err := os.ReadFile(c.TLS.Key)
|
||||
if err != nil {
|
||||
return configError{Field: "tls.key", Err: err}
|
||||
}
|
||||
_, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
if err != nil {
|
||||
return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)}
|
||||
var pathErr *os.PathError
|
||||
if errors.As(err, &pathErr) {
|
||||
if pathErr.Path == c.TLS.Cert {
|
||||
return configError{Field: "tls.cert", Err: pathErr}
|
||||
}
|
||||
if pathErr.Path == c.TLS.Key {
|
||||
return configError{Field: "tls.key", Err: pathErr}
|
||||
}
|
||||
}
|
||||
return configError{Field: "tls", Err: err}
|
||||
}
|
||||
// Use GetCertificate instead of Certificates so that
|
||||
// users can update the cert without restarting the server.
|
||||
hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
|
||||
return &cert, err
|
||||
}
|
||||
hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate
|
||||
} else {
|
||||
// ACME
|
||||
dataDir := c.ACME.Dir
|
||||
|
|
189
app/internal/utils/certloader.go
Normal file
189
app/internal/utils/certloader.go
Normal file
|
@ -0,0 +1,189 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LocalCertificateLoader struct {
|
||||
CertFile string
|
||||
KeyFile string
|
||||
SNIGuard SNIGuardFunc
|
||||
|
||||
lock sync.RWMutex
|
||||
cache *localCertificateCache
|
||||
}
|
||||
|
||||
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
|
||||
|
||||
// localCertificateCache holds the certificate and its mod times.
|
||||
// this struct is designed to be read-only.
|
||||
//
|
||||
// to update the cache, use LocalCertificateLoader.makeCache and
|
||||
// update the LocalCertificateLoader.cache field.
|
||||
type localCertificateCache struct {
|
||||
certificate *tls.Certificate
|
||||
certModTime time.Time
|
||||
keyModTime time.Time
|
||||
}
|
||||
|
||||
func (l *LocalCertificateLoader) InitializeCache() error {
|
||||
cache, err := l.makeCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
l.cache = cache
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := l.getCertificateWithCache()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if l.SNIGuard == nil {
|
||||
return cert, nil
|
||||
}
|
||||
err = l.SNIGuard(info, cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
|
||||
if fi, ferr := os.Stat(l.CertFile); ferr != nil {
|
||||
err = fmt.Errorf("failed to stat certificate file: %w", ferr)
|
||||
return
|
||||
} else {
|
||||
certModTime = fi.ModTime()
|
||||
}
|
||||
if fi, ferr := os.Stat(l.KeyFile); ferr != nil {
|
||||
err = fmt.Errorf("failed to stat key file: %w", ferr)
|
||||
return
|
||||
} else {
|
||||
keyModTime = fi.ModTime()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
|
||||
c := &localCertificateCache{}
|
||||
|
||||
c.certModTime, c.keyModTime, err = l.checkModTime()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.certificate = &cert
|
||||
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cache = c
|
||||
return
|
||||
}
|
||||
|
||||
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
|
||||
l.lock.RLock()
|
||||
cache := l.cache
|
||||
l.lock.RUnlock()
|
||||
|
||||
certModTime, keyModTime, terr := l.checkModTime()
|
||||
if terr != nil {
|
||||
if cache != nil {
|
||||
// use cache when file is temporarily unavailable
|
||||
return cache.certificate, nil
|
||||
}
|
||||
return nil, terr
|
||||
}
|
||||
|
||||
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
|
||||
// cache is up-to-date
|
||||
return cache.certificate, nil
|
||||
}
|
||||
|
||||
if cache != nil {
|
||||
if !l.lock.TryLock() {
|
||||
// another goroutine is updating the cache
|
||||
return cache.certificate, nil
|
||||
}
|
||||
} else {
|
||||
l.lock.Lock()
|
||||
}
|
||||
defer l.lock.Unlock()
|
||||
|
||||
newCache, err := l.makeCache()
|
||||
if err != nil {
|
||||
if cache != nil {
|
||||
// use cache when loading failed
|
||||
return cache.certificate, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.cache = newCache
|
||||
return newCache.certificate, nil
|
||||
}
|
||||
|
||||
// getNameFromClientHello returns a normalized form of hello.ServerName.
|
||||
// If hello.ServerName is empty (i.e. client did not use SNI), then the
|
||||
// associated connection's local address is used to extract an IP address.
|
||||
//
|
||||
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
|
||||
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
|
||||
normalizedName := func(serverName string) string {
|
||||
return strings.ToLower(strings.TrimSpace(serverName))
|
||||
}
|
||||
localIPFromConn := func(c net.Conn) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
localAddr := c.LocalAddr().String()
|
||||
ip, _, err := net.SplitHostPort(localAddr)
|
||||
if err != nil {
|
||||
ip = localAddr
|
||||
}
|
||||
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 {
|
||||
ip = ip[:scopeIDStart]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
if name := normalizedName(hello.ServerName); name != "" {
|
||||
return name
|
||||
}
|
||||
return localIPFromConn(hello.Conn)
|
||||
}
|
||||
|
||||
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
||||
if len(cert.Leaf.DNSNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
return SNIGuardStrict(info, cert)
|
||||
}
|
||||
|
||||
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
||||
hostname := getNameFromClientHello(info)
|
||||
err := cert.Leaf.VerifyHostname(hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sni guard: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue