mirror of
https://github.com/foxcpp/maddy.git
synced 2025-04-04 21:47:40 +03:00
Restructure code tree
Root package now contains only initialization code and 'dummy' module. Each module now got its own package. Module packages are grouped by their main purpose (storage/, target/, auth/, etc). Shared code is placed in these "group" packages. Parser for module references in config is moved into config/module. Code shared by tests (mock modules, etc) is placed in testutils.
This commit is contained in:
parent
d4d807d6c7
commit
35c3b1c792
51 changed files with 961 additions and 2223 deletions
3
address/doc.go
Normal file
3
address/doc.go
Normal file
|
@ -0,0 +1,3 @@
|
|||
// Package address provides utilities for parsing
|
||||
// and validation of RFC 2821 addresses.
|
||||
package address
|
24
address/split.go
Normal file
24
address/split.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package address
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Split(addr string) (mailbox, domain string, err error) {
|
||||
parts := strings.Split(addr, "@")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
if strings.EqualFold(parts[0], "postmaster") {
|
||||
return parts[0], "", nil
|
||||
}
|
||||
return "", "", fmt.Errorf("malformed address")
|
||||
case 2:
|
||||
if len(parts[0]) == 0 || len(parts[1]) == 0 {
|
||||
return "", "", fmt.Errorf("malformed address")
|
||||
}
|
||||
return parts[0], parts[1], nil
|
||||
default:
|
||||
return "", "", fmt.Errorf("malformed address")
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package address
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
@ -9,7 +9,8 @@ Rules for validation are subset of rules listed here:
|
|||
https://emailregex.com/email-validation-summary/
|
||||
*/
|
||||
|
||||
func validAddress(addr string) bool {
|
||||
// Valid checks whether ths string is valid as a email address.
|
||||
func Valid(addr string) bool {
|
||||
if len(addr) > 320 { // RFC 3696 says it's 320, not 255.
|
||||
return false
|
||||
}
|
||||
|
@ -18,19 +19,20 @@ func validAddress(addr string) bool {
|
|||
case 1:
|
||||
return strings.EqualFold(addr, "postmaster")
|
||||
case 2:
|
||||
return validMailboxName(parts[0]) && validDomain(parts[1])
|
||||
return ValidMailboxName(parts[0]) && ValidDomain(parts[1])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// validMailboxName checks whether the specified string is a valid mailbox-name
|
||||
// ValidMailboxName checks whether the specified string is a valid mailbox-name
|
||||
// element of e-mail address (left part of it, before at-sign).
|
||||
func validMailboxName(mbox string) bool {
|
||||
func ValidMailboxName(mbox string) bool {
|
||||
return true // TODO
|
||||
}
|
||||
|
||||
func validDomain(domain string) bool {
|
||||
// ValidDomain checks whether the specified string is a valid DNS domain.
|
||||
func ValidDomain(domain string) bool {
|
||||
if len(domain) > 255 {
|
||||
return false
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
package maddy
|
||||
package auth
|
||||
|
||||
import "strings"
|
||||
|
||||
func checkDomainAuth(username string, perDomain bool, allowedDomains []string) (loginName string, allowed bool) {
|
||||
func CheckDomainAuth(username string, perDomain bool, allowedDomains []string) (loginName string, allowed bool) {
|
||||
var accountName, domain string
|
||||
if perDomain {
|
||||
parts := strings.Split(username, "@")
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -59,7 +59,7 @@ func TestCheckDomainAuth(t *testing.T) {
|
|||
for _, case_ := range cases {
|
||||
case_ := case_
|
||||
t.Run(fmt.Sprintf("%+v", case_), func(t *testing.T) {
|
||||
loginName, allowed := checkDomainAuth(case_.rawUsername, case_.perDomain, case_.allowedDomains)
|
||||
loginName, allowed := CheckDomainAuth(case_.rawUsername, case_.perDomain, case_.allowedDomains)
|
||||
if case_.loginName != "" && !allowed {
|
||||
t.Fatalf("Unexpected authentication fail")
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package external
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -9,6 +9,7 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/foxcpp/maddy/auth"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
|
@ -75,9 +76,9 @@ func (ea *ExternalAuth) Init(cfg *config.Map) error {
|
|||
}
|
||||
}
|
||||
|
||||
ea.helperPath = filepath.Join(LibexecDirectory(cfg.Globals), helperName)
|
||||
ea.helperPath = filepath.Join(config.LibexecDirectory(cfg.Globals), helperName)
|
||||
if _, err := os.Stat(ea.helperPath); err != nil {
|
||||
return fmt.Errorf("no %s authentication support, %s is not found in %s and no custom path is set", ea.modName, LibexecDirectory(cfg.Globals), helperName)
|
||||
return fmt.Errorf("no %s authentication support, %s is not found in %s and no custom path is set", ea.modName, config.LibexecDirectory(cfg.Globals), helperName)
|
||||
}
|
||||
|
||||
ea.Log.Debugln("using helper:", ea.helperPath)
|
||||
|
@ -86,7 +87,7 @@ func (ea *ExternalAuth) Init(cfg *config.Map) error {
|
|||
}
|
||||
|
||||
func (ea *ExternalAuth) CheckPlain(username, password string) bool {
|
||||
accountName, ok := checkDomainAuth(username, ea.perDomain, ea.domains)
|
||||
accountName, ok := auth.CheckDomainAuth(username, ea.perDomain, ea.domains)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
62
check/action.go
Normal file
62
check/action.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package check
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
type FailAction struct {
|
||||
Quarantine bool
|
||||
Reject bool
|
||||
ScoreAdjust int
|
||||
}
|
||||
|
||||
func failActionDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
if len(node.Args) == 0 {
|
||||
return nil, m.MatchErr("expected at least 1 argument")
|
||||
}
|
||||
if len(node.Children) != 0 {
|
||||
return nil, m.MatchErr("can't declare block here")
|
||||
}
|
||||
|
||||
switch node.Args[0] {
|
||||
case "Reject", "Quarantine", "ignore":
|
||||
if len(node.Args) > 1 {
|
||||
return nil, m.MatchErr("too many arguments")
|
||||
}
|
||||
return FailAction{
|
||||
Reject: node.Args[0] == "Reject",
|
||||
Quarantine: node.Args[0] == "Quarantine",
|
||||
}, nil
|
||||
case "score":
|
||||
if len(node.Args) != 2 {
|
||||
return nil, m.MatchErr("expected 2 arguments")
|
||||
}
|
||||
scoreAdj, err := strconv.Atoi(node.Args[1])
|
||||
if err != nil {
|
||||
return nil, m.MatchErr("%v", err)
|
||||
}
|
||||
return FailAction{
|
||||
ScoreAdjust: scoreAdj,
|
||||
}, nil
|
||||
default:
|
||||
return nil, m.MatchErr("invalid action")
|
||||
}
|
||||
}
|
||||
|
||||
// Apply merges the result of check execution with action configuration specified
|
||||
// in the check configuration.
|
||||
func (cfa FailAction) Apply(originalRes module.CheckResult) module.CheckResult {
|
||||
if originalRes.RejectErr == nil {
|
||||
return originalRes
|
||||
}
|
||||
|
||||
originalRes.Quarantine = cfa.Quarantine
|
||||
originalRes.ScoreAdjust = int32(cfa.ScoreAdjust)
|
||||
if !cfa.Reject {
|
||||
originalRes.RejectErr = nil
|
||||
}
|
||||
return originalRes
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -13,31 +13,31 @@ import (
|
|||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// CheckGroup is a type that wraps a group of checks and runs them in parallel.
|
||||
// Group is a type that wraps a group of checks and runs them in parallel.
|
||||
//
|
||||
// It implements module.Check interface.
|
||||
type CheckGroup struct {
|
||||
checks []module.Check
|
||||
type Group struct {
|
||||
Checks []module.Check
|
||||
}
|
||||
|
||||
func (cg *CheckGroup) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
|
||||
states := make([]module.CheckState, 0, len(cg.checks))
|
||||
for _, check := range cg.checks {
|
||||
func (g *Group) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
|
||||
states := make([]module.CheckState, 0, len(g.Checks))
|
||||
for _, check := range g.Checks {
|
||||
state, err := check.NewMessage(msgMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
states = append(states, state)
|
||||
}
|
||||
return &checkGroupState{msgMeta, states}, nil
|
||||
return &groupState{msgMeta, states}, nil
|
||||
}
|
||||
|
||||
type checkGroupState struct {
|
||||
type groupState struct {
|
||||
msgMeta *module.MsgMetadata
|
||||
states []module.CheckState
|
||||
}
|
||||
|
||||
func (cgs *checkGroupState) runAndMergeResults(ctx context.Context, runner func(context.Context, module.CheckState) module.CheckResult) module.CheckResult {
|
||||
func (gs *groupState) runAndMergeResults(ctx context.Context, runner func(context.Context, module.CheckState) module.CheckResult) module.CheckResult {
|
||||
var (
|
||||
checkScore int32
|
||||
quarantineFlag atomicbool.AtomicBool
|
||||
|
@ -48,7 +48,7 @@ func (cgs *checkGroupState) runAndMergeResults(ctx context.Context, runner func(
|
|||
)
|
||||
|
||||
syncGroup, childCtx := errgroup.WithContext(ctx)
|
||||
for _, state := range cgs.states {
|
||||
for _, state := range gs.states {
|
||||
state := state
|
||||
syncGroup.Go(func() error {
|
||||
subCheckRes := runner(childCtx, state)
|
||||
|
@ -90,32 +90,32 @@ func (cgs *checkGroupState) runAndMergeResults(ctx context.Context, runner func(
|
|||
}
|
||||
}
|
||||
|
||||
func (cgs *checkGroupState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
func (gs *groupState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
return state.CheckConnection(childCtx)
|
||||
})
|
||||
}
|
||||
|
||||
func (cgs *checkGroupState) CheckSender(ctx context.Context, from string) module.CheckResult {
|
||||
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
func (gs *groupState) CheckSender(ctx context.Context, from string) module.CheckResult {
|
||||
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
return state.CheckSender(childCtx, from)
|
||||
})
|
||||
}
|
||||
|
||||
func (cgs *checkGroupState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
|
||||
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
func (gs *groupState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
|
||||
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
return state.CheckRcpt(childCtx, to)
|
||||
})
|
||||
}
|
||||
|
||||
func (cgs *checkGroupState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||
return cgs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
func (gs *groupState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||
return gs.runAndMergeResults(ctx, func(childCtx context.Context, state module.CheckState) module.CheckResult {
|
||||
return state.CheckBody(childCtx, header, body)
|
||||
})
|
||||
}
|
||||
|
||||
func (cgs *checkGroupState) Close() error {
|
||||
for _, state := range cgs.states {
|
||||
func (gs *groupState) Close() error {
|
||||
for _, state := range gs.states {
|
||||
state.Close()
|
||||
}
|
||||
return nil
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package dkim
|
||||
|
||||
import (
|
||||
"io"
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -6,11 +6,13 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/address"
|
||||
"github.com/foxcpp/maddy/check"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
func requireMatchingRDNS(ctx StatelessCheckContext) module.CheckResult {
|
||||
func requireMatchingRDNS(ctx check.StatelessCheckContext) module.CheckResult {
|
||||
tcpAddr, ok := ctx.MsgMeta.SrcAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
log.Debugf("non TCP/IP source (%v), skipped", ctx.MsgMeta.SrcAddr)
|
||||
|
@ -47,8 +49,8 @@ func requireMatchingRDNS(ctx StatelessCheckContext) module.CheckResult {
|
|||
}
|
||||
}
|
||||
|
||||
func requireMXRecord(ctx StatelessCheckContext, mailFrom string) module.CheckResult {
|
||||
_, domain, err := splitAddress(mailFrom)
|
||||
func requireMXRecord(ctx check.StatelessCheckContext, mailFrom string) module.CheckResult {
|
||||
_, domain, err := address.Split(mailFrom)
|
||||
if err != nil {
|
||||
return module.CheckResult{RejectErr: err}
|
||||
}
|
||||
|
@ -95,7 +97,7 @@ func requireMXRecord(ctx StatelessCheckContext, mailFrom string) module.CheckRes
|
|||
return module.CheckResult{}
|
||||
}
|
||||
|
||||
func requireMatchingEHLO(ctx StatelessCheckContext) module.CheckResult {
|
||||
func requireMatchingEHLO(ctx check.StatelessCheckContext) module.CheckResult {
|
||||
tcpAddr, ok := ctx.MsgMeta.SrcAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
ctx.Logger.Debugf("not TCP/IP source (%v), skipped", ctx.MsgMeta.SrcAddr)
|
||||
|
@ -132,10 +134,10 @@ func requireMatchingEHLO(ctx StatelessCheckContext) module.CheckResult {
|
|||
}
|
||||
|
||||
func init() {
|
||||
RegisterStatelessCheck("require_matching_rdns", checkFailAction{quarantine: true},
|
||||
check.RegisterStatelessCheck("require_matching_rdns", check.FailAction{Quarantine: true},
|
||||
requireMatchingRDNS, nil, nil, nil)
|
||||
RegisterStatelessCheck("require_mx_record", checkFailAction{quarantine: true},
|
||||
check.RegisterStatelessCheck("require_mx_record", check.FailAction{Quarantine: true},
|
||||
nil, requireMXRecord, nil, nil)
|
||||
RegisterStatelessCheck("require_matching_ehlo", checkFailAction{quarantine: true},
|
||||
check.RegisterStatelessCheck("require_matching_ehlo", check.FailAction{Quarantine: true},
|
||||
requireMatchingEHLO, nil, nil, nil)
|
||||
}
|
|
@ -1,21 +1,22 @@
|
|||
package maddy
|
||||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/dns"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
"github.com/foxcpp/maddy/target"
|
||||
)
|
||||
|
||||
type (
|
||||
StatelessCheckContext struct {
|
||||
// Resolver that should be used by the check for DNS queries.
|
||||
Resolver Resolver
|
||||
Resolver dns.Resolver
|
||||
|
||||
MsgMeta *module.MsgMetadata
|
||||
|
||||
|
@ -37,13 +38,13 @@ type (
|
|||
type statelessCheck struct {
|
||||
modName string
|
||||
instName string
|
||||
resolver Resolver
|
||||
resolver dns.Resolver
|
||||
logger log.Logger
|
||||
|
||||
// One used by Init if config option is not passed by a user.
|
||||
defaultFailAction checkFailAction
|
||||
defaultFailAction FailAction
|
||||
// The actual fail action that should be applied.
|
||||
failAction checkFailAction
|
||||
failAction FailAction
|
||||
okScore int
|
||||
|
||||
connCheck FuncConnCheck
|
||||
|
@ -57,21 +58,6 @@ type statelessCheckState struct {
|
|||
msgMeta *module.MsgMetadata
|
||||
}
|
||||
|
||||
func deliveryLogger(l log.Logger, msgMeta *module.MsgMetadata) log.Logger {
|
||||
out := l.Out
|
||||
if out == nil {
|
||||
out = log.DefaultLogger.Out
|
||||
}
|
||||
|
||||
return log.Logger{
|
||||
Out: func(t time.Time, debug bool, str string) {
|
||||
out(t, debug, str+" (msg ID = "+msgMeta.ID+")")
|
||||
},
|
||||
Name: l.Name,
|
||||
Debug: l.Debug,
|
||||
}
|
||||
}
|
||||
|
||||
func (s statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
if s.c.connCheck == nil {
|
||||
return module.CheckResult{}
|
||||
|
@ -81,9 +67,9 @@ func (s statelessCheckState) CheckConnection(ctx context.Context) module.CheckRe
|
|||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
CancelCtx: ctx,
|
||||
Logger: deliveryLogger(s.c.logger, s.msgMeta),
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
})
|
||||
return s.c.failAction.apply(originalRes)
|
||||
return s.c.failAction.Apply(originalRes)
|
||||
}
|
||||
|
||||
func (s statelessCheckState) CheckSender(ctx context.Context, mailFrom string) module.CheckResult {
|
||||
|
@ -95,9 +81,9 @@ func (s statelessCheckState) CheckSender(ctx context.Context, mailFrom string) m
|
|||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
CancelCtx: ctx,
|
||||
Logger: deliveryLogger(s.c.logger, s.msgMeta),
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
}, mailFrom)
|
||||
return s.c.failAction.apply(originalRes)
|
||||
return s.c.failAction.Apply(originalRes)
|
||||
}
|
||||
|
||||
func (s statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.CheckResult {
|
||||
|
@ -109,9 +95,9 @@ func (s statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modul
|
|||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
CancelCtx: ctx,
|
||||
Logger: deliveryLogger(s.c.logger, s.msgMeta),
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
}, rcptTo)
|
||||
return s.c.failAction.apply(originalRes)
|
||||
return s.c.failAction.Apply(originalRes)
|
||||
}
|
||||
|
||||
func (s statelessCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||
|
@ -123,9 +109,9 @@ func (s statelessCheckState) CheckBody(ctx context.Context, header textproto.Hea
|
|||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
CancelCtx: ctx,
|
||||
Logger: deliveryLogger(s.c.logger, s.msgMeta),
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
}, header, body)
|
||||
return s.c.failAction.apply(originalRes)
|
||||
return s.c.failAction.Apply(originalRes)
|
||||
}
|
||||
|
||||
func (s statelessCheckState) Close() error {
|
||||
|
@ -144,7 +130,7 @@ func (c *statelessCheck) Init(m *config.Map) error {
|
|||
m.Custom("fail_action", false, false,
|
||||
func() (interface{}, error) {
|
||||
return c.defaultFailAction, nil
|
||||
}, checkFailActionDirective, &c.failAction)
|
||||
}, failActionDirective, &c.failAction)
|
||||
_, err := m.Process()
|
||||
return err
|
||||
}
|
||||
|
@ -165,10 +151,10 @@ func (c *statelessCheck) InstanceName() string {
|
|||
//
|
||||
// Note about CheckResult that is returned by the functions:
|
||||
// StatelessCheck supports different action types based on the user configuration, but the particular check
|
||||
// code doesn't need to know about it. It should assume that it is always "reject" and hence it should
|
||||
// code doesn't need to know about it. It should assume that it is always "Reject" and hence it should
|
||||
// populate RejectErr field of the result object with the relevant error description. Fields ScoreAdjust and
|
||||
// Quarantine will be ignored.
|
||||
func RegisterStatelessCheck(name string, defaultFailAction checkFailAction, connCheck FuncConnCheck, senderCheck FuncSenderCheck, rcptCheck FuncRcptCheck, bodyCheck FuncBodyCheck) {
|
||||
func RegisterStatelessCheck(name string, defaultFailAction FailAction, connCheck FuncConnCheck, senderCheck FuncSenderCheck, rcptCheck FuncRcptCheck, bodyCheck FuncBodyCheck) {
|
||||
module.Register(name, func(modName, instName string, aliases []string) (module.Module, error) {
|
||||
return &statelessCheck{
|
||||
modName: modName,
|
|
@ -1,65 +0,0 @@
|
|||
package maddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
type testCheck struct {
|
||||
connRes module.CheckResult
|
||||
senderRes module.CheckResult
|
||||
rcptRes module.CheckResult
|
||||
bodyRes module.CheckResult
|
||||
}
|
||||
|
||||
func (tc *testCheck) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
|
||||
return &testCheckState{msgMeta, tc}, nil
|
||||
}
|
||||
|
||||
func (tc *testCheck) Init(*config.Map) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *testCheck) Name() string {
|
||||
return "test_check"
|
||||
}
|
||||
|
||||
func (tc *testCheck) InstanceName() string {
|
||||
return "test_check"
|
||||
}
|
||||
|
||||
type testCheckState struct {
|
||||
msgMeta *module.MsgMetadata
|
||||
check *testCheck
|
||||
}
|
||||
|
||||
func (tcs *testCheckState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
return tcs.check.connRes
|
||||
}
|
||||
|
||||
func (tcs *testCheckState) CheckSender(ctx context.Context, from string) module.CheckResult {
|
||||
return tcs.check.senderRes
|
||||
}
|
||||
|
||||
func (tcs *testCheckState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
|
||||
return tcs.check.rcptRes
|
||||
}
|
||||
|
||||
func (tcs *testCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||
return tcs.check.bodyRes
|
||||
}
|
||||
|
||||
func (tcs *testCheckState) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
module.Register("test_check", func(modName, instanceName string, aliases []string) (module.Module, error) {
|
||||
return &testCheck{}, nil
|
||||
})
|
||||
module.RegisterInstance(&testCheck{}, nil)
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/foxcpp/maddy/shadow"
|
||||
"github.com/foxcpp/maddy/auth/shadow"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", filepath.Join(maddy.ConfigDirectory(), "maddy.conf"), "path to configuration file")
|
||||
configPath := flag.String("config", filepath.Join(config.ConfigDirectory(), "maddy.conf"), "path to configuration file")
|
||||
debugFlag := flag.Bool("debug", false, "enable debug logging early")
|
||||
profileEndpoint := flag.String("debug.pprof", "", "enable live profiler HTTP endpoint and listen on the specified endpoint")
|
||||
blockProfileRate := flag.Int("debug.blockprofrate", 0, "set blocking profile rate")
|
||||
|
|
193
config.go
193
config.go
|
@ -4,208 +4,15 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
/*
|
||||
Config matchers for module interfaces.
|
||||
*/
|
||||
|
||||
// createInlineModule is a helper function for config matchers that can create inline modules.
|
||||
func createInlineModule(modName, instName string, aliases []string) (module.Module, error) {
|
||||
newMod := module.Get(modName)
|
||||
if newMod == nil {
|
||||
return nil, fmt.Errorf("unknown module: %s", modName)
|
||||
}
|
||||
|
||||
log.Debugln("module create", modName, instName, "(inline)")
|
||||
|
||||
return newMod(modName, instName, aliases)
|
||||
}
|
||||
|
||||
// initInlineModule constructs "faked" config tree and passes it to module
|
||||
// Init function to make it look like it is defined at top-level.
|
||||
//
|
||||
// args must contain at least one argument, otherwise initInlineModule panics.
|
||||
func initInlineModule(modObj module.Module, globals map[string]interface{}, block *config.Node) error {
|
||||
log.Debugln("module init", modObj.Name(), modObj.InstanceName(), "(inline)")
|
||||
return modObj.Init(config.NewMap(globals, block))
|
||||
}
|
||||
|
||||
// moduleFromNode does all work to create or get existing module object with a certain type.
|
||||
// It is not used by top-level module definitions, only for references from other
|
||||
// modules configuration blocks.
|
||||
//
|
||||
// inlineCfg should contain configuration directives for inline declarations.
|
||||
// args should contain values that are used to create module.
|
||||
// It should be either module name + instance name or just module name. Further extensions
|
||||
// may add other string arguments (currently, they can be accessed by module instances
|
||||
// as aliases argument to constructor).
|
||||
//
|
||||
// It checks using reflection whether it is possible to store a module object into modObj
|
||||
// pointer (e.g. it implements all necessary interfaces) and stores it if everything is fine.
|
||||
// If module object doesn't implement necessary module interfaces - error is returned.
|
||||
// If modObj is not a pointer, moduleFromNode panics.
|
||||
func moduleFromNode(args []string, inlineCfg *config.Node, globals map[string]interface{}, moduleIface interface{}) error {
|
||||
// single argument
|
||||
// - instance name of an existing module
|
||||
// single argument + block
|
||||
// - module name, inline definition
|
||||
// two+ arguments + block
|
||||
// - module name and instance name, inline definition
|
||||
// two+ arguments, no block
|
||||
// - module name and instance name, inline definition, empty config block
|
||||
|
||||
if len(args) == 0 {
|
||||
return config.NodeErr(inlineCfg, "at least one argument is required")
|
||||
}
|
||||
|
||||
var modObj module.Module
|
||||
var err error
|
||||
if inlineCfg.Children != nil || len(args) > 1 {
|
||||
modName := args[0]
|
||||
|
||||
modAliases := args[1:]
|
||||
instName := ""
|
||||
if len(args) >= 2 {
|
||||
modAliases = args[2:]
|
||||
instName = args[1]
|
||||
}
|
||||
|
||||
modObj, err = createInlineModule(modName, instName, modAliases)
|
||||
} else {
|
||||
if len(args) != 1 {
|
||||
return config.NodeErr(inlineCfg, "exactly one argument is to use existing config block")
|
||||
}
|
||||
modObj, err = module.GetInstance(args[0])
|
||||
}
|
||||
if err != nil {
|
||||
return config.NodeErr(inlineCfg, "%v", err)
|
||||
}
|
||||
|
||||
// NOTE: This will panic if moduleIface is not a pointer.
|
||||
modIfaceType := reflect.TypeOf(moduleIface).Elem()
|
||||
modObjType := reflect.TypeOf(modObj)
|
||||
if !modObjType.Implements(modIfaceType) && !modObjType.AssignableTo(modIfaceType) {
|
||||
return config.NodeErr(inlineCfg, "module %s (%s) doesn't implement %v interface", modObj.Name(), modObj.InstanceName(), modIfaceType)
|
||||
}
|
||||
|
||||
reflect.ValueOf(moduleIface).Elem().Set(reflect.ValueOf(modObj))
|
||||
|
||||
if inlineCfg.Children != nil {
|
||||
if err := initInlineModule(modObj, globals, inlineCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deliveryDirective is a callback for use in config.Map.Custom.
|
||||
//
|
||||
// It does all work necessary to create a module instance from the config
|
||||
// directive with the following structure:
|
||||
// directive_name mod_name [inst_name] [{
|
||||
// inline_mod_config
|
||||
// }]
|
||||
//
|
||||
// Note that if used configuration structure lacks directive_name before mod_name - this function
|
||||
// should not be used (call deliveryTarget directly).
|
||||
func deliveryDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
return deliveryTarget(m.Globals, node.Args, node)
|
||||
}
|
||||
|
||||
func deliveryTarget(globals map[string]interface{}, args []string, block *config.Node) (module.DeliveryTarget, error) {
|
||||
var target module.DeliveryTarget
|
||||
if err := moduleFromNode(args, block, globals, &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func messageCheck(globals map[string]interface{}, args []string, block *config.Node) (module.Check, error) {
|
||||
var check module.Check
|
||||
if err := moduleFromNode(args, block, globals, &check); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return check, nil
|
||||
}
|
||||
|
||||
func authDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
var provider module.AuthProvider
|
||||
if err := moduleFromNode(node.Args, node, m.Globals, &provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func storageDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
var backend module.Storage
|
||||
if err := moduleFromNode(node.Args, node, m.Globals, &backend); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
type checkFailAction struct {
|
||||
quarantine bool
|
||||
reject bool
|
||||
scoreAdjust int
|
||||
}
|
||||
|
||||
func checkFailActionDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
if len(node.Args) == 0 {
|
||||
return nil, m.MatchErr("expected at least 1 argument")
|
||||
}
|
||||
if len(node.Children) != 0 {
|
||||
return nil, m.MatchErr("can't declare block here")
|
||||
}
|
||||
|
||||
switch node.Args[0] {
|
||||
case "reject", "quarantine", "ignore":
|
||||
if len(node.Args) > 1 {
|
||||
return nil, m.MatchErr("too many arguments")
|
||||
}
|
||||
return checkFailAction{
|
||||
reject: node.Args[0] == "reject",
|
||||
quarantine: node.Args[0] == "quarantine",
|
||||
}, nil
|
||||
case "score":
|
||||
if len(node.Args) != 2 {
|
||||
return nil, m.MatchErr("expected 2 arguments")
|
||||
}
|
||||
scoreAdj, err := strconv.Atoi(node.Args[1])
|
||||
if err != nil {
|
||||
return nil, m.MatchErr("%v", err)
|
||||
}
|
||||
return checkFailAction{
|
||||
scoreAdjust: scoreAdj,
|
||||
}, nil
|
||||
default:
|
||||
return nil, m.MatchErr("invalid action")
|
||||
}
|
||||
}
|
||||
|
||||
// apply merges the result of check execution with action configuration specified
|
||||
// in the check configuration.
|
||||
func (cfa checkFailAction) apply(originalRes module.CheckResult) module.CheckResult {
|
||||
if originalRes.RejectErr == nil {
|
||||
return originalRes
|
||||
}
|
||||
|
||||
originalRes.Quarantine = cfa.quarantine
|
||||
originalRes.ScoreAdjust = int32(cfa.scoreAdjust)
|
||||
if !cfa.reject {
|
||||
originalRes.RejectErr = nil
|
||||
}
|
||||
return originalRes
|
||||
}
|
||||
|
||||
func logOutput(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
if len(node.Args) == 0 {
|
||||
return nil, m.MatchErr("expected at least 1 argument")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -86,9 +86,9 @@ func (a Address) IsTLS() bool {
|
|||
return a.Scheme == "imaps" || a.Scheme == "smtps"
|
||||
}
|
||||
|
||||
// standardizeAddress parses an address string into a structured format with separate
|
||||
// StandardizeAddress parses an address string into a structured format with separate
|
||||
// scheme, host, port, and path portions, as well as the original input string.
|
||||
func standardizeAddress(str string) (Address, error) {
|
||||
func StandardizeAddress(str string) (Address, error) {
|
||||
input := str
|
||||
|
||||
// Split input into components (prepend with // to assert host by default)
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package config
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
@ -13,7 +13,7 @@ func TestStandardizeAddress(t *testing.T) {
|
|||
{Original: "smtp://0.0.0.0:10025", Scheme: "smtp", Host: "0.0.0.0", Port: "10025"},
|
||||
{Original: "smtp://[::]:10025", Scheme: "smtp", Host: "::", Port: "10025"},
|
||||
} {
|
||||
actual, err := standardizeAddress(expected.Original)
|
||||
actual, err := StandardizeAddress(expected.Original)
|
||||
if err != nil {
|
||||
t.Error("Unexpected failure:", err)
|
||||
return
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
14
config/module/auth.go
Normal file
14
config/module/auth.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package modconfig
|
||||
|
||||
import (
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
func AuthDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
var provider module.AuthProvider
|
||||
if err := ModuleFromNode(node.Args, node, m.Globals, &provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
14
config/module/check.go
Normal file
14
config/module/check.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package modconfig
|
||||
|
||||
import (
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
func MessageCheck(globals map[string]interface{}, args []string, block *config.Node) (module.Check, error) {
|
||||
var check module.Check
|
||||
if err := ModuleFromNode(args, block, globals, &check); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return check, nil
|
||||
}
|
28
config/module/delivery.go
Normal file
28
config/module/delivery.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package modconfig
|
||||
|
||||
import (
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
// deliveryDirective is a callback for use in config.Map.Custom.
|
||||
//
|
||||
// It does all work necessary to create a module instance from the config
|
||||
// directive with the following structure:
|
||||
// directive_name mod_name [inst_name] [{
|
||||
// inline_mod_config
|
||||
// }]
|
||||
//
|
||||
// Note that if used configuration structure lacks directive_name before mod_name - this function
|
||||
// should not be used (call DeliveryTarget directly).
|
||||
func DeliveryDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
return DeliveryTarget(m.Globals, node.Args, node)
|
||||
}
|
||||
|
||||
func DeliveryTarget(globals map[string]interface{}, args []string, block *config.Node) (module.DeliveryTarget, error) {
|
||||
var target module.DeliveryTarget
|
||||
if err := ModuleFromNode(args, block, globals, &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return target, nil
|
||||
}
|
108
config/module/modconfig.go
Normal file
108
config/module/modconfig.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
// Package modconfig provides matchers for config.Map that query
|
||||
// modules registry and parse inline module definitions.
|
||||
//
|
||||
// They should be used instead of manual querying when there is need to
|
||||
// reference a module instance in the configuration.
|
||||
//
|
||||
// See ModuleFromNode documentation for explanation of what is 'args'
|
||||
// for some functions (DeliveryTarget).
|
||||
package modconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
// createInlineModule is a helper function for config matchers that can create inline modules.
|
||||
func createInlineModule(modName, instName string, aliases []string) (module.Module, error) {
|
||||
newMod := module.Get(modName)
|
||||
if newMod == nil {
|
||||
return nil, fmt.Errorf("unknown module: %s", modName)
|
||||
}
|
||||
|
||||
log.Debugln("module create", modName, instName, "(inline)")
|
||||
|
||||
return newMod(modName, instName, aliases)
|
||||
}
|
||||
|
||||
// initInlineModule constructs "faked" config tree and passes it to module
|
||||
// Init function to make it look like it is defined at top-level.
|
||||
//
|
||||
// args must contain at least one argument, otherwise initInlineModule panics.
|
||||
func initInlineModule(modObj module.Module, globals map[string]interface{}, block *config.Node) error {
|
||||
log.Debugln("module init", modObj.Name(), modObj.InstanceName(), "(inline)")
|
||||
return modObj.Init(config.NewMap(globals, block))
|
||||
}
|
||||
|
||||
// ModuleFromNode does all work to create or get existing module object with a certain type.
|
||||
// It is not used by top-level module definitions, only for references from other
|
||||
// modules configuration blocks.
|
||||
//
|
||||
// inlineCfg should contain configuration directives for inline declarations.
|
||||
// args should contain values that are used to create module.
|
||||
// It should be either module name + instance name or just module name. Further extensions
|
||||
// may add other string arguments (currently, they can be accessed by module instances
|
||||
// as aliases argument to constructor).
|
||||
//
|
||||
// It checks using reflection whether it is possible to store a module object into modObj
|
||||
// pointer (e.g. it implements all necessary interfaces) and stores it if everything is fine.
|
||||
// If module object doesn't implement necessary module interfaces - error is returned.
|
||||
// If modObj is not a pointer, ModuleFromNode panics.
|
||||
func ModuleFromNode(args []string, inlineCfg *config.Node, globals map[string]interface{}, moduleIface interface{}) error {
|
||||
// single argument
|
||||
// - instance name of an existing module
|
||||
// single argument + block
|
||||
// - module name, inline definition
|
||||
// two+ arguments + block
|
||||
// - module name and instance name, inline definition
|
||||
// two+ arguments, no block
|
||||
// - module name and instance name, inline definition, empty config block
|
||||
|
||||
if len(args) == 0 {
|
||||
return config.NodeErr(inlineCfg, "at least one argument is required")
|
||||
}
|
||||
|
||||
var modObj module.Module
|
||||
var err error
|
||||
if inlineCfg.Children != nil || len(args) > 1 {
|
||||
modName := args[0]
|
||||
|
||||
modAliases := args[1:]
|
||||
instName := ""
|
||||
if len(args) >= 2 {
|
||||
modAliases = args[2:]
|
||||
instName = args[1]
|
||||
}
|
||||
|
||||
modObj, err = createInlineModule(modName, instName, modAliases)
|
||||
} else {
|
||||
if len(args) != 1 {
|
||||
return config.NodeErr(inlineCfg, "exactly one argument is to use existing config block")
|
||||
}
|
||||
modObj, err = module.GetInstance(args[0])
|
||||
}
|
||||
if err != nil {
|
||||
return config.NodeErr(inlineCfg, "%v", err)
|
||||
}
|
||||
|
||||
// NOTE: This will panic if moduleIface is not a pointer.
|
||||
modIfaceType := reflect.TypeOf(moduleIface).Elem()
|
||||
modObjType := reflect.TypeOf(modObj)
|
||||
if !modObjType.Implements(modIfaceType) && !modObjType.AssignableTo(modIfaceType) {
|
||||
return config.NodeErr(inlineCfg, "module %s (%s) doesn't implement %v interface", modObj.Name(), modObj.InstanceName(), modIfaceType)
|
||||
}
|
||||
|
||||
reflect.ValueOf(moduleIface).Elem().Set(reflect.ValueOf(modObj))
|
||||
|
||||
if inlineCfg.Children != nil {
|
||||
if err := initInlineModule(modObj, globals, inlineCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
14
config/module/storage.go
Normal file
14
config/module/storage.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package modconfig
|
||||
|
||||
import (
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
func StorageDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
var backend module.Storage
|
||||
if err := ModuleFromNode(node.Args, node, m.Globals, &backend); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return backend, nil
|
||||
}
|
|
@ -1,23 +1,25 @@
|
|||
package maddy
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-msgauth/authres"
|
||||
"github.com/foxcpp/maddy/check"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
"github.com/foxcpp/maddy/testutils"
|
||||
)
|
||||
|
||||
func TestDispatcher_NoScoresChecked(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -26,31 +28,31 @@ func TestDispatcher_NoScoresChecked(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// No rejectScore or quarantineScore.
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
if target.messages[0].msgMeta.Quarantine {
|
||||
if target.Messages[0].MsgMeta.Quarantine {
|
||||
t.Fatalf("message is quarantined when it shouldn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_RejectScore(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
rejectScore := 10
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -60,30 +62,30 @@ func TestDispatcher_RejectScore(t *testing.T) {
|
|||
},
|
||||
rejectScore: &rejectScore,
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Should be rejected.
|
||||
if _, err := doTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
|
||||
if _, err := testutils.DoTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
|
||||
t.Fatalf("expected an error")
|
||||
}
|
||||
|
||||
if len(target.messages) != 0 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.messages))
|
||||
if len(target.Messages) != 0 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_RejectScore_notEnough(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
rejectScore := 15
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -93,29 +95,29 @@ func TestDispatcher_RejectScore_notEnough(t *testing.T) {
|
|||
},
|
||||
rejectScore: &rejectScore,
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
if target.messages[0].msgMeta.Quarantine {
|
||||
if target.Messages[0].MsgMeta.Quarantine {
|
||||
t.Fatalf("message is quarantined when it shouldn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_Quarantine(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{Quarantine: true},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{Quarantine: true},
|
||||
}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -124,31 +126,31 @@ func TestDispatcher_Quarantine(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Should be quarantined.
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
if !target.messages[0].msgMeta.Quarantine {
|
||||
if !target.Messages[0].MsgMeta.Quarantine {
|
||||
t.Fatalf("message is not quarantined when it should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_QuarantineScore(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
quarantineScore := 10
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -158,31 +160,31 @@ func TestDispatcher_QuarantineScore(t *testing.T) {
|
|||
},
|
||||
quarantineScore: &quarantineScore,
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Should be quarantined.
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
if !target.messages[0].msgMeta.Quarantine {
|
||||
if !target.Messages[0].MsgMeta.Quarantine {
|
||||
t.Fatalf("message is not quarantined when it should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_QuarantineScore_notEnough(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
quarantineScore := 15
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -192,32 +194,32 @@ func TestDispatcher_QuarantineScore_notEnough(t *testing.T) {
|
|||
},
|
||||
quarantineScore: &quarantineScore,
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Should be quarantined.
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
if target.messages[0].msgMeta.Quarantine {
|
||||
if target.Messages[0].MsgMeta.Quarantine {
|
||||
t.Fatalf("message is quarantined when it shouldn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_BothScores_Quarantined(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
quarantineScore := 10
|
||||
rejectScore := 15
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -228,32 +230,32 @@ func TestDispatcher_BothScores_Quarantined(t *testing.T) {
|
|||
quarantineScore: &quarantineScore,
|
||||
rejectScore: &rejectScore,
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Should be quarantined.
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
if !target.messages[0].msgMeta.Quarantine {
|
||||
if !target.Messages[0].MsgMeta.Quarantine {
|
||||
t.Fatalf("message is not quarantined when it should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_BothScores_Rejected(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{ScoreAdjust: 5},
|
||||
}
|
||||
quarantineScore := 5
|
||||
rejectScore := 10
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -264,23 +266,23 @@ func TestDispatcher_BothScores_Rejected(t *testing.T) {
|
|||
quarantineScore: &quarantineScore,
|
||||
rejectScore: &rejectScore,
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Should be quarantined.
|
||||
if _, err := doTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
|
||||
if _, err := testutils.DoTestDeliveryErr(t, &d, "whatever@whatever", []string{"whatever@whatever"}); err == nil {
|
||||
t.Fatalf("message not rejected")
|
||||
}
|
||||
|
||||
if len(target.messages) != 0 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.messages))
|
||||
if len(target.Messages) != 0 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 0, len(target.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatcher_AuthResults(t *testing.T) {
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{
|
||||
AuthResult: []authres.Result{
|
||||
&authres.SPFResult{
|
||||
Value: authres.ResultFail,
|
||||
|
@ -289,8 +291,8 @@ func TestDispatcher_AuthResults(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{
|
||||
AuthResult: []authres.Result{
|
||||
&authres.SPFResult{
|
||||
Value: authres.ResultFail,
|
||||
|
@ -302,7 +304,7 @@ func TestDispatcher_AuthResults(t *testing.T) {
|
|||
}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -311,17 +313,17 @@ func TestDispatcher_AuthResults(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
hostname: "TEST-HOST",
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Hostname: "TEST-HOST",
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
|
||||
authRes := target.messages[0].header.Get("Authentication-Results")
|
||||
authRes := target.Messages[0].Header.Get("Authentication-Results")
|
||||
id, parsed, err := authres.Parse(authRes)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse results")
|
||||
|
@ -362,19 +364,19 @@ func TestDispatcher_Headers(t *testing.T) {
|
|||
hdr2 := textproto.Header{}
|
||||
hdr2.Add("HDR2", "2")
|
||||
|
||||
target := testTarget{}
|
||||
check1, check2 := testCheck{
|
||||
bodyRes: module.CheckResult{
|
||||
target := testutils.Target{}
|
||||
check1, check2 := testutils.Check{
|
||||
BodyRes: module.CheckResult{
|
||||
Header: hdr1,
|
||||
},
|
||||
}, testCheck{
|
||||
bodyRes: module.CheckResult{
|
||||
}, testutils.Check{
|
||||
BodyRes: module.CheckResult{
|
||||
Header: hdr2,
|
||||
},
|
||||
}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
globalChecks: CheckGroup{checks: []module.Check{&check1, &check2}},
|
||||
globalChecks: check.Group{Checks: []module.Check{&check1, &check2}},
|
||||
perSource: map[string]sourceBlock{},
|
||||
defaultSource: sourceBlock{
|
||||
perRcpt: map[string]*rcptBlock{},
|
||||
|
@ -383,20 +385,20 @@ func TestDispatcher_Headers(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
hostname: "TEST-HOST",
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Hostname: "TEST-HOST",
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
testutils.DoTestDelivery(t, &d, "whatever@whatever", []string{"whatever@whatever"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
|
||||
if target.messages[0].header.Get("HDR1") != "1" {
|
||||
t.Fatalf("wrong HDR1 value, want %s, got %s", "1", target.messages[0].header.Get("HDR1"))
|
||||
if target.Messages[0].Header.Get("HDR1") != "1" {
|
||||
t.Fatalf("wrong HDR1 value, want %s, got %s", "1", target.Messages[0].Header.Get("HDR1"))
|
||||
}
|
||||
if target.messages[0].header.Get("HDR2") != "2" {
|
||||
t.Fatalf("wrong HDR2 value, want %s, got %s", "1", target.messages[0].header.Get("HDR2"))
|
||||
if target.Messages[0].Header.Get("HDR2") != "2" {
|
||||
t.Fatalf("wrong HDR2 value, want %s, got %s", "1", target.Messages[0].Header.Get("HDR2"))
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -6,11 +6,14 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/address"
|
||||
"github.com/foxcpp/maddy/check"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
modconfig "github.com/foxcpp/maddy/config/module"
|
||||
)
|
||||
|
||||
type dispatcherCfg struct {
|
||||
globalChecks CheckGroup
|
||||
globalChecks check.Group
|
||||
perSource map[string]sourceBlock
|
||||
defaultSource sourceBlock
|
||||
|
||||
|
@ -197,7 +200,7 @@ func parseDispatcherRcptCfg(globals map[string]interface{}, nodes []config.Node)
|
|||
if len(node.Args) == 0 {
|
||||
return nil, config.NodeErr(&node, "required at least one argument")
|
||||
}
|
||||
mod, err := deliveryTarget(globals, node.Args, &node)
|
||||
mod, err := modconfig.DeliveryTarget(globals, node.Args, &node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -276,19 +279,19 @@ func parseEnhancedCode(s string) (smtp.EnhancedCode, error) {
|
|||
return code, nil
|
||||
}
|
||||
|
||||
func parseChecksGroup(globals map[string]interface{}, nodes []config.Node) (CheckGroup, error) {
|
||||
checks := CheckGroup{}
|
||||
func parseChecksGroup(globals map[string]interface{}, nodes []config.Node) (check.Group, error) {
|
||||
checks := check.Group{}
|
||||
for _, child := range nodes {
|
||||
check, err := messageCheck(globals, append([]string{child.Name}, child.Args...), &child)
|
||||
msgCheck, err := modconfig.MessageCheck(globals, append([]string{child.Name}, child.Args...), &child)
|
||||
if err != nil {
|
||||
return CheckGroup{}, err
|
||||
return check.Group{}, err
|
||||
}
|
||||
|
||||
checks.checks = append(checks.checks, check)
|
||||
checks.Checks = append(checks.Checks, msgCheck)
|
||||
}
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
func validDispatchRule(rule string) bool {
|
||||
return validDomain(rule) || validAddress(rule)
|
||||
return address.ValidDomain(rule) || address.Valid(rule)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
@ -244,7 +244,7 @@ func TestDispatcherCfg_GlobalChecks(t *testing.T) {
|
|||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
|
||||
if len(parsed.globalChecks.checks) == 0 {
|
||||
if len(parsed.globalChecks.Checks) == 0 {
|
||||
t.Fatalf("missing test_check in globalChecks")
|
||||
}
|
||||
}
|
||||
|
@ -269,7 +269,7 @@ func TestDispatcherCfg_SourceChecks(t *testing.T) {
|
|||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
|
||||
if len(parsed.perSource["example.org"].checks.checks) == 0 {
|
||||
if len(parsed.perSource["example.org"].checks.Checks) == 0 {
|
||||
t.Fatalf("missing test_check in source checks")
|
||||
}
|
||||
}
|
||||
|
@ -294,7 +294,7 @@ func TestDispatcherCfg_RcptChecks(t *testing.T) {
|
|||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
|
||||
if len(parsed.defaultSource.perRcpt["example.org"].checks.checks) == 0 {
|
||||
if len(parsed.defaultSource.perRcpt["example.org"].checks.Checks) == 0 {
|
||||
t.Fatalf("missing test_check in rcpt checks")
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -8,10 +8,13 @@ import (
|
|||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-msgauth/authres"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/address"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/check"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
"github.com/foxcpp/maddy/target"
|
||||
)
|
||||
|
||||
// Dispatcher is a object that is responsible for selecting delivery targets
|
||||
|
@ -23,20 +26,20 @@ import (
|
|||
// source (Submission, SMTP, JMAP modules) implementation.
|
||||
type Dispatcher struct {
|
||||
dispatcherCfg
|
||||
hostname string
|
||||
Hostname string
|
||||
|
||||
Log log.Logger
|
||||
}
|
||||
|
||||
type sourceBlock struct {
|
||||
checks CheckGroup
|
||||
checks check.Group
|
||||
rejectErr error
|
||||
perRcpt map[string]*rcptBlock
|
||||
defaultRcpt *rcptBlock
|
||||
}
|
||||
|
||||
type rcptBlock struct {
|
||||
checks CheckGroup
|
||||
checks check.Group
|
||||
rejectErr error
|
||||
targets []module.DeliveryTarget
|
||||
}
|
||||
|
@ -48,26 +51,8 @@ func NewDispatcher(globals map[string]interface{}, cfg []config.Node) (*Dispatch
|
|||
}, err
|
||||
}
|
||||
|
||||
func splitAddress(addr string) (mailbox, domain string, err error) {
|
||||
parts := strings.Split(addr, "@")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
if strings.EqualFold(parts[0], "postmaster") {
|
||||
return parts[0], "", nil
|
||||
}
|
||||
return "", "", fmt.Errorf("malformed address")
|
||||
case 2:
|
||||
if len(parts[0]) == 0 || len(parts[1]) == 0 {
|
||||
return "", "", fmt.Errorf("malformed address")
|
||||
}
|
||||
return parts[0], parts[1], nil
|
||||
default:
|
||||
return "", "", fmt.Errorf("malformed address")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
dl := deliveryLogger(d.Log, msgMeta)
|
||||
dl := target.DeliveryLogger(d.Log, msgMeta)
|
||||
dd := dispatcherDelivery{
|
||||
d: d,
|
||||
rcptChecksState: make(map[*rcptBlock]module.CheckState),
|
||||
|
@ -101,7 +86,7 @@ func (d *Dispatcher) Start(msgMeta *module.MsgMetadata, mailFrom string) (module
|
|||
sourceBlock, ok := d.perSource[strings.ToLower(mailFrom)]
|
||||
if !ok {
|
||||
// Then try domain-only.
|
||||
_, domain, err := splitAddress(mailFrom)
|
||||
_, domain, err := address.Split(mailFrom)
|
||||
if err != nil {
|
||||
return nil, &smtp.SMTPError{
|
||||
Code: 501,
|
||||
|
@ -176,7 +161,7 @@ func (dd *dispatcherDelivery) AddRcpt(to string) error {
|
|||
rcptBlock, ok := dd.sourceBlock.perRcpt[strings.ToLower(to)]
|
||||
if !ok {
|
||||
// Then try domain-only.
|
||||
_, domain, err := splitAddress(to)
|
||||
_, domain, err := address.Split(to)
|
||||
if err != nil {
|
||||
return &smtp.SMTPError{
|
||||
Code: 501,
|
||||
|
@ -268,7 +253,7 @@ func (dd *dispatcherDelivery) Body(header textproto.Header, body buffer.Buffer)
|
|||
// After results for all checks are checked, authRes will be populated with values
|
||||
// we should put into Authentication-Results header.
|
||||
if len(dd.authRes) != 0 {
|
||||
header.Add("Authentication-Results", authres.Format(dd.d.hostname, dd.authRes))
|
||||
header.Add("Authentication-Results", authres.Format(dd.d.Hostname, dd.authRes))
|
||||
}
|
||||
for field := dd.header.Fields(); field.Next(); {
|
||||
header.Add(field.Key(), field.Value())
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -7,10 +7,11 @@ import (
|
|||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
"github.com/foxcpp/maddy/testutils"
|
||||
)
|
||||
|
||||
func TestDispatcher_AllToTarget(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -21,20 +22,20 @@ func TestDispatcher_AllToTarget(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
|
||||
checkTestMessage(t, &target, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
testutils.CheckTestMessage(t, &target, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_PerSourceDomainSplit(t *testing.T) {
|
||||
orgTarget, comTarget := testTarget{instName: "orgTarget"}, testTarget{instName: "comTarget"}
|
||||
orgTarget, comTarget := testutils.Target{InstName: "orgTarget"}, testutils.Target{InstName: "comTarget"}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{
|
||||
|
@ -53,25 +54,25 @@ func TestDispatcher_PerSourceDomainSplit(t *testing.T) {
|
|||
},
|
||||
defaultSource: sourceBlock{rejectErr: errors.New("default src block used")},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
doTestDelivery(t, &d, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
|
||||
if len(comTarget.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for comTarget, want %d, got %d", 1, len(comTarget.messages))
|
||||
if len(comTarget.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for comTarget, want %d, got %d", 1, len(comTarget.Messages))
|
||||
}
|
||||
checkTestMessage(t, &comTarget, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
testutils.CheckTestMessage(t, &comTarget, 0, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
|
||||
if len(orgTarget.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for orgTarget, want %d, got %d", 1, len(orgTarget.messages))
|
||||
if len(orgTarget.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for orgTarget, want %d, got %d", 1, len(orgTarget.Messages))
|
||||
}
|
||||
checkTestMessage(t, &orgTarget, 0, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
testutils.CheckTestMessage(t, &orgTarget, 0, "sender@example.org", []string{"rcpt1@example.com", "rcpt2@example.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_PerRcptAddrSplit(t *testing.T) {
|
||||
target1, target2 := testTarget{instName: "target1"}, testTarget{instName: "target2"}
|
||||
target1, target2 := testutils.Target{InstName: "target1"}, testutils.Target{InstName: "target2"}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{
|
||||
|
@ -90,25 +91,25 @@ func TestDispatcher_PerRcptAddrSplit(t *testing.T) {
|
|||
},
|
||||
defaultSource: sourceBlock{rejectErr: errors.New("default src block used")},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
doTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
|
||||
if len(target1.messages) != 1 {
|
||||
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.messages))
|
||||
if len(target1.Messages) != 1 {
|
||||
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
testutils.CheckTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
|
||||
if len(target2.messages) != 1 {
|
||||
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target2.messages))
|
||||
if len(target2.Messages) != 1 {
|
||||
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 1, len(target2.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
testutils.CheckTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_PerRcptDomainSplit(t *testing.T) {
|
||||
target1, target2 := testTarget{instName: "target1"}, testTarget{instName: "target2"}
|
||||
target1, target2 := testutils.Target{InstName: "target1"}, testutils.Target{InstName: "target2"}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -126,27 +127,27 @@ func TestDispatcher_PerRcptDomainSplit(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.org"})
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.org", "rcpt2@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.com", "rcpt2@example.org"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"rcpt1@example.org", "rcpt2@example.com"})
|
||||
|
||||
if len(target1.messages) != 2 {
|
||||
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 2, len(target1.messages))
|
||||
if len(target1.Messages) != 2 {
|
||||
t.Errorf("wrong amount of messages received for target1, want %d, got %d", 2, len(target1.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target1, 0, "sender@example.com", []string{"rcpt1@example.com"})
|
||||
checkTestMessage(t, &target1, 1, "sender@example.com", []string{"rcpt2@example.com"})
|
||||
testutils.CheckTestMessage(t, &target1, 0, "sender@example.com", []string{"rcpt1@example.com"})
|
||||
testutils.CheckTestMessage(t, &target1, 1, "sender@example.com", []string{"rcpt2@example.com"})
|
||||
|
||||
if len(target2.messages) != 2 {
|
||||
t.Errorf("wrong amount of messages received for target2, want %d, got %d", 2, len(target2.messages))
|
||||
if len(target2.Messages) != 2 {
|
||||
t.Errorf("wrong amount of messages received for target2, want %d, got %d", 2, len(target2.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target2, 0, "sender@example.com", []string{"rcpt2@example.org"})
|
||||
checkTestMessage(t, &target2, 1, "sender@example.com", []string{"rcpt1@example.org"})
|
||||
testutils.CheckTestMessage(t, &target2, 0, "sender@example.com", []string{"rcpt2@example.org"})
|
||||
testutils.CheckTestMessage(t, &target2, 1, "sender@example.com", []string{"rcpt1@example.org"})
|
||||
}
|
||||
|
||||
func TestDispatcher_PerSourceAddrAndDomainSplit(t *testing.T) {
|
||||
target1, target2 := testTarget{instName: "target1"}, testTarget{instName: "target2"}
|
||||
target1, target2 := testutils.Target{InstName: "target1"}, testutils.Target{InstName: "target2"}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{
|
||||
|
@ -164,25 +165,25 @@ func TestDispatcher_PerSourceAddrAndDomainSplit(t *testing.T) {
|
|||
},
|
||||
defaultSource: sourceBlock{rejectErr: errors.New("default src block used")},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
doTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
|
||||
if len(target1.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.messages))
|
||||
if len(target1.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target1, want %d, got %d", 1, len(target1.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
testutils.CheckTestMessage(t, &target1, 0, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
|
||||
if len(target2.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target2, want %d, got %d", 1, len(target2.messages))
|
||||
if len(target2.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target2, want %d, got %d", 1, len(target2.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
testutils.CheckTestMessage(t, &target2, 0, "sender2@example.com", []string{"rcpt@example.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_PerSourceReject(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{
|
||||
|
@ -198,10 +199,10 @@ func TestDispatcher_PerSourceReject(t *testing.T) {
|
|||
},
|
||||
defaultSource: sourceBlock{rejectErr: errors.New("go away")},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender1@example.com", []string{"rcpt@example.com"})
|
||||
|
||||
_, err := d.Start(&module.MsgMetadata{ID: "testing"}, "sender2@example.com")
|
||||
if err == nil {
|
||||
|
@ -215,7 +216,7 @@ func TestDispatcher_PerSourceReject(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDispatcher_PerRcptReject(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -233,7 +234,7 @@ func TestDispatcher_PerRcptReject(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
delivery, err := d.Start(&module.MsgMetadata{ID: "testing"}, "sender@example.com")
|
||||
|
@ -257,7 +258,7 @@ func TestDispatcher_PerRcptReject(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDispatcher_PostmasterRcpt(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -275,18 +276,18 @@ func TestDispatcher_PostmasterRcpt(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "disappointed-user@example.com", []string{"postmaster"})
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.messages))
|
||||
testutils.DoTestDelivery(t, &d, "disappointed-user@example.com", []string{"postmaster"})
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target, 0, "disappointed-user@example.com", []string{"postmaster"})
|
||||
testutils.CheckTestMessage(t, &target, 0, "disappointed-user@example.com", []string{"postmaster"})
|
||||
}
|
||||
|
||||
func TestDispatcher_PostmasterSrc(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{
|
||||
|
@ -304,18 +305,18 @@ func TestDispatcher_PostmasterSrc(t *testing.T) {
|
|||
rejectErr: errors.New("go away"),
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "postmaster", []string{"disappointed-user@example.com"})
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.messages))
|
||||
testutils.DoTestDelivery(t, &d, "postmaster", []string{"disappointed-user@example.com"})
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target, 0, "postmaster", []string{"disappointed-user@example.com"})
|
||||
testutils.CheckTestMessage(t, &target, 0, "postmaster", []string{"disappointed-user@example.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_CaseInsensetiveMatch_Src(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{
|
||||
|
@ -342,22 +343,22 @@ func TestDispatcher_CaseInsensetiveMatch_Src(t *testing.T) {
|
|||
rejectErr: errors.New("go away"),
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "POSTMastER", []string{"disappointed-user@example.com"})
|
||||
doTestDelivery(t, &d, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
|
||||
doTestDelivery(t, &d, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
|
||||
if len(target.messages) != 3 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.messages))
|
||||
testutils.DoTestDelivery(t, &d, "POSTMastER", []string{"disappointed-user@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
|
||||
if len(target.Messages) != 3 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target, 0, "POSTMastER", []string{"disappointed-user@example.com"})
|
||||
checkTestMessage(t, &target, 1, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
|
||||
checkTestMessage(t, &target, 2, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
|
||||
testutils.CheckTestMessage(t, &target, 0, "POSTMastER", []string{"disappointed-user@example.com"})
|
||||
testutils.CheckTestMessage(t, &target, 1, "SenDeR@EXAMPLE.com", []string{"disappointed-user@example.com"})
|
||||
testutils.CheckTestMessage(t, &target, 2, "sender@exAMPle.com", []string{"disappointed-user@example.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_CaseInsensetiveMatch_Rcpt(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -375,22 +376,22 @@ func TestDispatcher_CaseInsensetiveMatch_Rcpt(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"POSTMastER"})
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"sender@exAMPle.com"})
|
||||
if len(target.messages) != 3 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.messages))
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"POSTMastER"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"sender@exAMPle.com"})
|
||||
if len(target.Messages) != 3 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 3, len(target.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target, 0, "sender@example.com", []string{"POSTMastER"})
|
||||
checkTestMessage(t, &target, 1, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
|
||||
checkTestMessage(t, &target, 2, "sender@example.com", []string{"sender@exAMPle.com"})
|
||||
testutils.CheckTestMessage(t, &target, 0, "sender@example.com", []string{"POSTMastER"})
|
||||
testutils.CheckTestMessage(t, &target, 1, "sender@example.com", []string{"SenDeR@EXAMPLE.com"})
|
||||
testutils.CheckTestMessage(t, &target, 2, "sender@example.com", []string{"sender@exAMPle.com"})
|
||||
}
|
||||
|
||||
func TestDispatcher_MalformedSource(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -408,7 +409,7 @@ func TestDispatcher_MalformedSource(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
// Simple checks for violations that can make dispatcher misbehave.
|
||||
|
@ -421,7 +422,7 @@ func TestDispatcher_MalformedSource(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDispatcher_TwoRcptToOneTarget(t *testing.T) {
|
||||
target := testTarget{}
|
||||
target := testutils.Target{}
|
||||
d := Dispatcher{
|
||||
dispatcherCfg: dispatcherCfg{
|
||||
perSource: map[string]sourceBlock{},
|
||||
|
@ -436,13 +437,13 @@ func TestDispatcher_TwoRcptToOneTarget(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Log: testLogger(t, "dispatcher"),
|
||||
Log: testutils.Logger(t, "dispatcher"),
|
||||
}
|
||||
|
||||
doTestDelivery(t, &d, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
|
||||
testutils.DoTestDelivery(t, &d, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
|
||||
|
||||
if len(target.messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.messages))
|
||||
if len(target.Messages) != 1 {
|
||||
t.Fatalf("wrong amount of messages received for target, want %d, got %d", 1, len(target.Messages))
|
||||
}
|
||||
checkTestMessage(t, &target, 0, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
|
||||
testutils.CheckTestMessage(t, &target, 0, "sender@example.com", []string{"recipient@example.com", "recipient@example.org"})
|
||||
}
|
16
dispatcher/msgid.go
Normal file
16
dispatcher/msgid.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package dispatcher
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
// GenerateMsgID generates a string usable as MsgID field in module.MsgMeta.
|
||||
//
|
||||
// TODO: Find a better place for this function. 'dispatcher' package seems
|
||||
// irrelevant.
|
||||
func GenerateMsgID() (string, error) {
|
||||
rawID := make([]byte, 32)
|
||||
_, err := rand.Read(rawID)
|
||||
return hex.EncodeToString(rawID), err
|
||||
}
|
|
@ -1,4 +1,10 @@
|
|||
package maddy
|
||||
// Package dns defines interfaces used by maddy modules to perform DNS
|
||||
// lookups.
|
||||
//
|
||||
// Currently, there is only Resolver interface which is implemented
|
||||
// by net.DefaultResolver. In the future, DNSSEC-enabled stub resolver
|
||||
// implementation will be added here.
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package imap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
@ -9,23 +9,24 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/emersion/go-imap"
|
||||
imapbackend "github.com/emersion/go-imap/backend"
|
||||
imapserver "github.com/emersion/go-imap/server"
|
||||
"github.com/emersion/go-message"
|
||||
_ "github.com/emersion/go-message/charset"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
|
||||
appendlimit "github.com/emersion/go-imap-appendlimit"
|
||||
compress "github.com/emersion/go-imap-compress"
|
||||
idle "github.com/emersion/go-imap-idle"
|
||||
move "github.com/emersion/go-imap-move"
|
||||
unselect "github.com/emersion/go-imap-unselect"
|
||||
imapbackend "github.com/emersion/go-imap/backend"
|
||||
imapserver "github.com/emersion/go-imap/server"
|
||||
"github.com/emersion/go-message"
|
||||
_ "github.com/emersion/go-message/charset"
|
||||
"github.com/foxcpp/go-imap-sql/children"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
modconfig "github.com/foxcpp/maddy/config/module"
|
||||
"github.com/foxcpp/maddy/endpoint"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
type IMAPEndpoint struct {
|
||||
type Endpoint struct {
|
||||
name string
|
||||
aliases []string
|
||||
serv *imapserver.Server
|
||||
|
@ -40,8 +41,8 @@ type IMAPEndpoint struct {
|
|||
Log log.Logger
|
||||
}
|
||||
|
||||
func NewIMAPEndpoint(_, instName string, aliases []string) (module.Module, error) {
|
||||
endp := &IMAPEndpoint{
|
||||
func New(_, instName string, aliases []string) (module.Module, error) {
|
||||
endp := &Endpoint{
|
||||
name: instName,
|
||||
aliases: aliases,
|
||||
Log: log.Logger{Name: "imap"},
|
||||
|
@ -51,15 +52,15 @@ func NewIMAPEndpoint(_, instName string, aliases []string) (module.Module, error
|
|||
return endp, nil
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) Init(cfg *config.Map) error {
|
||||
func (endp *Endpoint) Init(cfg *config.Map) error {
|
||||
var (
|
||||
insecureAuth bool
|
||||
ioDebug bool
|
||||
)
|
||||
|
||||
cfg.Custom("auth", false, true, nil, authDirective, &endp.Auth)
|
||||
cfg.Custom("storage", false, true, nil, storageDirective, &endp.Store)
|
||||
cfg.Custom("tls", true, true, nil, tlsDirective, &endp.tlsConfig)
|
||||
cfg.Custom("auth", false, true, nil, modconfig.AuthDirective, &endp.Auth)
|
||||
cfg.Custom("storage", false, true, nil, modconfig.StorageDirective, &endp.Store)
|
||||
cfg.Custom("tls", true, true, nil, endpoint.TLSDirective, &endp.tlsConfig)
|
||||
cfg.Bool("insecure_auth", false, &insecureAuth)
|
||||
cfg.Bool("io_debug", false, &ioDebug)
|
||||
cfg.Bool("debug", true, &endp.Log.Debug)
|
||||
|
@ -80,9 +81,9 @@ func (endp *IMAPEndpoint) Init(cfg *config.Map) error {
|
|||
}
|
||||
|
||||
args := append([]string{endp.name}, endp.aliases...)
|
||||
addresses := make([]Address, 0, len(args))
|
||||
addresses := make([]config.Address, 0, len(args))
|
||||
for _, addr := range args {
|
||||
saddr, err := standardizeAddress(addr)
|
||||
saddr, err := config.StandardizeAddress(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("imap: invalid address: %s", endp.name)
|
||||
}
|
||||
|
@ -144,19 +145,19 @@ func (endp *IMAPEndpoint) Init(cfg *config.Map) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) Updates() <-chan imapbackend.Update {
|
||||
func (endp *Endpoint) Updates() <-chan imapbackend.Update {
|
||||
return endp.updater.Updates()
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) Name() string {
|
||||
func (endp *Endpoint) Name() string {
|
||||
return "imap"
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) InstanceName() string {
|
||||
func (endp *Endpoint) InstanceName() string {
|
||||
return endp.name
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) Close() error {
|
||||
func (endp *Endpoint) Close() error {
|
||||
for _, l := range endp.listeners {
|
||||
l.Close()
|
||||
}
|
||||
|
@ -167,7 +168,7 @@ func (endp *IMAPEndpoint) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) Login(connInfo *imap.ConnInfo, username, password string) (imapbackend.User, error) {
|
||||
func (endp *Endpoint) Login(connInfo *imap.ConnInfo, username, password string) (imapbackend.User, error) {
|
||||
if !endp.Auth.CheckPlain(username, password) {
|
||||
endp.Log.Printf("authentication failed for %s (from %v)", username, connInfo.RemoteAddr)
|
||||
return nil, imapbackend.ErrInvalidCredentials
|
||||
|
@ -176,11 +177,11 @@ func (endp *IMAPEndpoint) Login(connInfo *imap.ConnInfo, username, password stri
|
|||
return endp.Store.GetOrCreateUser(username)
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) EnableChildrenExt() bool {
|
||||
func (endp *Endpoint) EnableChildrenExt() bool {
|
||||
return endp.Store.(children.Backend).EnableChildrenExt()
|
||||
}
|
||||
|
||||
func (endp *IMAPEndpoint) enableExtensions() error {
|
||||
func (endp *Endpoint) enableExtensions() error {
|
||||
exts := endp.Store.IMAPExtensions()
|
||||
for _, ext := range exts {
|
||||
switch ext {
|
||||
|
@ -201,7 +202,7 @@ func (endp *IMAPEndpoint) enableExtensions() error {
|
|||
}
|
||||
|
||||
func init() {
|
||||
module.Register("imap", NewIMAPEndpoint)
|
||||
module.Register("imap", New)
|
||||
|
||||
imap.CharsetReader = message.CharsetReader
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
|
@ -11,13 +11,15 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
modconfig "github.com/foxcpp/maddy/config/module"
|
||||
"github.com/foxcpp/maddy/dispatcher"
|
||||
"github.com/foxcpp/maddy/endpoint"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
@ -38,8 +40,8 @@ func MsgMetaLog(l log.Logger, msgMeta *module.MsgMetadata) log.Logger {
|
|||
}
|
||||
}
|
||||
|
||||
type SMTPSession struct {
|
||||
endp *SMTPEndpoint
|
||||
type Session struct {
|
||||
endp *Endpoint
|
||||
delivery module.Delivery
|
||||
msgMeta *module.MsgMetadata
|
||||
log log.Logger
|
||||
|
@ -51,7 +53,7 @@ var errInternal = &smtp.SMTPError{
|
|||
Message: "Internal server error",
|
||||
}
|
||||
|
||||
func (s *SMTPSession) Reset() {
|
||||
func (s *Session) Reset() {
|
||||
if s.delivery != nil {
|
||||
if err := s.delivery.Abort(); err != nil {
|
||||
s.endp.Log.Printf("failed to abort delivery: %v", err)
|
||||
|
@ -60,15 +62,9 @@ func (s *SMTPSession) Reset() {
|
|||
}
|
||||
}
|
||||
|
||||
func generateMsgID() (string, error) {
|
||||
rawID := make([]byte, 32)
|
||||
_, err := rand.Read(rawID)
|
||||
return hex.EncodeToString(rawID), err
|
||||
}
|
||||
|
||||
func (s *SMTPSession) Mail(from string) error {
|
||||
func (s *Session) Mail(from string) error {
|
||||
var err error
|
||||
s.msgMeta.ID, err = generateMsgID()
|
||||
s.msgMeta.ID, err = dispatcher.GenerateMsgID()
|
||||
if err != nil {
|
||||
s.endp.Log.Printf("rand.Rand error: %v", err)
|
||||
return s.wrapErr(errInternal)
|
||||
|
@ -86,7 +82,7 @@ func (s *SMTPSession) Mail(from string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *SMTPSession) Rcpt(to string) error {
|
||||
func (s *Session) Rcpt(to string) error {
|
||||
err := s.delivery.AddRcpt(to)
|
||||
if err != nil {
|
||||
s.log.Printf("recipient rejected: %v, RCPT TO = %s", err, to)
|
||||
|
@ -94,7 +90,7 @@ func (s *SMTPSession) Rcpt(to string) error {
|
|||
return s.wrapErr(err)
|
||||
}
|
||||
|
||||
func (s *SMTPSession) Logout() error {
|
||||
func (s *Session) Logout() error {
|
||||
if s.delivery != nil {
|
||||
if err := s.delivery.Abort(); err != nil {
|
||||
s.endp.Log.Printf("failed to abort delivery: %v", err)
|
||||
|
@ -104,7 +100,7 @@ func (s *SMTPSession) Logout() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *SMTPSession) Data(r io.Reader) error {
|
||||
func (s *Session) Data(r io.Reader) error {
|
||||
bufr := bufio.NewReader(r)
|
||||
header, err := textproto.ReadHeader(bufr)
|
||||
if err != nil {
|
||||
|
@ -143,7 +139,7 @@ func (s *SMTPSession) Data(r io.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *SMTPSession) wrapErr(err error) error {
|
||||
func (s *Session) wrapErr(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -158,13 +154,13 @@ func (s *SMTPSession) wrapErr(err error) error {
|
|||
return fmt.Errorf("%v (msg ID = %s)", err, s.msgMeta.ID)
|
||||
}
|
||||
|
||||
type SMTPEndpoint struct {
|
||||
type Endpoint struct {
|
||||
Auth module.AuthProvider
|
||||
serv *smtp.Server
|
||||
name string
|
||||
aliases []string
|
||||
listeners []net.Listener
|
||||
dispatcher *Dispatcher
|
||||
dispatcher *dispatcher.Dispatcher
|
||||
|
||||
authAlwaysRequired bool
|
||||
|
||||
|
@ -175,16 +171,16 @@ type SMTPEndpoint struct {
|
|||
Log log.Logger
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) Name() string {
|
||||
func (endp *Endpoint) Name() string {
|
||||
return "smtp"
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) InstanceName() string {
|
||||
func (endp *Endpoint) InstanceName() string {
|
||||
return endp.name
|
||||
}
|
||||
|
||||
func NewSMTPEndpoint(modName, instName string, aliases []string) (module.Module, error) {
|
||||
endp := &SMTPEndpoint{
|
||||
func New(modName, instName string, aliases []string) (module.Module, error) {
|
||||
endp := &Endpoint{
|
||||
name: instName,
|
||||
aliases: aliases,
|
||||
submission: modName == "submission",
|
||||
|
@ -193,7 +189,7 @@ func NewSMTPEndpoint(modName, instName string, aliases []string) (module.Module,
|
|||
return endp, nil
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) Init(cfg *config.Map) error {
|
||||
func (endp *Endpoint) Init(cfg *config.Map) error {
|
||||
endp.serv = smtp.NewServer(endp)
|
||||
if err := endp.setConfig(cfg); err != nil {
|
||||
return err
|
||||
|
@ -204,9 +200,9 @@ func (endp *SMTPEndpoint) Init(cfg *config.Map) error {
|
|||
}
|
||||
|
||||
args := append([]string{endp.name}, endp.aliases...)
|
||||
addresses := make([]Address, 0, len(args))
|
||||
addresses := make([]config.Address, 0, len(args))
|
||||
for _, addr := range args {
|
||||
saddr, err := standardizeAddress(addr)
|
||||
saddr, err := config.StandardizeAddress(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp: invalid address: %s", addr)
|
||||
}
|
||||
|
@ -232,7 +228,7 @@ func (endp *SMTPEndpoint) Init(cfg *config.Map) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
|
||||
func (endp *Endpoint) setConfig(cfg *config.Map) error {
|
||||
var (
|
||||
err error
|
||||
ioDebug bool
|
||||
|
@ -242,14 +238,14 @@ func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
|
|||
readTimeoutSecs uint
|
||||
)
|
||||
|
||||
cfg.Custom("auth", false, false, nil, authDirective, &endp.Auth)
|
||||
cfg.Custom("auth", false, false, nil, modconfig.AuthDirective, &endp.Auth)
|
||||
cfg.String("hostname", true, true, "", &endp.serv.Domain)
|
||||
// TODO: Parse human-readable duration values.
|
||||
cfg.UInt("write_timeout", false, false, 60, &writeTimeoutSecs)
|
||||
cfg.UInt("read_timeout", false, false, 600, &readTimeoutSecs)
|
||||
cfg.Int("max_message_size", false, false, 32*1024*1024, &endp.serv.MaxMessageBytes)
|
||||
cfg.Int("max_recipients", false, false, 20000, &endp.serv.MaxRecipients)
|
||||
cfg.Custom("tls", true, true, nil, tlsDirective, &endp.serv.TLSConfig)
|
||||
cfg.Custom("tls", true, true, nil, endpoint.TLSDirective, &endp.serv.TLSConfig)
|
||||
cfg.Bool("insecure_auth", false, &endp.serv.AllowInsecureAuth)
|
||||
cfg.Bool("io_debug", false, &ioDebug)
|
||||
cfg.Bool("debug", true, &endp.Log.Debug)
|
||||
|
@ -259,14 +255,14 @@ func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endp.dispatcher, err = NewDispatcher(cfg.Globals, unmatched)
|
||||
endp.dispatcher, err = dispatcher.NewDispatcher(cfg.Globals, unmatched)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endp.dispatcher.hostname = endp.serv.Domain
|
||||
endp.dispatcher.Hostname = endp.serv.Domain
|
||||
endp.dispatcher.Log = log.Logger{Name: "smtp/dispatcher", Debug: endp.Log.Debug}
|
||||
|
||||
// endp.submission can be set to true by NewSMTPEndpoint, leave it on
|
||||
// endp.submission can be set to true by New, leave it on
|
||||
// even if directive is missing.
|
||||
if submission {
|
||||
endp.submission = true
|
||||
|
@ -291,7 +287,7 @@ func (endp *SMTPEndpoint) setConfig(cfg *config.Map) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) setupListeners(addresses []Address) error {
|
||||
func (endp *Endpoint) setupListeners(addresses []config.Address) error {
|
||||
var smtpUsed, lmtpUsed bool
|
||||
for _, addr := range addresses {
|
||||
if addr.Scheme == "smtp" || addr.Scheme == "smtps" {
|
||||
|
@ -341,7 +337,7 @@ func (endp *SMTPEndpoint) setupListeners(addresses []Address) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
|
||||
func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
|
||||
if endp.Auth == nil {
|
||||
return nil, smtp.ErrAuthUnsupported
|
||||
}
|
||||
|
@ -354,7 +350,7 @@ func (endp *SMTPEndpoint) Login(state *smtp.ConnectionState, username, password
|
|||
return endp.newSession(false, username, password, state), nil
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
||||
func (endp *Endpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
||||
if endp.authAlwaysRequired {
|
||||
return nil, smtp.ErrAuthRequired
|
||||
}
|
||||
|
@ -362,7 +358,7 @@ func (endp *SMTPEndpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Sess
|
|||
return endp.newSession(true, "", "", state), nil
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) newSession(anonymous bool, username, password string, state *smtp.ConnectionState) smtp.Session {
|
||||
func (endp *Endpoint) newSession(anonymous bool, username, password string, state *smtp.ConnectionState) smtp.Session {
|
||||
ctx := &module.MsgMetadata{
|
||||
Anonymous: anonymous,
|
||||
AuthUser: username,
|
||||
|
@ -385,18 +381,14 @@ func (endp *SMTPEndpoint) newSession(anonymous bool, username, password string,
|
|||
}
|
||||
}
|
||||
|
||||
return &SMTPSession{
|
||||
return &Session{
|
||||
endp: endp,
|
||||
msgMeta: ctx,
|
||||
log: MsgMetaLog(endp.Log, ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeString(raw string) string {
|
||||
return strings.Replace(raw, "\n", "", -1)
|
||||
}
|
||||
|
||||
func (endp *SMTPEndpoint) Close() error {
|
||||
func (endp *Endpoint) Close() error {
|
||||
for _, l := range endp.listeners {
|
||||
l.Close()
|
||||
}
|
||||
|
@ -406,8 +398,8 @@ func (endp *SMTPEndpoint) Close() error {
|
|||
}
|
||||
|
||||
func init() {
|
||||
module.Register("smtp", NewSMTPEndpoint)
|
||||
module.Register("submission", NewSMTPEndpoint)
|
||||
module.Register("smtp", New)
|
||||
module.Register("submission", New)
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"errors"
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package endpoint
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
|
@ -16,7 +16,7 @@ import (
|
|||
"github.com/foxcpp/maddy/log"
|
||||
)
|
||||
|
||||
func tlsDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
func TLSDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
switch len(node.Args) {
|
||||
case 1:
|
||||
switch node.Args[0] {
|
|
@ -1,3 +1,4 @@
|
|||
// Package log implements minimalistic logging library.
|
||||
package log
|
||||
|
||||
import (
|
||||
|
|
12
maddy.go
12
maddy.go
|
@ -7,8 +7,18 @@ import (
|
|||
"syscall"
|
||||
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/endpoint"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
|
||||
// Import packages for side-effect of module registration.
|
||||
_ "github.com/foxcpp/maddy/auth/external"
|
||||
_ "github.com/foxcpp/maddy/check/dns"
|
||||
_ "github.com/foxcpp/maddy/endpoint/imap"
|
||||
_ "github.com/foxcpp/maddy/endpoint/smtp"
|
||||
_ "github.com/foxcpp/maddy/storage/sql"
|
||||
_ "github.com/foxcpp/maddy/target/queue"
|
||||
_ "github.com/foxcpp/maddy/target/remote"
|
||||
)
|
||||
|
||||
type modInfo struct {
|
||||
|
@ -23,7 +33,7 @@ func Start(cfg []config.Node) error {
|
|||
globals.String("autogenerated_msg_domain", false, false, "", nil)
|
||||
globals.String("statedir", false, false, "", nil)
|
||||
globals.String("libexecdir", false, false, "", nil)
|
||||
globals.Custom("tls", false, false, nil, tlsDirective, nil)
|
||||
globals.Custom("tls", false, false, nil, endpoint.TLSDirective, nil)
|
||||
globals.Bool("auth_perdomain", false, nil)
|
||||
globals.StringList("auth_domains", false, false, nil, nil)
|
||||
globals.Custom("log", false, false, defaultLogOutput, logOutput, &log.DefaultLogger.Out)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// Package module contains interfaces implemented by maddy modules.
|
||||
// Package module contains modules registry and interfaces implemented
|
||||
// by modules.
|
||||
//
|
||||
// They are moved to separate package to prevent circular dependencies.
|
||||
// Interfaces are placed here to prevent circular dependencies.
|
||||
//
|
||||
// Each interface required by maddy for operation is provided by some object
|
||||
// called "module". This includes authentication, storage backends, DKIM,
|
||||
|
|
735
queue.go
735
queue.go
|
@ -1,735 +0,0 @@
|
|||
package maddy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/dsn"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
// PartialError describes state of partially successful message delivery.
|
||||
type PartialError struct {
|
||||
// Recipients for which delivery permanently failed.
|
||||
Failed []string
|
||||
// Recipients for which delivery temporary failed.
|
||||
TemporaryFailed []string
|
||||
|
||||
// Underlying error objects.
|
||||
Errs map[string]error
|
||||
}
|
||||
|
||||
func (pe PartialError) Error() string {
|
||||
return fmt.Sprintf("delivery failed for some recipients (permanently: %v, temporary: %v): %v", pe.Failed, pe.TemporaryFailed, pe.Errs)
|
||||
}
|
||||
|
||||
type Queue struct {
|
||||
name string
|
||||
location string
|
||||
hostname string
|
||||
autogenMsgDomain string
|
||||
wheel *TimeWheel
|
||||
|
||||
dsnDispatcher *Dispatcher
|
||||
|
||||
// Retry delay is calculated using the following formula:
|
||||
// initialRetryTime * retryTimeScale ^ (TriesCount - 1)
|
||||
|
||||
initialRetryTime time.Duration
|
||||
retryTimeScale float64
|
||||
maxTries int
|
||||
|
||||
// If any delivery is scheduled in less than postInitDelay
|
||||
// after Init, its delay will be increased by postInitDelay.
|
||||
//
|
||||
// Say, if postInitDelay is 10 secs.
|
||||
// Then if some message is scheduled to delivered 5 seconds
|
||||
// after init, it will be actually delivered 15 seconds
|
||||
// after start-up.
|
||||
//
|
||||
// This delay is added to make that if maddy is killed shortly
|
||||
// after start-up for whatever reason it will not affect the queue.
|
||||
postInitDelay time.Duration
|
||||
|
||||
Log log.Logger
|
||||
Target module.DeliveryTarget
|
||||
|
||||
workersWg sync.WaitGroup
|
||||
// Closed from Queue.Close.
|
||||
workersStop chan struct{}
|
||||
}
|
||||
|
||||
type QueueMetadata struct {
|
||||
MsgMeta *module.MsgMetadata
|
||||
From string
|
||||
|
||||
// Recipients that should be tried next.
|
||||
// May or may not be equal to PartialError.TemporaryFailed.
|
||||
To []string
|
||||
|
||||
// Information about previous failures.
|
||||
// Preserved to be included in a bounce message.
|
||||
FailedRcpts []string
|
||||
TemporaryFailedRcpts []string
|
||||
// All errors are converted to SMTPError we can serialize and
|
||||
// also it is directly usable for bounce messages.
|
||||
RcptErrs map[string]*smtp.SMTPError
|
||||
|
||||
// Amount of times delivery *already tried*.
|
||||
TriesCount int
|
||||
|
||||
FirstAttempt time.Time
|
||||
LastAttempt time.Time
|
||||
|
||||
// Whether this is a delivery notification.
|
||||
DSN bool
|
||||
}
|
||||
|
||||
func NewQueue(_, instName string, _ []string) (module.Module, error) {
|
||||
return &Queue{
|
||||
name: instName,
|
||||
initialRetryTime: 15 * time.Minute,
|
||||
retryTimeScale: 2,
|
||||
postInitDelay: 10 * time.Second,
|
||||
Log: log.Logger{Name: "queue"},
|
||||
workersStop: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (q *Queue) Init(cfg *config.Map) error {
|
||||
var workers int
|
||||
cfg.Bool("debug", true, &q.Log.Debug)
|
||||
cfg.Int("max_tries", false, false, 8, &q.maxTries)
|
||||
cfg.Int("workers", false, false, 16, &workers)
|
||||
cfg.String("location", false, false, "", &q.location)
|
||||
cfg.Custom("target", false, true, nil, deliveryDirective, &q.Target)
|
||||
cfg.String("hostname", true, true, "", &q.hostname)
|
||||
cfg.String("autogenerated_msg_domain", true, false, "", &q.autogenMsgDomain)
|
||||
cfg.Custom("bounce", false, false, nil, func(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
return NewDispatcher(m.Globals, node.Children)
|
||||
}, &q.dsnDispatcher)
|
||||
if _, err := cfg.Process(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if q.dsnDispatcher != nil {
|
||||
if q.autogenMsgDomain == "" {
|
||||
return errors.New("queue: autogenerated_msg_domain is required if bounce {} is specified")
|
||||
}
|
||||
|
||||
q.dsnDispatcher.hostname = q.hostname
|
||||
q.dsnDispatcher.Log = log.Logger{Name: "queue/dispatcher", Debug: q.Log.Debug}
|
||||
}
|
||||
if q.location == "" && q.name == "" {
|
||||
return errors.New("queue: need explicit location directive or config block name if defined inline")
|
||||
}
|
||||
if q.location == "" {
|
||||
q.location = filepath.Join(StateDirectory(cfg.Globals), q.name)
|
||||
}
|
||||
if !filepath.IsAbs(q.location) {
|
||||
q.location = filepath.Join(StateDirectory(cfg.Globals), q.location)
|
||||
}
|
||||
|
||||
// TODO: Check location write permissions.
|
||||
if err := os.MkdirAll(q.location, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return q.start(workers)
|
||||
}
|
||||
|
||||
func (q *Queue) start(workers int) error {
|
||||
q.wheel = NewTimeWheel()
|
||||
|
||||
if err := q.readDiskQueue(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.Log.Debugf("delivery target: %T", q.Target)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
q.workersWg.Add(1)
|
||||
go q.worker()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Queue) Close() error {
|
||||
// Make Close function idempotent. This makes it more
|
||||
// convenient to use in certain situations (see queue tests).
|
||||
if q.workersStop == nil {
|
||||
return nil
|
||||
}
|
||||
close(q.workersStop)
|
||||
q.workersWg.Wait()
|
||||
q.workersStop = nil
|
||||
q.wheel.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Queue) worker() {
|
||||
for {
|
||||
select {
|
||||
case <-q.workersStop:
|
||||
q.workersWg.Done()
|
||||
return
|
||||
case slot := <-q.wheel.Dispatch():
|
||||
q.Log.Debugln("worker woke up for", slot.Value)
|
||||
id := slot.Value.(string)
|
||||
|
||||
meta, header, body, err := q.openMessage(id)
|
||||
if err != nil {
|
||||
q.Log.Printf("failed to read message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
q.tryDelivery(meta, header, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Queue) tryDelivery(meta *QueueMetadata, header textproto.Header, body buffer.Buffer) {
|
||||
dl := deliveryLogger(q.Log, meta.MsgMeta)
|
||||
dl.Debugf("delivery attempt #%d", meta.TriesCount+1)
|
||||
|
||||
partialErr := q.deliver(meta, header, body)
|
||||
dl.Debugf("failures: permanently: %v, temporary: %v, errors: %v",
|
||||
partialErr.Failed, partialErr.TemporaryFailed, partialErr.Errs)
|
||||
|
||||
// Save permanent errors information for reporting in bounce message.
|
||||
meta.FailedRcpts = append(meta.FailedRcpts, partialErr.Failed...)
|
||||
for rcpt, rcptErr := range partialErr.Errs {
|
||||
var smtpErr *smtp.SMTPError
|
||||
var ok bool
|
||||
if smtpErr, ok = rcptErr.(*smtp.SMTPError); !ok {
|
||||
smtpErr = &smtp.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: smtp.EnhancedCode{5, 0, 0},
|
||||
Message: rcptErr.Error(),
|
||||
}
|
||||
if isTemporaryErr(rcptErr) {
|
||||
smtpErr.Code = 451
|
||||
smtpErr.EnhancedCode = smtp.EnhancedCode{4, 0, 0}
|
||||
}
|
||||
}
|
||||
|
||||
meta.RcptErrs[rcpt] = smtpErr
|
||||
}
|
||||
meta.To = partialErr.TemporaryFailed
|
||||
|
||||
meta.LastAttempt = time.Now()
|
||||
if meta.TriesCount == q.maxTries || len(partialErr.TemporaryFailed) == 0 {
|
||||
// Attempt either fully succeeded or completely failed.
|
||||
if meta.TriesCount == q.maxTries {
|
||||
dl.Printf("gave up trying to deliver to %v, errors: %v", meta.TemporaryFailedRcpts, meta.RcptErrs)
|
||||
}
|
||||
if len(meta.FailedRcpts) != 0 {
|
||||
dl.Printf("permanently failed to deliver to %v, errors: %v", meta.FailedRcpts, meta.RcptErrs)
|
||||
}
|
||||
if !meta.DSN {
|
||||
q.emitDSN(meta, header)
|
||||
}
|
||||
q.removeFromDisk(meta.MsgMeta)
|
||||
return
|
||||
}
|
||||
|
||||
meta.TriesCount++
|
||||
|
||||
if err := q.updateMetadataOnDisk(meta); err != nil {
|
||||
dl.Printf("failed to update meta-data: %v", err)
|
||||
}
|
||||
|
||||
nextTryTime := time.Now()
|
||||
nextTryTime = nextTryTime.Add(q.initialRetryTime * time.Duration(math.Pow(q.retryTimeScale, float64(meta.TriesCount-1))))
|
||||
dl.Printf("%d attempt failed, will retry in %v (at %v)", meta.TriesCount, time.Until(nextTryTime), nextTryTime)
|
||||
|
||||
q.wheel.Add(nextTryTime, meta.MsgMeta.ID)
|
||||
}
|
||||
|
||||
func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffer.Buffer) PartialError {
|
||||
dl := deliveryLogger(q.Log, meta.MsgMeta)
|
||||
perr := PartialError{
|
||||
Errs: map[string]error{},
|
||||
}
|
||||
|
||||
target := q.Target
|
||||
if meta.DSN {
|
||||
target = q.dsnDispatcher
|
||||
}
|
||||
|
||||
delivery, err := target.Start(meta.MsgMeta, meta.From)
|
||||
if err != nil {
|
||||
perr.Failed = append(perr.Failed, meta.To...)
|
||||
for _, rcpt := range meta.To {
|
||||
perr.Errs[rcpt] = err
|
||||
}
|
||||
return perr
|
||||
}
|
||||
|
||||
var acceptedRcpts []string
|
||||
for _, rcpt := range meta.To {
|
||||
if err := delivery.AddRcpt(rcpt); err != nil {
|
||||
if isTemporaryErr(err) {
|
||||
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
|
||||
} else {
|
||||
perr.Failed = append(perr.Failed, rcpt)
|
||||
}
|
||||
perr.Errs[rcpt] = err
|
||||
} else {
|
||||
acceptedRcpts = append(acceptedRcpts, rcpt)
|
||||
}
|
||||
}
|
||||
|
||||
if len(acceptedRcpts) == 0 {
|
||||
if err := delivery.Abort(); err != nil {
|
||||
dl.Printf("delivery.Abort failed: %v", err)
|
||||
}
|
||||
return perr
|
||||
}
|
||||
|
||||
expandToPartialErr := func(err error) {
|
||||
if expandedPerr, ok := err.(PartialError); ok {
|
||||
perr.TemporaryFailed = append(perr.TemporaryFailed, expandedPerr.TemporaryFailed...)
|
||||
perr.Failed = append(perr.Failed, expandedPerr.Failed...)
|
||||
for rcpt, rcptErr := range expandedPerr.Errs {
|
||||
perr.Errs[rcpt] = rcptErr
|
||||
}
|
||||
} else {
|
||||
if isTemporaryErr(err) {
|
||||
perr.TemporaryFailed = append(perr.TemporaryFailed, acceptedRcpts...)
|
||||
} else {
|
||||
perr.Failed = append(perr.Failed, acceptedRcpts...)
|
||||
}
|
||||
for _, rcpt := range acceptedRcpts {
|
||||
perr.Errs[rcpt] = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := delivery.Body(header, body); err != nil {
|
||||
expandToPartialErr(err)
|
||||
// No recipients succeeded.
|
||||
if len(perr.TemporaryFailed)+len(perr.Failed) == len(acceptedRcpts) {
|
||||
if err := delivery.Abort(); err != nil {
|
||||
dl.Printf("delivery.Abort failed: %v", err)
|
||||
}
|
||||
return perr
|
||||
}
|
||||
}
|
||||
if err := delivery.Commit(); err != nil {
|
||||
expandToPartialErr(err)
|
||||
}
|
||||
|
||||
return perr
|
||||
}
|
||||
|
||||
type queueDelivery struct {
|
||||
q *Queue
|
||||
meta *QueueMetadata
|
||||
|
||||
header textproto.Header
|
||||
body buffer.Buffer
|
||||
}
|
||||
|
||||
func (qd *queueDelivery) AddRcpt(rcptTo string) error {
|
||||
qd.meta.To = append(qd.meta.To, rcptTo)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qd *queueDelivery) Body(header textproto.Header, body buffer.Buffer) error {
|
||||
// Body buffer initially passed to us may not be valid after "delivery" to queue completes.
|
||||
// storeNewMessage returns a new buffer object created from message blob stored on disk.
|
||||
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
qd.body = storedBody
|
||||
qd.header = header
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qd *queueDelivery) Abort() error {
|
||||
if qd.body != nil {
|
||||
qd.q.removeFromDisk(qd.meta.MsgMeta)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qd *queueDelivery) Commit() error {
|
||||
// workersWg counter in incremented to make sure there will be no race with Close.
|
||||
// e.g. it will not close the wheel before we complete first attempt.
|
||||
// Also the first attempt is not scheduled using time wheel because
|
||||
// the "normal" code path requires re-reading and re-parsing of header
|
||||
// which is kinda expensive.
|
||||
// FIXME: Though this is temporary solution, the correct fix
|
||||
// would be to ditch "worker goroutines" altogether and enforce
|
||||
// concurrent deliveries limit using a semaphore.
|
||||
qd.q.workersWg.Add(1)
|
||||
go func() {
|
||||
qd.q.tryDelivery(qd.meta, qd.header, qd.body)
|
||||
qd.q.workersWg.Done()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Queue) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
meta := &QueueMetadata{
|
||||
MsgMeta: msgMeta,
|
||||
From: mailFrom,
|
||||
RcptErrs: map[string]*smtp.SMTPError{},
|
||||
FirstAttempt: time.Now(),
|
||||
LastAttempt: time.Now(),
|
||||
}
|
||||
return &queueDelivery{q: q, meta: meta}, nil
|
||||
}
|
||||
|
||||
func (q *Queue) StartDSN(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
meta := &QueueMetadata{
|
||||
DSN: true,
|
||||
MsgMeta: msgMeta,
|
||||
From: mailFrom,
|
||||
RcptErrs: map[string]*smtp.SMTPError{},
|
||||
FirstAttempt: time.Now(),
|
||||
LastAttempt: time.Now(),
|
||||
}
|
||||
return &queueDelivery{q: q, meta: meta}, nil
|
||||
}
|
||||
|
||||
func (q *Queue) removeFromDisk(msgMeta *module.MsgMetadata) {
|
||||
id := msgMeta.ID
|
||||
dl := deliveryLogger(q.Log, msgMeta)
|
||||
|
||||
// Order is important.
|
||||
// If we remove header and body but can't remove meta now - readDiskQueue
|
||||
// will detect and report it.
|
||||
headerPath := filepath.Join(q.location, id+".header")
|
||||
if err := os.Remove(headerPath); err != nil {
|
||||
dl.Printf("failed to remove header from disk: %v", err)
|
||||
}
|
||||
bodyPath := filepath.Join(q.location, id+".body")
|
||||
if err := os.Remove(bodyPath); err != nil {
|
||||
dl.Printf("failed to remove body from disk: %v", err)
|
||||
}
|
||||
metaPath := filepath.Join(q.location, id+".meta")
|
||||
if err := os.Remove(metaPath); err != nil {
|
||||
dl.Printf("failed to remove meta-data from disk: %v", err)
|
||||
}
|
||||
dl.Debugf("removed message from disk")
|
||||
}
|
||||
|
||||
func (q *Queue) readDiskQueue() error {
|
||||
dirInfo, err := ioutil.ReadDir(q.location)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Rewrite this function to pass all sub-tests in TestQueueDelivery_DeserializationCleanUp/NoMeta.
|
||||
|
||||
loadedCount := 0
|
||||
for _, entry := range dirInfo {
|
||||
// We start loading from meta-data files and then check whether ID.header and ID.body exist.
|
||||
// This allows us to properly detect dangling body files.
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta") {
|
||||
continue
|
||||
}
|
||||
id := entry.Name()[:len(entry.Name())-5]
|
||||
|
||||
meta, err := q.readMessageMeta(id)
|
||||
if err != nil {
|
||||
q.Log.Printf("failed to read meta-data, skipping: %v (msg ID = %s)", err, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check header file existence.
|
||||
if _, err := os.Stat(filepath.Join(q.location, id+".header")); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
q.Log.Printf("header file doesn't exist for msg ID = %s", id)
|
||||
q.tryRemoveDanglingFile(id + ".meta")
|
||||
q.tryRemoveDanglingFile(id + ".body")
|
||||
} else {
|
||||
q.Log.Printf("skipping nonstat'able header file: %v (msg ID = %s)", err, id)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check body file existence.
|
||||
if _, err := os.Stat(filepath.Join(q.location, id+".body")); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
q.Log.Printf("body file doesn't exist for msg ID = %s", id)
|
||||
q.tryRemoveDanglingFile(id + ".meta")
|
||||
q.tryRemoveDanglingFile(id + ".header")
|
||||
} else {
|
||||
q.Log.Printf("skipping nonstat'able body file: %v (msg ID = %s)", err, id)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
nextTryTime := meta.LastAttempt
|
||||
nextTryTime = nextTryTime.Add(q.initialRetryTime * time.Duration(math.Pow(q.retryTimeScale, float64(meta.TriesCount-1))))
|
||||
|
||||
if time.Until(nextTryTime) < q.postInitDelay {
|
||||
nextTryTime = time.Now().Add(q.postInitDelay)
|
||||
}
|
||||
|
||||
q.Log.Debugf("will try to deliver (msg ID = %s) in %v (%v)", id, time.Until(nextTryTime), nextTryTime)
|
||||
q.wheel.Add(nextTryTime, id)
|
||||
loadedCount++
|
||||
}
|
||||
|
||||
if loadedCount != 0 {
|
||||
q.Log.Printf("loaded %d saved queue entries", loadedCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Queue) storeNewMessage(meta *QueueMetadata, header textproto.Header, body buffer.Buffer) (buffer.Buffer, error) {
|
||||
id := meta.MsgMeta.ID
|
||||
|
||||
headerPath := filepath.Join(q.location, id+".header")
|
||||
headerFile, err := os.Create(headerPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer headerFile.Close()
|
||||
|
||||
if err := textproto.WriteHeader(headerFile, header); err != nil {
|
||||
q.tryRemoveDanglingFile(id + ".header")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader, err := body.Open()
|
||||
if err != nil {
|
||||
q.tryRemoveDanglingFile(id + ".header")
|
||||
return nil, err
|
||||
}
|
||||
defer bodyReader.Close()
|
||||
|
||||
bodyPath := filepath.Join(q.location, id+".body")
|
||||
bodyFile, err := os.Create(bodyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer bodyFile.Close()
|
||||
|
||||
if _, err := io.Copy(bodyFile, bodyReader); err != nil {
|
||||
q.tryRemoveDanglingFile(id + ".body")
|
||||
q.tryRemoveDanglingFile(id + ".header")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := q.updateMetadataOnDisk(meta); err != nil {
|
||||
q.tryRemoveDanglingFile(id + ".body")
|
||||
q.tryRemoveDanglingFile(id + ".header")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buffer.FileBuffer{Path: bodyPath}, nil
|
||||
}
|
||||
|
||||
func (q *Queue) updateMetadataOnDisk(meta *QueueMetadata) error {
|
||||
metaPath := filepath.Join(q.location, meta.MsgMeta.ID+".meta")
|
||||
file, err := os.Create(metaPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
metaCopy := *meta
|
||||
metaCopy.MsgMeta = meta.MsgMeta.DeepCopy()
|
||||
|
||||
if _, ok := metaCopy.MsgMeta.SrcAddr.(*net.TCPAddr); !ok {
|
||||
meta.MsgMeta.SrcAddr = nil
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(file).Encode(metaCopy); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Queue) readMessageMeta(id string) (*QueueMetadata, error) {
|
||||
metaPath := filepath.Join(q.location, id+".meta")
|
||||
file, err := os.Open(metaPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
meta := &QueueMetadata{}
|
||||
|
||||
// net.Addr can't be deserialized because we don't know concrete type. For
|
||||
// this reason we assume that SrcAddr is TCPAddr, if it is not - we drop it
|
||||
// during serialization (see updateMetadataOnDisk).
|
||||
meta.MsgMeta = &module.MsgMetadata{}
|
||||
meta.MsgMeta.SrcAddr = &net.TCPAddr{}
|
||||
|
||||
if err := json.NewDecoder(file).Decode(meta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
type BufferedReadCloser struct {
|
||||
*bufio.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func (q *Queue) tryRemoveDanglingFile(name string) {
|
||||
if err := os.Remove(filepath.Join(q.location, name)); err != nil {
|
||||
q.Log.Println(err)
|
||||
return
|
||||
}
|
||||
q.Log.Printf("removed dangling file %s", name)
|
||||
}
|
||||
|
||||
func (q *Queue) openMessage(id string) (*QueueMetadata, textproto.Header, buffer.Buffer, error) {
|
||||
meta, err := q.readMessageMeta(id)
|
||||
if err != nil {
|
||||
return nil, textproto.Header{}, nil, err
|
||||
}
|
||||
|
||||
bodyPath := filepath.Join(q.location, id+".body")
|
||||
_, err = os.Stat(bodyPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
q.tryRemoveDanglingFile(id + ".meta")
|
||||
}
|
||||
return nil, textproto.Header{}, nil, nil
|
||||
}
|
||||
body := buffer.FileBuffer{Path: bodyPath}
|
||||
|
||||
headerPath := filepath.Join(q.location, id+".header")
|
||||
headerFile, err := os.Open(headerPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
q.tryRemoveDanglingFile(id + ".meta")
|
||||
q.tryRemoveDanglingFile(id + ".body")
|
||||
}
|
||||
return nil, textproto.Header{}, nil, nil
|
||||
}
|
||||
|
||||
bufferedHeader := bufio.NewReader(headerFile)
|
||||
header, err := textproto.ReadHeader(bufferedHeader)
|
||||
if err != nil {
|
||||
return nil, textproto.Header{}, nil, nil
|
||||
}
|
||||
|
||||
return meta, header, body, nil
|
||||
}
|
||||
|
||||
func (q *Queue) InstanceName() string {
|
||||
return q.name
|
||||
}
|
||||
|
||||
func (q *Queue) Name() string {
|
||||
return "queue"
|
||||
}
|
||||
|
||||
func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
|
||||
// If, apparently, we have no DSN dispatcher configured - do nothing.
|
||||
if q.dsnDispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
dsnID, err := generateMsgID()
|
||||
if err != nil {
|
||||
q.Log.Printf("rand.Rand error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dsnEnvelope := dsn.Envelope{
|
||||
MsgID: "<" + dsnID + "@" + q.autogenMsgDomain + ">",
|
||||
From: "MAILER-DAEMON@" + q.autogenMsgDomain,
|
||||
To: meta.From,
|
||||
}
|
||||
mtaInfo := dsn.ReportingMTAInfo{
|
||||
ReportingMTA: q.hostname,
|
||||
XSender: meta.From,
|
||||
XMessageID: meta.MsgMeta.ID,
|
||||
ArrivalDate: meta.FirstAttempt,
|
||||
LastAttemptDate: meta.LastAttempt,
|
||||
}
|
||||
if !meta.MsgMeta.DontTraceSender {
|
||||
mtaInfo.ReceivedFromMTA = meta.MsgMeta.SrcHostname
|
||||
}
|
||||
|
||||
rcptInfo := make([]dsn.RecipientInfo, 0, len(meta.RcptErrs))
|
||||
for rcpt, err := range meta.RcptErrs {
|
||||
if meta.MsgMeta.OriginalRcpts != nil {
|
||||
originalRcpt := meta.MsgMeta.OriginalRcpts[rcpt]
|
||||
if originalRcpt != "" {
|
||||
rcpt = originalRcpt
|
||||
}
|
||||
}
|
||||
|
||||
rcptInfo = append(rcptInfo, dsn.RecipientInfo{
|
||||
FinalRecipient: rcpt,
|
||||
Action: dsn.ActionFailed,
|
||||
Status: err.EnhancedCode,
|
||||
DiagnosticCode: err,
|
||||
})
|
||||
}
|
||||
|
||||
var dsnBodyBlob bytes.Buffer
|
||||
dl := deliveryLogger(q.Log, meta.MsgMeta)
|
||||
dsnHeader, err := dsn.GenerateDSN(dsnEnvelope, mtaInfo, rcptInfo, header, &dsnBodyBlob)
|
||||
if err != nil {
|
||||
dl.Printf("failed to generate fail DSN: %v", err)
|
||||
return
|
||||
}
|
||||
dsnBody := buffer.MemoryBuffer{Slice: dsnBodyBlob.Bytes()}
|
||||
|
||||
dsnMeta := &module.MsgMetadata{
|
||||
ID: dsnID,
|
||||
SrcProto: "",
|
||||
SrcHostname: q.hostname,
|
||||
OurHostname: q.hostname,
|
||||
}
|
||||
dl.Printf("generated failed DSN, DSN ID = %s", dsnID)
|
||||
|
||||
dsnDelivery, err := q.StartDSN(dsnMeta, "MAILER-DAEMON@"+q.autogenMsgDomain)
|
||||
if err != nil {
|
||||
dl.Printf("failed to enqueue DSN: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
dl.Printf("failed to enqueue DSN: %v", err)
|
||||
dsnDelivery.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = dsnDelivery.AddRcpt(meta.From); err != nil {
|
||||
return
|
||||
}
|
||||
if err = dsnDelivery.Body(dsnHeader, dsnBody); err != nil {
|
||||
return
|
||||
}
|
||||
if err = dsnDelivery.Commit(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
module.Register("queue", NewQueue)
|
||||
}
|
549
queue_test.go
549
queue_test.go
|
@ -1,549 +0,0 @@
|
|||
package maddy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
// newTestQueue returns properly initialized Queue object usable for testing.
|
||||
//
|
||||
// See newTestQueueDir to create testing queue from an existing directory.
|
||||
// It is called responsibility to remove queue directory created by this function.
|
||||
func newTestQueue(t *testing.T, target module.DeliveryTarget) *Queue {
|
||||
dir, err := ioutil.TempDir("", "maddy-tests-queue")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temporary directory for queue:", err)
|
||||
}
|
||||
return newTestQueueDir(t, target, dir)
|
||||
}
|
||||
|
||||
func cleanQueue(t *testing.T, q *Queue) {
|
||||
if err := q.Close(); err != nil {
|
||||
t.Fatal("queue.Close:", err)
|
||||
}
|
||||
if err := os.RemoveAll(q.location); err != nil {
|
||||
t.Fatal("os.RemoveAll", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestQueueDir(t *testing.T, target module.DeliveryTarget, dir string) *Queue {
|
||||
mod, _ := NewQueue("", "queue", nil)
|
||||
q := mod.(*Queue)
|
||||
q.Log = testLogger(t, "queue")
|
||||
q.initialRetryTime = 0
|
||||
q.retryTimeScale = 1
|
||||
q.postInitDelay = 0
|
||||
q.maxTries = 5
|
||||
q.location = dir
|
||||
q.Target = target
|
||||
|
||||
if !testing.Verbose() {
|
||||
q.Log = log.Logger{Name: "", Out: log.WriterLog(ioutil.Discard)}
|
||||
}
|
||||
|
||||
q.start(1)
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// unreliableTarget is a module.DeliveryTarget implementation that stores
|
||||
// messages to a slice and sometimes fails with the specified error.
|
||||
type unreliableTarget struct {
|
||||
committed chan msg
|
||||
aborted chan msg
|
||||
|
||||
// Amount of completed deliveries (both failed and succeeded)
|
||||
passedMessages int
|
||||
|
||||
// To make unreliableTarget fail Commit for N-th delivery, set N-1-th
|
||||
// element of this slice to wanted error object. If slice is
|
||||
// nil/empty or N is bigger than its size - delivery will succeed.
|
||||
bodyFailures []error
|
||||
rcptFailures []map[string]error
|
||||
}
|
||||
|
||||
type unreliableTargetDelivery struct {
|
||||
ut *unreliableTarget
|
||||
msg msg
|
||||
}
|
||||
|
||||
func (utd *unreliableTargetDelivery) AddRcpt(rcptTo string) error {
|
||||
if len(utd.ut.rcptFailures) > utd.ut.passedMessages {
|
||||
rcptErrs := utd.ut.rcptFailures[utd.ut.passedMessages]
|
||||
if err := rcptErrs[rcptTo]; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
utd.msg.rcptTo = append(utd.msg.rcptTo, rcptTo)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (utd *unreliableTargetDelivery) Body(header textproto.Header, body buffer.Buffer) error {
|
||||
r, _ := body.Open()
|
||||
utd.msg.body, _ = ioutil.ReadAll(r)
|
||||
|
||||
if len(utd.ut.bodyFailures) > utd.ut.passedMessages {
|
||||
return utd.ut.bodyFailures[utd.ut.passedMessages]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (utd *unreliableTargetDelivery) Abort() error {
|
||||
utd.ut.passedMessages++
|
||||
if utd.ut.aborted != nil {
|
||||
utd.ut.aborted <- utd.msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (utd *unreliableTargetDelivery) Commit() error {
|
||||
utd.ut.passedMessages++
|
||||
if utd.ut.committed != nil {
|
||||
utd.ut.committed <- utd.msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *unreliableTarget) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
return &unreliableTargetDelivery{
|
||||
ut: ut,
|
||||
msg: msg{
|
||||
msgMeta: msgMeta,
|
||||
mailFrom: mailFrom,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readMsgChanTimeout(t *testing.T, ch <-chan msg, timeout time.Duration) *msg {
|
||||
t.Helper()
|
||||
timer := time.NewTimer(timeout)
|
||||
select {
|
||||
case msg := <-ch:
|
||||
return &msg
|
||||
case <-timer.C:
|
||||
t.Fatal("chan read timed out")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkQueueDir(t *testing.T, q *Queue, expectedIDs []string) {
|
||||
t.Helper()
|
||||
// We use the map to lookups and also to mark messages we found
|
||||
// we can report missing entries.
|
||||
expectedMap := make(map[string]bool, len(expectedIDs))
|
||||
for _, id := range expectedIDs {
|
||||
expectedMap[id] = false
|
||||
}
|
||||
|
||||
dir, err := ioutil.ReadDir(q.location)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read queue directory: %v", err)
|
||||
}
|
||||
|
||||
// Queue implementation uses file names in the following format:
|
||||
// DELIVERY_ID.SOMETHING
|
||||
for _, file := range dir {
|
||||
if file.IsDir() {
|
||||
t.Fatalf("queue should not create subdirectories in the store, but there is %s dir in it", file.Name())
|
||||
}
|
||||
|
||||
nameParts := strings.Split(file.Name(), ".")
|
||||
if len(nameParts) != 2 {
|
||||
t.Fatalf("did the queue files name format changed? got %s", file.Name())
|
||||
}
|
||||
|
||||
_, ok := expectedMap[nameParts[0]]
|
||||
if !ok {
|
||||
t.Errorf("message with unexpected Msg ID %s is stored in queue store", nameParts[0])
|
||||
continue
|
||||
}
|
||||
|
||||
expectedMap[nameParts[0]] = true
|
||||
}
|
||||
|
||||
for id, found := range expectedMap {
|
||||
if !found {
|
||||
t.Errorf("expected message with Msg ID %s is missing from queue store", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueDelivery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{committed: make(chan msg, 10)}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// This is far from being a proper blackbox testing.
|
||||
// But I can't come up with a better way to inspect the Queue state.
|
||||
// This probably will be improved when bounce messages will be implemented.
|
||||
// For now, this is a dirty hack. Close the Queue and inspect serialized state.
|
||||
// FIXME.
|
||||
|
||||
// Wait for the delivery to complete and stop processing.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
q.Close()
|
||||
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// There should be no queued messages.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_PermanentFail_NonPartial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
bodyFailures: []error{
|
||||
&smtp.SMTPError{
|
||||
Code: 500,
|
||||
EnhancedCode: smtp.EnhancedCode{5, 0, 0},
|
||||
Message: "you shall not pass",
|
||||
},
|
||||
},
|
||||
aborted: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// Queue will abort a delivery if it fails for all recipients.
|
||||
readMsgChanTimeout(t, dt.aborted, 5*time.Second)
|
||||
q.Close()
|
||||
|
||||
// Delivery is failed permanently, hence no retry should be rescheduled.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_PermanentFail_Partial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
bodyFailures: []error{
|
||||
PartialError{
|
||||
Failed: []string{"tester1@example.org", "tester2@example.org"},
|
||||
Errs: map[string]error{
|
||||
"tester1@example.org": errors.New("you shall not pass"),
|
||||
"tester2@example.org": errors.New("you shall not pass"),
|
||||
},
|
||||
},
|
||||
},
|
||||
aborted: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// This this is similar to the previous test, but checks PartialErr processing logic.
|
||||
// Here delivery fails for recipients too, but this is reported using PartialErr.
|
||||
|
||||
readMsgChanTimeout(t, dt.aborted, 5*time.Second)
|
||||
q.Close()
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_TemporaryFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
bodyFailures: []error{
|
||||
PartialError{
|
||||
TemporaryFailed: []string{"tester1@example.org", "tester2@example.org"},
|
||||
Errs: map[string]error{
|
||||
"tester1@example.org": errors.New("you shall not pass"),
|
||||
"tester2@example.org": errors.New("you shall not pass"),
|
||||
},
|
||||
},
|
||||
},
|
||||
aborted: make(chan msg, 10),
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// Delivery should be aborted, because it failed for all recipients.
|
||||
readMsgChanTimeout(t, dt.aborted, 5*time.Second)
|
||||
|
||||
// Second retry, should work fine.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
q.Close()
|
||||
// No more retries scheduled, queue storage is clear.
|
||||
defer checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_TemporaryFail_Partial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
bodyFailures: []error{
|
||||
PartialError{
|
||||
TemporaryFailed: []string{"tester2@example.org"},
|
||||
Errs: map[string]error{
|
||||
"tester2@example.org": &smtp.SMTPError{
|
||||
Code: 400,
|
||||
Message: "go away",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
aborted: make(chan msg, 10),
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// Committed, tester1@example.org - ok.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5000*time.Second)
|
||||
// Side note: unreliableTarget adds recipients to the msg object even if they were rejected
|
||||
// later using a partial error. So slice below is all recipients that were submitted by
|
||||
// the queue.
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// committed #2, tester2@example.org - ok
|
||||
msg = readMsgChanTimeout(t, dt.committed, 5000*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
|
||||
|
||||
q.Close()
|
||||
// No more retries scheduled, queue storage is clear.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_MultipleAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
bodyFailures: []error{
|
||||
PartialError{
|
||||
Failed: []string{"tester1@example.org"},
|
||||
TemporaryFailed: []string{"tester2@example.org"},
|
||||
Errs: map[string]error{
|
||||
"tester1@example.org": errors.New("you shall not pass"),
|
||||
"tester2@example.org": errors.New("you shall not pass"),
|
||||
},
|
||||
},
|
||||
PartialError{
|
||||
TemporaryFailed: []string{"tester2@example.org"},
|
||||
Errs: map[string]error{
|
||||
"tester2@example.org": errors.New("you shall not pass"),
|
||||
},
|
||||
},
|
||||
},
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org", "tester3@example.org"})
|
||||
|
||||
// Committed because delivery to tester3@example.org is succeeded.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
// Side note: This slice contains all recipients submitted by the queue, even if
|
||||
// they were rejected later using PartialError.
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org", "tester2@example.org", "tester3@example.org"})
|
||||
|
||||
// tester1 is failed permanently, should not be retried.
|
||||
// tester2 is failed temporary, should be retried.
|
||||
msg = readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
|
||||
|
||||
q.Close()
|
||||
// No more retries should be scheduled.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_PermanentRcptReject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
rcptFailures: []map[string]error{
|
||||
{
|
||||
"tester1@example.org": &smtp.SMTPError{
|
||||
Code: 500,
|
||||
Message: "go away",
|
||||
},
|
||||
},
|
||||
},
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
doTestDelivery(t, q, "tester@example.org", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// Committed, tester2@example.org succeeded.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.org", []string{"tester2@example.org"})
|
||||
|
||||
q.Close()
|
||||
// No more retries should be scheduled.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_TemporaryRcptReject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
rcptFailures: []map[string]error{
|
||||
{
|
||||
"tester1@example.org": &smtp.SMTPError{
|
||||
Code: 400,
|
||||
Message: "go away",
|
||||
},
|
||||
},
|
||||
},
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
// First attempt:
|
||||
// tester1 - temp. fail
|
||||
// tester2 - ok
|
||||
// Second attempt:
|
||||
// tester1 - ok
|
||||
doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
// Unlike previous tests where unreliableTarget rejected recipients by PartialError, here they are rejected
|
||||
// by AddRcpt directly, so they are NOT saved by the target.
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
|
||||
|
||||
msg = readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org"})
|
||||
|
||||
q.Close()
|
||||
// No more retries should be scheduled.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_SerializationRoundtrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dt := unreliableTarget{
|
||||
rcptFailures: []map[string]error{
|
||||
{
|
||||
"tester1@example.org": &smtp.SMTPError{
|
||||
Code: 400,
|
||||
Message: "go away",
|
||||
},
|
||||
},
|
||||
},
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
// This is the most tricky test because it is racy and I have no idea what can be done to avoid it.
|
||||
// It relies on us calling Close before queue dispatcher decides to retry delivery.
|
||||
// Hence retry delay is increased from 0ms used in other tests to make it reliable.
|
||||
q.initialRetryTime = 1 * time.Second
|
||||
|
||||
// To make sure we will not time out due to post-init delay.
|
||||
q.postInitDelay = 0
|
||||
|
||||
deliveryID := doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// Standard partial delivery, retry will be scheduled for tester1@example.org.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
|
||||
|
||||
// Then stop it.
|
||||
q.Close()
|
||||
|
||||
// Make sure it is saved.
|
||||
checkQueueDir(t, q, []string{deliveryID})
|
||||
|
||||
// Then reinit it.
|
||||
q = newTestQueueDir(t, &dt, q.location)
|
||||
|
||||
// Wait for retry and check it.
|
||||
msg = readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester1@example.org"})
|
||||
|
||||
// Close it again.
|
||||
q.Close()
|
||||
// No more retries should be scheduled.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
func TestQueueDelivery_DeserlizationCleanUp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
test := func(t *testing.T, fileSuffix string) {
|
||||
dt := unreliableTarget{
|
||||
rcptFailures: []map[string]error{
|
||||
{
|
||||
"tester1@example.org": &smtp.SMTPError{
|
||||
Code: 400,
|
||||
Message: "go away",
|
||||
},
|
||||
},
|
||||
},
|
||||
committed: make(chan msg, 10),
|
||||
}
|
||||
q := newTestQueue(t, &dt)
|
||||
defer cleanQueue(t, q)
|
||||
|
||||
// This is the most tricky test because it is racy and I have no idea what can be done to avoid it.
|
||||
// It relies on us calling Close before queue dispatcher decides to retry delivery.
|
||||
// Hence retry delay is increased from 0ms used in other tests to make it reliable.
|
||||
q.initialRetryTime = 1 * time.Second
|
||||
|
||||
// To make sure we will not time out due to post-init delay.
|
||||
q.postInitDelay = 0
|
||||
|
||||
deliveryID := doTestDelivery(t, q, "tester@example.com", []string{"tester1@example.org", "tester2@example.org"})
|
||||
|
||||
// Standard partial delivery, retry will be scheduled for tester1@example.org.
|
||||
msg := readMsgChanTimeout(t, dt.committed, 5*time.Second)
|
||||
checkMsg(t, msg, "tester@example.com", []string{"tester2@example.org"})
|
||||
|
||||
q.Close()
|
||||
|
||||
if err := os.Remove(filepath.Join(q.location, deliveryID+fileSuffix)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Dangling files should be removing during load.
|
||||
q = newTestQueueDir(t, &dt, q.location)
|
||||
q.Close()
|
||||
|
||||
// Nothing should be left.
|
||||
checkQueueDir(t, q, []string{})
|
||||
}
|
||||
|
||||
t.Run("NoMeta", func(t *testing.T) {
|
||||
t.Skip("Not implemented")
|
||||
test(t, ".meta")
|
||||
})
|
||||
t.Run("NoBody", func(t *testing.T) {
|
||||
test(t, ".body")
|
||||
})
|
||||
t.Run("NoHeader", func(t *testing.T) {
|
||||
test(t, ".header")
|
||||
})
|
||||
}
|
|
@ -1,4 +1,11 @@
|
|||
package maddy
|
||||
// Package sql implements SQL-based storage module
|
||||
// using go-imap-sql library (github.com/foxcpp/go-imap-sql).
|
||||
//
|
||||
// Interfaces implemented:
|
||||
// - module.StorageBackend
|
||||
// - module.AuthProvider
|
||||
// - module.DeliveryTarget
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -16,8 +23,11 @@ import (
|
|||
"github.com/emersion/go-smtp"
|
||||
imapsql "github.com/foxcpp/go-imap-sql"
|
||||
"github.com/foxcpp/go-imap-sql/fsstore"
|
||||
"github.com/foxcpp/maddy/address"
|
||||
"github.com/foxcpp/maddy/auth"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/dns"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
|
||||
|
@ -25,7 +35,7 @@ import (
|
|||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type SQLStorage struct {
|
||||
type Storage struct {
|
||||
back *imapsql.Backend
|
||||
instName string
|
||||
Log log.Logger
|
||||
|
@ -35,11 +45,11 @@ type SQLStorage struct {
|
|||
authDomains []string
|
||||
junkMbox string
|
||||
|
||||
resolver Resolver
|
||||
resolver dns.Resolver
|
||||
}
|
||||
|
||||
type sqlDelivery struct {
|
||||
sqlm *SQLStorage
|
||||
type delivery struct {
|
||||
store *Storage
|
||||
msgMeta *module.MsgMetadata
|
||||
d *imapsql.Delivery
|
||||
mailFrom string
|
||||
|
@ -47,7 +57,7 @@ type sqlDelivery struct {
|
|||
addedRcpts map[string]struct{}
|
||||
}
|
||||
|
||||
func LookupAddr(r Resolver, ip net.IP) (string, error) {
|
||||
func LookupAddr(r dns.Resolver, ip net.IP) (string, error) {
|
||||
names, err := r.LookupAddr(context.Background(), ip.String())
|
||||
if err != nil || len(names) == 0 {
|
||||
return "", err
|
||||
|
@ -55,7 +65,11 @@ func LookupAddr(r Resolver, ip net.IP) (string, error) {
|
|||
return strings.TrimRight(names[0], "."), nil
|
||||
}
|
||||
|
||||
func generateReceived(r Resolver, msgMeta *module.MsgMetadata, mailFrom, rcptTo string) string {
|
||||
func sanitizeString(raw string) string {
|
||||
return strings.Replace(raw, "\n", "", -1)
|
||||
}
|
||||
|
||||
func generateReceived(r dns.Resolver, msgMeta *module.MsgMetadata, mailFrom, rcptTo string) string {
|
||||
var received string
|
||||
if !msgMeta.DontTraceSender {
|
||||
received += "from " + msgMeta.SrcHostname
|
||||
|
@ -74,15 +88,15 @@ func generateReceived(r Resolver, msgMeta *module.MsgMetadata, mailFrom, rcptTo
|
|||
return received
|
||||
}
|
||||
|
||||
func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
|
||||
func (d *delivery) AddRcpt(rcptTo string) error {
|
||||
var accountName string
|
||||
// Side note: <postmaster> address will be always accepted
|
||||
// and delivered to "postmaster" account for both cases.
|
||||
if sd.sqlm.storagePerDomain {
|
||||
if d.store.storagePerDomain {
|
||||
accountName = rcptTo
|
||||
} else {
|
||||
var err error
|
||||
accountName, _, err = splitAddress(rcptTo)
|
||||
accountName, _, err = address.Split(rcptTo)
|
||||
if err != nil {
|
||||
return &smtp.SMTPError{
|
||||
Code: 501,
|
||||
|
@ -93,7 +107,7 @@ func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
|
|||
}
|
||||
|
||||
accountName = strings.ToLower(accountName)
|
||||
if _, ok := sd.addedRcpts[accountName]; ok {
|
||||
if _, ok := d.addedRcpts[accountName]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -102,9 +116,9 @@ func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
|
|||
// with small amount of per-recipient data in a efficient way.
|
||||
userHeader := textproto.Header{}
|
||||
userHeader.Add("Delivered-To", rcptTo)
|
||||
userHeader.Add("Received", generateReceived(sd.sqlm.resolver, sd.msgMeta, sd.mailFrom, rcptTo))
|
||||
userHeader.Add("Received", generateReceived(d.store.resolver, d.msgMeta, d.mailFrom, rcptTo))
|
||||
|
||||
if err := sd.d.AddRcpt(strings.ToLower(accountName), userHeader); err != nil {
|
||||
if err := d.d.AddRcpt(strings.ToLower(accountName), userHeader); err != nil {
|
||||
if err == imapsql.ErrUserDoesntExists || err == backend.ErrNoSuchMailbox {
|
||||
return &smtp.SMTPError{
|
||||
Code: 550,
|
||||
|
@ -115,37 +129,37 @@ func (sd *sqlDelivery) AddRcpt(rcptTo string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
sd.addedRcpts[accountName] = struct{}{}
|
||||
d.addedRcpts[accountName] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *sqlDelivery) Body(header textproto.Header, body buffer.Buffer) error {
|
||||
if sd.msgMeta.Quarantine {
|
||||
if err := sd.d.SpecialMailbox(specialuse.Junk, sd.sqlm.junkMbox); err != nil {
|
||||
func (d *delivery) Body(header textproto.Header, body buffer.Buffer) error {
|
||||
if d.msgMeta.Quarantine {
|
||||
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
header = header.Copy()
|
||||
header.Add("Return-Path", "<"+sanitizeString(sd.mailFrom)+">")
|
||||
return sd.d.BodyParsed(header, sd.msgMeta.BodyLength, body)
|
||||
header.Add("Return-Path", "<"+sanitizeString(d.mailFrom)+">")
|
||||
return d.d.BodyParsed(header, d.msgMeta.BodyLength, body)
|
||||
}
|
||||
|
||||
func (sd *sqlDelivery) Abort() error {
|
||||
return sd.d.Abort()
|
||||
func (d *delivery) Abort() error {
|
||||
return d.d.Abort()
|
||||
}
|
||||
|
||||
func (sd *sqlDelivery) Commit() error {
|
||||
return sd.d.Commit()
|
||||
func (d *delivery) Commit() error {
|
||||
return d.d.Commit()
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
d, err := sqlm.back.StartDelivery()
|
||||
func (store *Storage) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
d, err := store.back.StartDelivery()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlDelivery{
|
||||
sqlm: sqlm,
|
||||
return &delivery{
|
||||
store: store,
|
||||
msgMeta: msgMeta,
|
||||
d: d,
|
||||
mailFrom: mailFrom,
|
||||
|
@ -153,23 +167,23 @@ func (sqlm *SQLStorage) Start(msgMeta *module.MsgMetadata, mailFrom string) (mod
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) Name() string {
|
||||
func (store *Storage) Name() string {
|
||||
return "sql"
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) InstanceName() string {
|
||||
return sqlm.instName
|
||||
func (store *Storage) InstanceName() string {
|
||||
return store.instName
|
||||
}
|
||||
|
||||
func NewSQLStorage(_, instName string, _ []string) (module.Module, error) {
|
||||
return &SQLStorage{
|
||||
func New(_, instName string, _ []string) (module.Module, error) {
|
||||
return &Storage{
|
||||
instName: instName,
|
||||
Log: log.Logger{Name: "sql"},
|
||||
resolver: net.DefaultResolver,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) Init(cfg *config.Map) error {
|
||||
func (store *Storage) Init(cfg *config.Map) error {
|
||||
var driver, dsn string
|
||||
var fsstoreLocation string
|
||||
appendlimitVal := int64(-1)
|
||||
|
@ -182,24 +196,24 @@ func (sqlm *SQLStorage) Init(cfg *config.Map) error {
|
|||
cfg.String("driver", false, true, "", &driver)
|
||||
cfg.String("dsn", false, true, "", &dsn)
|
||||
cfg.Int64("appendlimit", false, false, 32*1024*1024, &appendlimitVal)
|
||||
cfg.Bool("debug", true, &sqlm.Log.Debug)
|
||||
cfg.Bool("storage_perdomain", true, &sqlm.storagePerDomain)
|
||||
cfg.Bool("auth_perdomain", true, &sqlm.authPerDomain)
|
||||
cfg.StringList("auth_domains", true, false, nil, &sqlm.authDomains)
|
||||
cfg.Bool("debug", true, &store.Log.Debug)
|
||||
cfg.Bool("storage_perdomain", true, &store.storagePerDomain)
|
||||
cfg.Bool("auth_perdomain", true, &store.authPerDomain)
|
||||
cfg.StringList("auth_domains", true, false, nil, &store.authDomains)
|
||||
cfg.Int("sqlite3_cache_size", false, false, 0, &opts.CacheSize)
|
||||
cfg.Int("sqlite3_busy_timeout", false, false, 0, &opts.BusyTimeout)
|
||||
cfg.Bool("sqlite3_exclusive_lock", false, &opts.ExclusiveLock)
|
||||
cfg.String("junk_mailbox", false, false, "Junk", &sqlm.junkMbox)
|
||||
cfg.String("junk_mailbox", false, false, "Junk", &store.junkMbox)
|
||||
|
||||
cfg.Custom("fsstore", false, false, func() (interface{}, error) {
|
||||
return "", nil
|
||||
}, func(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
switch len(node.Args) {
|
||||
case 0:
|
||||
if sqlm.instName == "" {
|
||||
if store.instName == "" {
|
||||
return nil, errors.New("sql: need explicit fsstore location for inline definition")
|
||||
}
|
||||
return filepath.Join(StateDirectory(cfg.Globals), "sql-"+sqlm.instName+"-fsstore"), nil
|
||||
return filepath.Join(config.StateDirectory(cfg.Globals), "sql-"+store.instName+"-fsstore"), nil
|
||||
case 1:
|
||||
return node.Args[0], nil
|
||||
default:
|
||||
|
@ -211,15 +225,15 @@ func (sqlm *SQLStorage) Init(cfg *config.Map) error {
|
|||
return err
|
||||
}
|
||||
|
||||
opts.Log = &sqlm.Log
|
||||
opts.Log = &store.Log
|
||||
|
||||
if sqlm.authPerDomain && sqlm.authDomains == nil {
|
||||
if store.authPerDomain && store.authDomains == nil {
|
||||
return errors.New("sql: auth_domains must be set if auth_perdomain is used")
|
||||
}
|
||||
|
||||
if fsstoreLocation != "" {
|
||||
if !filepath.IsAbs(fsstoreLocation) {
|
||||
fsstoreLocation = filepath.Join(StateDirectory(cfg.Globals), fsstoreLocation)
|
||||
fsstoreLocation = filepath.Join(config.StateDirectory(cfg.Globals), fsstoreLocation)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(fsstoreLocation, os.ModeDir|os.ModePerm); err != nil {
|
||||
|
@ -235,40 +249,40 @@ func (sqlm *SQLStorage) Init(cfg *config.Map) error {
|
|||
*opts.MaxMsgBytes = uint32(appendlimitVal)
|
||||
}
|
||||
var err error
|
||||
sqlm.back, err = imapsql.New(driver, dsn, opts)
|
||||
store.back, err = imapsql.New(driver, dsn, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sql: %s", err)
|
||||
}
|
||||
|
||||
sqlm.Log.Debugln("go-imap-sql version", imapsql.VersionStr)
|
||||
store.Log.Debugln("go-imap-sql version", imapsql.VersionStr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) IMAPExtensions() []string {
|
||||
func (store *Storage) IMAPExtensions() []string {
|
||||
return []string{"APPENDLIMIT", "MOVE", "CHILDREN"}
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) Updates() <-chan backend.Update {
|
||||
return sqlm.back.Updates()
|
||||
func (store *Storage) Updates() <-chan backend.Update {
|
||||
return store.back.Updates()
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) EnableChildrenExt() bool {
|
||||
return sqlm.back.EnableChildrenExt()
|
||||
func (store *Storage) EnableChildrenExt() bool {
|
||||
return store.back.EnableChildrenExt()
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) CheckPlain(username, password string) bool {
|
||||
accountName, ok := checkDomainAuth(username, sqlm.authPerDomain, sqlm.authDomains)
|
||||
func (store *Storage) CheckPlain(username, password string) bool {
|
||||
accountName, ok := auth.CheckDomainAuth(username, store.authPerDomain, store.authDomains)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return sqlm.back.CheckPlain(accountName, password)
|
||||
return store.back.CheckPlain(accountName, password)
|
||||
}
|
||||
|
||||
func (sqlm *SQLStorage) GetOrCreateUser(username string) (backend.User, error) {
|
||||
func (store *Storage) GetOrCreateUser(username string) (backend.User, error) {
|
||||
var accountName string
|
||||
if sqlm.storagePerDomain {
|
||||
if store.storagePerDomain {
|
||||
if !strings.Contains(username, "@") {
|
||||
return nil, errors.New("GetOrCreateUser: username@domain required")
|
||||
}
|
||||
|
@ -278,9 +292,9 @@ func (sqlm *SQLStorage) GetOrCreateUser(username string) (backend.User, error) {
|
|||
accountName = parts[0]
|
||||
}
|
||||
|
||||
return sqlm.back.GetOrCreateUser(accountName)
|
||||
return store.back.GetOrCreateUser(accountName)
|
||||
}
|
||||
|
||||
func init() {
|
||||
module.Register("sql", NewSQLStorage)
|
||||
module.Register("sql", New)
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
// +build !nosqlite3,cgo
|
||||
|
||||
package maddy
|
||||
package sql
|
||||
|
||||
import _ "github.com/mattn/go-sqlite3"
|
23
target/delivery.go
Normal file
23
target/delivery.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package target
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
func DeliveryLogger(l log.Logger, msgMeta *module.MsgMetadata) log.Logger {
|
||||
out := l.Out
|
||||
if out == nil {
|
||||
out = log.DefaultLogger.Out
|
||||
}
|
||||
|
||||
return log.Logger{
|
||||
Out: func(t time.Time, debug bool, str string) {
|
||||
out(t, debug, str+" (msg ID = "+msgMeta.ID+")")
|
||||
},
|
||||
Name: l.Name,
|
||||
Debug: l.Debug,
|
||||
}
|
||||
}
|
|
@ -1,4 +1,9 @@
|
|||
package maddy
|
||||
// Package remote implements module which does outgoing
|
||||
// message delivery using servers discovered using DNS MX records.
|
||||
//
|
||||
// Implemented interfaces:
|
||||
// - module.DeliveryTarget
|
||||
package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -7,7 +12,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
nettextproto "net/textproto"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
@ -16,21 +20,23 @@ import (
|
|||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/address"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/dns"
|
||||
"github.com/foxcpp/maddy/log"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
"github.com/foxcpp/maddy/mtasts"
|
||||
"github.com/foxcpp/maddy/target"
|
||||
"github.com/foxcpp/maddy/target/queue"
|
||||
)
|
||||
|
||||
var ErrTLSRequired = errors.New("TLS is required for outgoing connections but target server doesn't support STARTTLS")
|
||||
|
||||
type RemoteTarget struct {
|
||||
type Target struct {
|
||||
name string
|
||||
hostname string
|
||||
requireTLS bool
|
||||
|
||||
resolver Resolver
|
||||
resolver dns.Resolver
|
||||
|
||||
mtastsCache mtasts.Cache
|
||||
stsCacheUpdateTick *time.Ticker
|
||||
|
@ -39,10 +45,10 @@ type RemoteTarget struct {
|
|||
Log log.Logger
|
||||
}
|
||||
|
||||
var _ module.DeliveryTarget = &RemoteTarget{}
|
||||
var _ module.DeliveryTarget = &Target{}
|
||||
|
||||
func NewRemoteTarget(_, instName string, _ []string) (module.Module, error) {
|
||||
return &RemoteTarget{
|
||||
func New(_, instName string, _ []string) (module.Module, error) {
|
||||
return &Target{
|
||||
name: instName,
|
||||
resolver: net.DefaultResolver,
|
||||
mtastsCache: mtasts.Cache{Resolver: net.DefaultResolver},
|
||||
|
@ -52,9 +58,9 @@ func NewRemoteTarget(_, instName string, _ []string) (module.Module, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) Init(cfg *config.Map) error {
|
||||
func (rt *Target) Init(cfg *config.Map) error {
|
||||
cfg.String("hostname", true, true, "", &rt.hostname)
|
||||
cfg.String("mtasts_cache", false, false, filepath.Join(StateDirectory(cfg.Globals), "mtasts-cache"), &rt.mtastsCache.Location)
|
||||
cfg.String("mtasts_cache", false, false, filepath.Join(config.StateDirectory(cfg.Globals), "mtasts-cache"), &rt.mtastsCache.Location)
|
||||
cfg.Bool("debug", true, &rt.Log.Debug)
|
||||
cfg.Bool("require_tls", false, &rt.requireTLS)
|
||||
if _, err := cfg.Process(); err != nil {
|
||||
|
@ -62,7 +68,7 @@ func (rt *RemoteTarget) Init(cfg *config.Map) error {
|
|||
}
|
||||
|
||||
if !filepath.IsAbs(rt.mtastsCache.Location) {
|
||||
rt.mtastsCache.Location = filepath.Join(StateDirectory(cfg.Globals), rt.mtastsCache.Location)
|
||||
rt.mtastsCache.Location = filepath.Join(config.StateDirectory(cfg.Globals), rt.mtastsCache.Location)
|
||||
}
|
||||
if err := os.MkdirAll(rt.mtastsCache.Location, os.ModePerm); err != nil {
|
||||
return err
|
||||
|
@ -76,17 +82,17 @@ func (rt *RemoteTarget) Init(cfg *config.Map) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) Close() error {
|
||||
func (rt *Target) Close() error {
|
||||
rt.stsCacheUpdateDone <- struct{}{}
|
||||
<-rt.stsCacheUpdateDone
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) Name() string {
|
||||
func (rt *Target) Name() string {
|
||||
return "remote"
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) InstanceName() string {
|
||||
func (rt *Target) InstanceName() string {
|
||||
return rt.name
|
||||
}
|
||||
|
||||
|
@ -97,7 +103,7 @@ type remoteConnection struct {
|
|||
}
|
||||
|
||||
type remoteDelivery struct {
|
||||
rt *RemoteTarget
|
||||
rt *Target
|
||||
mailFrom string
|
||||
msgMeta *module.MsgMetadata
|
||||
Log log.Logger
|
||||
|
@ -105,18 +111,18 @@ type remoteDelivery struct {
|
|||
connections map[string]*remoteConnection
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
func (rt *Target) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
return &remoteDelivery{
|
||||
rt: rt,
|
||||
mailFrom: mailFrom,
|
||||
msgMeta: msgMeta,
|
||||
Log: deliveryLogger(rt.Log, msgMeta),
|
||||
Log: target.DeliveryLogger(rt.Log, msgMeta),
|
||||
connections: map[string]*remoteConnection{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rd *remoteDelivery) AddRcpt(to string) error {
|
||||
_, domain, err := splitAddress(to)
|
||||
_, domain, err := address.Split(to)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -187,13 +193,13 @@ func (rd *remoteDelivery) Body(header textproto.Header, b buffer.Buffer) error {
|
|||
|
||||
// TODO: Report partial errors early for LMTP. See github.com/emersion/go-smtp/pull/56
|
||||
|
||||
partialErr := PartialError{
|
||||
partialErr := queue.PartialError{
|
||||
Errs: map[string]error{},
|
||||
}
|
||||
for domain, conn := range rd.connections {
|
||||
err := <-errChans[domain]
|
||||
if err != nil {
|
||||
if isTemporaryErr(err) {
|
||||
if target.IsTemporaryErr(err) {
|
||||
partialErr.TemporaryFailed = append(partialErr.TemporaryFailed, conn.recipients...)
|
||||
} else {
|
||||
partialErr.Failed = append(partialErr.Failed, conn.recipients...)
|
||||
|
@ -280,7 +286,7 @@ func (rd *remoteDelivery) connectionForDomain(domain string) (*remoteConnection,
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) getSTSPolicy(domain string) (*mtasts.Policy, error) {
|
||||
func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) {
|
||||
stsPolicy, err := rt.mtastsCache.Get(domain)
|
||||
if err != nil && err != mtasts.ErrNoPolicy {
|
||||
rt.Log.Printf("failed to fetch MTA-STS policy for %s: %v", domain, err)
|
||||
|
@ -298,7 +304,7 @@ func (rt *RemoteTarget) getSTSPolicy(domain string) (*mtasts.Policy, error) {
|
|||
|
||||
var ErrNoMXMatchedBySTS = errors.New("remote: no MX record matched MTA-STS policy")
|
||||
|
||||
func (rt *RemoteTarget) stsCacheUpdater() {
|
||||
func (rt *Target) stsCacheUpdater() {
|
||||
// Always update cache on start-up since we may have been down for some
|
||||
// time.
|
||||
rt.Log.Debugln("updating MTA-STS cache...")
|
||||
|
@ -339,13 +345,13 @@ func connectToServer(ourHostname, address string, requireTLS bool) (*smtp.Client
|
|||
return nil, err
|
||||
}
|
||||
} else if requireTLS {
|
||||
return nil, ErrTLSRequired
|
||||
return nil, target.ErrTLSRequired
|
||||
}
|
||||
|
||||
return cl, nil
|
||||
}
|
||||
|
||||
func (rt *RemoteTarget) lookupTargetServers(domain string) ([]string, error) {
|
||||
func (rt *Target) lookupTargetServers(domain string) ([]string, error) {
|
||||
records, err := rt.resolver.LookupMX(context.Background(), domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -362,25 +368,6 @@ func (rt *RemoteTarget) lookupTargetServers(domain string) ([]string, error) {
|
|||
return hosts, nil
|
||||
}
|
||||
|
||||
func isTemporaryErr(err error) bool {
|
||||
if protoErr, ok := err.(*nettextproto.Error); ok {
|
||||
return (protoErr.Code / 100) == 4
|
||||
}
|
||||
if smtpErr, ok := err.(*smtp.SMTPError); ok {
|
||||
return (smtpErr.Code / 100) == 4
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
return netErr.Temporary()
|
||||
}
|
||||
|
||||
if err == ErrTLSRequired {
|
||||
return false
|
||||
}
|
||||
|
||||
// Connection error? Assume it is temporary.
|
||||
return true
|
||||
}
|
||||
|
||||
func init() {
|
||||
module.Register("remote", NewRemoteTarget)
|
||||
module.Register("remote", New)
|
||||
}
|
30
target/temporaryerr.go
Normal file
30
target/temporaryerr.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package target
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
)
|
||||
|
||||
func IsTemporaryErr(err error) bool {
|
||||
if protoErr, ok := err.(*textproto.Error); ok {
|
||||
return (protoErr.Code / 100) == 4
|
||||
}
|
||||
if smtpErr, ok := err.(*smtp.SMTPError); ok {
|
||||
return (smtpErr.Code / 100) == 4
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
return netErr.Temporary()
|
||||
}
|
||||
|
||||
if err == ErrTLSRequired {
|
||||
return false
|
||||
}
|
||||
|
||||
// Connection error? Assume it is temporary.
|
||||
return true
|
||||
}
|
||||
|
||||
var ErrTLSRequired = errors.New("TLS is required for outgoing connections but target server doesn't support STARTTLS")
|
65
testutils/check.go
Normal file
65
testutils/check.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/foxcpp/maddy/buffer"
|
||||
"github.com/foxcpp/maddy/config"
|
||||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
type Check struct {
|
||||
ConnRes module.CheckResult
|
||||
SenderRes module.CheckResult
|
||||
RcptRes module.CheckResult
|
||||
BodyRes module.CheckResult
|
||||
}
|
||||
|
||||
func (c *Check) NewMessage(msgMeta *module.MsgMetadata) (module.CheckState, error) {
|
||||
return &checkState{msgMeta, c}, nil
|
||||
}
|
||||
|
||||
func (c *Check) Init(*config.Map) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Check) Name() string {
|
||||
return "test_check"
|
||||
}
|
||||
|
||||
func (c *Check) InstanceName() string {
|
||||
return "test_check"
|
||||
}
|
||||
|
||||
type checkState struct {
|
||||
msgMeta *module.MsgMetadata
|
||||
check *Check
|
||||
}
|
||||
|
||||
func (cs *checkState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
return cs.check.ConnRes
|
||||
}
|
||||
|
||||
func (cs *checkState) CheckSender(ctx context.Context, from string) module.CheckResult {
|
||||
return cs.check.SenderRes
|
||||
}
|
||||
|
||||
func (cs *checkState) CheckRcpt(ctx context.Context, to string) module.CheckResult {
|
||||
return cs.check.RcptRes
|
||||
}
|
||||
|
||||
func (cs *checkState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||
return cs.check.BodyRes
|
||||
}
|
||||
|
||||
func (cs *checkState) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
module.Register("test_check", func(modName, instanceName string, aliases []string) (module.Module, error) {
|
||||
return &Check{}, nil
|
||||
})
|
||||
module.RegisterInstance(&Check{}, nil)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/foxcpp/maddy/log"
|
||||
)
|
||||
|
||||
func testLogger(t *testing.T, name string) log.Logger {
|
||||
func Logger(t *testing.T, name string) log.Logger {
|
||||
if testing.Verbose() {
|
||||
return log.Logger{
|
||||
Out: func(_ time.Time, _ bool, str string) {
|
|
@ -1,4 +1,4 @@
|
|||
package maddy
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
|
@ -14,76 +14,76 @@ import (
|
|||
"github.com/foxcpp/maddy/module"
|
||||
)
|
||||
|
||||
type msg struct {
|
||||
msgMeta *module.MsgMetadata
|
||||
mailFrom string
|
||||
rcptTo []string
|
||||
body []byte
|
||||
header textproto.Header
|
||||
type Msg struct {
|
||||
MsgMeta *module.MsgMetadata
|
||||
MailFrom string
|
||||
RcptTo []string
|
||||
Body []byte
|
||||
Header textproto.Header
|
||||
}
|
||||
|
||||
type testTarget struct {
|
||||
messages []msg
|
||||
type Target struct {
|
||||
Messages []Msg
|
||||
|
||||
startErr error
|
||||
rcptErr map[string]error
|
||||
bodyErr error
|
||||
abortErr error
|
||||
commitErr error
|
||||
StartErr error
|
||||
RcptErr map[string]error
|
||||
BodyErr error
|
||||
AbortErr error
|
||||
CommitErr error
|
||||
|
||||
instName string
|
||||
InstName string
|
||||
}
|
||||
|
||||
/*
|
||||
module.Module is implemented with dummy functions for logging done by Dispatcher code.
|
||||
*/
|
||||
|
||||
func (dt testTarget) Init(*config.Map) error {
|
||||
func (dt Target) Init(*config.Map) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt testTarget) InstanceName() string {
|
||||
if dt.instName != "" {
|
||||
return dt.instName
|
||||
func (dt Target) InstanceName() string {
|
||||
if dt.InstName != "" {
|
||||
return dt.InstName
|
||||
}
|
||||
return "test_instance"
|
||||
}
|
||||
|
||||
func (dt testTarget) Name() string {
|
||||
func (dt Target) Name() string {
|
||||
return "test_target"
|
||||
}
|
||||
|
||||
type testTargetDelivery struct {
|
||||
msg msg
|
||||
tgt *testTarget
|
||||
msg Msg
|
||||
tgt *Target
|
||||
aborted bool
|
||||
committed bool
|
||||
}
|
||||
|
||||
func (dt *testTarget) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
func (dt *Target) Start(msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
return &testTargetDelivery{
|
||||
tgt: dt,
|
||||
msg: msg{msgMeta: msgMeta, mailFrom: mailFrom},
|
||||
}, dt.startErr
|
||||
msg: Msg{MsgMeta: msgMeta, MailFrom: mailFrom},
|
||||
}, dt.StartErr
|
||||
}
|
||||
|
||||
func (dtd *testTargetDelivery) AddRcpt(to string) error {
|
||||
if dtd.tgt.rcptErr != nil {
|
||||
if err := dtd.tgt.rcptErr[to]; err != nil {
|
||||
if dtd.tgt.RcptErr != nil {
|
||||
if err := dtd.tgt.RcptErr[to]; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dtd.msg.rcptTo = append(dtd.msg.rcptTo, to)
|
||||
dtd.msg.RcptTo = append(dtd.msg.RcptTo, to)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dtd *testTargetDelivery) Body(header textproto.Header, buf buffer.Buffer) error {
|
||||
if dtd.tgt.bodyErr != nil {
|
||||
return dtd.tgt.bodyErr
|
||||
if dtd.tgt.BodyErr != nil {
|
||||
return dtd.tgt.BodyErr
|
||||
}
|
||||
|
||||
dtd.msg.header = header
|
||||
dtd.msg.Header = header
|
||||
|
||||
body, err := buf.Open()
|
||||
if err != nil {
|
||||
|
@ -91,23 +91,23 @@ func (dtd *testTargetDelivery) Body(header textproto.Header, buf buffer.Buffer)
|
|||
}
|
||||
defer body.Close()
|
||||
|
||||
dtd.msg.body, err = ioutil.ReadAll(body)
|
||||
dtd.msg.Body, err = ioutil.ReadAll(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func (dtd *testTargetDelivery) Abort() error {
|
||||
return dtd.tgt.abortErr
|
||||
return dtd.tgt.AbortErr
|
||||
}
|
||||
|
||||
func (dtd *testTargetDelivery) Commit() error {
|
||||
if dtd.tgt.commitErr != nil {
|
||||
return dtd.tgt.commitErr
|
||||
if dtd.tgt.CommitErr != nil {
|
||||
return dtd.tgt.CommitErr
|
||||
}
|
||||
dtd.tgt.messages = append(dtd.tgt.messages, dtd.msg)
|
||||
dtd.tgt.Messages = append(dtd.tgt.Messages, dtd.msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func doTestDelivery(t *testing.T, tgt module.DeliveryTarget, from string, to []string) string {
|
||||
func DoTestDelivery(t *testing.T, tgt module.DeliveryTarget, from string, to []string) string {
|
||||
t.Helper()
|
||||
|
||||
IDRaw := sha1.Sum([]byte(t.Name()))
|
||||
|
@ -137,7 +137,7 @@ func doTestDelivery(t *testing.T, tgt module.DeliveryTarget, from string, to []s
|
|||
return encodedID
|
||||
}
|
||||
|
||||
func doTestDeliveryErr(t *testing.T, tgt module.DeliveryTarget, from string, to []string) (string, error) {
|
||||
func DoTestDeliveryErr(t *testing.T, tgt module.DeliveryTarget, from string, to []string) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
IDRaw := sha1.Sum([]byte(t.Name()))
|
||||
|
@ -169,37 +169,37 @@ func doTestDeliveryErr(t *testing.T, tgt module.DeliveryTarget, from string, to
|
|||
return encodedID, err
|
||||
}
|
||||
|
||||
func checkTestMessage(t *testing.T, tgt *testTarget, indx int, sender string, rcpt []string) {
|
||||
func CheckTestMessage(t *testing.T, tgt *Target, indx int, sender string, rcpt []string) {
|
||||
t.Helper()
|
||||
|
||||
if len(tgt.messages) <= indx {
|
||||
t.Errorf("wrong amount of messages received, want at least %d, got %d", indx+1, len(tgt.messages))
|
||||
if len(tgt.Messages) <= indx {
|
||||
t.Errorf("wrong amount of messages received, want at least %d, got %d", indx+1, len(tgt.Messages))
|
||||
return
|
||||
}
|
||||
msg := tgt.messages[indx]
|
||||
msg := tgt.Messages[indx]
|
||||
|
||||
checkMsg(t, &msg, sender, rcpt)
|
||||
CheckMsg(t, &msg, sender, rcpt)
|
||||
}
|
||||
|
||||
func checkMsg(t *testing.T, msg *msg, sender string, rcpt []string) {
|
||||
func CheckMsg(t *testing.T, msg *Msg, sender string, rcpt []string) {
|
||||
t.Helper()
|
||||
|
||||
idRaw := sha1.Sum([]byte(t.Name()))
|
||||
encodedId := hex.EncodeToString(idRaw[:])
|
||||
|
||||
if msg.msgMeta.ID != encodedId {
|
||||
t.Errorf("empty or wrong delivery context for passed message? %+v", msg.msgMeta)
|
||||
if msg.MsgMeta.ID != encodedId {
|
||||
t.Errorf("empty or wrong delivery context for passed message? %+v", msg.MsgMeta)
|
||||
}
|
||||
if msg.mailFrom != sender {
|
||||
t.Errorf("wrong sender, want %s, got %s", sender, msg.mailFrom)
|
||||
if msg.MailFrom != sender {
|
||||
t.Errorf("wrong sender, want %s, got %s", sender, msg.MailFrom)
|
||||
}
|
||||
|
||||
sort.Strings(rcpt)
|
||||
sort.Strings(msg.rcptTo)
|
||||
if !reflect.DeepEqual(msg.rcptTo, rcpt) {
|
||||
t.Errorf("wrong recipients, want %v, got %v", rcpt, msg.rcptTo)
|
||||
sort.Strings(msg.RcptTo)
|
||||
if !reflect.DeepEqual(msg.RcptTo, rcpt) {
|
||||
t.Errorf("wrong recipients, want %v, got %v", rcpt, msg.RcptTo)
|
||||
}
|
||||
if string(msg.body) != "foobar" {
|
||||
t.Errorf("wrong body, want '%s', got '%s'", "foobar", string(msg.body))
|
||||
if string(msg.Body) != "foobar" {
|
||||
t.Errorf("wrong body, want '%s', got '%s'", "foobar", string(msg.Body))
|
||||
}
|
||||
}
|
115
timewheel.go
115
timewheel.go
|
@ -1,115 +0,0 @@
|
|||
package maddy
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TimeSlot struct {
|
||||
Time time.Time
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
type TimeWheel struct {
|
||||
slots *list.List
|
||||
slotsLock sync.Mutex
|
||||
|
||||
updateNotify chan time.Time
|
||||
stopNotify chan struct{}
|
||||
|
||||
dispatch chan TimeSlot
|
||||
}
|
||||
|
||||
func NewTimeWheel() *TimeWheel {
|
||||
tw := &TimeWheel{
|
||||
slots: list.New(),
|
||||
stopNotify: make(chan struct{}),
|
||||
updateNotify: make(chan time.Time),
|
||||
dispatch: make(chan TimeSlot, 10),
|
||||
}
|
||||
go tw.tick()
|
||||
return tw
|
||||
}
|
||||
|
||||
func (tw *TimeWheel) Add(target time.Time, value interface{}) {
|
||||
if value == nil {
|
||||
panic("can't insert nil objects into TimeWheel queue")
|
||||
}
|
||||
|
||||
tw.slotsLock.Lock()
|
||||
tw.slots.PushBack(TimeSlot{Time: target, Value: value})
|
||||
tw.slotsLock.Unlock()
|
||||
|
||||
tw.updateNotify <- target
|
||||
}
|
||||
|
||||
func (tw *TimeWheel) Close() {
|
||||
tw.stopNotify <- struct{}{}
|
||||
<-tw.stopNotify
|
||||
|
||||
close(tw.updateNotify)
|
||||
close(tw.dispatch)
|
||||
}
|
||||
|
||||
func (tw *TimeWheel) tick() {
|
||||
for {
|
||||
now := time.Now()
|
||||
// Look for list element closest to now.
|
||||
tw.slotsLock.Lock()
|
||||
var closestSlot TimeSlot
|
||||
var closestEl *list.Element
|
||||
for e := tw.slots.Front(); e != nil; e = e.Next() {
|
||||
slot := e.Value.(TimeSlot)
|
||||
if slot.Time.Sub(now) < closestSlot.Time.Sub(now) || closestSlot.Value == nil {
|
||||
closestSlot = slot
|
||||
closestEl = e
|
||||
}
|
||||
}
|
||||
tw.slotsLock.Unlock()
|
||||
// Only this goroutine removes elements from TimeWheel so we can be safe using closestSlot.
|
||||
|
||||
// Queue is empty. Just wait until update.
|
||||
if closestEl == nil {
|
||||
select {
|
||||
case <-tw.updateNotify:
|
||||
continue
|
||||
case <-tw.stopNotify:
|
||||
tw.stopNotify <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
timer := time.NewTimer(closestSlot.Time.Sub(now))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
tw.slotsLock.Lock()
|
||||
tw.slots.Remove(closestEl)
|
||||
tw.slotsLock.Unlock()
|
||||
tw.dispatch <- closestSlot
|
||||
|
||||
// break inside of select exits select, not for loop
|
||||
goto breakinnerloop
|
||||
case newTarget := <-tw.updateNotify:
|
||||
// Avoid unnecessary restarts if new target is not going to affect our
|
||||
// current wait time.
|
||||
if closestSlot.Time.Sub(now) <= newTarget.Sub(now) {
|
||||
continue
|
||||
}
|
||||
|
||||
timer.Stop()
|
||||
// Recalculate new slot time.
|
||||
case <-tw.stopNotify:
|
||||
tw.stopNotify <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
breakinnerloop:
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimeWheel) Dispatch() <-chan TimeSlot {
|
||||
return tw.dispatch
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue