Restructure code tree

Root package now contains only initialization code and 'dummy' module.

Each module now got its own package. Module packages are grouped by
their main purpose (storage/, target/, auth/, etc). Shared code is
placed in these "group" packages.

Parser for module references in config is moved into config/module.

Code shared by tests (mock modules, etc) is placed in testutils.
This commit is contained in:
fox.cpp 2019-09-08 15:46:46 +03:00
parent d4d807d6c7
commit 35c3b1c792
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
51 changed files with 961 additions and 2223 deletions

3
address/doc.go Normal file
View file

@ -0,0 +1,3 @@
// Package address provides utilities for parsing
// and validation of RFC 2821 addresses.
package address

24
address/split.go Normal file
View file

@ -0,0 +1,24 @@
package address
import (
"fmt"
"strings"
)
func Split(addr string) (mailbox, domain string, err error) {
parts := strings.Split(addr, "@")
switch len(parts) {
case 1:
if strings.EqualFold(parts[0], "postmaster") {
return parts[0], "", nil
}
return "", "", fmt.Errorf("malformed address")
case 2:
if len(parts[0]) == 0 || len(parts[1]) == 0 {
return "", "", fmt.Errorf("malformed address")
}
return parts[0], parts[1], nil
default:
return "", "", fmt.Errorf("malformed address")
}
}

View file

