target/remote: Fix race conditions in tests initialization

This commit is contained in:
fox.cpp 2019-12-29 19:27:56 +03:00
parent bafedd5792
commit 1b661f6eab
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
2 changed files with 45 additions and 14 deletions

View file

@ -62,6 +62,7 @@ func testSTSPolicy(t *testing.T, zones map[string]mockdns.Zone, mtastsGet func(c
}
p.mtastsGet = mtastsGet
p.log = testutils.Logger(t, "remote/mtasts")
p.StartUpdater()
return p
}
@ -79,6 +80,7 @@ func testSTSPreload(t *testing.T, download FuncPreloadList) *stsPreloadPolicy {
t.Fatal(err)
}
p.log = testutils.Logger(t, "remote/preload")
p.StartUpdater()
return p
}

View file

@ -117,8 +117,7 @@ type (
func NewMTASTSPolicy(r dns.Resolver, debug bool, cfg *config.Map) (*mtastsPolicy, error) {
c := &mtastsPolicy{
updaterStop: make(chan struct{}),
log: log.Logger{Name: "remote/mtasts", Debug: debug},
log: log.Logger{Name: "remote/mtasts", Debug: debug},
}
var (
@ -145,12 +144,20 @@ func NewMTASTSPolicy(r dns.Resolver, debug bool, cfg *config.Map) (*mtastsPolicy
c.cache.Resolver = dns.DefaultResolver()
c.mtastsGet = c.cache.Get
go c.updater()
return c, nil
}
// StartUpdater starts a goroutine to update MTA-STS cache periodically until
// Close is called.
//
// It can be called only once per mtastsPolicy instance.
func (c *mtastsPolicy) StartUpdater() {
c.updaterStop = make(chan struct{})
go c.updater()
}
func (c *mtastsPolicy) updater() {
// Always update cache on start-up since we may have been down for some
// time.
c.log.Debugln("updating MTA-STS cache...")
@ -183,8 +190,11 @@ func (c *mtastsPolicy) Start(msgMeta *module.MsgMetadata) DeliveryPolicy {
}
func (c *mtastsPolicy) Close() error {
c.updaterStop <- struct{}{}
<-c.updaterStop
if c.updaterStop != nil {
c.updaterStop <- struct{}{}
<-c.updaterStop
c.updaterStop = nil
}
return nil
}
@ -291,14 +301,13 @@ type (
func NewSTSPreloadPolicy(debug bool, client *http.Client, listDownload FuncPreloadList, cfg *config.Map) (*stsPreloadPolicy, error) {
p := &stsPreloadPolicy{
log: log.Logger{Name: "remote/preload", Debug: debug},
updaterStop: make(chan struct{}),
client: client,
listDownload: preload.Download,
}
var sourcePath string
cfg.String("source", false, false, "eff", &sourcePath)
cfg.Bool("enforceTesting", false, true, &p.enforceTesting)
cfg.Bool("enforce_testing", false, true, &p.enforceTesting)
if _, err := cfg.Process(); err != nil {
return nil, err
}
@ -311,8 +320,6 @@ func NewSTSPreloadPolicy(debug bool, client *http.Client, listDownload FuncPrelo
p.sourcePath = sourcePath
go p.updater()
return p, nil
}
@ -368,6 +375,15 @@ func (p *stsPreloadPolicy) load(client *http.Client, listDownload FuncPreloadLis
return l, nil
}
// StartUpdater starts a goroutine to update the used list periodically until Close is
// called.
//
// It can be called only once per stsPreloadPolicy instance.
func (p *stsPreloadPolicy) StartUpdater() {
p.updaterStop = make(chan struct{})
go p.updater()
}
func (p *stsPreloadPolicy) updater() {
for {
updateDelay := time.Until(time.Time(p.l.Expires)) - preloadUpdateGrace
@ -527,8 +543,11 @@ func (p *preloadDelivery) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLev
}
func (p *stsPreloadPolicy) Close() error {
p.updaterStop <- struct{}{}
<-p.updaterStop
if p.updaterStop != nil {
p.updaterStop <- struct{}{}
<-p.updaterStop
p.updaterStop = nil
}
return nil
}
@ -798,14 +817,24 @@ func (rt *Target) defaultPolicy(globals map[string]interface{}, name string) fun
)
switch name {
case "mtasts":
policy, err = NewMTASTSPolicy(rt.resolver, log.DefaultLogger.Debug, config.NewMap(globals, &config.Node{}))
mtastsPol, err := NewMTASTSPolicy(rt.resolver, log.DefaultLogger.Debug, config.NewMap(globals, &config.Node{}))
if err != nil {
return nil, err
}
mtastsPol.StartUpdater()
policy = mtastsPol
case "dane":
policy = NewDANEPolicy(rt.extResolver, log.DefaultLogger.Debug)
case "dnssec":
policy = &dnssecPolicy{}
case "sts_preload":
policy, err = NewSTSPreloadPolicy(log.DefaultLogger.Debug, http.DefaultClient, preload.Download,
preloadPolicy, err := NewSTSPreloadPolicy(log.DefaultLogger.Debug, http.DefaultClient, preload.Download,
config.NewMap(globals, &config.Node{}))
if err != nil {
return nil, err
}
preloadPolicy.StartUpdater()
policy = preloadPolicy
case "local":
policy, err = NewLocalPolicy(config.NewMap(globals, &config.Node{}))
default: