diff --git a/dnscrypt-proxy/config.go b/dnscrypt-proxy/config.go index 7f784f04..ff5c1c49 100644 --- a/dnscrypt-proxy/config.go +++ b/dnscrypt-proxy/config.go @@ -894,7 +894,7 @@ func (config *Config) loadSource(proxy *Proxy, cfgSourceName string, cfgSource * cfgSource.Prefix, ) if err != nil { - if len(source.in) <= 0 { + if len(source.bin) <= 0 { dlog.Criticalf("Unable to retrieve source [%s]: [%s]", cfgSourceName, err) return err } diff --git a/dnscrypt-proxy/sources.go b/dnscrypt-proxy/sources.go index 0cd633aa..c973644c 100644 --- a/dnscrypt-proxy/sources.go +++ b/dnscrypt-proxy/sources.go @@ -3,10 +3,10 @@ package main import ( "bytes" "fmt" + "math" "math/rand" "net/url" "os" - "path/filepath" "strings" "time" @@ -31,13 +31,14 @@ const ( type Source struct { name string urls []*url.URL - format SourceFormat - in []byte + bin []byte // copy of the file content - there's something wrong in our logic, we shouldn't need to keep that in memory + sig []byte // copy of the signature minisignKey *minisign.PublicKey cacheFile string + prefix string cacheTTL, prefetchDelay time.Duration refresh time.Time - prefix string + format SourceFormat } func (source *Source) checkSignature(bin, sig []byte) (err error) { @@ -48,11 +49,10 @@ func (source *Source) checkSignature(bin, sig []byte) (err error) { return err } -// timeNow can be replaced by tests to provide a static value +// timeNow() can be replaced by tests to provide a static value var timeNow = time.Now -func (source *Source) fetchFromCache(now time.Time) (delay time.Duration, err error) { - var bin, sig []byte +func (source *Source) fetchFromCache(now time.Time) (delay time.Duration, bin []byte, sig []byte, err error) { if bin, err = os.ReadFile(source.cacheFile); err != nil { return } @@ -62,14 +62,13 @@ func (source *Source) fetchFromCache(now time.Time) (delay time.Duration, err er if err = source.checkSignature(bin, sig); err != nil { return } - source.in = bin var fi os.FileInfo if fi, err = os.Stat(source.cacheFile); err != nil { return } if elapsed := now.Sub(fi.ModTime()); elapsed < source.cacheTTL { delay = source.prefetchDelay - elapsed - dlog.Debugf("Source [%s] cache file [%s] is still fresh, next update: %v", source.name, source.cacheFile, delay) + dlog.Debugf("Source [%s] cache file [%s] is still fresh, next update in %v min", source.name, source.cacheFile, math.Round(delay.Minutes())) } else { dlog.Debugf("Source [%s] cache file [%s] needs to be refreshed", source.name, source.cacheFile) } @@ -98,25 +97,25 @@ func writeSource(f string, bin, sig []byte) (err error) { return fSig.Commit() } -func (source *Source) writeToCache(bin, sig []byte, now time.Time) { +// Update the cache file with the new data +func (source *Source) updateCache(bin, sig []byte, now time.Time) error { f := source.cacheFile - var writeErr error // an error writing cache isn't fatal - defer func() { - source.in = bin - if writeErr == nil { - return + // If the data and signature are unchanged, update the files timestamps only + if bin != nil && bytes.Equal(source.bin, bin) && sig != nil && bytes.Equal(source.sig, sig) { + dlog.Debugf("Source [%s] content and signature are unchanged", source.name) + if err := os.Chtimes(f, now, now); err != nil { + return err } - if absPath, absErr := filepath.Abs(f); absErr == nil { - f = absPath - } - dlog.Warnf("%s: %s", f, writeErr) - }() - if !bytes.Equal(source.in, bin) { - if writeErr = writeSource(f, bin, sig); writeErr != nil { - return + if err := os.Chtimes(f+".minisig", now, now); err != nil { + return err } + return nil } - writeErr = os.Chtimes(f, now, now) + if err := writeSource(f, bin, sig); err != nil { + dlog.Warnf("Source [%s] failed to update cache file [%s]: %v", source.name, f, err) + return err + } + return nil } func (source *Source) parseURLs(urls []string) { @@ -134,11 +133,10 @@ func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { return bin, err } -func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) { - if delay, err = source.fetchFromCache(now); err != nil { +func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, bin []byte, sig []byte, err error) { + if delay, bin, sig, err = source.fetchFromCache(now); err != nil { if len(source.urls) == 0 { - dlog.Errorf("Source [%s] cache file [%s] not present and no valid URL", source.name, source.cacheFile) - return + dlog.Fatalf("Source [%s] cache file [%s] not present and no valid URL", source.name, source.cacheFile) } dlog.Debugf("Source [%s] cache file [%s] not present", source.name, source.cacheFile) } @@ -148,10 +146,9 @@ func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (del }() } if len(source.urls) == 0 || delay > 0 { - return + source.sig = sig + return delay, bin, sig, nil // source is still valid } - delay = MinimumPrefetchInterval - var bin, sig []byte for _, srcURL := range source.urls { dlog.Infof("Source [%s] loading from URL [%s]", source.name, srcURL) sigURL := &url.URL{} @@ -171,11 +168,13 @@ func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (del dlog.Debugf("Source [%s] failed signature check using URL [%s]", source.name, srcURL) } if err != nil { - return + return MinimumPrefetchInterval, nil, nil, err } - source.writeToCache(bin, sig, now) - delay = source.prefetchDelay - return + if err := source.updateCache(bin, sig, now); err != nil { + return MinimumPrefetchInterval, bin, sig, err // keep using the old data + } + source.sig = sig + return source.prefetchDelay, bin, sig, nil } // NewSource loads a new source using the given cacheFile and urls, ensuring it has a valid signature @@ -211,10 +210,14 @@ func NewSource( return source, err } source.parseURLs(urls) - if _, err = source.fetchWithCache(xTransport, timeNow()); err == nil { - dlog.Noticef("Source [%s] loaded", name) + _, bin, sig, err := source.fetchWithCache(xTransport, timeNow()) + if err != nil { + return nil, err } - return + dlog.Noticef("Source [%s] loaded", name) + source.bin = bin + source.sig = sig + return source, err } // PrefetchSources downloads latest versions of given sources, ensuring they have a valid signature before caching @@ -226,13 +229,16 @@ func PrefetchSources(xTransport *XTransport, sources []*Source) time.Duration { continue } dlog.Debugf("Prefetching [%s]", source.name) - if delay, err := source.fetchWithCache(xTransport, now); err != nil { + delay, bin, sig, err := source.fetchWithCache(xTransport, now) + if err != nil { dlog.Infof("Prefetching [%s] failed: %v, will retry in %v", source.name, err, interval) - } else { - dlog.Debugf("Prefetching [%s] succeeded, next update: %v", source.name, delay) - if delay >= MinimumPrefetchInterval && (interval == MinimumPrefetchInterval || interval > delay) { - interval = delay - } + continue + } + source.bin = bin + source.sig = sig + dlog.Debugf("Prefetching [%s] succeeded, next update in %v min", source.name, math.Round(delay.Minutes())) + if delay >= MinimumPrefetchInterval && (interval == MinimumPrefetchInterval || interval > delay) { + interval = delay } } return interval @@ -254,8 +260,8 @@ func (source *Source) parseV2() ([]RegisteredServer, error) { stampErrs = append(stampErrs, stampErr) dlog.Warn(stampErr) } - in := string(source.in) - parts := strings.Split(in, "## ") + bin := string(source.bin) + parts := strings.Split(bin, "## ") if len(parts) < 2 { return registeredServers, fmt.Errorf("Invalid format for source at [%v]", source.urls) } diff --git a/dnscrypt-proxy/sources_test.go b/dnscrypt-proxy/sources_test.go index f3e41b9e..04a6eac7 100644 --- a/dnscrypt-proxy/sources_test.go +++ b/dnscrypt-proxy/sources_test.go @@ -284,9 +284,9 @@ func prepSourceTestCache(t *testing.T, d *SourceTestData, e *SourceTestExpect, s e.cache = []SourceFixture{d.fixtures[state][source], d.fixtures[state][source+".minisig"]} switch state { case TestStateCorrect: - e.Source.in, e.success = e.cache[0].content, true + e.Source.bin, e.success = e.cache[0].content, true case TestStateExpired: - e.Source.in = e.cache[0].content + e.Source.bin = e.cache[0].content case TestStatePartial, TestStatePartialSig: e.err = "signature" case TestStateMissing, TestStateMissingSig, TestStateOpenErr, TestStateOpenSigErr: @@ -339,7 +339,7 @@ func prepSourceTestDownload( switch state { case TestStateCorrect: e.cache = []SourceFixture{d.fixtures[state][source], d.fixtures[state][source+".minisig"]} - e.Source.in, e.success = e.cache[0].content, true + e.Source.bin, e.success = e.cache[0].content, true fallthrough case TestStateMissingSig, TestStatePartial, TestStatePartialSig, TestStateReadSigErr: d.reqExpect[path+".minisig"]++ @@ -477,7 +477,7 @@ func TestPrefetchSources(t *testing.T) { e.mtime = d.timeUpd s := &Source{} *s = *e.Source - s.in = nil + s.bin = nil sources = append(sources, s) expects = append(expects, e) }