@ -1,4 +1,4 @@
package maddy
package address
import (
"strings"
@ -9,7 +9,8 @@ Rules for validation are subset of rules listed here:
https://emailregex.com/email-validation-summary/
*/
func validAddress(addr string) bool {
// Valid checks whether ths string is valid as a email address.
func Valid(addr string) bool {
if len(addr) > 320 { // RFC 3696 says it's 320, not 255.
return false
}
@ -18,19 +19,20 @@ func validAddress(addr string) bool {
case 1:
return strings.EqualFold(addr, "postmaster")
case 2:
return validMailboxName(parts[0]) && validDomain(parts[1])
return ValidMailboxName(parts[0]) && ValidDomain(parts[1])
default:
return false
}
}
// validMailboxName checks whether the specified string is a valid mailbox-name
// ValidMailboxName checks whether the specified string is a valid mailbox-name
// element of e-mail address (left part of it, before at-sign).
func validMailboxName(mbox string) bool {
func ValidMailboxName(mbox string) bool {
return true // TODO
}
func validDomain(domain string) bool {
// ValidDomain checks whether the specified string is a valid DNS domain.
func ValidDomain(domain string) bool {
if len(domain) > 255 {
return false
}

View file

@ -1,8 +1,8 @@
package maddy
package auth
import "strings"
func checkDomainAuth(username string, perDomain bool, allowedDomains []string) (loginName string, allowed bool) {
func CheckDomainAuth(username string, perDomain bool, allowedDomains []string) (loginName string, allowed bool) {
var accountName, domain string
if perDomain {
parts := strings.Split(username, "@")

View file

@ -1,4 +1,4 @@
package maddy
package auth
import (
"fmt"
@ -59,7 +59,7 @@ func TestCheckDomainAuth(t *testing.T) {
for _, case_ := range cases {
case_ := case_
t.Run(fmt.Sprintf("%+v", case_), func(t *testing.T) {
loginName, allowed := checkDomainAuth(case_.rawUsername, case_.perDomain, case_.allowedDomains)
loginName, allowed := CheckDomainAuth(case_.rawUsername, case_.perDomain, case_.allowedDomains)
if case_.loginName != "" && !allowed {
t.Fatalf("Unexpected authentication fail")
}

View file

@ -1,4 +1,4 @@
package maddy
package external
import (
"errors"
@ -9,6 +9,7 @@ import (
"path/filepath"
"strings"
"github.com/foxcpp/maddy/auth"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
@ -75,9 +76,9 @@ func (ea *ExternalAuth) Init(cfg *config.Map) error {
}
}
ea.helperPath = filepath.Join(LibexecDirectory(cfg.Globals), helperName)
ea.helperPath = filepath.Join(config.LibexecDirectory(cfg.Globals), helperName)
if _, err := os.Stat(ea.helperPath); err != nil {
return fmt.Errorf("no %s authentication support, %s is not found in %s and no custom path is set", ea.modName, LibexecDirectory(cfg.Globals), helperName)
return fmt.Errorf("no %s authentication support, %s is not found in %s and no custom path is set", ea.modName, config.LibexecDirectory(cfg.Globals), helperName)
}
ea.Log.Debugln("using helper:", ea.helperPath)
@ -86,7 +87,7 @@ func (ea *ExternalAuth) Init(cfg *config.Map) error {
}
func (ea *ExternalAuth) CheckPlain(username, password string) bool {
accountName, ok := checkDomainAuth(username, ea.perDomain, ea.domains)
accountName, ok := auth.CheckDomainAuth(username, ea.perDomain, ea.domains)
if !ok {
return false
}

62
check/action.go Normal file
View file

@ -0,0 +1,62 @@
package check
import (
"strconv"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
type FailAction struct {
Quarantine bool
Reject bool
ScoreAdjust int
}
func failActionDirective(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) == 0 {
return nil, m.MatchErr("expected at least 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
switch node.Args[0] {
case "Reject", "Quarantine", "ignore":
if len(node.Args) > 1 {
return nil, m.MatchErr("too many arguments")
}
return FailAction{
Reject: node.Args[0] == "Reject",
Quarantine: node.Args[0] == "Quarantine",
}, nil
case "score":
if len(node.Args) != 2 {
return nil, m.MatchErr("expected 2 arguments")
}
scoreAdj, err := strconv.Atoi(node.Args[1])
if err != nil {
return nil, m.MatchErr("%v", err)
}
return FailAction{
ScoreAdjust: scoreAdj,
}, nil
default:
return nil, m.MatchErr("invalid action")
}
}
// Apply merges the result of check execution with action configuration specified
// in the check configuration.
func (cfa FailAction) Apply(originalRes module.CheckResult) module.CheckResult {
if originalRes.RejectErr == nil {
return originalRes
}
originalRes.Quarantine = cfa.Quarantine
originalRes.ScoreAdjust = int32(cfa.ScoreAdjust)
if !cfa.Reject {
originalRes.RejectErr = nil
}
return originalRes
}

View file

@ -1,4 +1,4 @@
package maddy
package check
import (
"context"
@ -13,31 +13,31 @@ import (
"golang.org/x/sync/errgroup"
)
// CheckGroup is a type that wraps a group of checks and runs them in parallel.
// Group is a type that wraps a group of checks and runs them in parallel.
//
// It implements module.Check interface.
type CheckGroup struct {
checks []module.Check
type Group struct {
Checks []module.Check
}
func (cg *CheckGroup) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
states := make([]module.CheckState, 0, len(cg.checks))
for _, check := range cg.checks {
func (g *Group) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
states := make([]module.CheckState, 0, len(g.Checks))
for _, check := range g.Checks {
state, err := check.NewMessage(msgMeta)
if err != nil {
return nil, err
}
states = append(states, state)
}
return &checkGroupState{msgMeta, states}, nil
return &groupState{msgMeta, states}, nil
}
type checkGroupState struct {
type groupState struct {
msgMeta *module.MsgMetadata
states []module.CheckState
}
func (cgs *checkGroupState) runAndMergeResults(ctx context.Context, runner func(context.Context, module.CheckState) module.CheckResult) module.CheckResult {
func (gs *groupState) runAndMergeResults(ctx context.Context, runner func(context.Context, module.CheckState) module.CheckResult) module.CheckResult {
var (
checkScore int32
quarantineFlag atomicbool.AtomicBool
@ -48,7 +48,7 @@ func (cgs *checkGroupState) runAndMergeResults(ctx context.Context, runner func(
)
syncGroup, childCtx := errgroup.WithContext(ctx)
for _, state := range cgs.states {
for _, state := range gs.states {
state := state
syncGroup.Go(func() error {
subCheckRes := runner(childCtx, state)
@ -90,32 +90,32 @@ func (cgs *checkGroupState) runAndMergeResults(ctx context.Context, runner func(
}
}
func (cgs *checkGroupState) CheckConnection(ctx context.Context) module.CheckResult {
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
func (gs *groupState) CheckConnection(ctx context.Context) module.CheckResult {
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
return state.CheckConnection(childCtx)
})
}
func (cgs *checkGroupState) CheckSender(ctx context.Context, from string) module.CheckResult {
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
func (gs *groupState) CheckSender(ctx context.Context, from string) module.CheckResult {
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
return state.CheckSender(childCtx, from)
})
}
func (cgs *checkGroupState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
func (gs *groupState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
return state.CheckRcpt(childCtx, to)
})
}
func (cgs *checkGroupState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
func (gs *groupState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
return state.CheckBody(childCtx, header, body)
})
}
func (cgs *checkGroupState) Close() error {
for _, state := range cgs.states {
func (gs *groupState) Close() error {
for _, state := range gs.states {
state.Close()
}
return nil

View file

@ -1,4 +1,4 @@
package maddy
package dkim
import (
"io"

View file

@ -1,4 +1,4 @@
package maddy
package dns
import (
"context"
@ -6,11 +6,13 @@ import (
"strings"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/address"
"github.com/foxcpp/maddy/check"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
func requireMatchingRDNS(ctx StatelessCheckContext) module.CheckResult {
func requireMatchingRDNS(ctx check.StatelessCheckContext) module.CheckResult {
tcpAddr, ok := ctx.MsgMeta.SrcAddr.(*net.TCPAddr)
if !ok {
log.Debugf("non TCP/IP source (%v), skipped", ctx.MsgMeta.SrcAddr)
@ -47,8 +49,8 @@ func requireMatchingRDNS(ctx StatelessCheckContext) module.CheckResult {
}
}
func requireMXRecord(ctx StatelessCheckContext, mailFrom string) module.CheckResult {
_, domain, err := splitAddress(mailFrom)
func requireMXRecord(ctx check.StatelessCheckContext, mailFrom string) module.CheckResult {
_, domain, err := address.Split(mailFrom)
if err != nil {
return module.CheckResult{RejectErr: err}
}
@ -95,7 +97,7 @@ func requireMXRecord(ctx StatelessCheckContext, mailFrom string) module.CheckRes
return module.CheckResult{}
}
func requireMatchingEHLO(ctx StatelessCheckContext) module.CheckResult {
func requireMatchingEHLO(ctx check.StatelessCheckContext) module.CheckResult {
tcpAddr, ok := ctx.MsgMeta.SrcAddr.(*net.TCPAddr)
if !ok {
ctx.Logger.Debugf("not TCP/IP source (%v), skipped", ctx.MsgMeta.SrcAddr)
@ -132,10 +134,10 @@ func requireMatchingEHLO(ctx StatelessCheckContext) module.CheckResult {
}
func init() {
RegisterStatelessCheck("require_matching_rdns", checkFailAction{quarantine: true},
check.RegisterStatelessCheck("require_matching_rdns", check.FailAction{Quarantine: true},
requireMatchingRDNS, nil, nil, nil)
RegisterStatelessCheck("require_mx_record", checkFailAction{quarantine: true},
check.RegisterStatelessCheck("require_mx_record", check.FailAction{Quarantine: true},
nil, requireMXRecord, nil, nil)
RegisterStatelessCheck("require_matching_ehlo", checkFailAction{quarantine: true},
check.RegisterStatelessCheck("require_matching_ehlo", check.FailAction{Quarantine: true},
requireMatchingEHLO, nil, nil, nil)
}

View file

@ -1,21 +1,22 @@
package maddy
package check
import (
"context"
"net"
"time"
"github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/dns"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
"github.com/foxcpp/maddy/target"
)
type (
StatelessCheckContext struct {
// Resolver that should be used by the check for DNS queries.
Resolver Resolver
Resolver dns.Resolver
MsgMeta *module.MsgMetadata
@ -37,13 +38,13 @@ type (
type statelessCheck struct {
modName string
instName string
resolver Resolver
resolver dns.Resolver
logger log.Logger
// One used by Init if config option is not passed by a user.
defaultFailAction checkFailAction
defaultFailAction FailAction
// The actual fail action that should be applied.
failAction checkFailAction
failAction FailAction
okScore int
connCheck FuncConnCheck
@ -57,21 +58,6 @@ type statelessCheckState struct {
msgMeta *module.MsgMetadata
}
func deliveryLogger(l log.Logger, msgMeta *module.MsgMetadata) log.Logger {
out := l.Out
if out == nil {
out = log.DefaultLogger.Out
}
return log.Logger{
Out: func(t time.Time, debug bool, str string) {
out(t, debug, str+" (msg ID = "+msgMeta.ID+")")
},
Name: l.Name,
Debug: l.Debug,
}
}
func (s statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
if s.c.connCheck == nil {
return module.CheckResult{}
@ -81,9 +67,9 @@ func (s statelessCheckState) CheckConnection(ctx context.Context) module.CheckRe
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
CancelCtx: ctx,
Logger: deliveryLogger(s.c.logger, s.msgMeta),
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
})
return s.c.failAction.apply(originalRes)
return s.c.failAction.Apply(originalRes)
}
func (s statelessCheckState) CheckSender(ctx context.Context, mailFrom string) module.CheckResult {
@ -95,9 +81,9 @@ func (s statelessCheckState) CheckSender(ctx context.Context, mailFrom string) m
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
CancelCtx: ctx,
Logger: deliveryLogger(s.c.logger, s.msgMeta),
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
}, mailFrom)
return s.c.failAction.apply(originalRes)
return s.c.failAction.Apply(originalRes)
}
func (s statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.CheckResult {
@ -109,9 +95,9 @@ func (s statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modul
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
CancelCtx: ctx,
Logger: deliveryLogger(s.c.logger, s.msgMeta),
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
}, rcptTo)
return s.c.failAction.apply(originalRes)
return s.c.failAction.Apply(originalRes)
}
func (s statelessCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
@ -123,9 +109,9 @@ func (s statelessCheckState) CheckBody(ctx context.Context, header textproto.Hea
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
CancelCtx: ctx,
Logger: deliveryLogger(s.c.logger, s.msgMeta),
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
}, header, body)
return s.c.failAction.apply(originalRes)
return s.c.failAction.Apply(originalRes)
}
func (s statelessCheckState) Close() error {
@ -144,7 +130,7 @@ func (c *statelessCheck) Init(m *config.Map) error {
m.Custom("fail_action", false, false,
func() (interface{}, error) {
return c.defaultFailAction, nil
}, checkFailActionDirective, &c.failAction)
}, failActionDirective, &c.failAction)
_, err := m.Process()
return err
}
@ -165,10 +151,10 @@ func (c *statelessCheck) InstanceName() string {
//
// Note about CheckResult that is returned by the functions:
// StatelessCheck supports different action types based on the user configuration, but the particular check
// code doesn't need to know about it. It should assume that it is always "reject" and hence it should
// code doesn't need to know about it. It should assume that it is always "Reject" and hence it should
// populate RejectErr field of the result object with the relevant error description. Fields ScoreAdjust and
// Quarantine will be ignored.
func RegisterStatelessCheck(name string, defaultFailAction checkFailAction, connCheck FuncConnCheck, senderCheck FuncSenderCheck, rcptCheck FuncRcptCheck, bodyCheck FuncBodyCheck) {
func RegisterStatelessCheck(name string, defaultFailAction FailAction, connCheck FuncConnCheck, senderCheck FuncSenderCheck, rcptCheck FuncRcptCheck, bodyCheck FuncBodyCheck) {
module.Register(name, func(modName, instName string, aliases []string) (module.Module, error) {
return &statelessCheck{
modName: modName,

View file

@ -1,65 +0,0 @@
package maddy
import (
"context"
"github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
type testCheck struct {
connRes module.CheckResult
senderRes module.CheckResult
rcptRes module.CheckResult
bodyRes module.CheckResult
}
func (tc *testCheck) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
return &testCheckState{msgMeta, tc}, nil
}
func (tc *testCheck) Init(*config.Map) error {
return nil
}
func (tc *testCheck) Name() string {
return "test_check"
}
func (tc *testCheck) InstanceName() string {
return "test_check"
}
type testCheckState struct {
msgMeta *module.MsgMetadata
check *testCheck
}
func (tcs *testCheckState) CheckConnection(ctx context.Context) module.CheckResult {
return tcs.check.connRes
}
func (tcs *testCheckState) CheckSender(ctx context.Context, from string) module.CheckResult {
return tcs.check.senderRes
}
func (tcs *testCheckState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
return tcs.check.rcptRes
}
func (tcs *testCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
return tcs.check.bodyRes
}
func (tcs *testCheckState) Close() error {
return nil
}
func init() {
module.Register("test_check", func(modName, instanceName string, aliases []string) (module.Module, error) {
return &testCheck{}, nil
})
module.RegisterInstance(&testCheck{}, nil)
}

View file

@ -5,7 +5,7 @@ import (
"fmt"
"os"
"github.com/foxcpp/maddy/shadow"
"github.com/foxcpp/maddy/auth/shadow"
)
func main() {

View file

@ -14,7 +14,7 @@ import (
)
func main() {
configPath := flag.String("config", filepath.Join(maddy.ConfigDirectory(), "maddy.conf"), "path to configuration file")
configPath := flag.String("config", filepath.Join(config.ConfigDirectory(), "maddy.conf"), "path to configuration file")
debugFlag := flag.Bool("debug", false, "enable debug logging early")
profileEndpoint := flag.String("debug.pprof", "", "enable live profiler HTTP endpoint and listen on the specified endpoint")
blockProfileRate := flag.Int("debug.blockprofrate", 0, "set blocking profile rate")

193
config.go
View file

@ -4,208 +4,15 @@ import (
"errors"
"fmt"
"os"
"reflect"
"strconv"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
/*
Config matchers for module interfaces.
*/
// createInlineModule is a helper function for config matchers that can create inline modules.
func createInlineModule(modName, instName string, aliases []string) (module.Module, error) {
newMod := module.Get(modName)
if newMod == nil {
return nil, fmt.Errorf("unknown module: %s", modName)
}
log.Debugln("module create", modName, instName, "(inline)")
return newMod(modName, instName, aliases)
}
// initInlineModule constructs "faked" config tree and passes it to module
// Init function to make it look like it is defined at top-level.
//
// args must contain at least one argument, otherwise initInlineModule panics.
func initInlineModule(modObj module.Module, globals map[string]interface{}, block *config.Node) error {
log.Debugln("module init", modObj.Name(), modObj.InstanceName(), "(inline)")
return modObj.Init(config.NewMap(globals, block))
}
// moduleFromNode does all work to create or get existing module object with a certain type.
// It is not used by top-level module definitions, only for references from other
// modules configuration blocks.
//
// inlineCfg should contain configuration directives for inline declarations.
// args should contain values that are used to create module.
// It should be either module name + instance name or just module name. Further extensions
// may add other string arguments (currently, they can be accessed by module instances
// as aliases argument to constructor).
//
// It checks using reflection whether it is possible to store a module object into modObj
// pointer (e.g. it implements all necessary interfaces) and stores it if everything is fine.
// If module object doesn't implement necessary module interfaces - error is returned.
// If modObj is not a pointer, moduleFromNode panics.
func moduleFromNode(args []string, inlineCfg *config.Node, globals map[string]interface{}, moduleIface interface{}) error {
// single argument
// - instance name of an existing module
// single argument + block
// - module name, inline definition
// two+ arguments + block
// - module name and instance name, inline definition
// two+ arguments, no block
// - module name and instance name, inline definition, empty config block
if len(args) == 0 {
return config.NodeErr(inlineCfg, "at least one argument is required")
}
var modObj module.Module
var err error
if inlineCfg.Children != nil || len(args) > 1 {
modName := args[0]
modAliases := args[1:]
instName := ""
if len(args) >= 2 {
modAliases = args[2:]
instName = args[1]
}
modObj, err = createInlineModule(modName, instName, modAliases)
} else {
if len(args) != 1 {
return config.NodeErr(inlineCfg, "exactly one argument is to use existing config block")
}
modObj, err = module.GetInstance(args[0])
}
if err != nil {
return config.NodeErr(inlineCfg, "%v", err)
}
// NOTE: This will panic if moduleIface is not a pointer.
modIfaceType := reflect.TypeOf(moduleIface).Elem()
modObjType := reflect.TypeOf(modObj)
if !modObjType.Implements(modIfaceType) && !modObjType.AssignableTo(modIfaceType) {
return config.NodeErr(inlineCfg, "module %s (%s) doesn't implement %v interface", modObj.Name(), modObj.InstanceName(), modIfaceType)
}
reflect.ValueOf(moduleIface).Elem().Set(reflect.ValueOf(modObj))
if inlineCfg.Children != nil {
if err := initInlineModule(modObj, globals, inlineCfg); err != nil {
return err
}
}
return nil
}
// deliveryDirective is a callback for use in config.Map.Custom.
//
// It does all work necessary to create a module instance from the config
// directive with the following structure:
// directive_name mod_name [inst_name] [{
// inline_mod_config
// }]
//
// Note that if used configuration structure lacks directive_name before mod_name - this function
// should not be used (call deliveryTarget directly).
func deliveryDirective(m *config.Map, node *config.Node) (interface{}, error) {
return deliveryTarget(m.Globals, node.Args, node)
}
func deliveryTarget(globals map[string]interface{}, args []string, block *config.Node) (module.DeliveryTarget, error) {
var target module.DeliveryTarget
if err := moduleFromNode(args, block, globals, &target); err != nil {
return nil, err
}
return target, nil
}
func messageCheck(globals map[string]interface{}, args []string, block *config.Node) (module.Check, error) {
var check module.Check
if err := moduleFromNode(args, block, globals, &check); err != nil {
return nil, err
}
return check, nil
}
func authDirective(m *config.Map, node *config.Node) (interface{}, error) {
var provider module.AuthProvider
if err := moduleFromNode(node.Args, node, m.Globals, &provider); err != nil {
return nil, err
}
return provider, nil
}
func storageDirective(m *config.Map, node *config.Node) (interface{}, error) {
var backend module.Storage
if err := moduleFromNode(node.Args, node, m.Globals, &backend); err != nil {
return nil, err
}
return backend, nil
}
type checkFailAction struct {
quarantine bool
reject bool
scoreAdjust int
}
func checkFailActionDirective(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) == 0 {
return nil, m.MatchErr("expected at least 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
switch node.Args[0] {
case "reject", "quarantine", "ignore":
if len(node.Args) > 1 {
return nil, m.MatchErr("too many arguments")
}
return checkFailAction{
reject: node.Args[0] == "reject",
quarantine: node.Args[0] == "quarantine",
}, nil
case "score":
if len(node.Args) != 2 {
return nil, m.MatchErr("expected 2 arguments")
}
scoreAdj, err := strconv.Atoi(node.Args[1])
if err != nil {
return nil, m.MatchErr("%v", err)
}
return checkFailAction{
scoreAdjust: scoreAdj,
}, nil
default:
return nil, m.MatchErr("invalid action")
}
}
// apply merges the result of check execution with action configuration specified
// in the check configuration.
func (cfa checkFailAction) apply(originalRes module.CheckResult) module.CheckResult {
if originalRes.RejectErr == nil {
return originalRes
}
originalRes.Quarantine = cfa.quarantine
originalRes.ScoreAdjust = int32(cfa.scoreAdjust)
if !cfa.reject {
originalRes.RejectErr = nil
}
return originalRes
}
func logOutput(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) == 0 {
return nil, m.MatchErr("expected at least 1 argument")

View file

@ -1,4 +1,4 @@
package maddy
package config
import (
"fmt"
@ -86,9 +86,9 @@ func (a Address) IsTLS() bool {
return a.Scheme == "imaps" || a.Scheme == "smtps"
}
// standardizeAddress parses an address string into a structured format with separate
// StandardizeAddress parses an address string into a structured format with separate
// scheme, host, port, and path portions, as well as the original input string.
func standardizeAddress(str string) (Address, error) {
func StandardizeAddress(str string) (Address, error) {
input := str
// Split input into components (prepend with // to assert host by default)

View file

@ -1,4 +1,4 @@
package maddy
package config
import (
"net"
@ -13,7 +13,7 @@ func TestStandardizeAddress(t *testing.T) {
{Original: "smtp://0.0.0.0:10025", Scheme: "smtp", Host: "0.0.0.0", Port: "10025"},
{Original: "smtp://[::]:10025", Scheme: "smtp", Host: "::", Port: "10025"},
} {
actual, err := standardizeAddress(expected.Original)
actual, err := StandardizeAddress(expected.Original)
if err != nil {
t.Error("Unexpected failure:", err)
return

View file

@ -1,4 +1,4 @@
package maddy
package config
import (
"os"

14
config/module/auth.go Normal file
View file

@ -0,0 +1,14 @@
package modconfig
import (
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
func AuthDirective(m *config.Map, node *config.Node) (interface{}, error) {
var provider module.AuthProvider
if err := ModuleFromNode(node.Args, node, m.Globals, &provider); err != nil {
return nil, err
}
return provider, nil
}

14
config/module/check.go Normal file
View file

@ -0,0 +1,14 @@
package modconfig
import (
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
func MessageCheck(globals map[string]interface{}, args []string, block *config.Node) (module.Check, error) {
var check module.Check
if err := ModuleFromNode(args, block, globals, &check); err != nil {
return nil, err
}
return check, nil
}

28
config/module/delivery.go Normal file
View file

@ -0,0 +1,28 @@
package modconfig
import (
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
// deliveryDirective is a callback for use in config.Map.Custom.
//
// It does all work necessary to create a module instance from the config
// directive with the following structure:
// directive_name mod_name [inst_name] [{
// inline_mod_config
// }]
//
// Note that if used configuration structure lacks directive_name before mod_name - this function
// should not be used (call DeliveryTarget directly).
func DeliveryDirective(m *config.Map, node *config.Node) (interface{}, error) {
return DeliveryTarget(m.Globals, node.Args, node)
}
func DeliveryTarget(globals map[string]interface{}, args []string, block *config.Node) (module.DeliveryTarget, error) {
var target module.DeliveryTarget
if err := ModuleFromNode(args, block, globals, &target); err != nil {
return nil, err
}
return target, nil
}

108
config/module/modconfig.go Normal file
View file

@ -0,0 +1,108 @@
// Package modconfig provides matchers for config.Map that query
// modules registry and parse inline module definitions.
//
// They should be used instead of manual querying when there is need to
// reference a module instance in the configuration.
//
// See ModuleFromNode documentation for explanation of what is 'args'
// for some functions (DeliveryTarget).
package modconfig
import (
"fmt"
"reflect"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
// createInlineModule is a helper function for config matchers that can create inline modules.
func createInlineModule(modName, instName string, aliases []string) (module.Module, error) {
newMod := module.Get(modName)
if newMod == nil {
return nil, fmt.Errorf("unknown module: %s", modName)
}
log.Debugln("module create", modName, instName, "(inline)")
return newMod(modName, instName, aliases)
}
// initInlineModule constructs "faked" config tree and passes it to module
// Init function to make it look like it is defined at top-level.
//
// args must contain at least one argument, otherwise initInlineModule panics.
func initInlineModule(modObj module.Module, globals map[string]interface{}, block *config.Node) error {
log.Debugln("module init", modObj.Name(), modObj.InstanceName(), "(inline)")
return modObj.Init(config.NewMap(globals, block))
}
// ModuleFromNode does all work to create or get existing module object with a certain type.
// It is not used by top-level module definitions, only for references from other
// modules configuration blocks.
//
// inlineCfg should contain configuration directives for inline declarations.
// args should contain values that are used to create module.
// It should be either module name + instance name or just module name. Further extensions
// may add other string arguments (currently, they can be accessed by module instances
// as aliases argument to constructor).
//
// It checks using reflection whether it is possible to store a module object into modObj
// pointer (e.g. it implements all necessary interfaces) and stores it if everything is fine.
// If module object doesn't implement necessary module interfaces - error is returned.
// If modObj is not a pointer, ModuleFromNode panics.
func ModuleFromNode(args []string, inlineCfg *config.Node, globals map[string]interface{}, moduleIface interface{}) error {
// single argument
// - instance name of an existing module
// single argument + block
// - module name, inline definition
// two+ arguments + block
// - module name and instance name, inline definition
// two+ arguments, no block
// - module name and instance name, inline definition, empty config block
if len(args) == 0 {
return config.NodeErr(inlineCfg, "at least one argument is required")
}
var modObj module.Module
var err error
if inlineCfg.Children != nil || len(args) > 1 {
modName := args[0]
modAliases := args[1:]
instName := ""
if len(args) >= 2 {
modAliases = args[2:]
instName = args[1]
}
modObj, err = createInlineModule(modName, instName, modAliases)
} else {
if len(args) != 1 {
return config.NodeErr(inlineCfg, "exactly one argument is to use existing config block")
}
modObj, err = module.GetInstance(args[0])
}
if err != nil {
return config.NodeErr(inlineCfg, "%v", err)
}
// NOTE: This will panic if moduleIface is not a pointer.
modIfaceType := reflect.TypeOf(moduleIface).Elem()
modObjType := reflect.TypeOf(modObj)
if !modObjType.Implements(modIfaceType) && !modObjType.AssignableTo(modIfaceType) {
return config.NodeErr(inlineCfg, "module %s (%s) doesn't implement %v interface", modObj.Name(), modObj.InstanceName(), modIfaceType)
}
reflect.ValueOf(moduleIface).Elem().Set(reflect.ValueOf(modObj))
if inlineCfg.Children != nil {
if err := initInlineModule(modObj, globals, inlineCfg); err != nil {
return err
}
}
return nil
}

14
config/module/storage.go Normal file
View file

@ -0,0 +1,14 @@
package modconfig
import (
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
func StorageDirective(m *config.Map, node *config.Node) (interface{}, error) {
var backend module.Storage
if err := ModuleFromNode(node.Args, node, m.Globals, &backend); err != nil {
return nil, err
}
return backend, nil
}

View file

@ -1,23 +1,25 @@
package maddy
package dispatcher
import (
"testing"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-msgauth/authres"
"github.com/foxcpp/maddy/check"
"github.com/foxcpp/maddy/module"
"github.com/foxcpp/maddy/testutils"
)
func TestDispatcher_NoScoresChecked(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -26,31 +28,31 @@ func TestDispatcher_NoScoresChecked(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// No rejectScore or quarantineScore.
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if target.messages[0].msgMeta.Quarantine {
if target.Messages[0].MsgMeta.Quarantine {
t.Fatalf("message is quarantined when it shouldn't")
}
}
func TestDispatcher_RejectScore(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
rejectScore := 10
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -60,30 +62,30 @@ func TestDispatcher_RejectScore(t *testing.T) {
},
rejectScore: &rejectScore,
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Should be rejected.
if _, err := doTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
if _, err := testutils.DoTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
t.Fatalf("expected an error")
}
if len(target.messages) != 0 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.messages))
if len(target.Messages) != 0 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.Messages))
}
}
func TestDispatcher_RejectScore_notEnough(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
rejectScore := 15
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -93,29 +95,29 @@ func TestDispatcher_RejectScore_notEnough(t *testing.T) {
},
rejectScore: &rejectScore,
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if target.messages[0].msgMeta.Quarantine {
if target.Messages[0].MsgMeta.Quarantine {
t.Fatalf("message is quarantined when it shouldn't")
}
}
func TestDispatcher_Quarantine(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{},
}, testCheck{
bodyRes: module.CheckResult{Quarantine: true},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{},
}, testutils.Check{
BodyRes: module.CheckResult{Quarantine: true},
}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -124,31 +126,31 @@ func TestDispatcher_Quarantine(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Should be quarantined.
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if !target.messages[0].msgMeta.Quarantine {
if !target.Messages[0].MsgMeta.Quarantine {
t.Fatalf("message is not quarantined when it should")
}
}
func TestDispatcher_QuarantineScore(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
quarantineScore := 10
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -158,31 +160,31 @@ func TestDispatcher_QuarantineScore(t *testing.T) {
},
quarantineScore: &quarantineScore,
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Should be quarantined.
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if !target.messages[0].msgMeta.Quarantine {
if !target.Messages[0].MsgMeta.Quarantine {
t.Fatalf("message is not quarantined when it should")
}
}
func TestDispatcher_QuarantineScore_notEnough(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
quarantineScore := 15
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -192,32 +194,32 @@ func TestDispatcher_QuarantineScore_notEnough(t *testing.T) {
},
quarantineScore: &quarantineScore,
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Should be quarantined.
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if target.messages[0].msgMeta.Quarantine {
if target.Messages[0].MsgMeta.Quarantine {
t.Fatalf("message is quarantined when it shouldn't")
}
}
func TestDispatcher_BothScores_Quarantined(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
quarantineScore := 10
rejectScore := 15
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -228,32 +230,32 @@ func TestDispatcher_BothScores_Quarantined(t *testing.T) {
quarantineScore: &quarantineScore,
rejectScore: &rejectScore,
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Should be quarantined.
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if !target.messages[0].msgMeta.Quarantine {
if !target.Messages[0].MsgMeta.Quarantine {
t.Fatalf("message is not quarantined when it should")
}
}
func TestDispatcher_BothScores_Rejected(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
}, testCheck{
bodyRes: module.CheckResult{ScoreAdjust: 5},
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}, testutils.Check{
BodyRes: module.CheckResult{ScoreAdjust: 5},
}
quarantineScore := 5
rejectScore := 10
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -264,23 +266,23 @@ func TestDispatcher_BothScores_Rejected(t *testing.T) {
quarantineScore: &quarantineScore,
rejectScore: &rejectScore,
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Should be quarantined.
if _, err := doTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
if _, err := testutils.DoTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
t.Fatalf("message not rejected")
}
if len(target.messages) != 0 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.messages))
if len(target.Messages) != 0 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.Messages))
}
}
func TestDispatcher_AuthResults(t *testing.T) {
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{
AuthResult: []authres.Result{
&authres.SPFResult{
Value: authres.ResultFail,
@ -289,8 +291,8 @@ func TestDispatcher_AuthResults(t *testing.T) {
},
},
},
}, testCheck{
bodyRes: module.CheckResult{
}, testutils.Check{
BodyRes: module.CheckResult{
AuthResult: []authres.Result{
&authres.SPFResult{
Value: authres.ResultFail,
@ -302,7 +304,7 @@ func TestDispatcher_AuthResults(t *testing.T) {
}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -311,17 +313,17 @@ func TestDispatcher_AuthResults(t *testing.T) {
},
},
},
hostname: "TEST-HOST",
Log: testLogger(t, "dispatcher"),
Hostname: "TEST-HOST",
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
authRes := target.messages[0].header.Get("Authentication-Results")
authRes := target.Messages[0].Header.Get("Authentication-Results")
id, parsed, err := authres.Parse(authRes)
if err != nil {
t.Fatalf("failed to parse results")
@ -362,19 +364,19 @@ func TestDispatcher_Headers(t *testing.T) {
hdr2 := textproto.Header{}
hdr2.Add("HDR2", "2")
target := testTarget{}
check1, check2 := testCheck{
bodyRes: module.CheckResult{
target := testutils.Target{}
check1, check2 := testutils.Check{
BodyRes: module.CheckResult{
Header: hdr1,
},
}, testCheck{
bodyRes: module.CheckResult{
}, testutils.Check{
BodyRes: module.CheckResult{
Header: hdr2,
},
}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
perSource: map[string]sourceBlock{},
defaultSource: sourceBlock{
perRcpt: map[string]*rcptBlock{},
@ -383,20 +385,20 @@ func TestDispatcher_Headers(t *testing.T) {
},
},
},
hostname: "TEST-HOST",
Log: testLogger(t, "dispatcher"),
Hostname: "TEST-HOST",
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
if target.messages[0].header.Get("HDR1") != "1" {
t.Fatalf("wrong HDR1 value, want %s, got %s", "1", target.messages[0].header.Get("HDR1"))
if target.Messages[0].Header.Get("HDR1") != "1" {
t.Fatalf("wrong HDR1 value, want %s, got %s", "1", target.Messages[0].Header.Get("HDR1"))
}
if target.messages[0].header.Get("HDR2") != "2" {
t.Fatalf("wrong HDR2 value, want %s, got %s", "1", target.messages[0].header.Get("HDR2"))
if target.Messages[0].Header.Get("HDR2") != "2" {
t.Fatalf("wrong HDR2 value, want %s, got %s", "1", target.Messages[0].Header.Get("HDR2"))
}
}

View file

@ -1,4 +1,4 @@
package maddy
package dispatcher
import (
"fmt"
@ -6,11 +6,14 @@ import (
"strings"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/address"
"github.com/foxcpp/maddy/check"
"github.com/foxcpp/maddy/config"
modconfig "github.com/foxcpp/maddy/config/module"
)
type dispatcherCfg struct {
globalChecks CheckGroup
globalChecks check.Group
perSource map[string]sourceBlock
defaultSource sourceBlock
@ -197,7 +200,7 @@ func parseDispatcherRcptCfg(globals map[string]interface{}, nodes []config.Node)
if len(node.Args) == 0 {
return nil, config.NodeErr(&node, "required at least one argument")
}
mod, err := deliveryTarget(globals, node.Args, &node)
mod, err := modconfig.DeliveryTarget(globals, node.Args, &node)
if err != nil {
return nil, err
}
@ -276,19 +279,19 @@ func parseEnhancedCode(s string) (smtp.EnhancedCode, error) {
return code, nil
}
func parseChecksGroup(globals map[string]interface{}, nodes []config.Node) (CheckGroup, error) {
checks := CheckGroup{}
func parseChecksGroup(globals map[string]interface{}, nodes []config.Node) (check.Group, error) {
checks := check.Group{}
for _, child := range nodes {
check, err := messageCheck(globals, append([]string{child.Name}, child.Args...), &child)
msgCheck, err := modconfig.MessageCheck(globals, append([]string{child.Name}, child.Args...), &child)
if err != nil {
return CheckGroup{}, err
return check.Group{}, err
}
checks.checks = append(checks.checks, check)
checks.Checks = append(checks.Checks, msgCheck)
}
return checks, nil
}
func validDispatchRule(rule string) bool {
return validDomain(rule) || validAddress(rule)
return address.ValidDomain(rule) || address.Valid(rule)
}

View file

@ -1,4 +1,4 @@
package maddy
package dispatcher
import (
"reflect"
@ -244,7 +244,7 @@ func TestDispatcherCfg_GlobalChecks(t *testing.T) {
t.Fatalf("unexpected parse error: %v", err)
}
if len(parsed.globalChecks.checks) == 0 {
if len(parsed.globalChecks.Checks) == 0 {
t.Fatalf("missing test_check in globalChecks")
}
}
@ -269,7 +269,7 @@ func TestDispatcherCfg_SourceChecks(t *testing.T) {
t.Fatalf("unexpected parse error: %v", err)
}
if len(parsed.perSource["example.org"].checks.checks) == 0 {
if len(parsed.perSource["example.org"].checks.Checks) == 0 {
t.Fatalf("missing test_check in source checks")
}
}
@ -294,7 +294,7 @@ func TestDispatcherCfg_RcptChecks(t *testing.T) {
t.Fatalf("unexpected parse error: %v", err)
}
if len(parsed.defaultSource.perRcpt["example.org"].checks.checks) == 0 {
if len(parsed.defaultSource.perRcpt["example.org"].checks.Checks) == 0 {
t.Fatalf("missing test_check in rcpt checks")
}
}

View file

@ -1,4 +1,4 @@
package maddy
package dispatcher
import (
"context"
@ -8,10 +8,13 @@ import (
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-msgauth/authres"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/address"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/check"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
"github.com/foxcpp/maddy/target"
)
// Dispatcher is a object that is responsible for selecting delivery targets
@ -23,20 +26,20 @@ import (
// source (Submission, SMTP, JMAP modules) implementation.
type Dispatcher struct {
dispatcherCfg
hostname string
Hostname string
Log log.Logger
}
type sourceBlock struct {
checks CheckGroup
checks check.Group
rejectErr error
perRcpt map[string]*rcptBlock
defaultRcpt *rcptBlock
}
type rcptBlock struct {
checks CheckGroup
checks check.Group
rejectErr error
targets []module.DeliveryTarget
}
@ -48,26 +51,8 @@ func NewDispatcher(globals map[string]interface{}, cfg []config.Node) (*Dispatch
}, err
}
func splitAddress(addr string) (mailbox, domain string, err error) {
parts := strings.Split(addr, "@")
switch len(parts) {
case 1:
if strings.EqualFold(parts[0], "postmaster") {
return parts[0], "", nil
}
return "", "", fmt.Errorf("malformed address")
case 2:
if len(parts[0]) == 0 || len(parts[1]) == 0 {
return "", "", fmt.Errorf("malformed address")
}
return parts[0], parts[1], nil
default:
return "", "", fmt.Errorf("malformed address")
}
}
func (d *Dispatcher) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
dl := deliveryLogger(d.Log, msgMeta)
dl := target.DeliveryLogger(d.Log, msgMeta)
dd := dispatcherDelivery{
d: d,
rcptChecksState: make(map[*rcptBlock]module.CheckState),
@ -101,7 +86,7 @@ func (d *Dispatcher) Start(msgMeta *module.MsgMetadata, mailFrom string) (module
sourceBlock, ok := d.perSource[strings.ToLower(mailFrom)]
if !ok {
// Then try domain-only.
_, domain, err := splitAddress(mailFrom)
_, domain, err := address.Split(mailFrom)
if err != nil {
return nil, &smtp.SMTPError{
Code: 501,
@ -176,7 +161,7 @@ func (dd *dispatcherDelivery) AddRcpt(to string) error {
rcptBlock, ok := dd.sourceBlock.perRcpt[strings.ToLower(to)]
if !ok {
// Then try domain-only.
_, domain, err := splitAddress(to)
_, domain, err := address.Split(to)
if err != nil {
return &smtp.SMTPError{
Code: 501,
@ -268,7 +253,7 @@ func (dd *dispatcherDelivery) Body(header textproto.Header, body buffer.Buffer)
// After results for all checks are checked, authRes will be populated with values
// we should put into Authentication-Results header.
if len(dd.authRes) != 0 {
header.Add("Authentication-Results", authres.Format(dd.d.hostname, dd.authRes))
header.Add("Authentication-Results", authres.Format(dd.d.Hostname, dd.authRes))
}
for field := dd.header.Fields(); field.Next(); {
header.Add(field.Key(), field.Value())

View file

@ -1,4 +1,4 @@
package maddy
package dispatcher
import (
"errors"
@ -7,10 +7,11 @@ import (
"github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/module"
"github.com/foxcpp/maddy/testutils"
)
func TestDispatcher_AllToTarget(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -21,20 +22,20 @@ func TestDispatcher_AllToTarget(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
}
checkTestMessage(t, &target, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
testutils.CheckTestMessage(t, &target, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
}
func TestDispatcher_PerSourceDomainSplit(t *testing.T) {
orgTarget, comTarget := testTarget{instName: "orgTarget"}, testTarget{instName: "comTarget"}
orgTarget, comTarget := testutils.Target{InstName: "orgTarget"}, testutils.Target{InstName: "comTarget"}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{
@ -53,25 +54,25 @@ func TestDispatcher_PerSourceDomainSplit(t *testing.T) {
},
defaultSource: sourceBlock{rejectErr: errors.New("default src block used")},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
doTestDelivery(t, &d, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
testutils.DoTestDelivery(t, &d, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
if len(comTarget.messages) != 1 {
t.Fatalf("wrong amount of messages received for comTarget, want %d, got %d", 1, len(comTarget.messages))
if len(comTarget.Messages) != 1 {
t.Fatalf("wrong amount of messages received for comTarget, want %d, got %d", 1, len(comTarget.Messages))
}
checkTestMessage(t, &comTarget, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
testutils.CheckTestMessage(t, &comTarget, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
if len(orgTarget.messages) != 1 {
t.Fatalf("wrong amount of messages received for orgTarget, want %d, got %d", 1, len(orgTarget.messages))
if len(orgTarget.Messages) != 1 {
t.Fatalf("wrong amount of messages received for orgTarget, want %d, got %d", 1, len(orgTarget.Messages))
}
checkTestMessage(t, &orgTarget, 0, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
testutils.CheckTestMessage(t, &orgTarget, 0, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
}
func TestDispatcher_PerRcptAddrSplit(t *testing.T) {
target1, target2 := testTarget{instName: "target1"}, testTarget{instName: "target2"}
target1, target2 := testutils.Target{InstName: "target1"}, testutils.Target{InstName: "target2"}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{
@ -90,25 +91,25 @@ func TestDispatcher_PerRcptAddrSplit(t *testing.T) {
},
defaultSource: sourceBlock{rejectErr: errors.New("default src block used")},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
doTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
testutils.DoTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
testutils.DoTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
if len(target1.messages) != 1 {
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.messages))
if len(target1.Messages) != 1 {
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.Messages))
}
checkTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
testutils.CheckTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
if len(target2.messages) != 1 {
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target2.messages))
if len(target2.Messages) != 1 {
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target2.Messages))
}
checkTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
testutils.CheckTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
}
func TestDispatcher_PerRcptDomainSplit(t *testing.T) {
target1, target2 := testTarget{instName: "target1"}, testTarget{instName: "target2"}
target1, target2 := testutils.Target{InstName: "target1"}, testutils.Target{InstName: "target2"}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -126,27 +127,27 @@ func TestDispatcher_PerRcptDomainSplit(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.org"})
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.org", "rcpt2@example.com"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.org"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.org", "rcpt2@example.com"})
if len(target1.messages) != 2 {
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 2, len(target1.messages))
if len(target1.Messages) != 2 {
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 2, len(target1.Messages))
}
checkTestMessage(t, &target1, 0, "sender@example.com", []string{"rcpt1@example.com"})
checkTestMessage(t, &target1, 1, "sender@example.com", []string{"rcpt2@example.com"})
testutils.CheckTestMessage(t, &target1, 0, "sender@example.com", []string{"rcpt1@example.com"})
testutils.CheckTestMessage(t, &target1, 1, "sender@example.com", []string{"rcpt2@example.com"})
if len(target2.messages) != 2 {
t.Errorf("wrong amount of messages received for target2, want %d, got %d", 2, len(target2.messages))
if len(target2.Messages) != 2 {
t.Errorf("wrong amount of messages received for target2, want %d, got %d", 2, len(target2.Messages))
}
checkTestMessage(t, &target2, 0, "sender@example.com", []string{"rcpt2@example.org"})
checkTestMessage(t, &target2, 1, "sender@example.com", []string{"rcpt1@example.org"})
testutils.CheckTestMessage(t, &target2, 0, "sender@example.com", []string{"rcpt2@example.org"})
testutils.CheckTestMessage(t, &target2, 1, "sender@example.com", []string{"rcpt1@example.org"})
}
func TestDispatcher_PerSourceAddrAndDomainSplit(t *testing.T) {
target1, target2 := testTarget{instName: "target1"}, testTarget{instName: "target2"}
target1, target2 := testutils.Target{InstName: "target1"}, testutils.Target{InstName: "target2"}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{
@ -164,25 +165,25 @@ func TestDispatcher_PerSourceAddrAndDomainSplit(t *testing.T) {
},
defaultSource: sourceBlock{rejectErr: errors.New("default src block used")},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
doTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
testutils.DoTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
testutils.DoTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
if len(target1.messages) != 1 {
t.Fatalf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.messages))
if len(target1.Messages) != 1 {
t.Fatalf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.Messages))
}
checkTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
testutils.CheckTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
if len(target2.messages) != 1 {
t.Fatalf("wrong amount of messages received for target2, want %d, got %d", 1, len(target2.messages))
if len(target2.Messages) != 1 {
t.Fatalf("wrong amount of messages received for target2, want %d, got %d", 1, len(target2.Messages))
}
checkTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
testutils.CheckTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
}
func TestDispatcher_PerSourceReject(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{
@ -198,10 +199,10 @@ func TestDispatcher_PerSourceReject(t *testing.T) {
},
defaultSource: sourceBlock{rejectErr: errors.New("go away")},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
testutils.DoTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
_, err := d.Start(&module.MsgMetadata{ID: "testing"}, "sender2@example.com")
if err == nil {
@ -215,7 +216,7 @@ func TestDispatcher_PerSourceReject(t *testing.T) {
}
func TestDispatcher_PerRcptReject(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -233,7 +234,7 @@ func TestDispatcher_PerRcptReject(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
delivery, err := d.Start(&module.MsgMetadata{ID: "testing"}, "sender@example.com")
@ -257,7 +258,7 @@ func TestDispatcher_PerRcptReject(t *testing.T) {
}
func TestDispatcher_PostmasterRcpt(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -275,18 +276,18 @@ func TestDispatcher_PostmasterRcpt(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "disappointed-user@example.com", []string{"postmaster"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.messages))
testutils.DoTestDelivery(t, &d, "disappointed-user@example.com", []string{"postmaster"})
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.Messages))
}
checkTestMessage(t, &target, 0, "disappointed-user@example.com", []string{"postmaster"})
testutils.CheckTestMessage(t, &target, 0, "disappointed-user@example.com", []string{"postmaster"})
}
func TestDispatcher_PostmasterSrc(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{
@ -304,18 +305,18 @@ func TestDispatcher_PostmasterSrc(t *testing.T) {
rejectErr: errors.New("go away"),
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "postmaster", []string{"disappointed-user@example.com"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.messages))
testutils.DoTestDelivery(t, &d, "postmaster", []string{"disappointed-user@example.com"})
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.Messages))
}
checkTestMessage(t, &target, 0, "postmaster", []string{"disappointed-user@example.com"})
testutils.CheckTestMessage(t, &target, 0, "postmaster", []string{"disappointed-user@example.com"})
}
func TestDispatcher_CaseInsensetiveMatch_Src(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{
@ -342,22 +343,22 @@ func TestDispatcher_CaseInsensetiveMatch_Src(t *testing.T) {
rejectErr: errors.New("go away"),
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "POSTMastER", []string{"disappointed-user@example.com"})
doTestDelivery(t, &d, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
doTestDelivery(t, &d, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
if len(target.messages) != 3 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.messages))
testutils.DoTestDelivery(t, &d, "POSTMastER", []string{"disappointed-user@example.com"})
testutils.DoTestDelivery(t, &d, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
testutils.DoTestDelivery(t, &d, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
if len(target.Messages) != 3 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.Messages))
}
checkTestMessage(t, &target, 0, "POSTMastER", []string{"disappointed-user@example.com"})
checkTestMessage(t, &target, 1, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
checkTestMessage(t, &target, 2, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
testutils.CheckTestMessage(t, &target, 0, "POSTMastER", []string{"disappointed-user@example.com"})
testutils.CheckTestMessage(t, &target, 1, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
testutils.CheckTestMessage(t, &target, 2, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
}
func TestDispatcher_CaseInsensetiveMatch_Rcpt(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -375,22 +376,22 @@ func TestDispatcher_CaseInsensetiveMatch_Rcpt(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender@example.com", []string{"POSTMastER"})
doTestDelivery(t, &d, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
doTestDelivery(t, &d, "sender@example.com", []string{"sender@exAMPle.com"})
if len(target.messages) != 3 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.messages))
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"POSTMastER"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"sender@exAMPle.com"})
if len(target.Messages) != 3 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.Messages))
}
checkTestMessage(t, &target, 0, "sender@example.com", []string{"POSTMastER"})
checkTestMessage(t, &target, 1, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
checkTestMessage(t, &target, 2, "sender@example.com", []string{"sender@exAMPle.com"})
testutils.CheckTestMessage(t, &target, 0, "sender@example.com", []string{"POSTMastER"})
testutils.CheckTestMessage(t, &target, 1, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
testutils.CheckTestMessage(t, &target, 2, "sender@example.com", []string{"sender@exAMPle.com"})
}
func TestDispatcher_MalformedSource(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -408,7 +409,7 @@ func TestDispatcher_MalformedSource(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
// Simple checks for violations that can make dispatcher misbehave.
@ -421,7 +422,7 @@ func TestDispatcher_MalformedSource(t *testing.T) {
}
func TestDispatcher_TwoRcptToOneTarget(t *testing.T) {
target := testTarget{}
target := testutils.Target{}
d := Dispatcher{
dispatcherCfg: dispatcherCfg{
perSource: map[string]sourceBlock{},
@ -436,13 +437,13 @@ func TestDispatcher_TwoRcptToOneTarget(t *testing.T) {
},
},
},
Log: testLogger(t, "dispatcher"),
Log: testutils.Logger(t, "dispatcher"),
}
doTestDelivery(t, &d, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
if len(target.messages) != 1 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.messages))
if len(target.Messages) != 1 {
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.Messages))
}
checkTestMessage(t, &target, 0, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
testutils.CheckTestMessage(t, &target, 0, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
}

16
dispatcher/msgid.go Normal file
View file

@ -0,0 +1,16 @@
package dispatcher
import (
"encoding/hex"
"math/rand"
)
// GenerateMsgID generates a string usable as MsgID field in module.MsgMeta.
//
// TODO: Find a better place for this function. 'dispatcher' package seems
// irrelevant.
func GenerateMsgID() (string, error) {
rawID := make([]byte, 32)
_, err := rand.Read(rawID)
return hex.EncodeToString(rawID), err
}

View file

@ -1,4 +1,10 @@
package maddy
// Package dns defines interfaces used by maddy modules to perform DNS
// lookups.
//
// Currently, there is only Resolver interface which is implemented
// by net.DefaultResolver. In the future, DNSSEC-enabled stub resolver
// implementation will be added here.
package dns
import (
"context"

View file

@ -1,4 +1,4 @@
package maddy
package imap
import (
"crypto/tls"
@ -9,23 +9,24 @@ import (
"sync"
"github.com/emersion/go-imap"
imapbackend "github.com/emersion/go-imap/backend"
imapserver "github.com/emersion/go-imap/server"
"github.com/emersion/go-message"
_ "github.com/emersion/go-message/charset"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
appendlimit "github.com/emersion/go-imap-appendlimit"
compress "github.com/emersion/go-imap-compress"
idle "github.com/emersion/go-imap-idle"
move "github.com/emersion/go-imap-move"
unselect "github.com/emersion/go-imap-unselect"
imapbackend "github.com/emersion/go-imap/backend"
imapserver "github.com/emersion/go-imap/server"
"github.com/emersion/go-message"
_ "github.com/emersion/go-message/charset"
"github.com/foxcpp/go-imap-sql/children"
"github.com/foxcpp/maddy/config"
modconfig "github.com/foxcpp/maddy/config/module"
"github.com/foxcpp/maddy/endpoint"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
type IMAPEndpoint struct {
type Endpoint struct {
name string
aliases []string
serv *imapserver.Server
@ -40,8 +41,8 @@ type IMAPEndpoint struct {
Log log.Logger
}
func NewIMAPEndpoint(_, instName string, aliases []string) (module.Module, error) {
endp := &IMAPEndpoint{
func New(_, instName string, aliases []string) (module.Module, error) {
endp := &Endpoint{
name: instName,
aliases: aliases,
Log: log.Logger{Name: "imap"},
@ -51,15 +52,15 @@ func NewIMAPEndpoint(_, instName string, aliases []string) (module.Module, error
return endp, nil
}
func (endp *IMAPEndpoint) Init(cfg *config.Map) error {
func (endp *Endpoint) Init(cfg *config.Map) error {
var (
insecureAuth bool
ioDebug bool
)
cfg.Custom("auth", false, true, nil, authDirective, &endp.Auth)
cfg.Custom("storage", false, true, nil, storageDirective, &endp.Store)
cfg.Custom("tls", true, true, nil, tlsDirective, &endp.tlsConfig)
cfg.Custom("auth", false, true, nil, modconfig.AuthDirective, &endp.Auth)
cfg.Custom("storage", false, true, nil, modconfig.StorageDirective, &endp.Store)
cfg.Custom("tls", true, true, nil, endpoint.TLSDirective, &endp.tlsConfig)
cfg.Bool("insecure_auth", false, &insecureAuth)
cfg.Bool("io_debug", false, &ioDebug)
cfg.Bool("debug", true, &endp.Log.Debug)
@ -80,9 +81,9 @@ func (endp *IMAPEndpoint) Init(cfg *config.Map) error {
}
args := append([]string{endp.name}, endp.aliases...)
addresses := make([]Address, 0, len(args))
addresses := make([]config.Address, 0, len(args))
for _, addr := range args {
saddr, err := standardizeAddress(addr)
saddr, err := config.StandardizeAddress(addr)
if err != nil {
return fmt.Errorf("imap: invalid address: %s", endp.name)
}
@ -144,19 +145,19 @@ func (endp *IMAPEndpoint) Init(cfg *config.Map) error {
return nil
}
func (endp *IMAPEndpoint) Updates() <-chan imapbackend.Update {
func (endp *Endpoint) Updates() <-chan imapbackend.Update {
return endp.updater.Updates()
}
func (endp *IMAPEndpoint) Name() string {
func (endp *Endpoint) Name() string {
return "imap"
}
func (endp *IMAPEndpoint) InstanceName() string {
func (endp *Endpoint) InstanceName() string {
return endp.name
}
func (endp *IMAPEndpoint) Close() error {
func (endp *Endpoint) Close() error {
for _, l := range endp.listeners {
l.Close()
}
@ -167,7 +168,7 @@ func (endp *IMAPEndpoint) Close() error {
return nil
}
func (endp *IMAPEndpoint) Login(connInfo *imap.ConnInfo, username, password string) (imapbackend.User, error) {
func (endp *Endpoint) Login(connInfo *imap.ConnInfo, username, password string) (imapbackend.User, error) {
if !endp.Auth.CheckPlain(username, password) {
endp.Log.Printf("authentication failed for %s (from %v)", username, connInfo.RemoteAddr)
return nil, imapbackend.ErrInvalidCredentials
@ -176,11 +177,11 @@ func (endp *IMAPEndpoint) Login(connInfo *imap.ConnInfo, username, password stri
return endp.Store.GetOrCreateUser(username)
}
func (endp *IMAPEndpoint) EnableChildrenExt() bool {
func (endp *Endpoint) EnableChildrenExt() bool {
return endp.Store.(children.Backend).EnableChildrenExt()
}
func (endp *IMAPEndpoint) enableExtensions() error {
func (endp *Endpoint) enableExtensions() error {
exts := endp.Store.IMAPExtensions()
for _, ext := range exts {
switch ext {
@ -201,7 +202,7 @@ func (endp *IMAPEndpoint) enableExtensions() error {
}
func init() {
module.Register("imap", NewIMAPEndpoint)
module.Register("imap", New)
imap.CharsetReader = message.CharsetReader
}

View file

@ -1,4 +1,4 @@
package maddy
package smtp
import (
"bufio"
@ -11,13 +11,15 @@ import (
"sync"
"time"
"encoding/hex"
"math/rand"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
modconfig "github.com/foxcpp/maddy/config/module"
"github.com/foxcpp/maddy/dispatcher"
"github.com/foxcpp/maddy/endpoint"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
@ -38,8 +40,8 @@ func MsgMetaLog(l log.Logger, msgMeta *module.MsgMetadata) log.Logger {
}
}
type SMTPSession struct {
endp *SMTPEndpoint
type Session struct {
endp *Endpoint
delivery module.Delivery
msgMeta *module.MsgMetadata
log log.Logger
@ -51,7 +53,7 @@ var errInternal = &smtp.SMTPError{
Message: "Internal server error",
}
func (s *SMTPSession) Reset() {
func (s *Session) Reset() {
if s.delivery != nil {
if err := s.delivery.Abort(); err != nil {
s.endp.Log.Printf("failed to abort delivery: %v", err)
@ -60,15 +62,9 @@ func (s *SMTPSession) Reset() {
}
}
func generateMsgID() (string, error) {
rawID := make([]byte, 32)
_, err := rand.Read(rawID)
return hex.EncodeToString(rawID), err
}
func (s *SMTPSession) Mail(from string) error {
func (s *Session) Mail(from string) error {
var err error
s.msgMeta.ID, err = generateMsgID()
s.msgMeta.ID, err = dispatcher.GenerateMsgID()
if err != nil {
s.endp.Log.Printf("rand.Rand error: %v", err)
return s.wrapErr(errInternal)
@ -86,7 +82,7 @@ func (s *SMTPSession) Mail(from string) error {
return nil
}
func (s *SMTPSession) Rcpt(to string) error {
func (s *Session) Rcpt(to string) error {
err := s.delivery.AddRcpt(to)
if err != nil {
s.log.Printf("recipient rejected: %v, RCPT TO = %s", err, to)
@ -94,7 +90,7 @@ func (s *SMTPSession) Rcpt(to string) error {
return s.wrapErr(err)
}
func (s *SMTPSession) Logout() error {
func (s *Session) Logout() error {
if s.delivery != nil {
if err := s.delivery.Abort(); err != nil {
s.endp.Log.Printf("failed to abort delivery: %v", err)
@ -104,7 +100,7 @@ func (s *SMTPSession) Logout() error {
return nil
}
func (s *SMTPSession) Data(r io.Reader) error {
func (s *Session) Data(r io.Reader) error {
bufr := bufio.NewReader(r)
header, err := textproto.ReadHeader(bufr)
if err != nil {
@ -143,7 +139,7 @@ func (s *SMTPSession) Data(r io.Reader) error {
return nil
}
func (s *SMTPSession) wrapErr(err error) error {
func (s *Session) wrapErr(err error) error {
if err == nil {
return nil
}
@ -158,13 +154,13 @@ func (s *SMTPSession) wrapErr(err error) error {
return fmt.Errorf("%v (msg ID = %s)", err, s.msgMeta.ID)
}
type SMTPEndpoint struct {
type Endpoint struct {
Auth module.AuthProvider
serv *smtp.Server
name string
aliases []string
listeners []net.Listener
dispatcher *Dispatcher
dispatcher *dispatcher.Dispatcher
authAlwaysRequired bool
@ -175,16 +171,16 @@ type SMTPEndpoint struct {
Log log.Logger
}
func (endp *SMTPEndpoint) Name() string {
func (endp *Endpoint) Name() string {
return "smtp"
}
func (endp *SMTPEndpoint) InstanceName() string {
func (endp *Endpoint) InstanceName() string {
return endp.name
}
func NewSMTPEndpoint(modName, instName string, aliases []string) (module.Module, error) {
endp := &SMTPEndpoint{
func New(modName, instName string, aliases []string) (module.Module, error) {
endp := &Endpoint{
name: instName,
aliases: aliases,
submission: modName == "submission",
@ -193,7 +189,7 @@ func NewSMTPEndpoint(modName, instName string, aliases []string) (module.Module,
return endp, nil
}
func (endp *SMTPEndpoint) Init(cfg *config.Map) error {
func (endp *Endpoint) Init(cfg *config.Map) error {
endp.serv = smtp.NewServer(endp)
if err := endp.setConfig(cfg); err != nil {
return err
@ -204,9 +200,9 @@ func (endp *SMTPEndpoint) Init(cfg *config.Map) error {
}
args := append([]string{endp.name}, endp.aliases...)
addresses := make([]Address, 0, len(args))
addresses := make([]config.Address, 0, len(args))
for _, addr := range args {
saddr, err := standardizeAddress(addr)
saddr, err := config.StandardizeAddress(addr)
if err != nil {
return fmt.Errorf("smtp: invalid address: %s", addr)
}
@ -232,7 +228,7 @@ func (endp *SMTPEndpoint) Init(cfg *config.Map) error {
return nil
}
func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
func (endp *Endpoint) setConfig(cfg *config.Map) error {
var (
err error
ioDebug bool
@ -242,14 +238,14 @@ func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
readTimeoutSecs uint
)
cfg.Custom("auth", false, false, nil, authDirective, &endp.Auth)
cfg.Custom("auth", false, false, nil, modconfig.AuthDirective, &endp.Auth)
cfg.String("hostname", true, true, "", &endp.serv.Domain)
// TODO: Parse human-readable duration values.
cfg.UInt("write_timeout", false, false, 60, &writeTimeoutSecs)
cfg.UInt("read_timeout", false, false, 600, &readTimeoutSecs)
cfg.Int("max_message_size", false, false, 32*1024*1024, &endp.serv.MaxMessageBytes)
cfg.Int("max_recipients", false, false, 20000, &endp.serv.MaxRecipients)
cfg.Custom("tls", true, true, nil, tlsDirective, &endp.serv.TLSConfig)
cfg.Custom("tls", true, true, nil, endpoint.TLSDirective, &endp.serv.TLSConfig)
cfg.Bool("insecure_auth", false, &endp.serv.AllowInsecureAuth)
cfg.Bool("io_debug", false, &ioDebug)
cfg.Bool("debug", true, &endp.Log.Debug)
@ -259,14 +255,14 @@ func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
if err != nil {
return err
}
endp.dispatcher, err = NewDispatcher(cfg.Globals, unmatched)
endp.dispatcher, err = dispatcher.NewDispatcher(cfg.Globals, unmatched)
if err != nil {
return err
}
endp.dispatcher.hostname = endp.serv.Domain
endp.dispatcher.Hostname = endp.serv.Domain
endp.dispatcher.Log = log.Logger{Name: "smtp/dispatcher", Debug: endp.Log.Debug}
// endp.submission can be set to true by NewSMTPEndpoint, leave it on
// endp.submission can be set to true by New, leave it on
// even if directive is missing.
if submission {
endp.submission = true
@ -291,7 +287,7 @@ func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
return nil
}
func (endp *SMTPEndpoint) setupListeners(addresses []Address) error {
func (endp *Endpoint) setupListeners(addresses []config.Address) error {
var smtpUsed, lmtpUsed bool
for _, addr := range addresses {
if addr.Scheme == "smtp" || addr.Scheme == "smtps" {
@ -341,7 +337,7 @@ func (endp *SMTPEndpoint) setupListeners(addresses []Address) error {
return nil
}
func (endp *SMTPEndpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
if endp.Auth == nil {
return nil, smtp.ErrAuthUnsupported
}
@ -354,7 +350,7 @@ func (endp *SMTPEndpoint) Login(state *smtp.ConnectionState, username, password
return endp.newSession(false, username, password, state), nil
}
func (endp *SMTPEndpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
func (endp *Endpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
if endp.authAlwaysRequired {
return nil, smtp.ErrAuthRequired
}
@ -362,7 +358,7 @@ func (endp *SMTPEndpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Sess
return endp.newSession(true, "", "", state), nil
}
func (endp *SMTPEndpoint) newSession(anonymous bool, username, password string, state *smtp.ConnectionState) smtp.Session {
func (endp *Endpoint) newSession(anonymous bool, username, password string, state *smtp.ConnectionState) smtp.Session {
ctx := &module.MsgMetadata{
Anonymous: anonymous,
AuthUser: username,
@ -385,18 +381,14 @@ func (endp *SMTPEndpoint) newSession(anonymous bool, username, password string,
}
}
return &SMTPSession{
return &Session{
endp: endp,
msgMeta: ctx,
log: MsgMetaLog(endp.Log, ctx),
}
}
func sanitizeString(raw string) string {
return strings.Replace(raw, "\n", "", -1)
}
func (endp *SMTPEndpoint) Close() error {
func (endp *Endpoint) Close() error {
for _, l := range endp.listeners {
l.Close()
}
@ -406,8 +398,8 @@ func (endp *SMTPEndpoint) Close() error {
}
func init() {
module.Register("smtp", NewSMTPEndpoint)
module.Register("submission", NewSMTPEndpoint)
module.Register("smtp", New)
module.Register("submission", New)
rand.Seed(time.Now().UnixNano())
}

View file

@ -1,4 +1,4 @@
package maddy
package smtp
import (
"errors"

View file

@ -1,4 +1,4 @@
package maddy
package endpoint
import (
"crypto/ecdsa"
@ -16,7 +16,7 @@ import (
"github.com/foxcpp/maddy/log"
)
func tlsDirective(m *config.Map, node *config.Node) (interface{}, error) {
func TLSDirective(m *config.Map, node *config.Node) (interface{}, error) {
switch len(node.Args) {
case 1:
switch node.Args[0] {

View file

@ -1,3 +1,4 @@
// Package log implements minimalistic logging library.
package log
import (

View file

@ -7,8 +7,18 @@ import (
"syscall"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/endpoint"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
// Import packages for side-effect of module registration.
_ "github.com/foxcpp/maddy/auth/external"
_ "github.com/foxcpp/maddy/check/dns"
_ "github.com/foxcpp/maddy/endpoint/imap"
_ "github.com/foxcpp/maddy/endpoint/smtp"
_ "github.com/foxcpp/maddy/storage/sql"
_ "github.com/foxcpp/maddy/target/queue"
_ "github.com/foxcpp/maddy/target/remote"
)
type modInfo struct {
@ -23,7 +33,7 @@ func Start(cfg []config.Node) error {
globals.String("autogenerated_msg_domain", false, false, "", nil)
globals.String("statedir", false, false, "", nil)
globals.String("libexecdir", false, false, "", nil)
globals.Custom("tls", false, false, nil, tlsDirective, nil)
globals.Custom("tls", false, false, nil, endpoint.TLSDirective, nil)
globals.Bool("auth_perdomain", false, nil)
globals.StringList("auth_domains", false, false, nil, nil)
globals.Custom("log", false, false, defaultLogOutput, logOutput, &log.DefaultLogger.Out)

View file

@ -1,6 +1,7 @@
// Package module contains interfaces implemented by maddy modules.
// Package module contains modules registry and interfaces implemented
// by modules.
//
// They are moved to separate package to prevent circular dependencies.
// Interfaces are placed here to prevent circular dependencies.
//
// Each interface required by maddy for operation is provided by some object
// called "module". This includes authentication, storage backends, DKIM,

735
queue.go
View file

@ -1,735 +0,0 @@
package maddy
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/dsn"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
// PartialError describes state of partially successful message delivery.
type PartialError struct {
// Recipients for which delivery permanently failed.
Failed []string
// Recipients for which delivery temporary failed.
TemporaryFailed []string
// Underlying error objects.
Errs map[string]error
}
func (pe PartialError) Error() string {
return fmt.Sprintf("delivery failed for some recipients (permanently: %v, temporary: %v): %v", pe.Failed, pe.TemporaryFailed, pe.Errs)
}
type Queue struct {
name string
location string
hostname string
autogenMsgDomain string
wheel *TimeWheel
dsnDispatcher *Dispatcher
// Retry delay is calculated using the following formula:
// initialRetryTime * retryTimeScale ^ (TriesCount - 1)
initialRetryTime time.Duration
retryTimeScale float64
maxTries int
// If any delivery is scheduled in less than postInitDelay
// after Init, its delay will be increased by postInitDelay.
//
// Say, if postInitDelay is 10 secs.
// Then if some message is scheduled to delivered 5 seconds
// after init, it will be actually delivered 15 seconds
// after start-up.
//
// This delay is added to make that if maddy is killed shortly
// after start-up for whatever reason it will not affect the queue.
postInitDelay time.Duration
Log log.Logger
Target module.DeliveryTarget
workersWg sync.WaitGroup
// Closed from Queue.Close.
workersStop chan struct{}
}
type QueueMetadata struct {
MsgMeta *module.MsgMetadata
From string
// Recipients that should be tried next.
// May or may not be equal to PartialError.TemporaryFailed.
To []string
// Information about previous failures.
// Preserved to be included in a bounce message.
FailedRcpts []string
TemporaryFailedRcpts []string
// All errors are converted to SMTPError we can serialize and
// also it is directly usable for bounce messages.
RcptErrs map[string]*smtp.SMTPError
// Amount of times delivery *already tried*.
TriesCount int
FirstAttempt time.Time
LastAttempt time.Time
// Whether this is a delivery notification.
DSN bool
}
func NewQueue(_, instName string, _ []string) (module.Module, error) {
return &Queue{
name: instName,
initialRetryTime: 15 * time.Minute,
retryTimeScale: 2,
postInitDelay: 10 * time.Second,
Log: log.Logger{Name: "queue"},
workersStop: make(chan struct{}),
}, nil
}
func (q *Queue) Init(cfg *config.Map) error {
var workers int
cfg.Bool("debug", true, &q.Log.Debug)
cfg.Int("max_tries", false, false, 8, &q.maxTries)
cfg.Int("workers", false, false, 16, &workers)
cfg.String("location", false, false, "", &q.location)
cfg.Custom("target", false, true, nil, deliveryDirective, &q.Target)
cfg.String("hostname", true, true, "", &q.hostname)
cfg.String("autogenerated_msg_domain", true, false, "", &q.autogenMsgDomain)
cfg.Custom("bounce", false, false, nil, func(m *config.Map, node *config.Node) (interface{}, error) {
return NewDispatcher(m.Globals, node.Children)
}, &q.dsnDispatcher)
if _, err := cfg.Process(); err != nil {
return err
}
if q.dsnDispatcher != nil {
if q.autogenMsgDomain == "" {
return errors.New("queue: autogenerated_msg_domain is required if bounce {} is specified")
}
q.dsnDispatcher.hostname = q.hostname
q.dsnDispatcher.Log = log.Logger{Name: "queue/dispatcher", Debug: q.Log.Debug}
}
if q.location == "" && q.name == "" {
return errors.New("queue: need explicit location directive or config block name if defined inline")
}
if q.location == "" {
q.location = filepath.Join(StateDirectory(cfg.Globals), q.name)
}
if !filepath.IsAbs(q.location) {
q.location = filepath.Join(StateDirectory(cfg.Globals), q.location)
}
// TODO: Check location write permissions.
if err := os.MkdirAll(q.location, os.ModePerm); err != nil {
return err
}
return q.start(workers)
}
func (q *Queue) start(workers int) error {
q.wheel = NewTimeWheel()
if err := q.readDiskQueue(); err != nil {
return err
}
q.Log.Debugf("delivery target: %T", q.Target)
for i := 0; i < workers; i++ {
q.workersWg.Add(1)
go q.worker()
}
return nil
}
func (q *Queue) Close() error {
// Make Close function idempotent. This makes it more
// convenient to use in certain situations (see queue tests).
if q.workersStop == nil {
return nil
}
close(q.workersStop)
q.workersWg.Wait()
q.workersStop = nil
q.wheel.Close()
return nil
}
func (q *Queue) worker() {
for {
select {
case <-q.workersStop:
q.workersWg.Done()
return
case slot := <-q.wheel.Dispatch():
q.Log.Debugln("worker woke up for", slot.Value)
id := slot.Value.(string)
meta, header, body, err := q.openMessage(id)
if err != nil {
q.Log.Printf("failed to read message: %v", err)
continue
}
q.tryDelivery(meta, header, body)
}
}
}
func (q *Queue) tryDelivery(meta *QueueMetadata, header textproto.Header, body buffer.Buffer) {
dl := deliveryLogger(q.Log, meta.MsgMeta)
dl.Debugf("delivery attempt #%d", meta.TriesCount+1)
partialErr := q.deliver(meta, header, body)
dl.Debugf("failures: permanently: %v, temporary: %v, errors: %v",
partialErr.Failed, partialErr.TemporaryFailed, partialErr.Errs)
// Save permanent errors information for reporting in bounce message.
meta.FailedRcpts = append(meta.FailedRcpts, partialErr.Failed...)
for rcpt, rcptErr := range partialErr.Errs {
var smtpErr *smtp.SMTPError
var ok bool
if smtpErr, ok = rcptErr.(*smtp.SMTPError); !ok {
smtpErr = &smtp.SMTPError{
Code: 554,
EnhancedCode: smtp.EnhancedCode{5, 0, 0},
Message: rcptErr.Error(),
}
if isTemporaryErr(rcptErr) {
smtpErr.Code = 451
smtpErr.EnhancedCode = smtp.EnhancedCode{4, 0, 0}
}
}
meta.RcptErrs[rcpt] = smtpErr
}
meta.To = partialErr.TemporaryFailed
meta.LastAttempt = time.Now()
if meta.TriesCount == q.maxTries || len(partialErr.TemporaryFailed) == 0 {
// Attempt either fully succeeded or completely failed.
if meta.TriesCount == q.maxTries {
dl.Printf("gave up trying to deliver to %v, errors: %v", meta.TemporaryFailedRcpts, meta.RcptErrs)
}
if len(meta.FailedRcpts) != 0 {
dl.Printf("permanently failed to deliver to %v, errors: %v", meta.FailedRcpts, meta.RcptErrs)
}
if !meta.DSN {
q.emitDSN(meta, header)
}
q.removeFromDisk(meta.MsgMeta)
return
}
meta.TriesCount++
if err := q.updateMetadataOnDisk(meta); err != nil {
dl.Printf("failed to update meta-data: %v", err)
}
nextTryTime := time.Now()
nextTryTime = nextTryTime.Add(q.initialRetryTime * time.Duration(math.Pow(q.retryTimeScale, float64(meta.TriesCount-1))))
dl.Printf("%d attempt failed, will retry in %v (at %v)", meta.TriesCount, time.Until(nextTryTime), nextTryTime)
q.wheel.Add(nextTryTime, meta.MsgMeta.ID)
}
func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffer.Buffer) PartialError {
dl := deliveryLogger(q.Log, meta.MsgMeta)
perr := PartialError{
Errs: map[string]error{},
}
target := q.Target
if meta.DSN {
target = q.dsnDispatcher
}
delivery, err := target.Start(meta.MsgMeta, meta.From)
if err != nil {
perr.Failed = append(perr.Failed, meta.To...)
for _, rcpt := range meta.To {
perr.Errs[rcpt] = err
}
return perr
}
var acceptedRcpts []string
for _, rcpt := range meta.To {
if err := delivery.AddRcpt(rcpt); err != nil {
if isTemporaryErr(err) {
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
} else {
perr.Failed = append(perr.Failed, rcpt)
}
perr.Errs[rcpt] = err
} else {
acceptedRcpts = append(acceptedRcpts, rcpt)
}
}
if len(acceptedRcpts) == 0 {
if err := delivery.Abort(); err != nil {
dl.Printf("delivery.Abort failed: %v", err)
}
return perr
}
expandToPartialErr := func(err error) {
if expandedPerr, ok := err.(PartialError); ok {
perr.TemporaryFailed = append(perr.TemporaryFailed, expandedPerr.TemporaryFailed...)
perr.Failed = append(perr.Failed, expandedPerr.Failed...)
for rcpt, rcptErr := range expandedPerr.Errs {
perr.Errs[rcpt] = rcptErr
}
} else {
if isTemporaryErr(err) {
perr.TemporaryFailed = append(perr.TemporaryFailed, acceptedRcpts...)
} else {
perr.Failed = append(perr.Failed, acceptedRcpts...)
}
for _, rcpt := range acceptedRcpts {
perr.Errs[rcpt] = err
}
}
}
if err := delivery.Body(header, body); err != nil {
expandToPartialErr(err)
// No recipients succeeded.
if len(perr.TemporaryFailed)+len(perr.Failed) == len(acceptedRcpts) {
if err := delivery.Abort(); err != nil {
dl.Printf("delivery.Abort failed: %v", err)
}
return perr
}
}
if err := delivery.Commit(); err != nil {
expandToPartialErr(err)
}
return perr
}
type queueDelivery struct {
q *Queue
meta *QueueMetadata
header textproto.Header
body buffer.Buffer
}
func (qd *queueDelivery) AddRcpt(rcptTo string) error {
qd.meta.To = append(qd.meta.To, rcptTo)
return nil
}
func (qd *queueDelivery) Body(header textproto.Header, body buffer.Buffer) error {
// Body buffer initially passed to us may not be valid after "delivery" to queue completes.
// storeNewMessage returns a new buffer object created from message blob stored on disk.
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
if err != nil {
return err
}
qd.body = storedBody
qd.header = header
return nil
}
func (qd *queueDelivery) Abort() error {
if qd.body != nil {
qd.q.removeFromDisk(qd.meta.MsgMeta)
}
return nil
}
func (qd *queueDelivery) Commit() error {
// workersWg counter in incremented to make sure there will be no race with Close.
// e.g. it will not close the wheel before we complete first attempt.
// Also the first attempt is not scheduled using time wheel because
// the "normal" code path requires re-reading and re-parsing of header
// which is kinda expensive.
// FIXME: Though this is temporary solution, the correct fix
// would be to ditch "worker goroutines" altogether and enforce
// concurrent deliveries limit using a semaphore.
qd.q.workersWg.Add(1)
go func() {
qd.q.tryDelivery(qd.meta, qd.header, qd.body)
qd.q.workersWg.Done()
}()
return nil
}
func (q *Queue) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
meta := &QueueMetadata{
MsgMeta: msgMeta,
From: mailFrom,
RcptErrs: map[string]*smtp.SMTPError{},
FirstAttempt: time.Now(),
LastAttempt: time.Now(),
}
return &queueDelivery{q: q, meta: meta}, nil
}
func (q *Queue) StartDSN(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
meta := &QueueMetadata{
DSN: true,
MsgMeta: msgMeta,
From: mailFrom,
RcptErrs: map[string]*smtp.SMTPError{},
FirstAttempt: time.Now(),
LastAttempt: time.Now(),
}
return &queueDelivery{q: q, meta: meta}, nil
}
func (q *Queue) removeFromDisk(msgMeta *module.MsgMetadata) {
id := msgMeta.ID
dl := deliveryLogger(q.Log, msgMeta)
// Order is important.
// If we remove header and body but can't remove meta now - readDiskQueue
// will detect and report it.
headerPath := filepath.Join(q.location, id+".header")
if err := os.Remove(headerPath); err != nil {
dl.Printf("failed to remove header from disk: %v", err)
}
bodyPath := filepath.Join(q.location, id+".body")
if err := os.Remove(bodyPath); err != nil {
dl.Printf("failed to remove body from disk: %v", err)
}
metaPath := filepath.Join(q.location, id+".meta")
if err := os.Remove(metaPath); err != nil {
dl.Printf("failed to remove meta-data from disk: %v", err)
}
dl.Debugf("removed message from disk")
}
func (q *Queue) readDiskQueue() error {
dirInfo, err := ioutil.ReadDir(q.location)
if err != nil {
return err
}
// TODO: Rewrite this function to pass all sub-tests in TestQueueDelivery_DeserializationCleanUp/NoMeta.
loadedCount := 0
for _, entry := range dirInfo {
// We start loading from meta-data files and then check whether ID.header and ID.body exist.
// This allows us to properly detect dangling body files.
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta") {
continue
}
id := entry.Name()[:len(entry.Name())-5]
meta, err := q.readMessageMeta(id)
if err != nil {
q.Log.Printf("failed to read meta-data, skipping: %v (msg ID = %s)", err, id)
continue
}
// Check header file existence.
if _, err := os.Stat(filepath.Join(q.location, id+".header")); err != nil {
if os.IsNotExist(err) {
q.Log.Printf("header file doesn't exist for msg ID = %s", id)
q.tryRemoveDanglingFile(id + ".meta")
q.tryRemoveDanglingFile(id + ".body")
} else {
q.Log.Printf("skipping nonstat'able header file: %v (msg ID = %s)", err, id)
}
continue
}
// Check body file existence.
if _, err := os.Stat(filepath.Join(q.location, id+".body")); err != nil {
if os.IsNotExist(err) {
q.Log.Printf("body file doesn't exist for msg ID = %s", id)
q.tryRemoveDanglingFile(id + ".meta")
q.tryRemoveDanglingFile(id + ".header")
} else {
q.Log.Printf("skipping nonstat'able body file: %v (msg ID = %s)", err, id)
}
continue
}
nextTryTime := meta.LastAttempt
nextTryTime = nextTryTime.Add(q.initialRetryTime * time.Duration(math.Pow(q.retryTimeScale, float64(meta.TriesCount-1))))
if time.Until(nextTryTime) < q.postInitDelay {
nextTryTime = time.Now().Add(q.postInitDelay)
}
q.Log.Debugf("will try to deliver (msg ID = %s) in %v (%v)", id, time.Until(nextTryTime), nextTryTime)
q.wheel.Add(nextTryTime, id)
loadedCount++
}
if loadedCount != 0 {
q.Log.Printf("loaded %d saved queue entries", loadedCount)
}
return nil
}
func (q *Queue) storeNewMessage(meta *QueueMetadata, header textproto.Header, body buffer.Buffer) (buffer.Buffer, error) {
id := meta.MsgMeta.ID
headerPath := filepath.Join(q.location, id+".header")
headerFile, err := os.Create(headerPath)
if err != nil {
return nil, err
}
defer headerFile.Close()
if err := textproto.WriteHeader(headerFile, header); err != nil {
q.tryRemoveDanglingFile(id + ".header")
return nil, err
}
bodyReader, err := body.Open()
if err != nil {
q.tryRemoveDanglingFile(id + ".header")
return nil, err
}
defer bodyReader.Close()
bodyPath := filepath.Join(q.location, id+".body")
bodyFile, err := os.Create(bodyPath)
if err != nil {
return nil, err
}
defer bodyFile.Close()
if _, err := io.Copy(bodyFile, bodyReader); err != nil {
q.tryRemoveDanglingFile(id + ".body")
q.tryRemoveDanglingFile(id + ".header")
return nil, err
}
if err := q.updateMetadataOnDisk(meta); err != nil {
q.tryRemoveDanglingFile(id + ".body")
q.tryRemoveDanglingFile(id + ".header")
return nil, err
}
return buffer.FileBuffer{Path: bodyPath}, nil
}
func (q *Queue) updateMetadataOnDisk(meta *QueueMetadata) error {
metaPath := filepath.Join(q.location, meta.MsgMeta.ID+".meta")
file, err := os.Create(metaPath)
if err != nil {
return err
}
defer file.Close()
metaCopy := *meta
metaCopy.MsgMeta = meta.MsgMeta.DeepCopy()
if _, ok := metaCopy.MsgMeta.SrcAddr.(*net.TCPAddr); !ok {
meta.MsgMeta.SrcAddr = nil
}
if err := json.NewEncoder(file).Encode(metaCopy); err != nil {
return err
}
return nil
}
func (q *Queue) readMessageMeta(id string) (*QueueMetadata, error) {
metaPath := filepath.Join(q.location, id+".meta")
file, err := os.Open(metaPath)
if err != nil {
return nil, err
}
defer file.Close()
meta := &QueueMetadata{}
// net.Addr can't be deserialized because we don't know concrete type. For
// this reason we assume that SrcAddr is TCPAddr, if it is not - we drop it
// during serialization (see updateMetadataOnDisk).
meta.MsgMeta = &module.MsgMetadata{}
meta.MsgMeta.SrcAddr = &net.TCPAddr{}
if err := json.NewDecoder(file).Decode(meta); err != nil {
return nil, err
}
return meta, nil
}
type BufferedReadCloser struct {
*bufio.Reader
io.Closer
}
func (q *Queue) tryRemoveDanglingFile(name string) {
if err := os.Remove(filepath.Join(q.location, name)); err != nil {
q.Log.Println(err)
return
}
q.Log.Printf("removed dangling file %s", name)
}
func (q *Queue) openMessage(id string) (*QueueMetadata, textproto.Header, buffer.Buffer, error) {
meta, err := q.readMessageMeta(id)
if err != nil {
return nil, textproto.Header{}, nil, err
}
bodyPath := filepath.Join(q.location, id+".body")
_, err = os.Stat(bodyPath)
if err != nil {
if os.IsNotExist(err) {
q.tryRemoveDanglingFile(id + ".meta")
}
return nil, textproto.Header{}, nil, nil
}
body := buffer.FileBuffer{Path: bodyPath}
headerPath := filepath.Join(q.location, id+".header")
headerFile, err := os.Open(headerPath)
if err != nil {
if os.IsNotExist(err) {
q.tryRemoveDanglingFile(id + ".meta")
q.tryRemoveDanglingFile(id + ".body")
}
return nil, textproto.Header{}, nil, nil
}
bufferedHeader := bufio.NewReader(headerFile)
header, err := textproto.ReadHeader(bufferedHeader)
if err != nil {
return nil, textproto.Header{}, nil, nil
}
return meta, header, body, nil
}
func (q *Queue) InstanceName() string {
return q.name
}
func (q *Queue) Name() string {
return "queue"
}
func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
// If, apparently, we have no DSN dispatcher configured - do nothing.
if q.dsnDispatcher == nil {
return
}
dsnID, err := generateMsgID()
if err != nil {
q.Log.Printf("rand.Rand error: %v", err)
return
}
dsnEnvelope := dsn.Envelope{
MsgID: "<" + dsnID + "@" + q.autogenMsgDomain + ">",
From: "MAILER-DAEMON@" + q.autogenMsgDomain,
To: meta.From,
}
mtaInfo := dsn.ReportingMTAInfo{
ReportingMTA: q.hostname,
XSender: meta.From,
XMessageID: meta.MsgMeta.ID,
ArrivalDate: meta.FirstAttempt,
LastAttemptDate: meta.LastAttempt,
}
if !meta.MsgMeta.DontTraceSender {
mtaInfo.ReceivedFromMTA = meta.MsgMeta.SrcHostname
}
rcptInfo := make([]dsn.RecipientInfo, 0, len(meta.RcptErrs))
for rcpt, err := range meta.RcptErrs {
if meta.MsgMeta.OriginalRcpts != nil {
originalRcpt := meta.MsgMeta.OriginalRcpts[rcpt]
if originalRcpt != "" {
rcpt = originalRcpt
}
}
rcptInfo = append(rcptInfo, dsn.RecipientInfo{
FinalRecipient: rcpt,
Action: dsn.ActionFailed,
Status: err.EnhancedCode,
DiagnosticCode: err,
})
}
var dsnBodyBlob bytes.Buffer
dl := deliveryLogger(q.Log, meta.MsgMeta)
dsnHeader, err := dsn.GenerateDSN(dsnEnvelope, mtaInfo, rcptInfo, header, &dsnBodyBlob)
if err != nil {
dl.Printf("failed to generate fail DSN: %v", err)
return
}
dsnBody := buffer.MemoryBuffer{Slice: dsnBodyBlob.Bytes()}
dsnMeta := &module.MsgMetadata{
ID: dsnID,
SrcProto: "",
SrcHostname: q.hostname,
OurHostname: q.hostname,
}
dl.Printf("generated failed DSN, DSN ID = %s", dsnID)
dsnDelivery, err := q.StartDSN(dsnMeta, "MAILER-DAEMON@"+q.autogenMsgDomain)
if err != nil {
dl.Printf("failed to enqueue DSN: %v", err)
return
}
defer func() {
if err != nil {
dl.Printf("failed to enqueue DSN: %v", err)
dsnDelivery.Abort()
}
}()
if err = dsnDelivery.AddRcpt(meta.From); err != nil {
return
}
if err = dsnDelivery.Body(dsnHeader, dsnBody); err != nil {
return
}
if err = dsnDelivery.Commit(); err != nil {
return
}
}
func init() {
module.Register("queue", NewQueue)
}

View file

@ -1,549 +0,0 @@
package maddy
import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
// newTestQueue returns properly initialized Queue object usable for testing.
//
// See newTestQueueDir to create testing queue from an existing directory.
// It is called responsibility to remove queue directory created by this function.
func newTestQueue(t *testing.T, target module.DeliveryTarget) *Queue {
dir, err := ioutil.TempDir("", "maddy-tests-queue")
if err != nil {
t.Fatal("failed to create temporary directory for queue:", err)
}
return newTestQueueDir(t, target, dir)
}
func cleanQueue(t *testing.T, q *Queue) {
if err := q.Close(); err != nil {
t.Fatal("queue.Close:", err)
}
if err := os.RemoveAll(q.location); err != nil {
t.Fatal("os.RemoveAll", err)
}
}
func newTestQueueDir(t *testing.T, target module.DeliveryTarget, dir string) *Queue {
mod, _ := NewQueue("", "queue", nil)
q := mod.(*Queue)
q.Log = testLogger(t, "queue")
q.initialRetryTime = 0
q.retryTimeScale = 1
q.postInitDelay = 0
q.maxTries = 5
q.location = dir
q.Target = target
if !testing.Verbose() {
q.Log = log.Logger{Name: "", Out: log.WriterLog(ioutil.Discard)}
}
q.start(1)
return q
}
// unreliableTarget is a module.DeliveryTarget implementation that stores
// messages to a slice and sometimes fails with the specified error.
type unreliableTarget struct {
committed chan msg
aborted chan msg
// Amount of completed deliveries (both failed and succeeded)
passedMessages int
// To make unreliableTarget fail Commit for N-th delivery, set N-1-th
// element of this slice to wanted error object. If slice is
// nil/empty or N is bigger than its size - delivery will succeed.
bodyFailures []error
rcptFailures []map[string]error
}
type unreliableTargetDelivery struct {
ut *unreliableTarget
msg msg
}
func (utd *unreliableTargetDelivery) AddRcpt(rcptTo string) error {
if len(utd.ut.rcptFailures) > utd.ut.passedMessages {
rcptErrs := utd.ut.rcptFailures[utd.ut.passedMessages]
if err := rcptErrs[rcptTo]; err != nil {
return err
}
}
utd.msg.rcptTo = append(utd.msg.rcptTo, rcptTo)
return nil
}
func (utd *unreliableTargetDelivery) Body(header textproto.Header, body buffer.Buffer) error {
r, _ := body.Open()
utd.msg.body, _ = ioutil.ReadAll(r)
if len(utd.ut.bodyFailures) > utd.ut.passedMessages {
return utd.ut.bodyFailures[utd.ut.passedMessages]
}
return nil
}
func (utd *unreliableTargetDelivery) Abort() error {
utd.ut.passedMessages++
if utd.ut.aborted != nil {
utd.ut.aborted <- utd.msg
}
return nil
}
func (utd *unreliableTargetDelivery) Commit() error {
utd.ut.passedMessages++
if utd.ut.committed != nil {
utd.ut.committed <- utd.msg
}
return nil
}
func (ut *unreliableTarget) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
return &unreliableTargetDelivery{
ut: ut,
msg: msg{
msgMeta: msgMeta,
mailFrom: mailFrom,
},
}, nil
}
func readMsgChanTimeout(t *testing.T, ch <-chan msg, timeout time.Duration) *msg {
t.Helper()
timer := time.NewTimer(timeout)
select {
case msg := <-ch:
return &msg
case <-timer.C:
t.Fatal("chan read timed out")
return nil
}
}
func checkQueueDir(t *testing.T, q *Queue, expectedIDs []string) {
t.Helper()
// We use the map to lookups and also to mark messages we found
// we can report missing entries.
expectedMap := make(map[string]bool, len(expectedIDs))
for _, id := range expectedIDs {
expectedMap[id] = false
}
dir, err := ioutil.ReadDir(q.location)
if err != nil {
t.Fatalf("failed to read queue directory: %v", err)
}
// Queue implementation uses file names in the following format:
// DELIVERY_ID.SOMETHING
for _, file := range dir {
if file.IsDir() {
t.Fatalf("queue should not create subdirectories in the store, but there is %s dir in it", file.Name())
}
nameParts := strings.Split(file.Name(), ".")
if len(nameParts) != 2 {
t.Fatalf("did the queue files name format changed? got %s", file.Name())
}
_, ok := expectedMap[nameParts[0]]
if !ok {
t.Errorf("message with unexpected Msg ID %s is stored in queue store", nameParts[0])
continue
}
expectedMap[nameParts[0]] = true
}
for id, found := range expectedMap {
if !found {
t.Errorf("expected message with Msg ID %s is missing from queue store", id)
}
}
}
func TestQueueDelivery(t *testing.T) {
t.Parallel()
dt := unreliableTarget{committed: make(chan msg, 10)}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// This is far from being a proper blackbox testing.
// But I can't come up with a better way to inspect the Queue state.
// This probably will be improved when bounce messages will be implemented.
// For now, this is a dirty hack. Close the Queue and inspect serialized state.
// FIXME.
// Wait for the delivery to complete and stop processing.
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
q.Close()
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// There should be no queued messages.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_PermanentFail_NonPartial(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
bodyFailures: []error{
&smtp.SMTPError{
Code: 500,
EnhancedCode: smtp.EnhancedCode{5, 0, 0},
Message: "you shall not pass",
},
},
aborted: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// Queue will abort a delivery if it fails for all recipients.
readMsgChanTimeout(t, dt.aborted, 5*time.Second)
q.Close()
// Delivery is failed permanently, hence no retry should be rescheduled.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_PermanentFail_Partial(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
bodyFailures: []error{
PartialError{
Failed: []string{"tester1@example.org", "tester2@example.org"},
Errs: map[string]error{
"tester1@example.org": errors.New("you shall not pass"),
"tester2@example.org": errors.New("you shall not pass"),
},
},
},
aborted: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// This this is similar to the previous test, but checks PartialErr processing logic.
// Here delivery fails for recipients too, but this is reported using PartialErr.
readMsgChanTimeout(t, dt.aborted, 5*time.Second)
q.Close()
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_TemporaryFail(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
bodyFailures: []error{
PartialError{
TemporaryFailed: []string{"tester1@example.org", "tester2@example.org"},
Errs: map[string]error{
"tester1@example.org": errors.New("you shall not pass"),
"tester2@example.org": errors.New("you shall not pass"),
},
},
},
aborted: make(chan msg, 10),
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// Delivery should be aborted, because it failed for all recipients.
readMsgChanTimeout(t, dt.aborted, 5*time.Second)
// Second retry, should work fine.
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
q.Close()
// No more retries scheduled, queue storage is clear.
defer checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_TemporaryFail_Partial(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
bodyFailures: []error{
PartialError{
TemporaryFailed: []string{"tester2@example.org"},
Errs: map[string]error{
"tester2@example.org": &smtp.SMTPError{
Code: 400,
Message: "go away",
},
},
},
},
aborted: make(chan msg, 10),
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// Committed, tester1@example.org - ok.
msg := readMsgChanTimeout(t, dt.committed, 5000*time.Second)
// Side note: unreliableTarget adds recipients to the msg object even if they were rejected
// later using a partial error. So slice below is all recipients that were submitted by
// the queue.
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// committed #2, tester2@example.org - ok
msg = readMsgChanTimeout(t, dt.committed, 5000*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
q.Close()
// No more retries scheduled, queue storage is clear.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_MultipleAttempts(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
bodyFailures: []error{
PartialError{
Failed: []string{"tester1@example.org"},
TemporaryFailed: []string{"tester2@example.org"},
Errs: map[string]error{
"tester1@example.org": errors.New("you shall not pass"),
"tester2@example.org": errors.New("you shall not pass"),
},
},
PartialError{
TemporaryFailed: []string{"tester2@example.org"},
Errs: map[string]error{
"tester2@example.org": errors.New("you shall not pass"),
},
},
},
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org", "tester3@example.org"})
// Committed because delivery to tester3@example.org is succeeded.
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
// Side note: This slice contains all recipients submitted by the queue, even if
// they were rejected later using PartialError.
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org", "tester3@example.org"})
// tester1 is failed permanently, should not be retried.
// tester2 is failed temporary, should be retried.
msg = readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
q.Close()
// No more retries should be scheduled.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_PermanentRcptReject(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
rcptFailures: []map[string]error{
{
"tester1@example.org": &smtp.SMTPError{
Code: 500,
Message: "go away",
},
},
},
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
doTestDelivery(t, q, "tester@example.org", []string{"tester1@example.org", "tester2@example.org"})
// Committed, tester2@example.org succeeded.
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.org", []string{"tester2@example.org"})
q.Close()
// No more retries should be scheduled.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_TemporaryRcptReject(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
rcptFailures: []map[string]error{
{
"tester1@example.org": &smtp.SMTPError{
Code: 400,
Message: "go away",
},
},
},
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
// First attempt:
// tester1 - temp. fail
// tester2 - ok
// Second attempt:
// tester1 - ok
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
// Unlike previous tests where unreliableTarget rejected recipients by PartialError, here they are rejected
// by AddRcpt directly, so they are NOT saved by the target.
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
msg = readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org"})
q.Close()
// No more retries should be scheduled.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_SerializationRoundtrip(t *testing.T) {
t.Parallel()
dt := unreliableTarget{
rcptFailures: []map[string]error{
{
"tester1@example.org": &smtp.SMTPError{
Code: 400,
Message: "go away",
},
},
},
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
// This is the most tricky test because it is racy and I have no idea what can be done to avoid it.
// It relies on us calling Close before queue dispatcher decides to retry delivery.
// Hence retry delay is increased from 0ms used in other tests to make it reliable.
q.initialRetryTime = 1 * time.Second
// To make sure we will not time out due to post-init delay.
q.postInitDelay = 0
deliveryID := doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// Standard partial delivery, retry will be scheduled for tester1@example.org.
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
// Then stop it.
q.Close()
// Make sure it is saved.
checkQueueDir(t, q, []string{deliveryID})
// Then reinit it.
q = newTestQueueDir(t, &dt, q.location)
// Wait for retry and check it.
msg = readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org"})
// Close it again.
q.Close()
// No more retries should be scheduled.
checkQueueDir(t, q, []string{})
}
func TestQueueDelivery_DeserlizationCleanUp(t *testing.T) {
t.Parallel()
test := func(t *testing.T, fileSuffix string) {
dt := unreliableTarget{
rcptFailures: []map[string]error{
{
"tester1@example.org": &smtp.SMTPError{
Code: 400,
Message: "go away",
},
},
},
committed: make(chan msg, 10),
}
q := newTestQueue(t, &dt)
defer cleanQueue(t, q)
// This is the most tricky test because it is racy and I have no idea what can be done to avoid it.
// It relies on us calling Close before queue dispatcher decides to retry delivery.
// Hence retry delay is increased from 0ms used in other tests to make it reliable.
q.initialRetryTime = 1 * time.Second
// To make sure we will not time out due to post-init delay.
q.postInitDelay = 0
deliveryID := doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
// Standard partial delivery, retry will be scheduled for tester1@example.org.
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
q.Close()
if err := os.Remove(filepath.Join(q.location, deliveryID+fileSuffix)); err != nil {
t.Fatal(err)
}
// Dangling files should be removing during load.
q = newTestQueueDir(t, &dt, q.location)
q.Close()
// Nothing should be left.
checkQueueDir(t, q, []string{})
}
t.Run("NoMeta", func(t *testing.T) {
t.Skip("Not implemented")
test(t, ".meta")
})
t.Run("NoBody", func(t *testing.T) {
test(t, ".body")
})
t.Run("NoHeader", func(t *testing.T) {
test(t, ".header")
})
}

View file

@ -1,4 +1,11 @@
package maddy
// Package sql implements SQL-based storage module
// using go-imap-sql library (github.com/foxcpp/go-imap-sql).
//
// Interfaces implemented:
// - module.StorageBackend
// - module.AuthProvider
// - module.DeliveryTarget
package sql
import (
"context"
@ -16,8 +23,11 @@ import (
"github.com/emersion/go-smtp"
imapsql "github.com/foxcpp/go-imap-sql"
"github.com/foxcpp/go-imap-sql/fsstore"
"github.com/foxcpp/maddy/address"
"github.com/foxcpp/maddy/auth"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/dns"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
@ -25,7 +35,7 @@ import (
_ "github.com/lib/pq"
)
type SQLStorage struct {
type Storage struct {
back *imapsql.Backend
instName string
Log log.Logger
@ -35,11 +45,11 @@ type SQLStorage struct {
authDomains []string
junkMbox string
resolver Resolver
resolver dns.Resolver
}
type sqlDelivery struct {
sqlm *SQLStorage
type delivery struct {
store *Storage
msgMeta *module.MsgMetadata
d *imapsql.Delivery
mailFrom string
@ -47,7 +57,7 @@ type sqlDelivery struct {
addedRcpts map[string]struct{}
}
func LookupAddr(r Resolver, ip net.IP) (string, error) {
func LookupAddr(r dns.Resolver, ip net.IP) (string, error) {
names, err := r.LookupAddr(context.Background(), ip.String())
if err != nil || len(names) == 0 {
return "", err
@ -55,7 +65,11 @@ func LookupAddr(r Resolver, ip net.IP) (string, error) {
return strings.TrimRight(names[0], "."), nil
}
func generateReceived(r Resolver, msgMeta *module.MsgMetadata, mailFrom, rcptTo string) string {
func sanitizeString(raw string) string {
return strings.Replace(raw, "\n", "", -1)
}
func generateReceived(r dns.Resolver, msgMeta *module.MsgMetadata, mailFrom, rcptTo string) string {
var received string
if !msgMeta.DontTraceSender {
received += "from " + msgMeta.SrcHostname
@ -74,15 +88,15 @@ func generateReceived(r Resolver, msgMeta *module.MsgMetadata, mailFrom, rcptTo
return received
}
func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
func (d *delivery) AddRcpt(rcptTo string) error {
var accountName string
// Side note: <postmaster> address will be always accepted
// and delivered to "postmaster" account for both cases.
if sd.sqlm.storagePerDomain {
if d.store.storagePerDomain {
accountName = rcptTo
} else {
var err error
accountName, _, err = splitAddress(rcptTo)
accountName, _, err = address.Split(rcptTo)
if err != nil {
return &smtp.SMTPError{
Code: 501,
@ -93,7 +107,7 @@ func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
}
accountName = strings.ToLower(accountName)
if _, ok := sd.addedRcpts[accountName]; ok {
if _, ok := d.addedRcpts[accountName]; ok {
return nil
}
@ -102,9 +116,9 @@ func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
// with small amount of per-recipient data in a efficient way.
userHeader := textproto.Header{}
userHeader.Add("Delivered-To", rcptTo)
userHeader.Add("Received", generateReceived(sd.sqlm.resolver, sd.msgMeta, sd.mailFrom, rcptTo))
userHeader.Add("Received", generateReceived(d.store.resolver, d.msgMeta, d.mailFrom, rcptTo))
if err := sd.d.AddRcpt(strings.ToLower(accountName), userHeader); err != nil {
if err := d.d.AddRcpt(strings.ToLower(accountName), userHeader); err != nil {
if err == imapsql.ErrUserDoesntExists || err == backend.ErrNoSuchMailbox {
return &smtp.SMTPError{
Code: 550,
@ -115,37 +129,37 @@ func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
return err
}
sd.addedRcpts[accountName] = struct{}{}
d.addedRcpts[accountName] = struct{}{}
return nil
}
func (sd *sqlDelivery) Body(header textproto.Header, body buffer.Buffer) error {
if sd.msgMeta.Quarantine {
if err := sd.d.SpecialMailbox(specialuse.Junk, sd.sqlm.junkMbox); err != nil {
func (d *delivery) Body(header textproto.Header, body buffer.Buffer) error {
if d.msgMeta.Quarantine {
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
return err
}
}
header = header.Copy()
header.Add("Return-Path", "<"+sanitizeString(sd.mailFrom)+">")
return sd.d.BodyParsed(header, sd.msgMeta.BodyLength, body)
header.Add("Return-Path", "<"+sanitizeString(d.mailFrom)+">")
return d.d.BodyParsed(header, d.msgMeta.BodyLength, body)
}
func (sd *sqlDelivery) Abort() error {
return sd.d.Abort()
func (d *delivery) Abort() error {
return d.d.Abort()
}
func (sd *sqlDelivery) Commit() error {
return sd.d.Commit()
func (d *delivery) Commit() error {
return d.d.Commit()
}
func (sqlm *SQLStorage) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
d, err := sqlm.back.StartDelivery()
func (store *Storage) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
d, err := store.back.StartDelivery()
if err != nil {
return nil, err
}
return &sqlDelivery{
sqlm: sqlm,
return &delivery{
store: store,
msgMeta: msgMeta,
d: d,
mailFrom: mailFrom,
@ -153,23 +167,23 @@ func (sqlm *SQLStorage) Start(msgMeta *module.MsgMetadata, mailFrom string) (mod
}, nil
}
func (sqlm *SQLStorage) Name() string {
func (store *Storage) Name() string {
return "sql"
}
func (sqlm *SQLStorage) InstanceName() string {
return sqlm.instName
func (store *Storage) InstanceName() string {
return store.instName
}
func NewSQLStorage(_, instName string, _ []string) (module.Module, error) {
return &SQLStorage{
func New(_, instName string, _ []string) (module.Module, error) {
return &Storage{
instName: instName,
Log: log.Logger{Name: "sql"},
resolver: net.DefaultResolver,
}, nil
}
func (sqlm *SQLStorage) Init(cfg *config.Map) error {
func (store *Storage) Init(cfg *config.Map) error {
var driver, dsn string
var fsstoreLocation string
appendlimitVal := int64(-1)
@ -182,24 +196,24 @@ func (sqlm *SQLStorage) Init(cfg *config.Map) error {
cfg.String("driver", false, true, "", &driver)
cfg.String("dsn", false, true, "", &dsn)
cfg.Int64("appendlimit", false, false, 32*1024*1024, &appendlimitVal)
cfg.Bool("debug", true, &sqlm.Log.Debug)
cfg.Bool("storage_perdomain", true, &sqlm.storagePerDomain)
cfg.Bool("auth_perdomain", true, &sqlm.authPerDomain)
cfg.StringList("auth_domains", true, false, nil, &sqlm.authDomains)
cfg.Bool("debug", true, &store.Log.Debug)
cfg.Bool("storage_perdomain", true, &store.storagePerDomain)
cfg.Bool("auth_perdomain", true, &store.authPerDomain)
cfg.StringList("auth_domains", true, false, nil, &store.authDomains)
cfg.Int("sqlite3_cache_size", false, false, 0, &opts.CacheSize)
cfg.Int("sqlite3_busy_timeout", false, false, 0, &opts.BusyTimeout)
cfg.Bool("sqlite3_exclusive_lock", false, &opts.ExclusiveLock)
cfg.String("junk_mailbox", false, false, "Junk", &sqlm.junkMbox)
cfg.String("junk_mailbox", false, false, "Junk", &store.junkMbox)
cfg.Custom("fsstore", false, false, func() (interface{}, error) {
return "", nil
}, func(m *config.Map, node *config.Node) (interface{}, error) {
switch len(node.Args) {
case 0:
if sqlm.instName == "" {
if store.instName == "" {
return nil, errors.New("sql: need explicit fsstore location for inline definition")
}
return filepath.Join(StateDirectory(cfg.Globals), "sql-"+sqlm.instName+"-fsstore"), nil
return filepath.Join(config.StateDirectory(cfg.Globals), "sql-"+store.instName+"-fsstore"), nil
case 1:
return node.Args[0], nil
default:
@ -211,15 +225,15 @@ func (sqlm *SQLStorage) Init(cfg *config.Map) error {
return err
}
opts.Log = &sqlm.Log
opts.Log = &store.Log
if sqlm.authPerDomain && sqlm.authDomains == nil {
if store.authPerDomain && store.authDomains == nil {
return errors.New("sql: auth_domains must be set if auth_perdomain is used")
}
if fsstoreLocation != "" {
if !filepath.IsAbs(fsstoreLocation) {
fsstoreLocation = filepath.Join(StateDirectory(cfg.Globals), fsstoreLocation)
fsstoreLocation = filepath.Join(config.StateDirectory(cfg.Globals), fsstoreLocation)
}
if err := os.MkdirAll(fsstoreLocation, os.ModeDir|os.ModePerm); err != nil {
@ -235,40 +249,40 @@ func (sqlm *SQLStorage) Init(cfg *config.Map) error {
*opts.MaxMsgBytes = uint32(appendlimitVal)
}
var err error
sqlm.back, err = imapsql.New(driver, dsn, opts)
store.back, err = imapsql.New(driver, dsn, opts)
if err != nil {
return fmt.Errorf("sql: %s", err)
}
sqlm.Log.Debugln("go-imap-sql version", imapsql.VersionStr)
store.Log.Debugln("go-imap-sql version", imapsql.VersionStr)
return nil
}
func (sqlm *SQLStorage) IMAPExtensions() []string {
func (store *Storage) IMAPExtensions() []string {
return []string{"APPENDLIMIT", "MOVE", "CHILDREN"}
}
func (sqlm *SQLStorage) Updates() <-chan backend.Update {
return sqlm.back.Updates()
func (store *Storage) Updates() <-chan backend.Update {
return store.back.Updates()
}
func (sqlm *SQLStorage) EnableChildrenExt() bool {
return sqlm.back.EnableChildrenExt()
func (store *Storage) EnableChildrenExt() bool {
return store.back.EnableChildrenExt()
}
func (sqlm *SQLStorage) CheckPlain(username, password string) bool {
accountName, ok := checkDomainAuth(username, sqlm.authPerDomain, sqlm.authDomains)
func (store *Storage) CheckPlain(username, password string) bool {
accountName, ok := auth.CheckDomainAuth(username, store.authPerDomain, store.authDomains)
if !ok {
return false
}
return sqlm.back.CheckPlain(accountName, password)
return store.back.CheckPlain(accountName, password)
}
func (sqlm *SQLStorage) GetOrCreateUser(username string) (backend.User, error) {
func (store *Storage) GetOrCreateUser(username string) (backend.User, error) {
var accountName string
if sqlm.storagePerDomain {
if store.storagePerDomain {
if !strings.Contains(username, "@") {
return nil, errors.New("GetOrCreateUser: username@domain required")
}
@ -278,9 +292,9 @@ func (sqlm *SQLStorage) GetOrCreateUser(username string) (backend.User, error) {
accountName = parts[0]
}
return sqlm.back.GetOrCreateUser(accountName)
return store.back.GetOrCreateUser(accountName)
}
func init() {
module.Register("sql", NewSQLStorage)
module.Register("sql", New)
}

View file

@ -1,5 +1,5 @@
// +build !nosqlite3,cgo
package maddy
package sql
import _ "github.com/mattn/go-sqlite3"

23
target/delivery.go Normal file
View file

@ -0,0 +1,23 @@
package target
import (
"time"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
)
func DeliveryLogger(l log.Logger, msgMeta *module.MsgMetadata) log.Logger {
out := l.Out
if out == nil {
out = log.DefaultLogger.Out
}
return log.Logger{
Out: func(t time.Time, debug bool, str string) {
out(t, debug, str+" (msg ID = "+msgMeta.ID+")")
},
Name: l.Name,
Debug: l.Debug,
}
}

View file

@ -1,4 +1,9 @@
package maddy
// Package remote implements module which does outgoing
// message delivery using servers discovered using DNS MX records.
//
// Implemented interfaces:
// - module.DeliveryTarget
package remote
import (
"context"
@ -7,7 +12,6 @@ import (
"fmt"
"io"
"net"
nettextproto "net/textproto"
"os"
"path/filepath"
"sort"
@ -16,21 +20,23 @@ import (
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/address"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/dns"
"github.com/foxcpp/maddy/log"
"github.com/foxcpp/maddy/module"
"github.com/foxcpp/maddy/mtasts"
"github.com/foxcpp/maddy/target"
"github.com/foxcpp/maddy/target/queue"
)
var ErrTLSRequired = errors.New("TLS is required for outgoing connections but target server doesn't support STARTTLS")
type RemoteTarget struct {
type Target struct {
name string
hostname string
requireTLS bool
resolver Resolver
resolver dns.Resolver
mtastsCache mtasts.Cache
stsCacheUpdateTick *time.Ticker
@ -39,10 +45,10 @@ type RemoteTarget struct {
Log log.Logger
}
var _ module.DeliveryTarget = &RemoteTarget{}
var _ module.DeliveryTarget = &Target{}
func NewRemoteTarget(_, instName string, _ []string) (module.Module, error) {
return &RemoteTarget{
func New(_, instName string, _ []string) (module.Module, error) {
return &Target{
name: instName,
resolver: net.DefaultResolver,
mtastsCache: mtasts.Cache{Resolver: net.DefaultResolver},
@ -52,9 +58,9 @@ func NewRemoteTarget(_, instName string, _ []string) (module.Module, error) {
}, nil
}
func (rt *RemoteTarget) Init(cfg *config.Map) error {
func (rt *Target) Init(cfg *config.Map) error {
cfg.String("hostname", true, true, "", &rt.hostname)
cfg.String("mtasts_cache", false, false, filepath.Join(StateDirectory(cfg.Globals), "mtasts-cache"), &rt.mtastsCache.Location)
cfg.String("mtasts_cache", false, false, filepath.Join(config.StateDirectory(cfg.Globals), "mtasts-cache"), &rt.mtastsCache.Location)
cfg.Bool("debug", true, &rt.Log.Debug)
cfg.Bool("require_tls", false, &rt.requireTLS)
if _, err := cfg.Process(); err != nil {
@ -62,7 +68,7 @@ func (rt *RemoteTarget) Init(cfg *config.Map) error {
}
if !filepath.IsAbs(rt.mtastsCache.Location) {
rt.mtastsCache.Location = filepath.Join(StateDirectory(cfg.Globals), rt.mtastsCache.Location)
rt.mtastsCache.Location = filepath.Join(config.StateDirectory(cfg.Globals), rt.mtastsCache.Location)
}
if err := os.MkdirAll(rt.mtastsCache.Location, os.ModePerm); err != nil {
return err
@ -76,17 +82,17 @@ func (rt *RemoteTarget) Init(cfg *config.Map) error {
return nil
}
func (rt *RemoteTarget) Close() error {
func (rt *Target) Close() error {
rt.stsCacheUpdateDone <- struct{}{}
<-rt.stsCacheUpdateDone
return nil
}
func (rt *RemoteTarget) Name() string {
func (rt *Target) Name() string {
return "remote"
}
func (rt *RemoteTarget) InstanceName() string {
func (rt *Target) InstanceName() string {
return rt.name
}
@ -97,7 +103,7 @@ type remoteConnection struct {
}
type remoteDelivery struct {
rt *RemoteTarget
rt *Target
mailFrom string
msgMeta *module.MsgMetadata
Log log.Logger
@ -105,18 +111,18 @@ type remoteDelivery struct {
connections map[string]*remoteConnection
}
func (rt *RemoteTarget) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
func (rt *Target) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
return &remoteDelivery{
rt: rt,
mailFrom: mailFrom,
msgMeta: msgMeta,
Log: deliveryLogger(rt.Log, msgMeta),
Log: target.DeliveryLogger(rt.Log, msgMeta),
connections: map[string]*remoteConnection{},
}, nil
}
func (rd *remoteDelivery) AddRcpt(to string) error {
_, domain, err := splitAddress(to)
_, domain, err := address.Split(to)
if err != nil {
return err
}
@ -187,13 +193,13 @@ func (rd *remoteDelivery) Body(header textproto.Header, b buffer.Buffer) error {
// TODO: Report partial errors early for LMTP. See github.com/emersion/go-smtp/pull/56
partialErr := PartialError{
partialErr := queue.PartialError{
Errs: map[string]error{},
}
for domain, conn := range rd.connections {
err := <-errChans[domain]
if err != nil {
if isTemporaryErr(err) {
if target.IsTemporaryErr(err) {
partialErr.TemporaryFailed = append(partialErr.TemporaryFailed, conn.recipients...)
} else {
partialErr.Failed = append(partialErr.Failed, conn.recipients...)
@ -280,7 +286,7 @@ func (rd *remoteDelivery) connectionForDomain(domain string) (*remoteConnection,
return conn, nil
}
func (rt *RemoteTarget) getSTSPolicy(domain string) (*mtasts.Policy, error) {
func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) {
stsPolicy, err := rt.mtastsCache.Get(domain)
if err != nil && err != mtasts.ErrNoPolicy {
rt.Log.Printf("failed to fetch MTA-STS policy for %s: %v", domain, err)
@ -298,7 +304,7 @@ func (rt *RemoteTarget) getSTSPolicy(domain string) (*mtasts.Policy, error) {
var ErrNoMXMatchedBySTS = errors.New("remote: no MX record matched MTA-STS policy")
func (rt *RemoteTarget) stsCacheUpdater() {
func (rt *Target) stsCacheUpdater() {
// Always update cache on start-up since we may have been down for some
// time.
rt.Log.Debugln("updating MTA-STS cache...")
@ -339,13 +345,13 @@ func connectToServer(ourHostname, address string, requireTLS bool) (*smtp.Client
return nil, err
}
} else if requireTLS {
return nil, ErrTLSRequired
return nil, target.ErrTLSRequired
}
return cl, nil
}
func (rt *RemoteTarget) lookupTargetServers(domain string) ([]string, error) {
func (rt *Target) lookupTargetServers(domain string) ([]string, error) {
records, err := rt.resolver.LookupMX(context.Background(), domain)
if err != nil {
return nil, err
@ -362,25 +368,6 @@ func (rt *RemoteTarget) lookupTargetServers(domain string) ([]string, error) {
return hosts, nil
}
func isTemporaryErr(err error) bool {
if protoErr, ok := err.(*nettextproto.Error); ok {
return (protoErr.Code / 100) == 4
}
if smtpErr, ok := err.(*smtp.SMTPError); ok {
return (smtpErr.Code / 100) == 4
}
if netErr, ok := err.(net.Error); ok {
return netErr.Temporary()
}
if err == ErrTLSRequired {
return false
}
// Connection error? Assume it is temporary.
return true
}
func init() {
module.Register("remote", NewRemoteTarget)
module.Register("remote", New)
}

30
target/temporaryerr.go Normal file
View file

@ -0,0 +1,30 @@
package target
import (
"errors"
"net"
"net/textproto"
"github.com/emersion/go-smtp"
)
func IsTemporaryErr(err error) bool {
if protoErr, ok := err.(*textproto.Error); ok {
return (protoErr.Code / 100) == 4
}
if smtpErr, ok := err.(*smtp.SMTPError); ok {
return (smtpErr.Code / 100) == 4
}
if netErr, ok := err.(net.Error); ok {
return netErr.Temporary()
}
if err == ErrTLSRequired {
return false
}
// Connection error? Assume it is temporary.
return true
}
var ErrTLSRequired = errors.New("TLS is required for outgoing connections but target server doesn't support STARTTLS")

65
testutils/check.go Normal file
View file

@ -0,0 +1,65 @@
package testutils
import (
"context"
"github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/buffer"
"github.com/foxcpp/maddy/config"
"github.com/foxcpp/maddy/module"
)
type Check struct {
ConnRes module.CheckResult
SenderRes module.CheckResult
RcptRes module.CheckResult
BodyRes module.CheckResult
}
func (c *Check) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
return &checkState{msgMeta, c}, nil
}
func (c *Check) Init(*config.Map) error {
return nil
}
func (c *Check) Name() string {
return "test_check"
}
func (c *Check) InstanceName() string {
return "test_check"
}
type checkState struct {
msgMeta *module.MsgMetadata
check *Check
}
func (cs *checkState) CheckConnection(ctx context.Context) module.CheckResult {
return cs.check.ConnRes
}
func (cs *checkState) CheckSender(ctx context.Context, from string) module.CheckResult {
return cs.check.SenderRes
}
func (cs *checkState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
return cs.check.RcptRes
}
func (cs *checkState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
return cs.check.BodyRes
}
func (cs *checkState) Close() error {
return nil
}
func init() {
module.Register("test_check", func(modName, instanceName string, aliases []string) (module.Module, error) {
return &Check{}, nil
})
module.RegisterInstance(&Check{}, nil)
}

View file

@ -1,4 +1,4 @@
package maddy
package testutils
import (
"strings"
@ -8,7 +8,7 @@ import (
"github.com/foxcpp/maddy/log"
)
func testLogger(t *testing.T, name string) log.Logger {
func Logger(t *testing.T, name string) log.Logger {
if testing.Verbose() {
return log.Logger{
Out: func(_ time.Time, _ bool, str string) {

View file

@ -1,4 +1,4 @@
package maddy
package testutils
import (
"crypto/sha1"
@ -14,76 +14,76 @@ import (
"github.com/foxcpp/maddy/module"
)
type msg struct {
msgMeta *module.MsgMetadata
mailFrom string
rcptTo []string
body []byte
header textproto.Header
type Msg struct {
MsgMeta *module.MsgMetadata
MailFrom string
RcptTo []string
Body []byte
Header textproto.Header
}
type testTarget struct {
messages []msg
type Target struct {
Messages []Msg
startErr error
rcptErr map[string]error
bodyErr error
abortErr error
commitErr error
StartErr error
RcptErr map[string]error
BodyErr error
AbortErr error
CommitErr error
instName string
InstName string
}
/*
module.Module is implemented with dummy functions for logging done by Dispatcher code.
*/
func (dt testTarget) Init(*config.Map) error {
func (dt Target) Init(*config.Map) error {
return nil
}
func (dt testTarget) InstanceName() string {
if dt.instName != "" {
return dt.instName
func (dt Target) InstanceName() string {
if dt.InstName != "" {
return dt.InstName
}
return "test_instance"
}
func (dt testTarget) Name() string {
func (dt Target) Name() string {
return "test_target"
}
type testTargetDelivery struct {
msg msg
tgt *testTarget
msg Msg
tgt *Target
aborted bool
committed bool
}
func (dt *testTarget) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
func (dt *Target) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
return &testTargetDelivery{
tgt: dt,
msg: msg{msgMeta: msgMeta, mailFrom: mailFrom},
}, dt.startErr
msg: Msg{MsgMeta: msgMeta, MailFrom: mailFrom},
}, dt.StartErr
}
func (dtd *testTargetDelivery) AddRcpt(to string) error {
if dtd.tgt.rcptErr != nil {
if err := dtd.tgt.rcptErr[to]; err != nil {
if dtd.tgt.RcptErr != nil {
if err := dtd.tgt.RcptErr[to]; err != nil {
return err
}
}
dtd.msg.rcptTo = append(dtd.msg.rcptTo, to)
dtd.msg.RcptTo = append(dtd.msg.RcptTo, to)
return nil
}
func (dtd *testTargetDelivery) Body(header textproto.Header, buf buffer.Buffer) error {
if dtd.tgt.bodyErr != nil {
return dtd.tgt.bodyErr
if dtd.tgt.BodyErr != nil {
return dtd.tgt.BodyErr
}
dtd.msg.header = header
dtd.msg.Header = header
body, err := buf.Open()
if err != nil {
@ -91,23 +91,23 @@ func (dtd *testTargetDelivery) Body(header textproto.Header, buf buffer.Buffer)
}
defer body.Close()
dtd.msg.body, err = ioutil.ReadAll(body)
dtd.msg.Body, err = ioutil.ReadAll(body)
return err
}
func (dtd *testTargetDelivery) Abort() error {
return dtd.tgt.abortErr
return dtd.tgt.AbortErr
}
func (dtd *testTargetDelivery) Commit() error {
if dtd.tgt.commitErr != nil {
return dtd.tgt.commitErr
if dtd.tgt.CommitErr != nil {
return dtd.tgt.CommitErr
}
dtd.tgt.messages = append(dtd.tgt.messages, dtd.msg)
dtd.tgt.Messages = append(dtd.tgt.Messages, dtd.msg)
return nil
}
func doTestDelivery(t *testing.T, tgt module.DeliveryTarget, from string, to []string) string {
func DoTestDelivery(t *testing.T, tgt module.DeliveryTarget, from string, to []string) string {
t.Helper()
IDRaw := sha1.Sum([]byte(t.Name()))
@ -137,7 +137,7 @@ func doTestDelivery(t *testing.T, tgt module.DeliveryTarget, from string, to []s
return encodedID
}
func doTestDeliveryErr(t *testing.T, tgt module.DeliveryTarget, from string, to []string) (string, error) {
func DoTestDeliveryErr(t *testing.T, tgt module.DeliveryTarget, from string, to []string) (string, error) {
t.Helper()
IDRaw := sha1.Sum([]byte(t.Name()))
@ -169,37 +169,37 @@ func doTestDeliveryErr(t *testing.T, tgt module.DeliveryTarget, from string, to
return encodedID, err
}
func checkTestMessage(t *testing.T, tgt *testTarget, indx int, sender string, rcpt []string) {
func CheckTestMessage(t *testing.T, tgt *Target, indx int, sender string, rcpt []string) {
t.Helper()
if len(tgt.messages) <= indx {
t.Errorf("wrong amount of messages received, want at least %d, got %d", indx+1, len(tgt.messages))
if len(tgt.Messages) <= indx {
t.Errorf("wrong amount of messages received, want at least %d, got %d", indx+1, len(tgt.Messages))
return
}
msg := tgt.messages[indx]
msg := tgt.Messages[indx]
checkMsg(t, &msg, sender, rcpt)
CheckMsg(t, &msg, sender, rcpt)
}
func checkMsg(t *testing.T, msg *msg, sender string, rcpt []string) {
func CheckMsg(t *testing.T, msg *Msg, sender string, rcpt []string) {
t.Helper()
idRaw := sha1.Sum([]byte(t.Name()))
encodedId := hex.EncodeToString(idRaw[:])
if msg.msgMeta.ID != encodedId {
t.Errorf("empty or wrong delivery context for passed message? %+v", msg.msgMeta)
if msg.MsgMeta.ID != encodedId {
t.Errorf("empty or wrong delivery context for passed message? %+v", msg.MsgMeta)
}
if msg.mailFrom != sender {
t.Errorf("wrong sender, want %s, got %s", sender, msg.mailFrom)
if msg.MailFrom != sender {
t.Errorf("wrong sender, want %s, got %s", sender, msg.MailFrom)
}
sort.Strings(rcpt)
sort.Strings(msg.rcptTo)
if !reflect.DeepEqual(msg.rcptTo, rcpt) {
t.Errorf("wrong recipients, want %v, got %v", rcpt, msg.rcptTo)
sort.Strings(msg.RcptTo)
if !reflect.DeepEqual(msg.RcptTo, rcpt) {
t.Errorf("wrong recipients, want %v, got %v", rcpt, msg.RcptTo)
}
if string(msg.body) != "foobar" {
t.Errorf("wrong body, want '%s', got '%s'", "foobar", string(msg.body))
if string(msg.Body) != "foobar" {
t.Errorf("wrong body, want '%s', got '%s'", "foobar", string(msg.Body))
}
}

View file

@ -1,115 +0,0 @@
package maddy
import (
"container/list"
"sync"
"time"
)
type TimeSlot struct {
Time time.Time
Value interface{}
}
type TimeWheel struct {
slots *list.List
slotsLock sync.Mutex
updateNotify chan time.Time
stopNotify chan struct{}
dispatch chan TimeSlot
}
func NewTimeWheel() *TimeWheel {
tw := &TimeWheel{
slots: list.New(),
stopNotify: make(chan struct{}),
updateNotify: make(chan time.Time),
dispatch: make(chan TimeSlot, 10),
}
go tw.tick()
return tw
}
func (tw *TimeWheel) Add(target time.Time, value interface{}) {
if value == nil {
panic("can't insert nil objects into TimeWheel queue")
}
tw.slotsLock.Lock()
tw.slots.PushBack(TimeSlot{Time: target, Value: value})
tw.slotsLock.Unlock()
tw.updateNotify <- target
}
func (tw *TimeWheel) Close() {
tw.stopNotify <- struct{}{}
<-tw.stopNotify
close(tw.updateNotify)
close(tw.dispatch)
}
func (tw *TimeWheel) tick() {
for {
now := time.Now()
// Look for list element closest to now.
tw.slotsLock.Lock()
var closestSlot TimeSlot
var closestEl *list.Element
for e := tw.slots.Front(); e != nil; e = e.Next() {
slot := e.Value.(TimeSlot)
if slot.Time.Sub(now) < closestSlot.Time.Sub(now) || closestSlot.Value == nil {
closestSlot = slot
closestEl = e
}
}
tw.slotsLock.Unlock()
// Only this goroutine removes elements from TimeWheel so we can be safe using closestSlot.
// Queue is empty. Just wait until update.
if closestEl == nil {
select {
case <-tw.updateNotify:
continue
case <-tw.stopNotify:
tw.stopNotify <- struct{}{}
return
}
}
timer := time.NewTimer(closestSlot.Time.Sub(now))
for {
select {
case <-timer.C:
tw.slotsLock.Lock()
tw.slots.Remove(closestEl)
tw.slotsLock.Unlock()
tw.dispatch <- closestSlot
// break inside of select exits select, not for loop
goto breakinnerloop
case newTarget := <-tw.updateNotify:
// Avoid unnecessary restarts if new target is not going to affect our
// current wait time.
if closestSlot.Time.Sub(now) <= newTarget.Sub(now) {
continue
}
timer.Stop()
// Recalculate new slot time.
case <-tw.stopNotify:
tw.stopNotify <- struct{}{}
return
}
}
breakinnerloop:
}
}
func (tw *TimeWheel) Dispatch() <-chan TimeSlot {
return tw.dispatch
}