chore: replace rwlock with atomic pointer

This commit is contained in:
Haruue 2024-08-24 10:37:08 +08:00
parent fd2d20a46a
commit 57a48a674b
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148

View file

@ -8,6 +8,7 @@ import (
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
@ -16,8 +17,8 @@ type LocalCertificateLoader struct {
KeyFile string
SNIGuard SNIGuardFunc
lock sync.RWMutex
cache *localCertificateCache
lock sync.Mutex
cache atomic.Pointer[localCertificateCache]
}
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
@ -34,14 +35,15 @@ type localCertificateCache struct {
}
func (l *LocalCertificateLoader) InitializeCache() error {
l.lock.Lock()
defer l.lock.Unlock()
cache, err := l.makeCache()
if err != nil {
return err
}
l.lock.Lock()
defer l.lock.Unlock()
l.cache = cache
l.cache.Store(cache)
return nil
}
@ -63,18 +65,19 @@ func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls
}
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
if fi, ferr := os.Stat(l.CertFile); ferr != nil {
err = fmt.Errorf("failed to stat certificate file: %w", ferr)
fi, err := os.Stat(l.CertFile)
if err != nil {
err = fmt.Errorf("failed to stat certificate file: %w", err)
return
} else {
certModTime = fi.ModTime()
}
if fi, ferr := os.Stat(l.KeyFile); ferr != nil {
err = fmt.Errorf("failed to stat key file: %w", ferr)
certModTime = fi.ModTime()
fi, err = os.Stat(l.KeyFile)
if err != nil {
err = fmt.Errorf("failed to stat key file: %w", err)
return
} else {
keyModTime = fi.ModTime()
}
keyModTime = fi.ModTime()
return
}
@ -101,9 +104,7 @@ func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err
}
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
l.lock.RLock()
cache := l.cache
l.lock.RUnlock()
cache := l.cache.Load()
certModTime, keyModTime, terr := l.checkModTime()
if terr != nil {
@ -129,6 +130,11 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er
}
defer l.lock.Unlock()
if l.cache.Load() != cache {
// another goroutine updated the cache
return l.cache.Load().certificate, nil
}
newCache, err := l.makeCache()
if err != nil {
if cache != nil {
@ -138,7 +144,7 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er
return nil, err
}
l.cache = newCache
l.cache.Store(newCache)
return newCache.certificate, nil
}