Instrument the SMTP code using runtime/trace

runtime/trace together with 'go tool trace' provides extremely powerful
tooling for performance (latency) analysis. Since maddy prides itself on
being "optimized for concurrency", it is a good idea to actually live up
to this promise.

Closes #144. No need to reinvent the wheel. The original issue
proposed a solution to use in production to detect "performance
anomalies", it is possible to use runtime/trace in production too, but
the corresponding flag to enable profiler endpoint is hidden behind the
'debugflags' build tag at the moment.

For SMTP code, the basic latency information can be obtained from
regular logs since they include timestamps with millisecond granularity.
After the issue is apparent, it is possible to deploy the server
executable compiled with tracing support and obtain more information

... Also add missing context.Context arguments to smtpconn.C.
This commit is contained in:
fox.cpp 2019-12-09 23:11:39 +03:00
parent 305fdddf24
commit c4ea9a730f
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
20 changed files with 281 additions and 135 deletions

View file

@ -11,6 +11,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"regexp" "regexp"
"runtime/trace"
"strconv" "strconv"
"strings" "strings"
@ -300,6 +301,9 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
return module.CheckResult{} return module.CheckResult{}
} }
// TODO: It is not possible to distinguish different commands.
defer trace.StartRegion(ctx, "command/CheckConnection").End()
cmdName, cmdArgs := s.expandCommand("") cmdName, cmdArgs := s.expandCommand("")
return s.run(cmdName, cmdArgs, bytes.NewReader(nil)) return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
} }
@ -311,6 +315,8 @@ func (s *state) CheckSender(ctx context.Context, addr string) module.CheckResult
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, "command/CheckSender").End()
cmdName, cmdArgs := s.expandCommand(addr) cmdName, cmdArgs := s.expandCommand(addr)
return s.run(cmdName, cmdArgs, bytes.NewReader(nil)) return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
} }
@ -321,6 +327,7 @@ func (s *state) CheckRcpt(ctx context.Context, addr string) module.CheckResult {
if s.c.stage != StageRcpt { if s.c.stage != StageRcpt {
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, "command/CheckRcpt").End()
cmdName, cmdArgs := s.expandCommand(addr) cmdName, cmdArgs := s.expandCommand(addr)
return s.run(cmdName, cmdArgs, bytes.NewReader(nil)) return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
@ -331,6 +338,8 @@ func (s *state) CheckBody(ctx context.Context, hdr textproto.Header, body buffer
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, "command/CheckBody").End()
cmdName, cmdArgs := s.expandCommand("") cmdName, cmdArgs := s.expandCommand("")
var buf bytes.Buffer var buf bytes.Buffer

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"io" "io"
nettextproto "net/textproto" nettextproto "net/textproto"
"runtime/trace"
"strings" "strings"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
@ -96,6 +97,8 @@ func (d *dkimCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.Ch
} }
func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult { func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
defer trace.StartRegion(ctx, "verify_dkim/CheckBody").End()
if !header.Has("DKIM-Signature") { if !header.Has("DKIM-Signature") {
if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine { if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine {
d.log.Printf("no signatures present") d.log.Printf("no signatures present")

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"runtime/trace"
"strings" "strings"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
@ -323,6 +324,8 @@ func (bl *DNSBL) CheckConnection(ctx context.Context, state *smtp.ConnectionStat
return nil return nil
} }
defer trace.StartRegion(ctx, "dnsbl/CheckConnection (Early)").End()
ip, ok := state.RemoteAddr.(*net.TCPAddr) ip, ok := state.RemoteAddr.(*net.TCPAddr)
if !ok { if !ok {
bl.log.Msg("non-TCP/IP source", bl.log.Msg("non-TCP/IP source",
@ -358,6 +361,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, "dnsbl/CheckConnection").End()
if s.msgMeta.Conn == nil { if s.msgMeta.Conn == nil {
s.log.Msg("locally generated message, ignoring") s.log.Msg("locally generated message, ignoring")
return module.CheckResult{} return module.CheckResult{}

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"runtime/debug" "runtime/debug"
"runtime/trace"
"blitiri.com.ar/go/spf" "blitiri.com.ar/go/spf"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
@ -267,6 +268,8 @@ func prepareMailFrom(from string) (string, error) {
} }
func (s *state) CheckConnection(ctx context.Context) module.CheckResult { func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
defer trace.StartRegion(ctx, "apply_spf/CheckConnection").End()
if s.msgMeta.Conn == nil { if s.msgMeta.Conn == nil {
s.log.Println("locally generated message, skipping") s.log.Println("locally generated message, skipping")
return module.CheckResult{} return module.CheckResult{}
@ -306,6 +309,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
} }
}() }()
defer trace.StartRegion(ctx, "apply_spf/CheckConnection (Async)").End()
res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom) res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom)
s.log.Debugf("result: %s (%v)", res, err) s.log.Debugf("result: %s (%v)", res, err)
s.spfFetch <- spfRes{res, err} s.spfFetch <- spfRes{res, err}
@ -328,6 +333,8 @@ func (s *state) CheckBody(ctx context.Context, header textproto.Header, body buf
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, "apply_spf/CheckBody").End()
res, ok := <-s.spfFetch res, ok := <-s.spfFetch
if !ok { if !ok {
return module.CheckResult{ return module.CheckResult{

View file

@ -3,6 +3,7 @@ package check
import ( import (
"context" "context"
"fmt" "fmt"
"runtime/trace"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/internal/buffer" "github.com/foxcpp/maddy/internal/buffer"
@ -57,12 +58,18 @@ type statelessCheckState struct {
msgMeta *module.MsgMetadata msgMeta *module.MsgMetadata
} }
func (s *statelessCheckState) String() string {
return s.c.modName + ":" + s.c.instName
}
func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult { func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
if s.c.connCheck == nil { if s.c.connCheck == nil {
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, s.c.modName+"/CheckConnection").End()
originalRes := s.c.connCheck(StatelessCheckContext{ originalRes := s.c.connCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver, Resolver: s.c.resolver,
MsgMeta: s.msgMeta, MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
@ -74,8 +81,10 @@ func (s *statelessCheckState) CheckSender(ctx context.Context, mailFrom string)
if s.c.senderCheck == nil { if s.c.senderCheck == nil {
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, s.c.modName+"/CheckSender").End()
originalRes := s.c.senderCheck(StatelessCheckContext{ originalRes := s.c.senderCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver, Resolver: s.c.resolver,
MsgMeta: s.msgMeta, MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
@ -87,8 +96,10 @@ func (s *statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modu
if s.c.rcptCheck == nil { if s.c.rcptCheck == nil {
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, s.c.modName+"/CheckRcpt").End()
originalRes := s.c.rcptCheck(StatelessCheckContext{ originalRes := s.c.rcptCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver, Resolver: s.c.resolver,
MsgMeta: s.msgMeta, MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
@ -100,8 +111,10 @@ func (s *statelessCheckState) CheckBody(ctx context.Context, header textproto.He
if s.c.bodyCheck == nil { if s.c.bodyCheck == nil {
return module.CheckResult{} return module.CheckResult{}
} }
defer trace.StartRegion(ctx, s.c.modName+"/CheckBody").End()
originalRes := s.c.bodyCheck(StatelessCheckContext{ originalRes := s.c.bodyCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver, Resolver: s.c.resolver,
MsgMeta: s.msgMeta, MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"math/rand" "math/rand"
"net" "net"
"runtime/trace"
"strings" "strings"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
@ -81,6 +82,8 @@ func (v *Verifier) FetchRecord(ctx context.Context, header textproto.Header) {
} }
}() }()
defer trace.StartRegion(ctx, "DMARC/FetchRecord").End()
policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain) policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain)
v.fetchCh <- verifyData{ v.fetchCh <- verifyData{
policyDomain: policyDomain, policyDomain: policyDomain,

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"runtime/trace"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -46,6 +47,7 @@ type Session struct {
// msgCtx is not used for cancellation or timeouts, only for tracing. // msgCtx is not used for cancellation or timeouts, only for tracing.
// It is the subcontext of sessionCtx. // It is the subcontext of sessionCtx.
msgCtx context.Context msgCtx context.Context
msgTask *trace.Task
mailFrom string mailFrom string
opts smtp.MailOptions opts smtp.MailOptions
msgMeta *module.MsgMetadata msgMeta *module.MsgMetadata
@ -75,6 +77,7 @@ func (s *Session) abort(ctx context.Context) {
s.delivery = nil s.delivery = nil
s.deliveryErr = nil s.deliveryErr = nil
s.msgCtx = nil s.msgCtx = nil
s.msgTask.End()
} }
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) { func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
@ -138,15 +141,20 @@ func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.Mail
) )
} }
delivery, err := s.endp.pipeline.Start(ctx, msgMeta, cleanFrom) s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
defer mailTask.End()
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
if err != nil { if err != nil {
s.msgCtx = nil
s.msgTask.End()
return msgMeta.ID, err return msgMeta.ID, err
} }
s.msgMeta = msgMeta s.msgMeta = msgMeta
s.mailFrom = cleanFrom s.mailFrom = cleanFrom
s.delivery = delivery s.delivery = delivery
s.msgCtx = s.sessionCtx // TODO: Trace annotations.
return msgMeta.ID, nil return msgMeta.ID, nil
} }
@ -171,6 +179,8 @@ func (s *Session) Mail(from string, opts smtp.MailOptions) error {
} }
func (s *Session) fetchRDNSName(ctx context.Context) { func (s *Session) fetchRDNSName(ctx context.Context) {
defer trace.StartRegion(ctx, "rDNS fetch").End()
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr) tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
if !ok { if !ok {
s.connState.RDNSName.Set(nil) s.connState.RDNSName.Set(nil)
@ -211,7 +221,8 @@ func (s *Session) Rcpt(to string) error {
} }
} }
rcptCtx := s.msgCtx rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
defer rcptTask.End()
if err := s.rcpt(rcptCtx, to); err != nil { if err := s.rcpt(rcptCtx, to); err != nil {
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors { if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
@ -293,7 +304,8 @@ func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Heade
} }
func (s *Session) Data(r io.Reader) error { func (s *Session) Data(r io.Reader) error {
bodyCtx := s.msgCtx bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error { wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID) s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
@ -317,6 +329,9 @@ func (s *Session) Data(r io.Reader) error {
// go-smtp will call Reset, but it will call Abort if delivery is non-nil. // go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.endp.semaphore.Release() s.endp.semaphore.Release()
return nil return nil
@ -332,7 +347,8 @@ func (sw statusWrapper) SetStatus(rcpt string, err error) {
} }
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error { func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
bodyCtx := s.msgCtx bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error { wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID) s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
@ -356,6 +372,9 @@ func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
// go-smtp will call Reset, but it will call Abort if delivery is non-nil. // go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.endp.semaphore.Release() s.endp.semaphore.Release()
return nil return nil
@ -671,7 +690,7 @@ func (endp *Endpoint) newSession(anonymous bool, username, password string, stat
} }
if endp.resolver != nil { if endp.resolver != nil {
rdnsCtx, cancelRDNS := context.WithCancel(context.TODO()) rdnsCtx, cancelRDNS := context.WithCancel(s.sessionCtx)
s.connState.RDNSName = future.New() s.connState.RDNSName = future.New()
s.cancelRDNS = cancelRDNS s.cancelRDNS = cancelRDNS
go s.fetchRDNSName(rdnsCtx) go s.fetchRDNSName(rdnsCtx)

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"net/mail" "net/mail"
"path/filepath" "path/filepath"
"runtime/trace"
"strings" "strings"
"time" "time"
@ -337,6 +338,8 @@ func (s state) RewriteRcpt(ctx context.Context, rcptTo string) (string, error) {
} }
func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error { func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error {
defer trace.StartRegion(ctx, "sign_dkim/RewriteBody").End()
var authUser string var authUser string
if s.meta.Conn != nil { if s.meta.Conn != nil {
authUser = s.meta.Conn.AuthUser authUser = s.meta.Conn.AuthUser

View file

@ -11,10 +11,7 @@ import (
func objectName(x interface{}) string { func objectName(x interface{}) string {
mod, ok := x.(module.Module) mod, ok := x.(module.Module)
if ok { if ok {
if mod.InstanceName() == "" { return mod.Name() + ":" + mod.InstanceName()
return mod.Name()
}
return mod.InstanceName()
} }
_, pipeline := x.(*MsgPipeline) _, pipeline := x.(*MsgPipeline)
@ -22,5 +19,10 @@ func objectName(x interface{}) string {
return "reroute" return "reroute"
} }
str, ok := x.(fmt.Stringer)
if ok {
return str.String()
}
return fmt.Sprintf("%T", x) return fmt.Sprintf("%T", x)
} }

View file

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime/trace"
"time" "time"
"github.com/foxcpp/maddy/internal/exterrors" "github.com/foxcpp/maddy/internal/exterrors"
@ -23,10 +24,14 @@ var httpClient = &http.Client{
Timeout: time.Minute, Timeout: time.Minute,
} }
func downloadPolicy(domain string) (*Policy, error) { func downloadPolicy(ctx context.Context, domain string) (*Policy, error) {
// TODO: Consult OCSP/CRL to detect revoked certificates? // TODO: Consult OCSP/CRL to detect revoked certificates?
resp, err := httpClient.Get("https://mta-sts." + domain + "/.well-known/mta-sts.txt") req, err := http.NewRequestWithContext(ctx, "GET", "https://mta-sts."+domain+"/.well-known/mta-sts.txt", nil)
if err != nil {
return nil, err
}
resp, err := httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -80,8 +85,8 @@ var ErrIgnorePolicy = errors.New("mtasts: policy ignored due to errors")
// Get reads policy from cache or tries to fetch it from Policy Host. // Get reads policy from cache or tries to fetch it from Policy Host.
// //
// The domain is assumed to be normalized, as done by dns.ForLookup. // The domain is assumed to be normalized, as done by dns.ForLookup.
func (c *Cache) Get(domain string) (*Policy, error) { func (c *Cache) Get(ctx context.Context, domain string) (*Policy, error) {
_, p, err := c.fetch(false, time.Now(), domain) _, p, err := c.fetch(ctx, false, time.Now(), domain)
return p, err return p, err
} }
@ -118,6 +123,9 @@ func (c *Cache) load(domain string) (id string, fetchTime time.Time, p *Policy,
} }
func (c *Cache) Refresh() error { func (c *Cache) Refresh() error {
refreshCtx, refreshTask := trace.NewTask(context.Background(), "mtasts.Cache/Refresh")
defer refreshTask.End()
dir, err := ioutil.ReadDir(c.Location) dir, err := ioutil.ReadDir(c.Location)
if err != nil { if err != nil {
return err return err
@ -132,7 +140,7 @@ func (c *Cache) Refresh() error {
// Since otherwise we are going to have expired policy for another 6 hours, // Since otherwise we are going to have expired policy for another 6 hours,
// which makes it useless. // which makes it useless.
// See https://tools.ietf.org/html/rfc8461#section-10.2. // See https://tools.ietf.org/html/rfc8461#section-10.2.
cacheHit, _, err := c.fetch(true, time.Now().Add(6*time.Hour), ent.Name()) cacheHit, _, err := c.fetch(refreshCtx, false, time.Now().Add(6*time.Hour), ent.Name())
if err != nil && err != ErrIgnorePolicy { if err != nil && err != ErrIgnorePolicy {
c.Logger.Error("policy update error", err, "domain", ent.Name()) c.Logger.Error("policy update error", err, "domain", ent.Name())
} }
@ -147,7 +155,9 @@ func (c *Cache) Refresh() error {
return nil return nil
} }
func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) { func (c *Cache) fetch(ctx context.Context, ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
defer trace.StartRegion(ctx, "mtasts.Cache/fetch").End()
validCache := true validCache := true
cachedId, fetchTime, cachedPolicy, err := c.load(domain) cachedId, fetchTime, cachedPolicy, err := c.load(domain)
if err != nil { if err != nil {
@ -163,7 +173,7 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
var dnsId string var dnsId string
if !ignoreDns { if !ignoreDns {
records, err := c.Resolver.LookupTXT(context.Background(), "_mta-sts."+domain) records, err := c.Resolver.LookupTXT(ctx, "_mta-sts."+domain)
if err != nil { if err != nil {
if validCache { if validCache {
if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound { if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound {
@ -218,11 +228,15 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
} }
if !validCache || dnsId != cachedId { if !validCache || dnsId != cachedId {
download := downloadPolicy var (
policy *Policy
err error
)
if c.downloadPolicy != nil { if c.downloadPolicy != nil {
download = c.downloadPolicy policy, err = c.downloadPolicy(domain)
} else {
policy, err = downloadPolicy(ctx, domain)
} }
policy, err := download(domain)
if err != nil { if err != nil {
if validCache { if validCache {
c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain) c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain)

View file

@ -1,6 +1,7 @@
package mtasts package mtasts
import ( import (
"context"
"errors" "errors"
"os" "os"
"reflect" "reflect"
@ -37,7 +38,7 @@ func TestCacheGet(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -57,7 +58,7 @@ func TestCacheGet_Error_DNS(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
_, err := c.Get("example.org") _, err := c.Get(context.Background(), "example.org")
if err != ErrIgnorePolicy { if err != ErrIgnorePolicy {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -78,7 +79,7 @@ func TestCacheGet_Error_HTTPS(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
_, err := c.Get("example.org") _, err := c.Get(context.Background(), "example.org")
if err != ErrIgnorePolicy { if err != ErrIgnorePolicy {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -104,7 +105,7 @@ func TestCacheGet_Cached(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -115,7 +116,7 @@ func TestCacheGet_Cached(t *testing.T) {
// calling downloadPolicy. // calling downloadPolicy.
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -146,7 +147,7 @@ func TestCacheGet_Expired(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -160,7 +161,7 @@ func TestCacheGet_Expired(t *testing.T) {
expectedPolicy.MX = []string{"b"} expectedPolicy.MX = []string{"b"}
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil) c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -190,7 +191,7 @@ func TestCacheGet_IDChange(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -206,7 +207,7 @@ func TestCacheGet_IDChange(t *testing.T) {
expectedPolicy.MX = []string{"b"} expectedPolicy.MX = []string{"b"}
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil) c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -236,7 +237,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -251,7 +252,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
resolver.Zones = nil resolver.Zones = nil
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -281,7 +282,7 @@ func TestCacheGet_HTTPGet_ErrNoPolicy(t *testing.T) {
// >(for any reason), and there is no valid (non-expired) previously // >(for any reason), and there is no valid (non-expired) previously
// >cached policy, senders MUST continue with delivery as though the // >cached policy, senders MUST continue with delivery as though the
// >domain has not implemented MTA-STS. // >domain has not implemented MTA-STS.
_, err := c.Get("example.org") _, err := c.Get(context.Background(), "example.org")
if err != ErrIgnorePolicy { if err != ErrIgnorePolicy {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -308,7 +309,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -325,7 +326,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
} }
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -357,7 +358,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -376,7 +377,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
} }
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err == nil { if err == nil {
t.Fatalf("expected error, got policy %v", policy) t.Fatalf("expected error, got policy %v", policy)
} }
@ -405,7 +406,7 @@ func TestCacheRefresh(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -427,7 +428,7 @@ func TestCacheRefresh(t *testing.T) {
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
// It should return the new record from cache. // It should return the new record from cache.
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -459,7 +460,7 @@ func TestCacheRefresh_Error(t *testing.T) {
} }
defer os.RemoveAll(c.Location) defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org") policy, err := c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }
@ -477,7 +478,7 @@ func TestCacheRefresh_Error(t *testing.T) {
} }
// It should return the old record from cache. // It should return the old record from cache.
policy, err = c.Get("example.org") policy, err = c.Get(context.Background(), "example.org")
if err != nil { if err != nil {
t.Fatalf("policy get: %v", err) t.Fatalf("policy get: %v", err)
} }

View file

@ -10,9 +10,11 @@
package smtpconn package smtpconn
import ( import (
"context"
"crypto/tls" "crypto/tls"
"io" "io"
"net" "net"
"runtime/trace"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
@ -27,9 +29,9 @@ import (
// //
// Currently, the C object represents one session and cannot be reused. // Currently, the C object represents one session and cannot be reused.
type C struct { type C struct {
// Dialer to use to estabilish new network connections. Set to net.Dial by // Dialer to use to estabilish new network connections. Set to net.Dialer
// New. // DialContext by New.
Dialer func(network, addr string) (net.Conn, error) Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Fail if the connection cannot use TLS. // Fail if the connection cannot use TLS.
RequireTLS bool RequireTLS bool
@ -60,7 +62,7 @@ type C struct {
// with resonable default values. // with resonable default values.
func New() *C { func New() *C {
return &C{ return &C{
Dialer: net.Dial, Dialer: (&net.Dialer{}).DialContext,
TLSConfig: &tls.Config{}, TLSConfig: &tls.Config{},
Hostname: "localhost.localdomain", Hostname: "localhost.localdomain",
} }
@ -122,9 +124,11 @@ func (c *C) wrapClientErr(err error, serverName string) error {
} }
// Connect actually estabilishes the network connection with the remote host. // Connect actually estabilishes the network connection with the remote host.
func (c *C) Connect(endp config.Endpoint) error { func (c *C) Connect(ctx context.Context, endp config.Endpoint) error {
defer trace.StartRegion(ctx, "smtpconn/Connect+TLS").End()
// TODO: Helper function to try multiple endpoints? // TODO: Helper function to try multiple endpoints?
cl, err := c.attemptConnect(endp, c.AttemptTLS) cl, err := c.attemptConnect(ctx, endp, c.AttemptTLS)
if err != nil { if err != nil {
return c.wrapClientErr(err, endp.Host) return c.wrapClientErr(err, endp.Host)
} }
@ -134,9 +138,9 @@ func (c *C) Connect(endp config.Endpoint) error {
return nil return nil
} }
func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) { func (c *C) attemptConnect(ctx context.Context, endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
var conn net.Conn var conn net.Conn
conn, err := c.Dialer(endp.Network(), endp.Address()) conn, err := c.Dialer(ctx, endp.Network(), endp.Address())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -197,7 +201,7 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
// Re-attempt without STARTTLS. It is not possible to reuse connection // Re-attempt without STARTTLS. It is not possible to reuse connection
// since it is probably in a bad state. // since it is probably in a bad state.
c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port) c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port)
return c.attemptConnect(endp, false) return c.attemptConnect(ctx, endp, false)
} }
return cl, nil return cl, nil
@ -209,7 +213,9 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
// SMTPUTF8 is forwarded if supported by the remote server, if it is not // SMTPUTF8 is forwarded if supported by the remote server, if it is not
// supported - attempt will be done to convert addresses to the ASCII form, if // supported - attempt will be done to convert addresses to the ASCII form, if
// this is not possible, the corresponding method (Mail or Rcpt) will fail. // this is not possible, the corresponding method (Mail or Rcpt) will fail.
func (c *C) Mail(from string, opts smtp.MailOptions) error { func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error {
defer trace.StartRegion(ctx, "smtpconn/MAIL FROM").End()
outOpts := smtp.MailOptions{ outOpts := smtp.MailOptions{
// Future extensions may add additional fields that should not be // Future extensions may add additional fields that should not be
// copied blindly. So we copy only fields we know should be handled // copied blindly. So we copy only fields we know should be handled
@ -268,7 +274,9 @@ func (c *C) Client() *smtp.Client {
// //
// If the address is non-ASCII and cannot be converted to ASCII and the remote // If the address is non-ASCII and cannot be converted to ASCII and the remote
// server does not support SMTPUTF8, error will be returned. // server does not support SMTPUTF8, error will be returned.
func (c *C) Rcpt(to string) error { func (c *C) Rcpt(ctx context.Context, to string) error {
defer trace.StartRegion(ctx, "smtpconn/RCPT TO").End()
// If necessary, the extension flag is enabled in Start. // If necessary, the extension flag is enabled in Start.
if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok { if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok {
var err error var err error
@ -300,7 +308,9 @@ func (c *C) Rcpt(to string) error {
// //
// If the Data command fails, the connection may be in a unclean state (e.g. in // If the Data command fails, the connection may be in a unclean state (e.g. in
// the middle of message data stream). It is not safe to continue using it. // the middle of message data stream). It is not safe to continue using it.
func (c *C) Data(hdr textproto.Header, body io.Reader) error { func (c *C) Data(ctx context.Context, hdr textproto.Header, body io.Reader) error {
defer trace.StartRegion(ctx, "smtpconn/DATA").End()
wc, err := c.cl.Data() wc, err := c.cl.Data()
if err != nil { if err != nil {
return c.wrapClientErr(err, c.serverName) return c.wrapClientErr(err, c.serverName)

View file

@ -1,6 +1,7 @@
package smtpconn package smtpconn
import ( import (
"context"
"strings" "strings"
"testing" "testing"
@ -14,11 +15,11 @@ import (
func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error { func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error {
t.Helper() t.Helper()
if err := conn.Mail(from, opts); err != nil { if err := conn.Mail(context.Background(), from, opts); err != nil {
return err return err
} }
for _, rcpt := range to { for _, rcpt := range to {
if err := conn.Rcpt(rcpt); err != nil { if err := conn.Rcpt(context.Background(), rcpt); err != nil {
return err return err
} }
} }
@ -26,7 +27,7 @@ func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.M
hdr := textproto.Header{} hdr := textproto.Header{}
hdr.Add("B", "2") hdr.Add("B", "2")
hdr.Add("A", "1") hdr.Add("A", "1")
if err := conn.Data(hdr, strings.NewReader("foobar\n")); err != nil { if err := conn.Data(context.Background(), hdr, strings.NewReader("foobar\n")); err != nil {
return err return err
} }
@ -41,7 +42,7 @@ func TestSMTPUTF8_Sender_UTF8_Punycode(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -70,7 +71,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Punycode(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -99,7 +100,7 @@ func TestSMTPUTF8_Sender_UTF8_Reject(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -122,7 +123,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Reject(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -144,7 +145,7 @@ func TestSMTPUTF8_Sender_UTF8_Domain(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -172,7 +173,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Domain(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -201,7 +202,7 @@ func TestSMTPUTF8_Sender_UTF8_Username(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,
@ -230,7 +231,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Username(t *testing.T) {
c := New() c := New()
c.Log = testutils.Logger(t, "smtp_downstream") c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{ if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: testPort, Port: testPort,

View file

@ -15,6 +15,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"runtime/trace"
"strconv" "strconv"
"strings" "strings"
@ -63,7 +64,13 @@ type delivery struct {
addedRcpts map[string]struct{} addedRcpts map[string]struct{}
} }
func (d *delivery) String() string {
return d.store.Name() + ":" + d.store.InstanceName()
}
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error { func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
defer trace.StartRegion(ctx, "sql/AddRcpt").End()
accountName, err := prepareUsername(rcptTo) accountName, err := prepareUsername(rcptTo)
if err != nil { if err != nil {
return &exterrors.SMTPError{ return &exterrors.SMTPError{
@ -104,6 +111,8 @@ func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
} }
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error { func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
defer trace.StartRegion(ctx, "sql/Body").End()
if d.msgMeta.Quarantine { if d.msgMeta.Quarantine {
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil { if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
return err return err
@ -116,14 +125,20 @@ func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffe
} }
func (d *delivery) Abort(ctx context.Context) error { func (d *delivery) Abort(ctx context.Context) error {
defer trace.StartRegion(ctx, "sql/Abort").End()
return d.d.Abort() return d.d.Abort()
} }
func (d *delivery) Commit(ctx context.Context) error { func (d *delivery) Commit(ctx context.Context) error {
defer trace.StartRegion(ctx, "sql/Commit").End()
return d.d.Commit() return d.d.Commit()
} }
func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) { func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
defer trace.StartRegion(ctx, "sql/Start").End()
return &delivery{ return &delivery{
store: store, store: store,
msgMeta: msgMeta, msgMeta: msgMeta,
@ -360,6 +375,9 @@ func prepareUsername(username string) (string, error) {
} }
func (store *Storage) CheckPlain(username, password string) bool { func (store *Storage) CheckPlain(username, password string) bool {
// TODO: Pass session context there.
defer trace.StartRegion(context.Background(), "sql/CheckPlain").End()
accountName, err := prepareUsername(username) accountName, err := prepareUsername(username)
if err != nil { if err != nil {
return false return false

View file

@ -50,6 +50,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime/debug" "runtime/debug"
"runtime/trace"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -438,10 +439,12 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1) msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1)
dl.Debugf("using message ID = %s", msgMeta.ID) dl.Debugf("using message ID = %s", msgMeta.ID)
// TODO: Trace annotations? msgCtx, msgTask := trace.NewTask(context.Background(), "Queue delivery")
msgCtx := context.Background() defer msgTask.End()
delivery, err := q.Target.Start(msgCtx, msgMeta, meta.From) mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
delivery, err := q.Target.Start(mailCtx, msgMeta, meta.From)
mailTask.End()
if err != nil { if err != nil {
dl.Debugf("target.Start failed: %v", err) dl.Debugf("target.Start failed: %v", err)
perr.Failed = append(perr.Failed, meta.To...) perr.Failed = append(perr.Failed, meta.To...)
@ -454,7 +457,8 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
var acceptedRcpts []string var acceptedRcpts []string
for _, rcpt := range meta.To { for _, rcpt := range meta.To {
if err := delivery.AddRcpt(msgCtx, rcpt); err != nil { rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
if err := delivery.AddRcpt(rcptCtx, rcpt); err != nil {
if exterrors.IsTemporaryOrUnspec(err) { if exterrors.IsTemporaryOrUnspec(err) {
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt) perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
} else { } else {
@ -466,6 +470,7 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
dl.Debugf("delivery.AddRcpt %s OK", rcpt) dl.Debugf("delivery.AddRcpt %s OK", rcpt)
acceptedRcpts = append(acceptedRcpts, rcpt) acceptedRcpts = append(acceptedRcpts, rcpt)
} }
rcptTask.End()
} }
if len(acceptedRcpts) == 0 { if len(acceptedRcpts) == 0 {
@ -487,12 +492,15 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
} }
} }
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
defer bodyTask.End()
partDelivery, ok := delivery.(module.PartialDelivery) partDelivery, ok := delivery.(module.PartialDelivery)
if ok { if ok {
dl.Debugf("using delivery.BodyNonAtomic") dl.Debugf("using delivery.BodyNonAtomic")
partDelivery.BodyNonAtomic(msgCtx, &perr, header, body) partDelivery.BodyNonAtomic(bodyCtx, &perr, header, body)
} else { } else {
if err := delivery.Body(msgCtx, header, body); err != nil { if err := delivery.Body(bodyCtx, header, body); err != nil {
dl.Debugf("delivery.Body failed: %v", err) dl.Debugf("delivery.Body failed: %v", err)
expandToPartialErr(err) expandToPartialErr(err)
} }
@ -508,13 +516,13 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
if allFailed { if allFailed {
// No recipients succeeded. // No recipients succeeded.
dl.Debugf("delivery.Abort (all recipients failed)") dl.Debugf("delivery.Abort (all recipients failed)")
if err := delivery.Abort(msgCtx); err != nil { if err := delivery.Abort(bodyCtx); err != nil {
dl.Msg("delivery.Abort failed", err) dl.Msg("delivery.Abort failed", err)
} }
return perr return perr
} }
if err := delivery.Commit(msgCtx); err != nil { if err := delivery.Commit(bodyCtx); err != nil {
dl.Debugf("delivery.Commit failed: %v", err) dl.Debugf("delivery.Commit failed: %v", err)
expandToPartialErr(err) expandToPartialErr(err)
} }
@ -537,6 +545,8 @@ func (qd *queueDelivery) AddRcpt(ctx context.Context, rcptTo string) error {
} }
func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error { func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
defer trace.StartRegion(ctx, "queue/Body").End()
// Body buffer initially passed to us may not be valid after "delivery" to queue completes. // Body buffer initially passed to us may not be valid after "delivery" to queue completes.
// storeNewMessage returns a new buffer object created from message blob stored on disk. // storeNewMessage returns a new buffer object created from message blob stored on disk.
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body) storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
@ -550,6 +560,8 @@ func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body
} }
func (qd *queueDelivery) Abort(ctx context.Context) error { func (qd *queueDelivery) Abort(ctx context.Context) error {
defer trace.StartRegion(ctx, "queue/Abort").End()
if qd.body != nil { if qd.body != nil {
qd.q.removeFromDisk(qd.meta.MsgMeta) qd.q.removeFromDisk(qd.meta.MsgMeta)
} }
@ -557,6 +569,8 @@ func (qd *queueDelivery) Abort(ctx context.Context) error {
} }
func (qd *queueDelivery) Commit(ctx context.Context) error { func (qd *queueDelivery) Commit(ctx context.Context) error {
defer trace.StartRegion(ctx, "queue/Commit").End()
if qd.meta == nil { if qd.meta == nil {
panic("queue: double Commit") panic("queue: double Commit")
} }
@ -899,14 +913,17 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
} }
dl.Msg("generated failed DSN", "dsn_id", dsnID) dl.Msg("generated failed DSN", "dsn_id", dsnID)
// TODO: Trace annotations? msgCtx, msgTask := trace.NewTask(context.Background(), "DSN Delivery")
msgCtx := context.Background() defer msgTask.End()
dsnDelivery, err := q.dsnPipeline.Start(msgCtx, dsnMeta, "") mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
dsnDelivery, err := q.dsnPipeline.Start(mailCtx, dsnMeta, "")
mailTask.End()
if err != nil { if err != nil {
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID) dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
return return
} }
defer func() { defer func() {
if err != nil { if err != nil {
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID) dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
@ -916,15 +933,23 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
} }
}() }()
if err = dsnDelivery.AddRcpt(msgCtx, meta.From); err != nil { rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
if err = dsnDelivery.AddRcpt(rcptCtx, meta.From); err != nil {
rcptTask.End()
return return
} }
if err = dsnDelivery.Body(msgCtx, dsnHeader, dsnBody); err != nil { rcptTask.End()
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
if err = dsnDelivery.Body(bodyCtx, dsnHeader, dsnBody); err != nil {
bodyTask.End()
return return
} }
if err = dsnDelivery.Commit(msgCtx); err != nil { if err = dsnDelivery.Commit(bodyCtx); err != nil {
bodyTask.End()
return return
} }
bodyTask.End()
} }
func init() { func init() {

View file

@ -1,6 +1,7 @@
package remote package remote
import ( import (
"context"
"errors" "errors"
"net" "net"
"strconv" "strconv"
@ -33,7 +34,7 @@ func TestRemoteDelivery_AuthMX_Fail(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
@ -63,12 +64,12 @@ func TestRemoteDelivery_AuthMX_MTASTS(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
tlsConfig: clientCfg, tlsConfig: clientCfg,
mxAuth: map[string]struct{}{AuthMTASTS: {}}, mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) { mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" { if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup") return nil, errors.New("Wrong domain in lookup")
} }
@ -114,12 +115,12 @@ func TestRemoteDelivery_AuthMX_PreferAuth(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
tlsConfig: clientCfg, tlsConfig: clientCfg,
mxAuth: map[string]struct{}{AuthMTASTS: {}}, mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) { mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" { if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup") return nil, errors.New("Wrong domain in lookup")
} }
@ -166,12 +167,12 @@ func TestRemoteDelivery_MTASTS_SkipNonMatching(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
tlsConfig: clientCfg, tlsConfig: clientCfg,
mxAuth: map[string]struct{}{AuthMTASTS: {}}, mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) { mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" { if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup") return nil, errors.New("Wrong domain in lookup")
} }
@ -208,11 +209,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_Fail(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthMTASTS: {}}, mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) { mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" { if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup") return nil, errors.New("Wrong domain in lookup")
} }
@ -251,11 +252,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_NoPolicy(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthMTASTS: {}}, mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) { mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" { if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup") return nil, errors.New("Wrong domain in lookup")
} }
@ -290,7 +291,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthCommonDomain: {}}, mxAuth: map[string]struct{}{AuthCommonDomain: {}},
@ -321,7 +322,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_Fail(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthCommonDomain: {}}, mxAuth: map[string]struct{}{AuthCommonDomain: {}},
@ -354,7 +355,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_NotETLDp1(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthCommonDomain: {}}, mxAuth: map[string]struct{}{AuthCommonDomain: {}},
@ -405,7 +406,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: extResolver, extResolver: extResolver,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}}, mxAuth: map[string]struct{}{AuthDNSSEC: {}},
@ -452,7 +453,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC_Fail(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: extResolver, extResolver: extResolver,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}}, mxAuth: map[string]struct{}{AuthDNSSEC: {}},
@ -506,7 +507,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &resolver, resolver: &resolver,
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: extResolver, extResolver: extResolver,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}}, mxAuth: map[string]struct{}{AuthDNSSEC: {}},
@ -556,7 +557,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral_Fail(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &resolver, resolver: &resolver,
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: extResolver, extResolver: extResolver,
requireMXAuth: true, requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}}, mxAuth: map[string]struct{}{AuthDNSSEC: {}},

View file

@ -13,6 +13,7 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime/trace"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -57,12 +58,12 @@ type Target struct {
mxAuth map[string]struct{} mxAuth map[string]struct{}
resolver dns.Resolver resolver dns.Resolver
dialer func(network, addr string) (net.Conn, error) dialer func(ctx context.Context, network, addr string) (net.Conn, error)
extResolver *dns.ExtResolver extResolver *dns.ExtResolver
// This is the callback that is usually mtastsCache.Get, // This is the callback that is usually mtastsCache.Get,
// but replaced by tests to mock mtasts.Cache. // but replaced by tests to mock mtasts.Cache.
mtastsGet func(domain string) (*mtasts.Policy, error) mtastsGet func(ctx context.Context, domain string) (*mtasts.Policy, error)
mtastsCache mtasts.Cache mtastsCache mtasts.Cache
stsCacheUpdateTick *time.Ticker stsCacheUpdateTick *time.Ticker
@ -80,7 +81,7 @@ func New(_, instName string, _, inlineArgs []string) (module.Module, error) {
return &Target{ return &Target{
name: instName, name: instName,
resolver: dns.DefaultResolver(), resolver: dns.DefaultResolver(),
dialer: net.Dial, dialer: (&net.Dialer{}).DialContext,
mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()}, mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()},
Log: log.Logger{Name: "remote"}, Log: log.Logger{Name: "remote"},
@ -195,6 +196,8 @@ func (rt *Target) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFr
} }
func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error { func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
defer trace.StartRegion(ctx, "remote/AddRcpt").End()
if rd.msgMeta.Quarantine { if rd.msgMeta.Quarantine {
return &exterrors.SMTPError{ return &exterrors.SMTPError{
Code: 550, Code: 550,
@ -232,7 +235,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
return err return err
} }
if err := conn.Rcpt(to); err != nil { if err := conn.Rcpt(ctx, to); err != nil {
return moduleError(err) return moduleError(err)
} }
@ -379,6 +382,8 @@ func (m *multipleErrs) SetStatus(rcptTo string, err error) {
} }
func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error { func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error {
defer trace.StartRegion(ctx, "remote/Body").End()
merr := multipleErrs{ merr := multipleErrs{
errs: make(map[string]error), errs: make(map[string]error),
} }
@ -396,6 +401,8 @@ func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buf
} }
func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) { func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) {
defer trace.StartRegion(ctx, "remote/BodyNonAtomic").End()
if rd.msgMeta.Quarantine { if rd.msgMeta.Quarantine {
for _, rcpt := range rd.recipients { for _, rcpt := range rd.recipients {
c.SetStatus(rcpt, &exterrors.SMTPError{ c.SetStatus(rcpt, &exterrors.SMTPError{
@ -425,7 +432,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl
} }
defer bodyR.Close() defer bodyR.Close()
err = conn.Data(header, bodyR) err = conn.Data(ctx, header, bodyR)
for _, rcpt := range conn.Rcpts() { for _, rcpt := range conn.Rcpts() {
c.SetStatus(rcpt, err) c.SetStatus(rcpt, err)
} }
@ -486,7 +493,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
addrs = append(addrs, nonAuthMXs...) addrs = append(addrs, nonAuthMXs...)
rd.Log.DebugMsg("considering", "mxs", addrs) rd.Log.DebugMsg("considering", "mxs", addrs)
for i, addr := range addrs { for i, addr := range addrs {
err = conn.Connect(config.Endpoint{ err = conn.Connect(ctx, config.Endpoint{
Scheme: "tcp", Scheme: "tcp",
Host: addr, Host: addr,
Port: smtpPort, Port: smtpPort,
@ -517,7 +524,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
rd.Log.DebugMsg("connected", "remote_server", conn.ServerName()) rd.Log.DebugMsg("connected", "remote_server", conn.ServerName())
if err := conn.Mail(rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil { if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
conn.Close() conn.Close()
return nil, moduleError(err) return nil, moduleError(err)
} }
@ -526,8 +533,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
return conn, nil return conn, nil
} }
func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) { func (rt *Target) getSTSPolicy(ctx context.Context, domain string) (*mtasts.Policy, error) {
stsPolicy, err := rt.mtastsGet(domain) stsPolicy, err := rt.mtastsGet(ctx, domain)
if err != nil && !mtasts.IsNoPolicy(err) { if err != nil && !mtasts.IsNoPolicy(err) {
return nil, &exterrors.SMTPError{ return nil, &exterrors.SMTPError{
Code: exterrors.SMTPCode(err, 450, 554), Code: exterrors.SMTPCode(err, 450, 554),
@ -568,9 +575,11 @@ func (rt *Target) stsCacheUpdater() {
} }
func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) { func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) {
defer trace.StartRegion(ctx, "remote/LookupAndFilterMX").End()
var policy *mtasts.Policy var policy *mtasts.Policy
if _, use := rd.rt.mxAuth[AuthMTASTS]; use { if _, use := rd.rt.mxAuth[AuthMTASTS]; use {
policy, err = rd.rt.getSTSPolicy(domain) policy, err = rd.rt.getSTSPolicy(ctx, domain)
if err != nil { if err != nil {
return nil, nil, false, err return nil, nil, false, err
} }

View file

@ -42,7 +42,7 @@ func TestRemoteDelivery(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -74,7 +74,7 @@ func TestRemoteDelivery_IPLiteral(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &resolver, resolver: &resolver,
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -100,7 +100,7 @@ func TestRemoteDelivery_FallbackMX(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: resolver, resolver: resolver,
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -127,7 +127,7 @@ func TestRemoteDelivery_BodyNonAtomic(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: resolver, resolver: resolver,
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -162,7 +162,7 @@ func TestRemoteDelivery_Abort(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -199,7 +199,7 @@ func TestRemoteDelivery_CommitWithoutBody(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -243,7 +243,7 @@ func TestRemoteDelivery_MAILFROMErr(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -276,7 +276,7 @@ func TestRemoteDelivery_NoMX(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: resolver, resolver: resolver,
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -313,7 +313,7 @@ func TestRemoteDelivery_NullMX(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -349,7 +349,7 @@ func TestRemoteDelivery_Quarantined(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -404,7 +404,7 @@ func TestRemoteDelivery_MAILFROMErr_Repeated(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -451,7 +451,7 @@ func TestRemoteDelivery_RcptErr(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -509,7 +509,7 @@ func TestRemoteDelivery_DownMX(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -539,7 +539,7 @@ func TestRemoteDelivery_AllMXDown(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -577,7 +577,7 @@ func TestRemoteDelivery_Split(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -623,7 +623,7 @@ func TestRemoteDelivery_Split_Fail(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -683,7 +683,7 @@ func TestRemoteDelivery_BodyErr(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -744,7 +744,7 @@ func TestRemoteDelivery_Split_BodyErr(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -807,7 +807,7 @@ func TestRemoteDelivery_Split_BodyErr_NonAtomic(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
} }
@ -867,7 +867,7 @@ func TestRemoteDelivery_TLSErrFallback(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
tlsConfig: &tls.Config{}, tlsConfig: &tls.Config{},
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
@ -895,7 +895,7 @@ func TestRemoteDelivery_RequireTLS_Missing(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireTLS: true, requireTLS: true,
Log: testutils.Logger(t, "remote"), Log: testutils.Logger(t, "remote"),
@ -925,7 +925,7 @@ func TestRemoteDelivery_RequireTLS_Present(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
requireTLS: true, requireTLS: true,
tlsConfig: clientCfg, tlsConfig: clientCfg,
@ -954,7 +954,7 @@ func TestRemoteDelivery_RequireTLS_NoErrFallback(t *testing.T) {
name: "remote", name: "remote",
hostname: "mx.example.com", hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones}, resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial, dialer: resolver.DialContext,
extResolver: nil, extResolver: nil,
tlsConfig: &tls.Config{}, tlsConfig: &tls.Config{},
requireTLS: true, requireTLS: true,

View file

@ -44,7 +44,7 @@ func TestRemoteDelivery_EHLO_ALabel(t *testing.T) {
tgt := mod.(*Target) tgt := mod.(*Target)
tgt.resolver = &mockdns.Resolver{Zones: zones} tgt.resolver = &mockdns.Resolver{Zones: zones}
tgt.dialer = resolver.Dial tgt.dialer = resolver.DialContext
tgt.extResolver = nil tgt.extResolver = nil
tgt.Log = testutils.Logger(t, "remote") tgt.Log = testutils.Logger(t, "remote")

View file

@ -14,6 +14,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"runtime/trace"
"github.com/emersion/go-message/textproto" "github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/internal/buffer" "github.com/foxcpp/maddy/internal/buffer"
@ -126,24 +127,26 @@ type delivery struct {
} }
func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) { func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
defer trace.StartRegion(ctx, "smtp_downstream/Start").End()
d := &delivery{ d := &delivery{
u: u, u: u,
log: target.DeliveryLogger(u.log, msgMeta), log: target.DeliveryLogger(u.log, msgMeta),
msgMeta: msgMeta, msgMeta: msgMeta,
mailFrom: mailFrom, mailFrom: mailFrom,
} }
if err := d.connect(); err != nil { if err := d.connect(ctx); err != nil {
return nil, err return nil, err
} }
if err := d.conn.Mail(mailFrom, msgMeta.SMTPOpts); err != nil { if err := d.conn.Mail(ctx, mailFrom, msgMeta.SMTPOpts); err != nil {
d.conn.Close() d.conn.Close()
return nil, err return nil, err
} }
return d, nil return d, nil
} }
func (d *delivery) connect() error { func (d *delivery) connect(ctx context.Context) error {
// TODO: Review possibility of connection pooling here. // TODO: Review possibility of connection pooling here.
var lastErr error var lastErr error
@ -156,7 +159,7 @@ func (d *delivery) connect() error {
conn.AddrInSMTPMsg = false conn.AddrInSMTPMsg = false
for _, endp := range d.u.endpoints { for _, endp := range d.u.endpoints {
err := conn.Connect(endp) err := conn.Connect(ctx, endp)
if err == nil { if err == nil {
d.log.DebugMsg("connected", "downstream_server", conn.ServerName()) d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
lastErr = nil lastErr = nil
@ -191,7 +194,7 @@ func (d *delivery) connect() error {
} }
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error { func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
return moduleError(d.conn.Rcpt(rcptTo)) return moduleError(d.conn.Rcpt(ctx, rcptTo))
} }
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error { func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
@ -217,7 +220,7 @@ func (d *delivery) Commit(ctx context.Context) error {
defer d.conn.Close() defer d.conn.Close()
defer d.body.Close() defer d.body.Close()
return moduleError(d.conn.Data(d.hdr, d.body)) return moduleError(d.conn.Data(ctx, d.hdr, d.body))
} }
func init() { func init() {