chore: replace rwlock with atomic pointer

This commit is contained in:
Haruue 2024-08-24 10:37:08 +08:00
parent fd2d20a46a
commit 57a48a674b
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148

View file

@ -8,6 +8,7 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@ -16,8 +17,8 @@ type LocalCertificateLoader struct {
KeyFile string KeyFile string
SNIGuard SNIGuardFunc SNIGuard SNIGuardFunc
lock sync.RWMutex lock sync.Mutex
cache *localCertificateCache cache atomic.Pointer[localCertificateCache]
} }
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
@ -34,14 +35,15 @@ type localCertificateCache struct {
} }
func (l *LocalCertificateLoader) InitializeCache() error { func (l *LocalCertificateLoader) InitializeCache() error {
l.lock.Lock()
defer l.lock.Unlock()
cache, err := l.makeCache() cache, err := l.makeCache()
if err != nil { if err != nil {
return err return err
} }
l.lock.Lock() l.cache.Store(cache)
defer l.lock.Unlock()
l.cache = cache
return nil return nil
} }
@ -63,18 +65,19 @@ func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls
} }
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
if fi, ferr := os.Stat(l.CertFile); ferr != nil { fi, err := os.Stat(l.CertFile)
err = fmt.Errorf("failed to stat certificate file: %w", ferr) if err != nil {
err = fmt.Errorf("failed to stat certificate file: %w", err)
return return
} else {
certModTime = fi.ModTime()
} }
if fi, ferr := os.Stat(l.KeyFile); ferr != nil { certModTime = fi.ModTime()
err = fmt.Errorf("failed to stat key file: %w", ferr)
fi, err = os.Stat(l.KeyFile)
if err != nil {
err = fmt.Errorf("failed to stat key file: %w", err)
return return
} else {
keyModTime = fi.ModTime()
} }
keyModTime = fi.ModTime()
return return
} }
@ -101,9 +104,7 @@ func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err
} }
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
l.lock.RLock() cache := l.cache.Load()
cache := l.cache
l.lock.RUnlock()
certModTime, keyModTime, terr := l.checkModTime() certModTime, keyModTime, terr := l.checkModTime()
if terr != nil { if terr != nil {
@ -129,6 +130,11 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er
} }
defer l.lock.Unlock() defer l.lock.Unlock()
if l.cache.Load() != cache {
// another goroutine updated the cache
return l.cache.Load().certificate, nil
}
newCache, err := l.makeCache() newCache, err := l.makeCache()
if err != nil { if err != nil {
if cache != nil { if cache != nil {
@ -138,7 +144,7 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er
return nil, err return nil, err
} }
l.cache = newCache l.cache.Store(newCache)
return newCache.certificate, nil return newCache.certificate, nil
} }