From fd2d20a46a801e18289a99fe02625f6835c42969 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 00:27:57 +0800 Subject: [PATCH] feat: local cert loader & sni guard --- app/cmd/server.go | 48 +++++--- app/internal/utils/certloader.go | 189 +++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 16 deletions(-) create mode 100644 app/internal/utils/certloader.go diff --git a/app/cmd/server.go b/app/cmd/server.go index b45fb15..3d37c09 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -83,8 +83,9 @@ type serverConfigObfs struct { } type serverConfigTLS struct { - Cert string `mapstructure:"cert"` - Key string `mapstructure:"key"` + Cert string `mapstructure:"cert"` + Key string `mapstructure:"key"` + SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict" } type serverConfigACME struct { @@ -290,31 +291,46 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error { if c.TLS != nil && c.ACME != nil { return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} } + // SNI guard + var sniGuard utils.SNIGuardFunc + switch strings.ToLower(c.TLS.SNIGuard) { + case "", "dns-san": + sniGuard = utils.SNIGuardDNSSAN + case "strict": + sniGuard = utils.SNIGuardStrict + case "disable": + sniGuard = nil + default: + return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")} + } if c.TLS != nil { // Local TLS cert if c.TLS.Cert == "" || c.TLS.Key == "" { return configError{Field: "tls", Err: errors.New("empty cert or key path")} } + certLoader := &utils.LocalCertificateLoader{ + CertFile: c.TLS.Cert, + KeyFile: c.TLS.Key, + SNIGuard: sniGuard, + } // Try loading the cert-key pair here to catch errors early // (e.g. invalid files or insufficient permissions) - certPEMBlock, err := os.ReadFile(c.TLS.Cert) + err := certLoader.InitializeCache() if err != nil { - return configError{Field: "tls.cert", Err: err} - } - keyPEMBlock, err := os.ReadFile(c.TLS.Key) - if err != nil { - return configError{Field: "tls.key", Err: err} - } - _, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock) - if err != nil { - return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)} + var pathErr *os.PathError + if errors.As(err, &pathErr) { + if pathErr.Path == c.TLS.Cert { + return configError{Field: "tls.cert", Err: pathErr} + } + if pathErr.Path == c.TLS.Key { + return configError{Field: "tls.key", Err: pathErr} + } + } + return configError{Field: "tls", Err: err} } // Use GetCertificate instead of Certificates so that // users can update the cert without restarting the server. - hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key) - return &cert, err - } + hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate } else { // ACME dataDir := c.ACME.Dir diff --git a/app/internal/utils/certloader.go b/app/internal/utils/certloader.go new file mode 100644 index 0000000..390548e --- /dev/null +++ b/app/internal/utils/certloader.go @@ -0,0 +1,189 @@ +package utils + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + "strings" + "sync" + "time" +) + +type LocalCertificateLoader struct { + CertFile string + KeyFile string + SNIGuard SNIGuardFunc + + lock sync.RWMutex + cache *localCertificateCache +} + +type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error + +// localCertificateCache holds the certificate and its mod times. +// this struct is designed to be read-only. +// +// to update the cache, use LocalCertificateLoader.makeCache and +// update the LocalCertificateLoader.cache field. +type localCertificateCache struct { + certificate *tls.Certificate + certModTime time.Time + keyModTime time.Time +} + +func (l *LocalCertificateLoader) InitializeCache() error { + cache, err := l.makeCache() + if err != nil { + return err + } + + l.lock.Lock() + defer l.lock.Unlock() + l.cache = cache + return nil +} + +func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := l.getCertificateWithCache() + if err != nil { + return nil, err + } + + if l.SNIGuard == nil { + return cert, nil + } + err = l.SNIGuard(info, cert) + if err != nil { + return nil, err + } + + return cert, nil +} + +func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { + if fi, ferr := os.Stat(l.CertFile); ferr != nil { + err = fmt.Errorf("failed to stat certificate file: %w", ferr) + return + } else { + certModTime = fi.ModTime() + } + if fi, ferr := os.Stat(l.KeyFile); ferr != nil { + err = fmt.Errorf("failed to stat key file: %w", ferr) + return + } else { + keyModTime = fi.ModTime() + } + return +} + +func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) { + c := &localCertificateCache{} + + c.certModTime, c.keyModTime, err = l.checkModTime() + if err != nil { + return + } + + cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile) + if err != nil { + return + } + c.certificate = &cert + c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return + } + + cache = c + return +} + +func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { + l.lock.RLock() + cache := l.cache + l.lock.RUnlock() + + certModTime, keyModTime, terr := l.checkModTime() + if terr != nil { + if cache != nil { + // use cache when file is temporarily unavailable + return cache.certificate, nil + } + return nil, terr + } + + if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) { + // cache is up-to-date + return cache.certificate, nil + } + + if cache != nil { + if !l.lock.TryLock() { + // another goroutine is updating the cache + return cache.certificate, nil + } + } else { + l.lock.Lock() + } + defer l.lock.Unlock() + + newCache, err := l.makeCache() + if err != nil { + if cache != nil { + // use cache when loading failed + return cache.certificate, nil + } + return nil, err + } + + l.cache = newCache + return newCache.certificate, nil +} + +// getNameFromClientHello returns a normalized form of hello.ServerName. +// If hello.ServerName is empty (i.e. client did not use SNI), then the +// associated connection's local address is used to extract an IP address. +// +// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838 +func getNameFromClientHello(hello *tls.ClientHelloInfo) string { + normalizedName := func(serverName string) string { + return strings.ToLower(strings.TrimSpace(serverName)) + } + localIPFromConn := func(c net.Conn) string { + if c == nil { + return "" + } + localAddr := c.LocalAddr().String() + ip, _, err := net.SplitHostPort(localAddr) + if err != nil { + ip = localAddr + } + if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 { + ip = ip[:scopeIDStart] + } + return ip + } + + if name := normalizedName(hello.ServerName); name != "" { + return name + } + return localIPFromConn(hello.Conn) +} + +func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error { + if len(cert.Leaf.DNSNames) == 0 { + return nil + } + return SNIGuardStrict(info, cert) +} + +func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error { + hostname := getNameFromClientHello(info) + err := cert.Leaf.VerifyHostname(hostname) + if err != nil { + return fmt.Errorf("sni guard: %w", err) + } + return nil +}