diff --git a/msgpipeline/config.go b/msgpipeline/config.go index aba789a..ba7d067 100644 --- a/msgpipeline/config.go +++ b/msgpipeline/config.go @@ -34,21 +34,23 @@ func parseMsgPipelineRootCfg(globals map[string]interface{}, nodes []config.Node return msgpipelineCfg{}, config.NodeErr(&node, "empty checks block") } - var err error - cfg.globalChecks, err = parseChecksGroup(globals, node.Children) + globalChecks, err := parseChecksGroup(globals, node.Children) if err != nil { return msgpipelineCfg{}, err } + + cfg.globalChecks = append(cfg.globalChecks, globalChecks...) case "modify": if len(node.Children) == 0 { return msgpipelineCfg{}, config.NodeErr(&node, "empty modifiers block") } - var err error - cfg.globalModifiers, err = parseModifiersGroup(globals, node.Children) + globalModifiers, err := parseModifiersGroup(globals, node.Children) if err != nil { return msgpipelineCfg{}, err } + + cfg.globalModifiers.Modifiers = append(cfg.globalModifiers.Modifiers, globalModifiers.Modifiers...) case "source": srcBlock, err := parseMsgPipelineSrcCfg(globals, node.Children) if err != nil { @@ -123,21 +125,23 @@ func parseMsgPipelineSrcCfg(globals map[string]interface{}, nodes []config.Node) return sourceBlock{}, config.NodeErr(&node, "empty checks block") } - var err error - src.checks, err = parseChecksGroup(globals, node.Children) + checks, err := parseChecksGroup(globals, node.Children) if err != nil { return sourceBlock{}, err } + + src.checks = append(src.checks, checks...) case "modify": if len(node.Children) == 0 { return sourceBlock{}, config.NodeErr(&node, "empty modifiers block") } - var err error - src.modifiers, err = parseModifiersGroup(globals, node.Children) + modifiers, err := parseModifiersGroup(globals, node.Children) if err != nil { return sourceBlock{}, err } + + src.modifiers.Modifiers = append(src.modifiers.Modifiers, modifiers.Modifiers...) case "destination": rcptBlock, err := parseMsgPipelineRcptCfg(globals, node.Children) if err != nil { @@ -195,21 +199,23 @@ func parseMsgPipelineRcptCfg(globals map[string]interface{}, nodes []config.Node return nil, config.NodeErr(&node, "empty checks block") } - var err error - rcpt.checks, err = parseChecksGroup(globals, node.Children) + checks, err := parseChecksGroup(globals, node.Children) if err != nil { return nil, err } + + rcpt.checks = append(rcpt.checks, checks...) case "modify": if len(node.Children) == 0 { return nil, config.NodeErr(&node, "empty modifiers block") } - var err error - rcpt.modifiers, err = parseModifiersGroup(globals, node.Children) + modifiers, err := parseModifiersGroup(globals, node.Children) if err != nil { return nil, err } + + rcpt.modifiers.Modifiers = append(rcpt.modifiers.Modifiers, modifiers.Modifiers...) case "deliver_to": if rcpt.rejectErr != nil { return nil, config.NodeErr(&node, "can't use 'reject' and 'deliver_to' together") diff --git a/msgpipeline/config_test.go b/msgpipeline/config_test.go index a5198e2..05858c3 100644 --- a/msgpipeline/config_test.go +++ b/msgpipeline/config_test.go @@ -252,6 +252,30 @@ func TestMsgPipelineCfg_GlobalChecks(t *testing.T) { } } +func TestMsgPipelineCfg_GlobalChecksMultiple(t *testing.T) { + str := ` + check { + test_check + } + check { + test_check + } + default_destination { + reject 500 + } + ` + + cfg, _ := parser.Read(strings.NewReader(str), "literal") + parsed, err := parseMsgPipelineRootCfg(nil, cfg) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + if len(parsed.globalChecks) != 2 { + t.Fatalf("wrong amount of test_check's in globalChecks: %d", len(parsed.globalChecks)) + } +} + func TestMsgPipelineCfg_SourceChecks(t *testing.T) { str := ` source example.org { @@ -277,6 +301,34 @@ func TestMsgPipelineCfg_SourceChecks(t *testing.T) { } } +func TestMsgPipelineCfg_SourceChecks_Multiple(t *testing.T) { + str := ` + source example.org { + check { + test_check + } + check { + test_check + } + + reject 500 + } + default_source { + reject 500 + } + ` + + cfg, _ := parser.Read(strings.NewReader(str), "literal") + parsed, err := parseMsgPipelineRootCfg(nil, cfg) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + if len(parsed.perSource["example.org"].checks) != 2 { + t.Fatalf("wrong amount of test_check's in source checks: %d", len(parsed.perSource["example.org"].checks)) + } +} + func TestMsgPipelineCfg_RcptChecks(t *testing.T) { str := ` destination example.org { @@ -301,3 +353,31 @@ func TestMsgPipelineCfg_RcptChecks(t *testing.T) { t.Fatalf("missing test_check in rcpt checks") } } + +func TestMsgPipelineCfg_RcptChecks_Multiple(t *testing.T) { + str := ` + destination example.org { + check { + test_check + } + check { + test_check + } + + reject 500 + } + default_destination { + reject 500 + } + ` + + cfg, _ := parser.Read(strings.NewReader(str), "literal") + parsed, err := parseMsgPipelineRootCfg(nil, cfg) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + if len(parsed.defaultSource.perRcpt["example.org"].checks) != 2 { + t.Fatalf("wrong amount of test_check's in rcpt checks: %d", len(parsed.defaultSource.perRcpt["example.org"].checks)) + } +}