Merge pull request #1 from lifenjoiner/pr2537

Add `cert_refresh_concurrency`
This commit is contained in:
Xiaotong Liu 2023-12-11 21:17:54 +08:00 committed by GitHub
commit e75fa68301
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 11 deletions

View file

@ -42,6 +42,7 @@ type Config struct {
Timeout int `toml:"timeout"` Timeout int `toml:"timeout"`
KeepAlive int `toml:"keepalive"` KeepAlive int `toml:"keepalive"`
Proxy string `toml:"proxy"` Proxy string `toml:"proxy"`
CertRefreshConcurrency int `toml:"cert_refresh_concurrency"`
CertRefreshDelay int `toml:"cert_refresh_delay"` CertRefreshDelay int `toml:"cert_refresh_delay"`
CertIgnoreTimestamp bool `toml:"cert_ignore_timestamp"` CertIgnoreTimestamp bool `toml:"cert_ignore_timestamp"`
EphemeralKeys bool `toml:"dnscrypt_ephemeral_keys"` EphemeralKeys bool `toml:"dnscrypt_ephemeral_keys"`
@ -116,6 +117,7 @@ func newConfig() Config {
LocalDoH: LocalDoHConfig{Path: "/dns-query"}, LocalDoH: LocalDoHConfig{Path: "/dns-query"},
Timeout: 5000, Timeout: 5000,
KeepAlive: 5, KeepAlive: 5,
CertRefreshConcurrency: 10,
CertRefreshDelay: 240, CertRefreshDelay: 240,
HTTP3: false, HTTP3: false,
CertIgnoreTimestamp: false, CertIgnoreTimestamp: false,
@ -437,6 +439,7 @@ func ConfigLoad(proxy *Proxy, flags *ConfigFlags) error {
if config.ForceTCP { if config.ForceTCP {
proxy.mainProto = "tcp" proxy.mainProto = "tcp"
} }
proxy.certRefreshConcurrency = Max(1, config.CertRefreshConcurrency)
proxy.certRefreshDelay = time.Duration(Max(60, config.CertRefreshDelay)) * time.Minute proxy.certRefreshDelay = time.Duration(Max(60, config.CertRefreshDelay)) * time.Minute
proxy.certRefreshDelayAfterFailure = time.Duration(10 * time.Second) proxy.certRefreshDelayAfterFailure = time.Duration(10 * time.Second)
proxy.certIgnoreTimestamp = config.CertIgnoreTimestamp proxy.certIgnoreTimestamp = config.CertIgnoreTimestamp

View file

@ -183,6 +183,12 @@ keepalive = 30
# use_syslog = true # use_syslog = true
## The maximum concurrency to reload certificates from the resolvers.
## Default is 10.
# cert_refresh_concurrency = 10
## Delay, in minutes, after which certificates are reloaded ## Delay, in minutes, after which certificates are reloaded
cert_refresh_delay = 240 cert_refresh_delay = 240

View file

@ -74,6 +74,7 @@ type Proxy struct {
certRefreshDelayAfterFailure time.Duration certRefreshDelayAfterFailure time.Duration
timeout time.Duration timeout time.Duration
certRefreshDelay time.Duration certRefreshDelay time.Duration
certRefreshConcurrency int
cacheSize int cacheSize int
logMaxBackups int logMaxBackups int
logMaxAge int logMaxAge int

View file

@ -228,24 +228,23 @@ func (serversInfo *ServersInfo) refresh(proxy *Proxy) (int, error) {
copy(registeredServers, serversInfo.registeredServers) copy(registeredServers, serversInfo.registeredServers)
serversInfo.RUnlock() serversInfo.RUnlock()
liveServers := 0 liveServers := 0
countChannel := make(chan struct{}, proxy.certRefreshConcurrency)
waitChannel := make(chan struct{})
var err error var err error
// simultaneously refresh all servers
wg := sync.WaitGroup{}
wg.Add(len(registeredServers))
for i := range registeredServers { for i := range registeredServers {
go func(rs *RegisteredServer) { countChannel <- struct{}{}
if err = serversInfo.refreshServer(proxy, rs.name, rs.stamp); err == nil { go func(registeredServer *RegisteredServer) {
serversInfo.Lock() if err = serversInfo.refreshServer(proxy, registeredServer.name, registeredServer.stamp); err == nil {
liveServers++ liveServers++
proxy.xTransport.internalResolverReady = true proxy.xTransport.internalResolverReady = true
serversInfo.Unlock()
} }
wg.Done() <-countChannel
if len(countChannel) == 0 {
close(waitChannel)
}
}(&registeredServers[i]) }(&registeredServers[i])
} }
wg.Wait() <-waitChannel
serversInfo.Lock() serversInfo.Lock()
sort.SliceStable(serversInfo.inner, func(i, j int) bool { sort.SliceStable(serversInfo.inner, func(i, j int) bool {
return serversInfo.inner[i].initialRtt < serversInfo.inner[j].initialRtt return serversInfo.inner[i].initialRtt < serversInfo.inner[j].initialRtt