From f49196c6e8966038e6378d9414be9067093a0e13 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Fri, 21 Feb 2025 18:07:41 +0100 Subject: [PATCH] xTransport: avoid updating the host->IP map in multiple goroutines When a goroutine is updating an IP, keep serving the previous IP to other goroutines. --- dnscrypt-proxy/xtransport.go | 37 +++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/dnscrypt-proxy/xtransport.go b/dnscrypt-proxy/xtransport.go index 80402746..0bd1b93c 100644 --- a/dnscrypt-proxy/xtransport.go +++ b/dnscrypt-proxy/xtransport.go @@ -40,8 +40,9 @@ const ( ) type CachedIPItem struct { - ip net.IP - expiration *time.Time + ip net.IP + expiration *time.Time + updating_until *time.Time } type CachedIPs struct { @@ -105,7 +106,7 @@ func ParseIP(ipStr string) net.IP { // If ttl < 0, never expire // Otherwise, ttl is set to max(ttl, MinResolverIPTTL) func (xTransport *XTransport) saveCachedIP(host string, ip net.IP, ttl time.Duration) { - item := &CachedIPItem{ip: ip, expiration: nil} + item := &CachedIPItem{ip: ip, expiration: nil, updating_until: nil} if ttl >= 0 { if ttl < MinResolverIPTTL { ttl = MinResolverIPTTL @@ -118,8 +119,21 @@ func (xTransport *XTransport) saveCachedIP(host string, ip net.IP, ttl time.Dura xTransport.cachedIPs.Unlock() } -func (xTransport *XTransport) loadCachedIP(host string) (ip net.IP, expired bool) { - ip, expired = nil, false +// Mark an entry as being updated +func (xTransport *XTransport) markUpdatingCachedIP(host string) { + xTransport.cachedIPs.Lock() + item, ok := xTransport.cachedIPs.cache[host] + if ok { + now := time.Now() + until := now.Add(xTransport.timeout) + item.updating_until = &until + xTransport.cachedIPs.cache[host] = item + } + xTransport.cachedIPs.Unlock() +} + +func (xTransport *XTransport) loadCachedIP(host string) (ip net.IP, expired bool, updating bool) { + ip, expired, updating = nil, false, false xTransport.cachedIPs.RLock() item, ok := xTransport.cachedIPs.cache[host] xTransport.cachedIPs.RUnlock() @@ -130,6 +144,9 @@ func (xTransport *XTransport) loadCachedIP(host string) (ip net.IP, expired bool expiration := item.expiration if expiration != nil && time.Until(*expiration) < 0 { expired = true + if item.updating_until != nil && time.Until(*item.updating_until) > 0 { + updating = true + } } return } @@ -153,7 +170,7 @@ func (xTransport *XTransport) rebuildTransport() { ipOnly := host // resolveAndUpdateCache() is always called in `Fetch()` before the `Dial()` // method is used, so that a cached entry must be present at this point. - cachedIP, _ := xTransport.loadCachedIP(host) + cachedIP, _, _ := xTransport.loadCachedIP(host) if cachedIP != nil { if ipv4 := cachedIP.To4(); ipv4 != nil { ipOnly = ipv4.String() @@ -263,7 +280,7 @@ func (xTransport *XTransport) rebuildTransport() { dlog.Debugf("Dialing for H3: [%v]", addrStr) host, port := ExtractHostAndPort(addrStr, stamps.DefaultPort) ipOnly := host - cachedIP, _ := xTransport.loadCachedIP(host) + cachedIP, _, _ := xTransport.loadCachedIP(host) network := "udp4" if cachedIP != nil { if ipv4 := cachedIP.To4(); ipv4 != nil { @@ -402,10 +419,12 @@ func (xTransport *XTransport) resolveAndUpdateCache(host string) error { if ParseIP(host) != nil { return nil } - cachedIP, expired := xTransport.loadCachedIP(host) - if cachedIP != nil && !expired { + cachedIP, expired, updating := xTransport.loadCachedIP(host) + if cachedIP != nil && (!expired || updating) { return nil } + xTransport.markUpdatingCachedIP(host) + var foundIP net.IP var ttl time.Duration var err error