mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
chore: replace rwlock with atomic pointer
This commit is contained in:
parent
fd2d20a46a
commit
57a48a674b
1 changed files with 23 additions and 17 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -16,8 +17,8 @@ type LocalCertificateLoader struct {
|
|||
KeyFile string
|
||||
SNIGuard SNIGuardFunc
|
||||
|
||||
lock sync.RWMutex
|
||||
cache *localCertificateCache
|
||||
lock sync.Mutex
|
||||
cache atomic.Pointer[localCertificateCache]
|
||||
}
|
||||
|
||||
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
|
||||
|
@ -34,14 +35,15 @@ type localCertificateCache struct {
|
|||
}
|
||||
|
||||
func (l *LocalCertificateLoader) InitializeCache() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
cache, err := l.makeCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
l.cache = cache
|
||||
l.cache.Store(cache)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -63,18 +65,19 @@ func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls
|
|||
}
|
||||
|
||||
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
|
||||
if fi, ferr := os.Stat(l.CertFile); ferr != nil {
|
||||
err = fmt.Errorf("failed to stat certificate file: %w", ferr)
|
||||
fi, err := os.Stat(l.CertFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to stat certificate file: %w", err)
|
||||
return
|
||||
} else {
|
||||
certModTime = fi.ModTime()
|
||||
}
|
||||
if fi, ferr := os.Stat(l.KeyFile); ferr != nil {
|
||||
err = fmt.Errorf("failed to stat key file: %w", ferr)
|
||||
certModTime = fi.ModTime()
|
||||
|
||||
fi, err = os.Stat(l.KeyFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to stat key file: %w", err)
|
||||
return
|
||||
} else {
|
||||
keyModTime = fi.ModTime()
|
||||
}
|
||||
keyModTime = fi.ModTime()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -101,9 +104,7 @@ func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err
|
|||
}
|
||||
|
||||
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
|
||||
l.lock.RLock()
|
||||
cache := l.cache
|
||||
l.lock.RUnlock()
|
||||
cache := l.cache.Load()
|
||||
|
||||
certModTime, keyModTime, terr := l.checkModTime()
|
||||
if terr != nil {
|
||||
|
@ -129,6 +130,11 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er
|
|||
}
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.cache.Load() != cache {
|
||||
// another goroutine updated the cache
|
||||
return l.cache.Load().certificate, nil
|
||||
}
|
||||
|
||||
newCache, err := l.makeCache()
|
||||
if err != nil {
|
||||
if cache != nil {
|
||||
|
@ -138,7 +144,7 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er
|
|||
return nil, err
|
||||
}
|
||||
|
||||
l.cache = newCache
|
||||
l.cache.Store(newCache)
|
||||
return newCache.certificate, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue