mirror of
https://github.com/foxcpp/maddy.git
synced 2025-04-06 06:27:38 +03:00
target/remote: Fix race conditions in tests initialization
This commit is contained in:
parent
bafedd5792
commit
1b661f6eab
2 changed files with 45 additions and 14 deletions
|
@ -62,6 +62,7 @@ func testSTSPolicy(t *testing.T, zones map[string]mockdns.Zone, mtastsGet func(c
|
|||
}
|
||||
p.mtastsGet = mtastsGet
|
||||
p.log = testutils.Logger(t, "remote/mtasts")
|
||||
p.StartUpdater()
|
||||
|
||||
return p
|
||||
}
|
||||
|
@ -79,6 +80,7 @@ func testSTSPreload(t *testing.T, download FuncPreloadList) *stsPreloadPolicy {
|
|||
t.Fatal(err)
|
||||
}
|
||||
p.log = testutils.Logger(t, "remote/preload")
|
||||
p.StartUpdater()
|
||||
|
||||
return p
|
||||
}
|
||||
|
|
|
@ -117,8 +117,7 @@ type (
|
|||
|
||||
func NewMTASTSPolicy(r dns.Resolver, debug bool, cfg *config.Map) (*mtastsPolicy, error) {
|
||||
c := &mtastsPolicy{
|
||||
updaterStop: make(chan struct{}),
|
||||
log: log.Logger{Name: "remote/mtasts", Debug: debug},
|
||||
log: log.Logger{Name: "remote/mtasts", Debug: debug},
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -145,12 +144,20 @@ func NewMTASTSPolicy(r dns.Resolver, debug bool, cfg *config.Map) (*mtastsPolicy
|
|||
c.cache.Resolver = dns.DefaultResolver()
|
||||
c.mtastsGet = c.cache.Get
|
||||
|
||||
go c.updater()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// StartUpdater starts a goroutine to update MTA-STS cache periodically until
|
||||
// Close is called.
|
||||
//
|
||||
// It can be called only once per mtastsPolicy instance.
|
||||
func (c *mtastsPolicy) StartUpdater() {
|
||||
c.updaterStop = make(chan struct{})
|
||||
go c.updater()
|
||||
}
|
||||
|
||||
func (c *mtastsPolicy) updater() {
|
||||
|
||||
// Always update cache on start-up since we may have been down for some
|
||||
// time.
|
||||
c.log.Debugln("updating MTA-STS cache...")
|
||||
|
@ -183,8 +190,11 @@ func (c *mtastsPolicy) Start(msgMeta *module.MsgMetadata) DeliveryPolicy {
|
|||
}
|
||||
|
||||
func (c *mtastsPolicy) Close() error {
|
||||
c.updaterStop <- struct{}{}
|
||||
<-c.updaterStop
|
||||
if c.updaterStop != nil {
|
||||
c.updaterStop <- struct{}{}
|
||||
<-c.updaterStop
|
||||
c.updaterStop = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -291,14 +301,13 @@ type (
|
|||
func NewSTSPreloadPolicy(debug bool, client *http.Client, listDownload FuncPreloadList, cfg *config.Map) (*stsPreloadPolicy, error) {
|
||||
p := &stsPreloadPolicy{
|
||||
log: log.Logger{Name: "remote/preload", Debug: debug},
|
||||
updaterStop: make(chan struct{}),
|
||||
client: client,
|
||||
listDownload: preload.Download,
|
||||
}
|
||||
|
||||
var sourcePath string
|
||||
cfg.String("source", false, false, "eff", &sourcePath)
|
||||
cfg.Bool("enforceTesting", false, true, &p.enforceTesting)
|
||||
cfg.Bool("enforce_testing", false, true, &p.enforceTesting)
|
||||
if _, err := cfg.Process(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -311,8 +320,6 @@ func NewSTSPreloadPolicy(debug bool, client *http.Client, listDownload FuncPrelo
|
|||
|
||||
p.sourcePath = sourcePath
|
||||
|
||||
go p.updater()
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
|
@ -368,6 +375,15 @@ func (p *stsPreloadPolicy) load(client *http.Client, listDownload FuncPreloadLis
|
|||
return l, nil
|
||||
}
|
||||
|
||||
// StartUpdater starts a goroutine to update the used list periodically until Close is
|
||||
// called.
|
||||
//
|
||||
// It can be called only once per stsPreloadPolicy instance.
|
||||
func (p *stsPreloadPolicy) StartUpdater() {
|
||||
p.updaterStop = make(chan struct{})
|
||||
go p.updater()
|
||||
}
|
||||
|
||||
func (p *stsPreloadPolicy) updater() {
|
||||
for {
|
||||
updateDelay := time.Until(time.Time(p.l.Expires)) - preloadUpdateGrace
|
||||
|
@ -527,8 +543,11 @@ func (p *preloadDelivery) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLev
|
|||
}
|
||||
|
||||
func (p *stsPreloadPolicy) Close() error {
|
||||
p.updaterStop <- struct{}{}
|
||||
<-p.updaterStop
|
||||
if p.updaterStop != nil {
|
||||
p.updaterStop <- struct{}{}
|
||||
<-p.updaterStop
|
||||
p.updaterStop = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -798,14 +817,24 @@ func (rt *Target) defaultPolicy(globals map[string]interface{}, name string) fun
|
|||
)
|
||||
switch name {
|
||||
case "mtasts":
|
||||
policy, err = NewMTASTSPolicy(rt.resolver, log.DefaultLogger.Debug, config.NewMap(globals, &config.Node{}))
|
||||
mtastsPol, err := NewMTASTSPolicy(rt.resolver, log.DefaultLogger.Debug, config.NewMap(globals, &config.Node{}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mtastsPol.StartUpdater()
|
||||
policy = mtastsPol
|
||||
case "dane":
|
||||
policy = NewDANEPolicy(rt.extResolver, log.DefaultLogger.Debug)
|
||||
case "dnssec":
|
||||
policy = &dnssecPolicy{}
|
||||
case "sts_preload":
|
||||
policy, err = NewSTSPreloadPolicy(log.DefaultLogger.Debug, http.DefaultClient, preload.Download,
|
||||
preloadPolicy, err := NewSTSPreloadPolicy(log.DefaultLogger.Debug, http.DefaultClient, preload.Download,
|
||||
config.NewMap(globals, &config.Node{}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
preloadPolicy.StartUpdater()
|
||||
policy = preloadPolicy
|
||||
case "local":
|
||||
policy, err = NewLocalPolicy(config.NewMap(globals, &config.Node{}))
|
||||
default:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue