diff --git a/internal/check/command/command.go b/internal/check/command/command.go index c5229a3..e73d9cc 100644 --- a/internal/check/command/command.go +++ b/internal/check/command/command.go @@ -11,6 +11,7 @@ import ( "os" "os/exec" "regexp" + "runtime/trace" "strconv" "strings" @@ -300,6 +301,9 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult { return module.CheckResult{} } + // TODO: It is not possible to distinguish different commands. + defer trace.StartRegion(ctx, "command/CheckConnection").End() + cmdName, cmdArgs := s.expandCommand("") return s.run(cmdName, cmdArgs, bytes.NewReader(nil)) } @@ -311,6 +315,8 @@ func (s *state) CheckSender(ctx context.Context, addr string) module.CheckResult return module.CheckResult{} } + defer trace.StartRegion(ctx, "command/CheckSender").End() + cmdName, cmdArgs := s.expandCommand(addr) return s.run(cmdName, cmdArgs, bytes.NewReader(nil)) } @@ -321,6 +327,7 @@ func (s *state) CheckRcpt(ctx context.Context, addr string) module.CheckResult { if s.c.stage != StageRcpt { return module.CheckResult{} } + defer trace.StartRegion(ctx, "command/CheckRcpt").End() cmdName, cmdArgs := s.expandCommand(addr) return s.run(cmdName, cmdArgs, bytes.NewReader(nil)) @@ -331,6 +338,8 @@ func (s *state) CheckBody(ctx context.Context, hdr textproto.Header, body buffer return module.CheckResult{} } + defer trace.StartRegion(ctx, "command/CheckBody").End() + cmdName, cmdArgs := s.expandCommand("") var buf bytes.Buffer diff --git a/internal/check/dkim/dkim.go b/internal/check/dkim/dkim.go index 5283e41..c98982a 100644 --- a/internal/check/dkim/dkim.go +++ b/internal/check/dkim/dkim.go @@ -6,6 +6,7 @@ import ( "errors" "io" nettextproto "net/textproto" + "runtime/trace" "strings" "github.com/emersion/go-message/textproto" @@ -96,6 +97,8 @@ func (d *dkimCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.Ch } func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult { + defer trace.StartRegion(ctx, "verify_dkim/CheckBody").End() + if !header.Has("DKIM-Signature") { if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine { d.log.Printf("no signatures present") diff --git a/internal/check/dnsbl/dnsbl.go b/internal/check/dnsbl/dnsbl.go index 1e14caf..e189360 100644 --- a/internal/check/dnsbl/dnsbl.go +++ b/internal/check/dnsbl/dnsbl.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "runtime/trace" "strings" "github.com/emersion/go-message/textproto" @@ -323,6 +324,8 @@ func (bl *DNSBL) CheckConnection(ctx context.Context, state *smtp.ConnectionStat return nil } + defer trace.StartRegion(ctx, "dnsbl/CheckConnection (Early)").End() + ip, ok := state.RemoteAddr.(*net.TCPAddr) if !ok { bl.log.Msg("non-TCP/IP source", @@ -358,6 +361,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult { return module.CheckResult{} } + defer trace.StartRegion(ctx, "dnsbl/CheckConnection").End() + if s.msgMeta.Conn == nil { s.log.Msg("locally generated message, ignoring") return module.CheckResult{} diff --git a/internal/check/spf/spf.go b/internal/check/spf/spf.go index c4ca19d..f976847 100644 --- a/internal/check/spf/spf.go +++ b/internal/check/spf/spf.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "runtime/debug" + "runtime/trace" "blitiri.com.ar/go/spf" "github.com/emersion/go-message/textproto" @@ -267,6 +268,8 @@ func prepareMailFrom(from string) (string, error) { } func (s *state) CheckConnection(ctx context.Context) module.CheckResult { + defer trace.StartRegion(ctx, "apply_spf/CheckConnection").End() + if s.msgMeta.Conn == nil { s.log.Println("locally generated message, skipping") return module.CheckResult{} @@ -306,6 +309,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult { } }() + defer trace.StartRegion(ctx, "apply_spf/CheckConnection (Async)").End() + res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom) s.log.Debugf("result: %s (%v)", res, err) s.spfFetch <- spfRes{res, err} @@ -328,6 +333,8 @@ func (s *state) CheckBody(ctx context.Context, header textproto.Header, body buf return module.CheckResult{} } + defer trace.StartRegion(ctx, "apply_spf/CheckBody").End() + res, ok := <-s.spfFetch if !ok { return module.CheckResult{ diff --git a/internal/check/stateless_check.go b/internal/check/stateless_check.go index 7561bd3..012c504 100644 --- a/internal/check/stateless_check.go +++ b/internal/check/stateless_check.go @@ -3,6 +3,7 @@ package check import ( "context" "fmt" + "runtime/trace" "github.com/emersion/go-message/textproto" "github.com/foxcpp/maddy/internal/buffer" @@ -57,12 +58,18 @@ type statelessCheckState struct { msgMeta *module.MsgMetadata } +func (s *statelessCheckState) String() string { + return s.c.modName + ":" + s.c.instName +} + func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult { if s.c.connCheck == nil { return module.CheckResult{} } + defer trace.StartRegion(ctx, s.c.modName+"/CheckConnection").End() originalRes := s.c.connCheck(StatelessCheckContext{ + Context: ctx, Resolver: s.c.resolver, MsgMeta: s.msgMeta, Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), @@ -74,8 +81,10 @@ func (s *statelessCheckState) CheckSender(ctx context.Context, mailFrom string) if s.c.senderCheck == nil { return module.CheckResult{} } + defer trace.StartRegion(ctx, s.c.modName+"/CheckSender").End() originalRes := s.c.senderCheck(StatelessCheckContext{ + Context: ctx, Resolver: s.c.resolver, MsgMeta: s.msgMeta, Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), @@ -87,8 +96,10 @@ func (s *statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modu if s.c.rcptCheck == nil { return module.CheckResult{} } + defer trace.StartRegion(ctx, s.c.modName+"/CheckRcpt").End() originalRes := s.c.rcptCheck(StatelessCheckContext{ + Context: ctx, Resolver: s.c.resolver, MsgMeta: s.msgMeta, Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), @@ -100,8 +111,10 @@ func (s *statelessCheckState) CheckBody(ctx context.Context, header textproto.He if s.c.bodyCheck == nil { return module.CheckResult{} } + defer trace.StartRegion(ctx, s.c.modName+"/CheckBody").End() originalRes := s.c.bodyCheck(StatelessCheckContext{ + Context: ctx, Resolver: s.c.resolver, MsgMeta: s.msgMeta, Logger: target.DeliveryLogger(s.c.logger, s.msgMeta), diff --git a/internal/dmarc/verifier.go b/internal/dmarc/verifier.go index bd9b9de..cdbafbd 100644 --- a/internal/dmarc/verifier.go +++ b/internal/dmarc/verifier.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "net" + "runtime/trace" "strings" "github.com/emersion/go-message/textproto" @@ -81,6 +82,8 @@ func (v *Verifier) FetchRecord(ctx context.Context, header textproto.Header) { } }() + defer trace.StartRegion(ctx, "DMARC/FetchRecord").End() + policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain) v.fetchCh <- verifyData{ policyDomain: policyDomain, diff --git a/internal/endpoint/smtp/smtp.go b/internal/endpoint/smtp/smtp.go index 4512ba3..670bb7f 100644 --- a/internal/endpoint/smtp/smtp.go +++ b/internal/endpoint/smtp/smtp.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "runtime/trace" "strings" "sync" "time" @@ -46,6 +47,7 @@ type Session struct { // msgCtx is not used for cancellation or timeouts, only for tracing. // It is the subcontext of sessionCtx. msgCtx context.Context + msgTask *trace.Task mailFrom string opts smtp.MailOptions msgMeta *module.MsgMetadata @@ -75,6 +77,7 @@ func (s *Session) abort(ctx context.Context) { s.delivery = nil s.deliveryErr = nil s.msgCtx = nil + s.msgTask.End() } func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) { @@ -138,15 +141,20 @@ func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.Mail ) } - delivery, err := s.endp.pipeline.Start(ctx, msgMeta, cleanFrom) + s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message") + mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM") + defer mailTask.End() + + delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom) if err != nil { + s.msgCtx = nil + s.msgTask.End() return msgMeta.ID, err } s.msgMeta = msgMeta s.mailFrom = cleanFrom s.delivery = delivery - s.msgCtx = s.sessionCtx // TODO: Trace annotations. return msgMeta.ID, nil } @@ -171,6 +179,8 @@ func (s *Session) Mail(from string, opts smtp.MailOptions) error { } func (s *Session) fetchRDNSName(ctx context.Context) { + defer trace.StartRegion(ctx, "rDNS fetch").End() + tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr) if !ok { s.connState.RDNSName.Set(nil) @@ -211,7 +221,8 @@ func (s *Session) Rcpt(to string) error { } } - rcptCtx := s.msgCtx + rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO") + defer rcptTask.End() if err := s.rcpt(rcptCtx, to); err != nil { if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors { @@ -293,7 +304,8 @@ func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Heade } func (s *Session) Data(r io.Reader) error { - bodyCtx := s.msgCtx + bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA") + defer bodyTask.End() wrapErr := func(err error) error { s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID) @@ -317,6 +329,9 @@ func (s *Session) Data(r io.Reader) error { // go-smtp will call Reset, but it will call Abort if delivery is non-nil. s.delivery = nil + s.msgCtx = nil + s.msgTask.End() + s.msgTask = nil s.endp.semaphore.Release() return nil @@ -332,7 +347,8 @@ func (sw statusWrapper) SetStatus(rcpt string, err error) { } func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error { - bodyCtx := s.msgCtx + bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA") + defer bodyTask.End() wrapErr := func(err error) error { s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID) @@ -356,6 +372,9 @@ func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error { // go-smtp will call Reset, but it will call Abort if delivery is non-nil. s.delivery = nil + s.msgCtx = nil + s.msgTask.End() + s.msgTask = nil s.endp.semaphore.Release() return nil @@ -671,7 +690,7 @@ func (endp *Endpoint) newSession(anonymous bool, username, password string, stat } if endp.resolver != nil { - rdnsCtx, cancelRDNS := context.WithCancel(context.TODO()) + rdnsCtx, cancelRDNS := context.WithCancel(s.sessionCtx) s.connState.RDNSName = future.New() s.cancelRDNS = cancelRDNS go s.fetchRDNSName(rdnsCtx) diff --git a/internal/modify/dkim/dkim.go b/internal/modify/dkim/dkim.go index fc603c1..006d889 100644 --- a/internal/modify/dkim/dkim.go +++ b/internal/modify/dkim/dkim.go @@ -7,6 +7,7 @@ import ( "io" "net/mail" "path/filepath" + "runtime/trace" "strings" "time" @@ -337,6 +338,8 @@ func (s state) RewriteRcpt(ctx context.Context, rcptTo string) (string, error) { } func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error { + defer trace.StartRegion(ctx, "sign_dkim/RewriteBody").End() + var authUser string if s.meta.Conn != nil { authUser = s.meta.Conn.AuthUser diff --git a/internal/msgpipeline/objname.go b/internal/msgpipeline/objname.go index 089d42b..8c35e90 100644 --- a/internal/msgpipeline/objname.go +++ b/internal/msgpipeline/objname.go @@ -11,10 +11,7 @@ import ( func objectName(x interface{}) string { mod, ok := x.(module.Module) if ok { - if mod.InstanceName() == "" { - return mod.Name() - } - return mod.InstanceName() + return mod.Name() + ":" + mod.InstanceName() } _, pipeline := x.(*MsgPipeline) @@ -22,5 +19,10 @@ func objectName(x interface{}) string { return "reroute" } + str, ok := x.(fmt.Stringer) + if ok { + return str.String() + } + return fmt.Sprintf("%T", x) } diff --git a/internal/mtasts/cache.go b/internal/mtasts/cache.go index 19e9900..687601a 100644 --- a/internal/mtasts/cache.go +++ b/internal/mtasts/cache.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path/filepath" + "runtime/trace" "time" "github.com/foxcpp/maddy/internal/exterrors" @@ -23,10 +24,14 @@ var httpClient = &http.Client{ Timeout: time.Minute, } -func downloadPolicy(domain string) (*Policy, error) { +func downloadPolicy(ctx context.Context, domain string) (*Policy, error) { // TODO: Consult OCSP/CRL to detect revoked certificates? - resp, err := httpClient.Get("https://mta-sts." + domain + "/.well-known/mta-sts.txt") + req, err := http.NewRequestWithContext(ctx, "GET", "https://mta-sts."+domain+"/.well-known/mta-sts.txt", nil) + if err != nil { + return nil, err + } + resp, err := httpClient.Do(req) if err != nil { return nil, err } @@ -80,8 +85,8 @@ var ErrIgnorePolicy = errors.New("mtasts: policy ignored due to errors") // Get reads policy from cache or tries to fetch it from Policy Host. // // The domain is assumed to be normalized, as done by dns.ForLookup. -func (c *Cache) Get(domain string) (*Policy, error) { - _, p, err := c.fetch(false, time.Now(), domain) +func (c *Cache) Get(ctx context.Context, domain string) (*Policy, error) { + _, p, err := c.fetch(ctx, false, time.Now(), domain) return p, err } @@ -118,6 +123,9 @@ func (c *Cache) load(domain string) (id string, fetchTime time.Time, p *Policy, } func (c *Cache) Refresh() error { + refreshCtx, refreshTask := trace.NewTask(context.Background(), "mtasts.Cache/Refresh") + defer refreshTask.End() + dir, err := ioutil.ReadDir(c.Location) if err != nil { return err @@ -132,7 +140,7 @@ func (c *Cache) Refresh() error { // Since otherwise we are going to have expired policy for another 6 hours, // which makes it useless. // See https://tools.ietf.org/html/rfc8461#section-10.2. - cacheHit, _, err := c.fetch(true, time.Now().Add(6*time.Hour), ent.Name()) + cacheHit, _, err := c.fetch(refreshCtx, false, time.Now().Add(6*time.Hour), ent.Name()) if err != nil && err != ErrIgnorePolicy { c.Logger.Error("policy update error", err, "domain", ent.Name()) } @@ -147,7 +155,9 @@ func (c *Cache) Refresh() error { return nil } -func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) { +func (c *Cache) fetch(ctx context.Context, ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) { + defer trace.StartRegion(ctx, "mtasts.Cache/fetch").End() + validCache := true cachedId, fetchTime, cachedPolicy, err := c.load(domain) if err != nil { @@ -163,7 +173,7 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo var dnsId string if !ignoreDns { - records, err := c.Resolver.LookupTXT(context.Background(), "_mta-sts."+domain) + records, err := c.Resolver.LookupTXT(ctx, "_mta-sts."+domain) if err != nil { if validCache { if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound { @@ -218,11 +228,15 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo } if !validCache || dnsId != cachedId { - download := downloadPolicy + var ( + policy *Policy + err error + ) if c.downloadPolicy != nil { - download = c.downloadPolicy + policy, err = c.downloadPolicy(domain) + } else { + policy, err = downloadPolicy(ctx, domain) } - policy, err := download(domain) if err != nil { if validCache { c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain) diff --git a/internal/mtasts/cache_test.go b/internal/mtasts/cache_test.go index aea66b9..593d5ea 100644 --- a/internal/mtasts/cache_test.go +++ b/internal/mtasts/cache_test.go @@ -1,6 +1,7 @@ package mtasts import ( + "context" "errors" "os" "reflect" @@ -37,7 +38,7 @@ func TestCacheGet(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -57,7 +58,7 @@ func TestCacheGet_Error_DNS(t *testing.T) { } defer os.RemoveAll(c.Location) - _, err := c.Get("example.org") + _, err := c.Get(context.Background(), "example.org") if err != ErrIgnorePolicy { t.Fatalf("policy get: %v", err) } @@ -78,7 +79,7 @@ func TestCacheGet_Error_HTTPS(t *testing.T) { } defer os.RemoveAll(c.Location) - _, err := c.Get("example.org") + _, err := c.Get(context.Background(), "example.org") if err != ErrIgnorePolicy { t.Fatalf("policy get: %v", err) } @@ -104,7 +105,7 @@ func TestCacheGet_Cached(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -115,7 +116,7 @@ func TestCacheGet_Cached(t *testing.T) { // calling downloadPolicy. c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -146,7 +147,7 @@ func TestCacheGet_Expired(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -160,7 +161,7 @@ func TestCacheGet_Expired(t *testing.T) { expectedPolicy.MX = []string{"b"} c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil) - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -190,7 +191,7 @@ func TestCacheGet_IDChange(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -206,7 +207,7 @@ func TestCacheGet_IDChange(t *testing.T) { expectedPolicy.MX = []string{"b"} c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil) - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -236,7 +237,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -251,7 +252,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) { resolver.Zones = nil c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -281,7 +282,7 @@ func TestCacheGet_HTTPGet_ErrNoPolicy(t *testing.T) { // >(for any reason), and there is no valid (non-expired) previously // >cached policy, senders MUST continue with delivery as though the // >domain has not implemented MTA-STS. - _, err := c.Get("example.org") + _, err := c.Get(context.Background(), "example.org") if err != ErrIgnorePolicy { t.Fatalf("policy get: %v", err) } @@ -308,7 +309,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -325,7 +326,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) { } c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -357,7 +358,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -376,7 +377,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) { } c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err == nil { t.Fatalf("expected error, got policy %v", policy) } @@ -405,7 +406,7 @@ func TestCacheRefresh(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -427,7 +428,7 @@ func TestCacheRefresh(t *testing.T) { c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken")) // It should return the new record from cache. - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -459,7 +460,7 @@ func TestCacheRefresh_Error(t *testing.T) { } defer os.RemoveAll(c.Location) - policy, err := c.Get("example.org") + policy, err := c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } @@ -477,7 +478,7 @@ func TestCacheRefresh_Error(t *testing.T) { } // It should return the old record from cache. - policy, err = c.Get("example.org") + policy, err = c.Get(context.Background(), "example.org") if err != nil { t.Fatalf("policy get: %v", err) } diff --git a/internal/smtpconn/smtpconn.go b/internal/smtpconn/smtpconn.go index 9f139bb..d8f13ea 100644 --- a/internal/smtpconn/smtpconn.go +++ b/internal/smtpconn/smtpconn.go @@ -10,9 +10,11 @@ package smtpconn import ( + "context" "crypto/tls" "io" "net" + "runtime/trace" "github.com/emersion/go-message/textproto" "github.com/emersion/go-smtp" @@ -27,9 +29,9 @@ import ( // // Currently, the C object represents one session and cannot be reused. type C struct { - // Dialer to use to estabilish new network connections. Set to net.Dial by - // New. - Dialer func(network, addr string) (net.Conn, error) + // Dialer to use to estabilish new network connections. Set to net.Dialer + // DialContext by New. + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) // Fail if the connection cannot use TLS. RequireTLS bool @@ -60,7 +62,7 @@ type C struct { // with resonable default values. func New() *C { return &C{ - Dialer: net.Dial, + Dialer: (&net.Dialer{}).DialContext, TLSConfig: &tls.Config{}, Hostname: "localhost.localdomain", } @@ -122,9 +124,11 @@ func (c *C) wrapClientErr(err error, serverName string) error { } // Connect actually estabilishes the network connection with the remote host. -func (c *C) Connect(endp config.Endpoint) error { +func (c *C) Connect(ctx context.Context, endp config.Endpoint) error { + defer trace.StartRegion(ctx, "smtpconn/Connect+TLS").End() + // TODO: Helper function to try multiple endpoints? - cl, err := c.attemptConnect(endp, c.AttemptTLS) + cl, err := c.attemptConnect(ctx, endp, c.AttemptTLS) if err != nil { return c.wrapClientErr(err, endp.Host) } @@ -134,9 +138,9 @@ func (c *C) Connect(endp config.Endpoint) error { return nil } -func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) { +func (c *C) attemptConnect(ctx context.Context, endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) { var conn net.Conn - conn, err := c.Dialer(endp.Network(), endp.Address()) + conn, err := c.Dialer(ctx, endp.Network(), endp.Address()) if err != nil { return nil, err } @@ -197,7 +201,7 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, // Re-attempt without STARTTLS. It is not possible to reuse connection // since it is probably in a bad state. c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port) - return c.attemptConnect(endp, false) + return c.attemptConnect(ctx, endp, false) } return cl, nil @@ -209,7 +213,9 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, // SMTPUTF8 is forwarded if supported by the remote server, if it is not // supported - attempt will be done to convert addresses to the ASCII form, if // this is not possible, the corresponding method (Mail or Rcpt) will fail. -func (c *C) Mail(from string, opts smtp.MailOptions) error { +func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error { + defer trace.StartRegion(ctx, "smtpconn/MAIL FROM").End() + outOpts := smtp.MailOptions{ // Future extensions may add additional fields that should not be // copied blindly. So we copy only fields we know should be handled @@ -268,7 +274,9 @@ func (c *C) Client() *smtp.Client { // // If the address is non-ASCII and cannot be converted to ASCII and the remote // server does not support SMTPUTF8, error will be returned. -func (c *C) Rcpt(to string) error { +func (c *C) Rcpt(ctx context.Context, to string) error { + defer trace.StartRegion(ctx, "smtpconn/RCPT TO").End() + // If necessary, the extension flag is enabled in Start. if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok { var err error @@ -300,7 +308,9 @@ func (c *C) Rcpt(to string) error { // // If the Data command fails, the connection may be in a unclean state (e.g. in // the middle of message data stream). It is not safe to continue using it. -func (c *C) Data(hdr textproto.Header, body io.Reader) error { +func (c *C) Data(ctx context.Context, hdr textproto.Header, body io.Reader) error { + defer trace.StartRegion(ctx, "smtpconn/DATA").End() + wc, err := c.cl.Data() if err != nil { return c.wrapClientErr(err, c.serverName) diff --git a/internal/smtpconn/smtputf8_test.go b/internal/smtpconn/smtputf8_test.go index 6e3adfc..9acf6b2 100644 --- a/internal/smtpconn/smtputf8_test.go +++ b/internal/smtpconn/smtputf8_test.go @@ -1,6 +1,7 @@ package smtpconn import ( + "context" "strings" "testing" @@ -14,11 +15,11 @@ import ( func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error { t.Helper() - if err := conn.Mail(from, opts); err != nil { + if err := conn.Mail(context.Background(), from, opts); err != nil { return err } for _, rcpt := range to { - if err := conn.Rcpt(rcpt); err != nil { + if err := conn.Rcpt(context.Background(), rcpt); err != nil { return err } } @@ -26,7 +27,7 @@ func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.M hdr := textproto.Header{} hdr.Add("B", "2") hdr.Add("A", "1") - if err := conn.Data(hdr, strings.NewReader("foobar\n")); err != nil { + if err := conn.Data(context.Background(), hdr, strings.NewReader("foobar\n")); err != nil { return err } @@ -41,7 +42,7 @@ func TestSMTPUTF8_Sender_UTF8_Punycode(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -70,7 +71,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Punycode(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -99,7 +100,7 @@ func TestSMTPUTF8_Sender_UTF8_Reject(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -122,7 +123,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Reject(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -144,7 +145,7 @@ func TestSMTPUTF8_Sender_UTF8_Domain(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -172,7 +173,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Domain(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -201,7 +202,7 @@ func TestSMTPUTF8_Sender_UTF8_Username(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, @@ -230,7 +231,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Username(t *testing.T) { c := New() c.Log = testutils.Logger(t, "smtp_downstream") - if err := c.Connect(config.Endpoint{ + if err := c.Connect(context.Background(), config.Endpoint{ Scheme: "tcp", Host: "127.0.0.1", Port: testPort, diff --git a/internal/storage/sql/sql.go b/internal/storage/sql/sql.go index bad3e26..7d7b71b 100644 --- a/internal/storage/sql/sql.go +++ b/internal/storage/sql/sql.go @@ -15,6 +15,7 @@ import ( "fmt" "os" "path/filepath" + "runtime/trace" "strconv" "strings" @@ -63,7 +64,13 @@ type delivery struct { addedRcpts map[string]struct{} } +func (d *delivery) String() string { + return d.store.Name() + ":" + d.store.InstanceName() +} + func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error { + defer trace.StartRegion(ctx, "sql/AddRcpt").End() + accountName, err := prepareUsername(rcptTo) if err != nil { return &exterrors.SMTPError{ @@ -104,6 +111,8 @@ func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error { } func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error { + defer trace.StartRegion(ctx, "sql/Body").End() + if d.msgMeta.Quarantine { if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil { return err @@ -116,14 +125,20 @@ func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffe } func (d *delivery) Abort(ctx context.Context) error { + defer trace.StartRegion(ctx, "sql/Abort").End() + return d.d.Abort() } func (d *delivery) Commit(ctx context.Context) error { + defer trace.StartRegion(ctx, "sql/Commit").End() + return d.d.Commit() } func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) { + defer trace.StartRegion(ctx, "sql/Start").End() + return &delivery{ store: store, msgMeta: msgMeta, @@ -360,6 +375,9 @@ func prepareUsername(username string) (string, error) { } func (store *Storage) CheckPlain(username, password string) bool { + // TODO: Pass session context there. + defer trace.StartRegion(context.Background(), "sql/CheckPlain").End() + accountName, err := prepareUsername(username) if err != nil { return false diff --git a/internal/target/queue/queue.go b/internal/target/queue/queue.go index 1bf3f83..fbf1389 100644 --- a/internal/target/queue/queue.go +++ b/internal/target/queue/queue.go @@ -50,6 +50,7 @@ import ( "os" "path/filepath" "runtime/debug" + "runtime/trace" "strconv" "strings" "sync" @@ -438,10 +439,12 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1) dl.Debugf("using message ID = %s", msgMeta.ID) - // TODO: Trace annotations? - msgCtx := context.Background() + msgCtx, msgTask := trace.NewTask(context.Background(), "Queue delivery") + defer msgTask.End() - delivery, err := q.Target.Start(msgCtx, msgMeta, meta.From) + mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM") + delivery, err := q.Target.Start(mailCtx, msgMeta, meta.From) + mailTask.End() if err != nil { dl.Debugf("target.Start failed: %v", err) perr.Failed = append(perr.Failed, meta.To...) @@ -454,7 +457,8 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe var acceptedRcpts []string for _, rcpt := range meta.To { - if err := delivery.AddRcpt(msgCtx, rcpt); err != nil { + rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO") + if err := delivery.AddRcpt(rcptCtx, rcpt); err != nil { if exterrors.IsTemporaryOrUnspec(err) { perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt) } else { @@ -466,6 +470,7 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe dl.Debugf("delivery.AddRcpt %s OK", rcpt) acceptedRcpts = append(acceptedRcpts, rcpt) } + rcptTask.End() } if len(acceptedRcpts) == 0 { @@ -487,12 +492,15 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe } } + bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA") + defer bodyTask.End() + partDelivery, ok := delivery.(module.PartialDelivery) if ok { dl.Debugf("using delivery.BodyNonAtomic") - partDelivery.BodyNonAtomic(msgCtx, &perr, header, body) + partDelivery.BodyNonAtomic(bodyCtx, &perr, header, body) } else { - if err := delivery.Body(msgCtx, header, body); err != nil { + if err := delivery.Body(bodyCtx, header, body); err != nil { dl.Debugf("delivery.Body failed: %v", err) expandToPartialErr(err) } @@ -508,13 +516,13 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe if allFailed { // No recipients succeeded. dl.Debugf("delivery.Abort (all recipients failed)") - if err := delivery.Abort(msgCtx); err != nil { + if err := delivery.Abort(bodyCtx); err != nil { dl.Msg("delivery.Abort failed", err) } return perr } - if err := delivery.Commit(msgCtx); err != nil { + if err := delivery.Commit(bodyCtx); err != nil { dl.Debugf("delivery.Commit failed: %v", err) expandToPartialErr(err) } @@ -537,6 +545,8 @@ func (qd *queueDelivery) AddRcpt(ctx context.Context, rcptTo string) error { } func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error { + defer trace.StartRegion(ctx, "queue/Body").End() + // Body buffer initially passed to us may not be valid after "delivery" to queue completes. // storeNewMessage returns a new buffer object created from message blob stored on disk. storedBody, err := qd.q.storeNewMessage(qd.meta, header, body) @@ -550,6 +560,8 @@ func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body } func (qd *queueDelivery) Abort(ctx context.Context) error { + defer trace.StartRegion(ctx, "queue/Abort").End() + if qd.body != nil { qd.q.removeFromDisk(qd.meta.MsgMeta) } @@ -557,6 +569,8 @@ func (qd *queueDelivery) Abort(ctx context.Context) error { } func (qd *queueDelivery) Commit(ctx context.Context) error { + defer trace.StartRegion(ctx, "queue/Commit").End() + if qd.meta == nil { panic("queue: double Commit") } @@ -899,14 +913,17 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) { } dl.Msg("generated failed DSN", "dsn_id", dsnID) - // TODO: Trace annotations? - msgCtx := context.Background() + msgCtx, msgTask := trace.NewTask(context.Background(), "DSN Delivery") + defer msgTask.End() - dsnDelivery, err := q.dsnPipeline.Start(msgCtx, dsnMeta, "") + mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM") + dsnDelivery, err := q.dsnPipeline.Start(mailCtx, dsnMeta, "") + mailTask.End() if err != nil { dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID) return } + defer func() { if err != nil { dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID) @@ -916,15 +933,23 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) { } }() - if err = dsnDelivery.AddRcpt(msgCtx, meta.From); err != nil { + rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO") + if err = dsnDelivery.AddRcpt(rcptCtx, meta.From); err != nil { + rcptTask.End() return } - if err = dsnDelivery.Body(msgCtx, dsnHeader, dsnBody); err != nil { + rcptTask.End() + + bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA") + if err = dsnDelivery.Body(bodyCtx, dsnHeader, dsnBody); err != nil { + bodyTask.End() return } - if err = dsnDelivery.Commit(msgCtx); err != nil { + if err = dsnDelivery.Commit(bodyCtx); err != nil { + bodyTask.End() return } + bodyTask.End() } func init() { diff --git a/internal/target/remote/mxauth_test.go b/internal/target/remote/mxauth_test.go index 97f8e4a..2ce32d8 100644 --- a/internal/target/remote/mxauth_test.go +++ b/internal/target/remote/mxauth_test.go @@ -1,6 +1,7 @@ package remote import ( + "context" "errors" "net" "strconv" @@ -33,7 +34,7 @@ func TestRemoteDelivery_AuthMX_Fail(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, Log: testutils.Logger(t, "remote"), @@ -63,12 +64,12 @@ func TestRemoteDelivery_AuthMX_MTASTS(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, tlsConfig: clientCfg, mxAuth: map[string]struct{}{AuthMTASTS: {}}, - mtastsGet: func(domain string) (*mtasts.Policy, error) { + mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) { if domain != "example.invalid" { return nil, errors.New("Wrong domain in lookup") } @@ -114,12 +115,12 @@ func TestRemoteDelivery_AuthMX_PreferAuth(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, tlsConfig: clientCfg, mxAuth: map[string]struct{}{AuthMTASTS: {}}, - mtastsGet: func(domain string) (*mtasts.Policy, error) { + mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) { if domain != "example.invalid" { return nil, errors.New("Wrong domain in lookup") } @@ -166,12 +167,12 @@ func TestRemoteDelivery_MTASTS_SkipNonMatching(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, tlsConfig: clientCfg, mxAuth: map[string]struct{}{AuthMTASTS: {}}, - mtastsGet: func(domain string) (*mtasts.Policy, error) { + mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) { if domain != "example.invalid" { return nil, errors.New("Wrong domain in lookup") } @@ -208,11 +209,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_Fail(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, mxAuth: map[string]struct{}{AuthMTASTS: {}}, - mtastsGet: func(domain string) (*mtasts.Policy, error) { + mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) { if domain != "example.invalid" { return nil, errors.New("Wrong domain in lookup") } @@ -251,11 +252,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_NoPolicy(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, mxAuth: map[string]struct{}{AuthMTASTS: {}}, - mtastsGet: func(domain string) (*mtasts.Policy, error) { + mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) { if domain != "example.invalid" { return nil, errors.New("Wrong domain in lookup") } @@ -290,7 +291,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, mxAuth: map[string]struct{}{AuthCommonDomain: {}}, @@ -321,7 +322,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_Fail(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, mxAuth: map[string]struct{}{AuthCommonDomain: {}}, @@ -354,7 +355,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_NotETLDp1(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireMXAuth: true, mxAuth: map[string]struct{}{AuthCommonDomain: {}}, @@ -405,7 +406,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: extResolver, requireMXAuth: true, mxAuth: map[string]struct{}{AuthDNSSEC: {}}, @@ -452,7 +453,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC_Fail(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: extResolver, requireMXAuth: true, mxAuth: map[string]struct{}{AuthDNSSEC: {}}, @@ -506,7 +507,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &resolver, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: extResolver, requireMXAuth: true, mxAuth: map[string]struct{}{AuthDNSSEC: {}}, @@ -556,7 +557,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral_Fail(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &resolver, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: extResolver, requireMXAuth: true, mxAuth: map[string]struct{}{AuthDNSSEC: {}}, diff --git a/internal/target/remote/remote.go b/internal/target/remote/remote.go index ebc078d..5bed6ad 100644 --- a/internal/target/remote/remote.go +++ b/internal/target/remote/remote.go @@ -13,6 +13,7 @@ import ( "net" "os" "path/filepath" + "runtime/trace" "sort" "strings" "sync" @@ -57,12 +58,12 @@ type Target struct { mxAuth map[string]struct{} resolver dns.Resolver - dialer func(network, addr string) (net.Conn, error) + dialer func(ctx context.Context, network, addr string) (net.Conn, error) extResolver *dns.ExtResolver // This is the callback that is usually mtastsCache.Get, // but replaced by tests to mock mtasts.Cache. - mtastsGet func(domain string) (*mtasts.Policy, error) + mtastsGet func(ctx context.Context, domain string) (*mtasts.Policy, error) mtastsCache mtasts.Cache stsCacheUpdateTick *time.Ticker @@ -80,7 +81,7 @@ func New(_, instName string, _, inlineArgs []string) (module.Module, error) { return &Target{ name: instName, resolver: dns.DefaultResolver(), - dialer: net.Dial, + dialer: (&net.Dialer{}).DialContext, mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()}, Log: log.Logger{Name: "remote"}, @@ -195,6 +196,8 @@ func (rt *Target) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFr } func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error { + defer trace.StartRegion(ctx, "remote/AddRcpt").End() + if rd.msgMeta.Quarantine { return &exterrors.SMTPError{ Code: 550, @@ -232,7 +235,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error { return err } - if err := conn.Rcpt(to); err != nil { + if err := conn.Rcpt(ctx, to); err != nil { return moduleError(err) } @@ -379,6 +382,8 @@ func (m *multipleErrs) SetStatus(rcptTo string, err error) { } func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error { + defer trace.StartRegion(ctx, "remote/Body").End() + merr := multipleErrs{ errs: make(map[string]error), } @@ -396,6 +401,8 @@ func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buf } func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) { + defer trace.StartRegion(ctx, "remote/BodyNonAtomic").End() + if rd.msgMeta.Quarantine { for _, rcpt := range rd.recipients { c.SetStatus(rcpt, &exterrors.SMTPError{ @@ -425,7 +432,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl } defer bodyR.Close() - err = conn.Data(header, bodyR) + err = conn.Data(ctx, header, bodyR) for _, rcpt := range conn.Rcpts() { c.SetStatus(rcpt, err) } @@ -486,7 +493,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string addrs = append(addrs, nonAuthMXs...) rd.Log.DebugMsg("considering", "mxs", addrs) for i, addr := range addrs { - err = conn.Connect(config.Endpoint{ + err = conn.Connect(ctx, config.Endpoint{ Scheme: "tcp", Host: addr, Port: smtpPort, @@ -517,7 +524,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string rd.Log.DebugMsg("connected", "remote_server", conn.ServerName()) - if err := conn.Mail(rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil { + if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil { conn.Close() return nil, moduleError(err) } @@ -526,8 +533,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string return conn, nil } -func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) { - stsPolicy, err := rt.mtastsGet(domain) +func (rt *Target) getSTSPolicy(ctx context.Context, domain string) (*mtasts.Policy, error) { + stsPolicy, err := rt.mtastsGet(ctx, domain) if err != nil && !mtasts.IsNoPolicy(err) { return nil, &exterrors.SMTPError{ Code: exterrors.SMTPCode(err, 450, 554), @@ -568,9 +575,11 @@ func (rt *Target) stsCacheUpdater() { } func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) { + defer trace.StartRegion(ctx, "remote/LookupAndFilterMX").End() + var policy *mtasts.Policy if _, use := rd.rt.mxAuth[AuthMTASTS]; use { - policy, err = rd.rt.getSTSPolicy(domain) + policy, err = rd.rt.getSTSPolicy(ctx, domain) if err != nil { return nil, nil, false, err } diff --git a/internal/target/remote/remote_test.go b/internal/target/remote/remote_test.go index 976c1a3..1362c68 100644 --- a/internal/target/remote/remote_test.go +++ b/internal/target/remote/remote_test.go @@ -42,7 +42,7 @@ func TestRemoteDelivery(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -74,7 +74,7 @@ func TestRemoteDelivery_IPLiteral(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &resolver, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -100,7 +100,7 @@ func TestRemoteDelivery_FallbackMX(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: resolver, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -127,7 +127,7 @@ func TestRemoteDelivery_BodyNonAtomic(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: resolver, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -162,7 +162,7 @@ func TestRemoteDelivery_Abort(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -199,7 +199,7 @@ func TestRemoteDelivery_CommitWithoutBody(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -243,7 +243,7 @@ func TestRemoteDelivery_MAILFROMErr(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -276,7 +276,7 @@ func TestRemoteDelivery_NoMX(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: resolver, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -313,7 +313,7 @@ func TestRemoteDelivery_NullMX(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -349,7 +349,7 @@ func TestRemoteDelivery_Quarantined(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -404,7 +404,7 @@ func TestRemoteDelivery_MAILFROMErr_Repeated(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -451,7 +451,7 @@ func TestRemoteDelivery_RcptErr(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -509,7 +509,7 @@ func TestRemoteDelivery_DownMX(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -539,7 +539,7 @@ func TestRemoteDelivery_AllMXDown(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -577,7 +577,7 @@ func TestRemoteDelivery_Split(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -623,7 +623,7 @@ func TestRemoteDelivery_Split_Fail(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -683,7 +683,7 @@ func TestRemoteDelivery_BodyErr(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -744,7 +744,7 @@ func TestRemoteDelivery_Split_BodyErr(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -807,7 +807,7 @@ func TestRemoteDelivery_Split_BodyErr_NonAtomic(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, Log: testutils.Logger(t, "remote"), } @@ -867,7 +867,7 @@ func TestRemoteDelivery_TLSErrFallback(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, tlsConfig: &tls.Config{}, Log: testutils.Logger(t, "remote"), @@ -895,7 +895,7 @@ func TestRemoteDelivery_RequireTLS_Missing(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireTLS: true, Log: testutils.Logger(t, "remote"), @@ -925,7 +925,7 @@ func TestRemoteDelivery_RequireTLS_Present(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, requireTLS: true, tlsConfig: clientCfg, @@ -954,7 +954,7 @@ func TestRemoteDelivery_RequireTLS_NoErrFallback(t *testing.T) { name: "remote", hostname: "mx.example.com", resolver: &mockdns.Resolver{Zones: zones}, - dialer: resolver.Dial, + dialer: resolver.DialContext, extResolver: nil, tlsConfig: &tls.Config{}, requireTLS: true, diff --git a/internal/target/remote/smtputf8_test.go b/internal/target/remote/smtputf8_test.go index fa3369a..7a7fe29 100644 --- a/internal/target/remote/smtputf8_test.go +++ b/internal/target/remote/smtputf8_test.go @@ -44,7 +44,7 @@ func TestRemoteDelivery_EHLO_ALabel(t *testing.T) { tgt := mod.(*Target) tgt.resolver = &mockdns.Resolver{Zones: zones} - tgt.dialer = resolver.Dial + tgt.dialer = resolver.DialContext tgt.extResolver = nil tgt.Log = testutils.Logger(t, "remote") diff --git a/internal/target/smtp_downstream/smtp_downstream.go b/internal/target/smtp_downstream/smtp_downstream.go index d192614..67e1293 100644 --- a/internal/target/smtp_downstream/smtp_downstream.go +++ b/internal/target/smtp_downstream/smtp_downstream.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "net" + "runtime/trace" "github.com/emersion/go-message/textproto" "github.com/foxcpp/maddy/internal/buffer" @@ -126,24 +127,26 @@ type delivery struct { } func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) { + defer trace.StartRegion(ctx, "smtp_downstream/Start").End() + d := &delivery{ u: u, log: target.DeliveryLogger(u.log, msgMeta), msgMeta: msgMeta, mailFrom: mailFrom, } - if err := d.connect(); err != nil { + if err := d.connect(ctx); err != nil { return nil, err } - if err := d.conn.Mail(mailFrom, msgMeta.SMTPOpts); err != nil { + if err := d.conn.Mail(ctx, mailFrom, msgMeta.SMTPOpts); err != nil { d.conn.Close() return nil, err } return d, nil } -func (d *delivery) connect() error { +func (d *delivery) connect(ctx context.Context) error { // TODO: Review possibility of connection pooling here. var lastErr error @@ -156,7 +159,7 @@ func (d *delivery) connect() error { conn.AddrInSMTPMsg = false for _, endp := range d.u.endpoints { - err := conn.Connect(endp) + err := conn.Connect(ctx, endp) if err == nil { d.log.DebugMsg("connected", "downstream_server", conn.ServerName()) lastErr = nil @@ -191,7 +194,7 @@ func (d *delivery) connect() error { } func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error { - return moduleError(d.conn.Rcpt(rcptTo)) + return moduleError(d.conn.Rcpt(ctx, rcptTo)) } func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error { @@ -217,7 +220,7 @@ func (d *delivery) Commit(ctx context.Context) error { defer d.conn.Close() defer d.body.Close() - return moduleError(d.conn.Data(d.hdr, d.body)) + return moduleError(d.conn.Data(ctx, d.hdr, d.body)) } func init() {