Instrument the SMTP code using runtime/trace

runtime/trace together with 'go tool trace' provides extremely powerful
tooling for performance (latency) analysis. Since maddy prides itself on
being "optimized for concurrency", it is a good idea to actually live up
to this promise.

Closes #144. No need to reinvent the wheel. The original issue
proposed a solution to use in production to detect "performance
anomalies", it is possible to use runtime/trace in production too, but
the corresponding flag to enable profiler endpoint is hidden behind the
'debugflags' build tag at the moment.

For SMTP code, the basic latency information can be obtained from
regular logs since they include timestamps with millisecond granularity.
After the issue is apparent, it is possible to deploy the server
executable compiled with tracing support and obtain more information

... Also add missing context.Context arguments to smtpconn.C.
This commit is contained in:
fox.cpp 2019-12-09 23:11:39 +03:00
parent 305fdddf24
commit c4ea9a730f
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
20 changed files with 281 additions and 135 deletions

View file

@ -11,6 +11,7 @@ import (
"os"
"os/exec"
"regexp"
"runtime/trace"
"strconv"
"strings"
@ -300,6 +301,9 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
return module.CheckResult{}
}
// TODO: It is not possible to distinguish different commands.
defer trace.StartRegion(ctx, "command/CheckConnection").End()
cmdName, cmdArgs := s.expandCommand("")
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
}
@ -311,6 +315,8 @@ func (s *state) CheckSender(ctx context.Context, addr string) module.CheckResult
return module.CheckResult{}
}
defer trace.StartRegion(ctx, "command/CheckSender").End()
cmdName, cmdArgs := s.expandCommand(addr)
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
}
@ -321,6 +327,7 @@ func (s *state) CheckRcpt(ctx context.Context, addr string) module.CheckResult {
if s.c.stage != StageRcpt {
return module.CheckResult{}
}
defer trace.StartRegion(ctx, "command/CheckRcpt").End()
cmdName, cmdArgs := s.expandCommand(addr)
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
@ -331,6 +338,8 @@ func (s *state) CheckBody(ctx context.Context, hdr textproto.Header, body buffer
return module.CheckResult{}
}
defer trace.StartRegion(ctx, "command/CheckBody").End()
cmdName, cmdArgs := s.expandCommand("")
var buf bytes.Buffer

View file

@ -6,6 +6,7 @@ import (
"errors"
"io"
nettextproto "net/textproto"
"runtime/trace"
"strings"
"github.com/emersion/go-message/textproto"
@ -96,6 +97,8 @@ func (d *dkimCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.Ch
}
func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
defer trace.StartRegion(ctx, "verify_dkim/CheckBody").End()
if !header.Has("DKIM-Signature") {
if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine {
d.log.Printf("no signatures present")

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"runtime/trace"
"strings"
"github.com/emersion/go-message/textproto"
@ -323,6 +324,8 @@ func (bl *DNSBL) CheckConnection(ctx context.Context, state *smtp.ConnectionStat
return nil
}
defer trace.StartRegion(ctx, "dnsbl/CheckConnection (Early)").End()
ip, ok := state.RemoteAddr.(*net.TCPAddr)
if !ok {
bl.log.Msg("non-TCP/IP source",
@ -358,6 +361,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
return module.CheckResult{}
}
defer trace.StartRegion(ctx, "dnsbl/CheckConnection").End()
if s.msgMeta.Conn == nil {
s.log.Msg("locally generated message, ignoring")
return module.CheckResult{}

View file

@ -6,6 +6,7 @@ import (
"fmt"
"net"
"runtime/debug"
"runtime/trace"
"blitiri.com.ar/go/spf"
"github.com/emersion/go-message/textproto"
@ -267,6 +268,8 @@ func prepareMailFrom(from string) (string, error) {
}
func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
defer trace.StartRegion(ctx, "apply_spf/CheckConnection").End()
if s.msgMeta.Conn == nil {
s.log.Println("locally generated message, skipping")
return module.CheckResult{}
@ -306,6 +309,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
}
}()
defer trace.StartRegion(ctx, "apply_spf/CheckConnection (Async)").End()
res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom)
s.log.Debugf("result: %s (%v)", res, err)
s.spfFetch <- spfRes{res, err}
@ -328,6 +333,8 @@ func (s *state) CheckBody(ctx context.Context, header textproto.Header, body buf
return module.CheckResult{}
}
defer trace.StartRegion(ctx, "apply_spf/CheckBody").End()
res, ok := <-s.spfFetch
if !ok {
return module.CheckResult{

View file

@ -3,6 +3,7 @@ package check
import (
"context"
"fmt"
"runtime/trace"
"github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/internal/buffer"
@ -57,12 +58,18 @@ type statelessCheckState struct {
msgMeta *module.MsgMetadata
}
func (s *statelessCheckState) String() string {
return s.c.modName + ":" + s.c.instName
}
func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
if s.c.connCheck == nil {
return module.CheckResult{}
}
defer trace.StartRegion(ctx, s.c.modName+"/CheckConnection").End()
originalRes := s.c.connCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
@ -74,8 +81,10 @@ func (s *statelessCheckState) CheckSender(ctx context.Context, mailFrom string)
if s.c.senderCheck == nil {
return module.CheckResult{}
}
defer trace.StartRegion(ctx, s.c.modName+"/CheckSender").End()
originalRes := s.c.senderCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
@ -87,8 +96,10 @@ func (s *statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modu
if s.c.rcptCheck == nil {
return module.CheckResult{}
}
defer trace.StartRegion(ctx, s.c.modName+"/CheckRcpt").End()
originalRes := s.c.rcptCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
@ -100,8 +111,10 @@ func (s *statelessCheckState) CheckBody(ctx context.Context, header textproto.He
if s.c.bodyCheck == nil {
return module.CheckResult{}
}
defer trace.StartRegion(ctx, s.c.modName+"/CheckBody").End()
originalRes := s.c.bodyCheck(StatelessCheckContext{
Context: ctx,
Resolver: s.c.resolver,
MsgMeta: s.msgMeta,
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),

View file

@ -4,6 +4,7 @@ import (
"context"
"math/rand"
"net"
"runtime/trace"
"strings"
"github.com/emersion/go-message/textproto"
@ -81,6 +82,8 @@ func (v *Verifier) FetchRecord(ctx context.Context, header textproto.Header) {
}
}()
defer trace.StartRegion(ctx, "DMARC/FetchRecord").End()
policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain)
v.fetchCh <- verifyData{
policyDomain: policyDomain,

View file

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"runtime/trace"
"strings"
"sync"
"time"
@ -46,6 +47,7 @@ type Session struct {
// msgCtx is not used for cancellation or timeouts, only for tracing.
// It is the subcontext of sessionCtx.
msgCtx context.Context
msgTask *trace.Task
mailFrom string
opts smtp.MailOptions
msgMeta *module.MsgMetadata
@ -75,6 +77,7 @@ func (s *Session) abort(ctx context.Context) {
s.delivery = nil
s.deliveryErr = nil
s.msgCtx = nil
s.msgTask.End()
}
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
@ -138,15 +141,20 @@ func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.Mail
)
}
delivery, err := s.endp.pipeline.Start(ctx, msgMeta, cleanFrom)
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
defer mailTask.End()
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
if err != nil {
s.msgCtx = nil
s.msgTask.End()
return msgMeta.ID, err
}
s.msgMeta = msgMeta
s.mailFrom = cleanFrom
s.delivery = delivery
s.msgCtx = s.sessionCtx // TODO: Trace annotations.
return msgMeta.ID, nil
}
@ -171,6 +179,8 @@ func (s *Session) Mail(from string, opts smtp.MailOptions) error {
}
func (s *Session) fetchRDNSName(ctx context.Context) {
defer trace.StartRegion(ctx, "rDNS fetch").End()
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
if !ok {
s.connState.RDNSName.Set(nil)
@ -211,7 +221,8 @@ func (s *Session) Rcpt(to string) error {
}
}
rcptCtx := s.msgCtx
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
defer rcptTask.End()
if err := s.rcpt(rcptCtx, to); err != nil {
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
@ -293,7 +304,8 @@ func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Heade
}
func (s *Session) Data(r io.Reader) error {
bodyCtx := s.msgCtx
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
@ -317,6 +329,9 @@ func (s *Session) Data(r io.Reader) error {
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.endp.semaphore.Release()
return nil
@ -332,7 +347,8 @@ func (sw statusWrapper) SetStatus(rcpt string, err error) {
}
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
bodyCtx := s.msgCtx
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
@ -356,6 +372,9 @@ func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.endp.semaphore.Release()
return nil
@ -671,7 +690,7 @@ func (endp *Endpoint) newSession(anonymous bool, username, password string, stat
}
if endp.resolver != nil {
rdnsCtx, cancelRDNS := context.WithCancel(context.TODO())
rdnsCtx, cancelRDNS := context.WithCancel(s.sessionCtx)
s.connState.RDNSName = future.New()
s.cancelRDNS = cancelRDNS
go s.fetchRDNSName(rdnsCtx)

View file

@ -7,6 +7,7 @@ import (
"io"
"net/mail"
"path/filepath"
"runtime/trace"
"strings"
"time"
@ -337,6 +338,8 @@ func (s state) RewriteRcpt(ctx context.Context, rcptTo string) (string, error) {
}
func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error {
defer trace.StartRegion(ctx, "sign_dkim/RewriteBody").End()
var authUser string
if s.meta.Conn != nil {
authUser = s.meta.Conn.AuthUser

View file

@ -11,10 +11,7 @@ import (
func objectName(x interface{}) string {
mod, ok := x.(module.Module)
if ok {
if mod.InstanceName() == "" {
return mod.Name()
}
return mod.InstanceName()
return mod.Name() + ":" + mod.InstanceName()
}
_, pipeline := x.(*MsgPipeline)
@ -22,5 +19,10 @@ func objectName(x interface{}) string {
return "reroute"
}
str, ok := x.(fmt.Stringer)
if ok {
return str.String()
}
return fmt.Sprintf("%T", x)
}

View file

@ -10,6 +10,7 @@ import (
"net/http"
"os"
"path/filepath"
"runtime/trace"
"time"
"github.com/foxcpp/maddy/internal/exterrors"
@ -23,10 +24,14 @@ var httpClient = &http.Client{
Timeout: time.Minute,
}
func downloadPolicy(domain string) (*Policy, error) {
func downloadPolicy(ctx context.Context, domain string) (*Policy, error) {
// TODO: Consult OCSP/CRL to detect revoked certificates?
resp, err := httpClient.Get("https://mta-sts." + domain + "/.well-known/mta-sts.txt")
req, err := http.NewRequestWithContext(ctx, "GET", "https://mta-sts."+domain+"/.well-known/mta-sts.txt", nil)
if err != nil {
return nil, err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
@ -80,8 +85,8 @@ var ErrIgnorePolicy = errors.New("mtasts: policy ignored due to errors")
// Get reads policy from cache or tries to fetch it from Policy Host.
//
// The domain is assumed to be normalized, as done by dns.ForLookup.
func (c *Cache) Get(domain string) (*Policy, error) {
_, p, err := c.fetch(false, time.Now(), domain)
func (c *Cache) Get(ctx context.Context, domain string) (*Policy, error) {
_, p, err := c.fetch(ctx, false, time.Now(), domain)
return p, err
}
@ -118,6 +123,9 @@ func (c *Cache) load(domain string) (id string, fetchTime time.Time, p *Policy,
}
func (c *Cache) Refresh() error {
refreshCtx, refreshTask := trace.NewTask(context.Background(), "mtasts.Cache/Refresh")
defer refreshTask.End()
dir, err := ioutil.ReadDir(c.Location)
if err != nil {
return err
@ -132,7 +140,7 @@ func (c *Cache) Refresh() error {
// Since otherwise we are going to have expired policy for another 6 hours,
// which makes it useless.
// See https://tools.ietf.org/html/rfc8461#section-10.2.
cacheHit, _, err := c.fetch(true, time.Now().Add(6*time.Hour), ent.Name())
cacheHit, _, err := c.fetch(refreshCtx, false, time.Now().Add(6*time.Hour), ent.Name())
if err != nil && err != ErrIgnorePolicy {
c.Logger.Error("policy update error", err, "domain", ent.Name())
}
@ -147,7 +155,9 @@ func (c *Cache) Refresh() error {
return nil
}
func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
func (c *Cache) fetch(ctx context.Context, ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
defer trace.StartRegion(ctx, "mtasts.Cache/fetch").End()
validCache := true
cachedId, fetchTime, cachedPolicy, err := c.load(domain)
if err != nil {
@ -163,7 +173,7 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
var dnsId string
if !ignoreDns {
records, err := c.Resolver.LookupTXT(context.Background(), "_mta-sts."+domain)
records, err := c.Resolver.LookupTXT(ctx, "_mta-sts."+domain)
if err != nil {
if validCache {
if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound {
@ -218,11 +228,15 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
}
if !validCache || dnsId != cachedId {
download := downloadPolicy
var (
policy *Policy
err error
)
if c.downloadPolicy != nil {
download = c.downloadPolicy
policy, err = c.downloadPolicy(domain)
} else {
policy, err = downloadPolicy(ctx, domain)
}
policy, err := download(domain)
if err != nil {
if validCache {
c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain)

View file

@ -1,6 +1,7 @@
package mtasts
import (
"context"
"errors"
"os"
"reflect"
@ -37,7 +38,7 @@ func TestCacheGet(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -57,7 +58,7 @@ func TestCacheGet_Error_DNS(t *testing.T) {
}
defer os.RemoveAll(c.Location)
_, err := c.Get("example.org")
_, err := c.Get(context.Background(), "example.org")
if err != ErrIgnorePolicy {
t.Fatalf("policy get: %v", err)
}
@ -78,7 +79,7 @@ func TestCacheGet_Error_HTTPS(t *testing.T) {
}
defer os.RemoveAll(c.Location)
_, err := c.Get("example.org")
_, err := c.Get(context.Background(), "example.org")
if err != ErrIgnorePolicy {
t.Fatalf("policy get: %v", err)
}
@ -104,7 +105,7 @@ func TestCacheGet_Cached(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -115,7 +116,7 @@ func TestCacheGet_Cached(t *testing.T) {
// calling downloadPolicy.
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -146,7 +147,7 @@ func TestCacheGet_Expired(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -160,7 +161,7 @@ func TestCacheGet_Expired(t *testing.T) {
expectedPolicy.MX = []string{"b"}
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -190,7 +191,7 @@ func TestCacheGet_IDChange(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -206,7 +207,7 @@ func TestCacheGet_IDChange(t *testing.T) {
expectedPolicy.MX = []string{"b"}
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -236,7 +237,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -251,7 +252,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
resolver.Zones = nil
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -281,7 +282,7 @@ func TestCacheGet_HTTPGet_ErrNoPolicy(t *testing.T) {
// >(for any reason), and there is no valid (non-expired) previously
// >cached policy, senders MUST continue with delivery as though the
// >domain has not implemented MTA-STS.
_, err := c.Get("example.org")
_, err := c.Get(context.Background(), "example.org")
if err != ErrIgnorePolicy {
t.Fatalf("policy get: %v", err)
}
@ -308,7 +309,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -325,7 +326,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
}
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -357,7 +358,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -376,7 +377,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
}
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err == nil {
t.Fatalf("expected error, got policy %v", policy)
}
@ -405,7 +406,7 @@ func TestCacheRefresh(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -427,7 +428,7 @@ func TestCacheRefresh(t *testing.T) {
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
// It should return the new record from cache.
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -459,7 +460,7 @@ func TestCacheRefresh_Error(t *testing.T) {
}
defer os.RemoveAll(c.Location)
policy, err := c.Get("example.org")
policy, err := c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}
@ -477,7 +478,7 @@ func TestCacheRefresh_Error(t *testing.T) {
}
// It should return the old record from cache.
policy, err = c.Get("example.org")
policy, err = c.Get(context.Background(), "example.org")
if err != nil {
t.Fatalf("policy get: %v", err)
}

View file

@ -10,9 +10,11 @@
package smtpconn
import (
"context"
"crypto/tls"
"io"
"net"
"runtime/trace"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
@ -27,9 +29,9 @@ import (
//
// Currently, the C object represents one session and cannot be reused.
type C struct {
// Dialer to use to estabilish new network connections. Set to net.Dial by
// New.
Dialer func(network, addr string) (net.Conn, error)
// Dialer to use to estabilish new network connections. Set to net.Dialer
// DialContext by New.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Fail if the connection cannot use TLS.
RequireTLS bool
@ -60,7 +62,7 @@ type C struct {
// with resonable default values.
func New() *C {
return &C{
Dialer: net.Dial,
Dialer: (&net.Dialer{}).DialContext,
TLSConfig: &tls.Config{},
Hostname: "localhost.localdomain",
}
@ -122,9 +124,11 @@ func (c *C) wrapClientErr(err error, serverName string) error {
}
// Connect actually estabilishes the network connection with the remote host.
func (c *C) Connect(endp config.Endpoint) error {
func (c *C) Connect(ctx context.Context, endp config.Endpoint) error {
defer trace.StartRegion(ctx, "smtpconn/Connect+TLS").End()
// TODO: Helper function to try multiple endpoints?
cl, err := c.attemptConnect(endp, c.AttemptTLS)
cl, err := c.attemptConnect(ctx, endp, c.AttemptTLS)
if err != nil {
return c.wrapClientErr(err, endp.Host)
}
@ -134,9 +138,9 @@ func (c *C) Connect(endp config.Endpoint) error {
return nil
}
func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
func (c *C) attemptConnect(ctx context.Context, endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
var conn net.Conn
conn, err := c.Dialer(endp.Network(), endp.Address())
conn, err := c.Dialer(ctx, endp.Network(), endp.Address())
if err != nil {
return nil, err
}
@ -197,7 +201,7 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
// Re-attempt without STARTTLS. It is not possible to reuse connection
// since it is probably in a bad state.
c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port)
return c.attemptConnect(endp, false)
return c.attemptConnect(ctx, endp, false)
}
return cl, nil
@ -209,7 +213,9 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
// SMTPUTF8 is forwarded if supported by the remote server, if it is not
// supported - attempt will be done to convert addresses to the ASCII form, if
// this is not possible, the corresponding method (Mail or Rcpt) will fail.
func (c *C) Mail(from string, opts smtp.MailOptions) error {
func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error {
defer trace.StartRegion(ctx, "smtpconn/MAIL FROM").End()
outOpts := smtp.MailOptions{
// Future extensions may add additional fields that should not be
// copied blindly. So we copy only fields we know should be handled
@ -268,7 +274,9 @@ func (c *C) Client() *smtp.Client {
//
// If the address is non-ASCII and cannot be converted to ASCII and the remote
// server does not support SMTPUTF8, error will be returned.
func (c *C) Rcpt(to string) error {
func (c *C) Rcpt(ctx context.Context, to string) error {
defer trace.StartRegion(ctx, "smtpconn/RCPT TO").End()
// If necessary, the extension flag is enabled in Start.
if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok {
var err error
@ -300,7 +308,9 @@ func (c *C) Rcpt(to string) error {
//
// If the Data command fails, the connection may be in a unclean state (e.g. in
// the middle of message data stream). It is not safe to continue using it.
func (c *C) Data(hdr textproto.Header, body io.Reader) error {
func (c *C) Data(ctx context.Context, hdr textproto.Header, body io.Reader) error {
defer trace.StartRegion(ctx, "smtpconn/DATA").End()
wc, err := c.cl.Data()
if err != nil {
return c.wrapClientErr(err, c.serverName)

View file

@ -1,6 +1,7 @@
package smtpconn
import (
"context"
"strings"
"testing"
@ -14,11 +15,11 @@ import (
func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error {
t.Helper()
if err := conn.Mail(from, opts); err != nil {
if err := conn.Mail(context.Background(), from, opts); err != nil {
return err
}
for _, rcpt := range to {
if err := conn.Rcpt(rcpt); err != nil {
if err := conn.Rcpt(context.Background(), rcpt); err != nil {
return err
}
}
@ -26,7 +27,7 @@ func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.M
hdr := textproto.Header{}
hdr.Add("B", "2")
hdr.Add("A", "1")
if err := conn.Data(hdr, strings.NewReader("foobar\n")); err != nil {
if err := conn.Data(context.Background(), hdr, strings.NewReader("foobar\n")); err != nil {
return err
}
@ -41,7 +42,7 @@ func TestSMTPUTF8_Sender_UTF8_Punycode(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -70,7 +71,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Punycode(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -99,7 +100,7 @@ func TestSMTPUTF8_Sender_UTF8_Reject(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -122,7 +123,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Reject(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -144,7 +145,7 @@ func TestSMTPUTF8_Sender_UTF8_Domain(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -172,7 +173,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Domain(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -201,7 +202,7 @@ func TestSMTPUTF8_Sender_UTF8_Username(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
@ -230,7 +231,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Username(t *testing.T) {
c := New()
c.Log = testutils.Logger(t, "smtp_downstream")
if err := c.Connect(config.Endpoint{
if err := c.Connect(context.Background(), config.Endpoint{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,

View file

@ -15,6 +15,7 @@ import (
"fmt"
"os"
"path/filepath"
"runtime/trace"
"strconv"
"strings"
@ -63,7 +64,13 @@ type delivery struct {
addedRcpts map[string]struct{}
}
func (d *delivery) String() string {
return d.store.Name() + ":" + d.store.InstanceName()
}
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
defer trace.StartRegion(ctx, "sql/AddRcpt").End()
accountName, err := prepareUsername(rcptTo)
if err != nil {
return &exterrors.SMTPError{
@ -104,6 +111,8 @@ func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
}
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
defer trace.StartRegion(ctx, "sql/Body").End()
if d.msgMeta.Quarantine {
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
return err
@ -116,14 +125,20 @@ func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffe
}
func (d *delivery) Abort(ctx context.Context) error {
defer trace.StartRegion(ctx, "sql/Abort").End()
return d.d.Abort()
}
func (d *delivery) Commit(ctx context.Context) error {
defer trace.StartRegion(ctx, "sql/Commit").End()
return d.d.Commit()
}
func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
defer trace.StartRegion(ctx, "sql/Start").End()
return &delivery{
store: store,
msgMeta: msgMeta,
@ -360,6 +375,9 @@ func prepareUsername(username string) (string, error) {
}
func (store *Storage) CheckPlain(username, password string) bool {
// TODO: Pass session context there.
defer trace.StartRegion(context.Background(), "sql/CheckPlain").End()
accountName, err := prepareUsername(username)
if err != nil {
return false

View file

@ -50,6 +50,7 @@ import (
"os"
"path/filepath"
"runtime/debug"
"runtime/trace"
"strconv"
"strings"
"sync"
@ -438,10 +439,12 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1)
dl.Debugf("using message ID = %s", msgMeta.ID)
// TODO: Trace annotations?
msgCtx := context.Background()
msgCtx, msgTask := trace.NewTask(context.Background(), "Queue delivery")
defer msgTask.End()
delivery, err := q.Target.Start(msgCtx, msgMeta, meta.From)
mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
delivery, err := q.Target.Start(mailCtx, msgMeta, meta.From)
mailTask.End()
if err != nil {
dl.Debugf("target.Start failed: %v", err)
perr.Failed = append(perr.Failed, meta.To...)
@ -454,7 +457,8 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
var acceptedRcpts []string
for _, rcpt := range meta.To {
if err := delivery.AddRcpt(msgCtx, rcpt); err != nil {
rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
if err := delivery.AddRcpt(rcptCtx, rcpt); err != nil {
if exterrors.IsTemporaryOrUnspec(err) {
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
} else {
@ -466,6 +470,7 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
dl.Debugf("delivery.AddRcpt %s OK", rcpt)
acceptedRcpts = append(acceptedRcpts, rcpt)
}
rcptTask.End()
}
if len(acceptedRcpts) == 0 {
@ -487,12 +492,15 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
}
}
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
defer bodyTask.End()
partDelivery, ok := delivery.(module.PartialDelivery)
if ok {
dl.Debugf("using delivery.BodyNonAtomic")
partDelivery.BodyNonAtomic(msgCtx, &perr, header, body)
partDelivery.BodyNonAtomic(bodyCtx, &perr, header, body)
} else {
if err := delivery.Body(msgCtx, header, body); err != nil {
if err := delivery.Body(bodyCtx, header, body); err != nil {
dl.Debugf("delivery.Body failed: %v", err)
expandToPartialErr(err)
}
@ -508,13 +516,13 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
if allFailed {
// No recipients succeeded.
dl.Debugf("delivery.Abort (all recipients failed)")
if err := delivery.Abort(msgCtx); err != nil {
if err := delivery.Abort(bodyCtx); err != nil {
dl.Msg("delivery.Abort failed", err)
}
return perr
}
if err := delivery.Commit(msgCtx); err != nil {
if err := delivery.Commit(bodyCtx); err != nil {
dl.Debugf("delivery.Commit failed: %v", err)
expandToPartialErr(err)
}
@ -537,6 +545,8 @@ func (qd *queueDelivery) AddRcpt(ctx context.Context, rcptTo string) error {
}
func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
defer trace.StartRegion(ctx, "queue/Body").End()
// Body buffer initially passed to us may not be valid after "delivery" to queue completes.
// storeNewMessage returns a new buffer object created from message blob stored on disk.
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
@ -550,6 +560,8 @@ func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body
}
func (qd *queueDelivery) Abort(ctx context.Context) error {
defer trace.StartRegion(ctx, "queue/Abort").End()
if qd.body != nil {
qd.q.removeFromDisk(qd.meta.MsgMeta)
}
@ -557,6 +569,8 @@ func (qd *queueDelivery) Abort(ctx context.Context) error {
}
func (qd *queueDelivery) Commit(ctx context.Context) error {
defer trace.StartRegion(ctx, "queue/Commit").End()
if qd.meta == nil {
panic("queue: double Commit")
}
@ -899,14 +913,17 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
}
dl.Msg("generated failed DSN", "dsn_id", dsnID)
// TODO: Trace annotations?
msgCtx := context.Background()
msgCtx, msgTask := trace.NewTask(context.Background(), "DSN Delivery")
defer msgTask.End()
dsnDelivery, err := q.dsnPipeline.Start(msgCtx, dsnMeta, "")
mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
dsnDelivery, err := q.dsnPipeline.Start(mailCtx, dsnMeta, "")
mailTask.End()
if err != nil {
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
return
}
defer func() {
if err != nil {
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
@ -916,15 +933,23 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
}
}()
if err = dsnDelivery.AddRcpt(msgCtx, meta.From); err != nil {
rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
if err = dsnDelivery.AddRcpt(rcptCtx, meta.From); err != nil {
rcptTask.End()
return
}
if err = dsnDelivery.Body(msgCtx, dsnHeader, dsnBody); err != nil {
rcptTask.End()
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
if err = dsnDelivery.Body(bodyCtx, dsnHeader, dsnBody); err != nil {
bodyTask.End()
return
}
if err = dsnDelivery.Commit(msgCtx); err != nil {
if err = dsnDelivery.Commit(bodyCtx); err != nil {
bodyTask.End()
return
}
bodyTask.End()
}
func init() {

View file

@ -1,6 +1,7 @@
package remote
import (
"context"
"errors"
"net"
"strconv"
@ -33,7 +34,7 @@ func TestRemoteDelivery_AuthMX_Fail(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
Log: testutils.Logger(t, "remote"),
@ -63,12 +64,12 @@ func TestRemoteDelivery_AuthMX_MTASTS(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
tlsConfig: clientCfg,
mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) {
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup")
}
@ -114,12 +115,12 @@ func TestRemoteDelivery_AuthMX_PreferAuth(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
tlsConfig: clientCfg,
mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) {
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup")
}
@ -166,12 +167,12 @@ func TestRemoteDelivery_MTASTS_SkipNonMatching(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
tlsConfig: clientCfg,
mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) {
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup")
}
@ -208,11 +209,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_Fail(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) {
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup")
}
@ -251,11 +252,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_NoPolicy(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthMTASTS: {}},
mtastsGet: func(domain string) (*mtasts.Policy, error) {
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
if domain != "example.invalid" {
return nil, errors.New("Wrong domain in lookup")
}
@ -290,7 +291,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
@ -321,7 +322,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_Fail(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
@ -354,7 +355,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_NotETLDp1(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
@ -405,7 +406,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: extResolver,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
@ -452,7 +453,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC_Fail(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: extResolver,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
@ -506,7 +507,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &resolver,
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: extResolver,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
@ -556,7 +557,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral_Fail(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &resolver,
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: extResolver,
requireMXAuth: true,
mxAuth: map[string]struct{}{AuthDNSSEC: {}},

View file

@ -13,6 +13,7 @@ import (
"net"
"os"
"path/filepath"
"runtime/trace"
"sort"
"strings"
"sync"
@ -57,12 +58,12 @@ type Target struct {
mxAuth map[string]struct{}
resolver dns.Resolver
dialer func(network, addr string) (net.Conn, error)
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
extResolver *dns.ExtResolver
// This is the callback that is usually mtastsCache.Get,
// but replaced by tests to mock mtasts.Cache.
mtastsGet func(domain string) (*mtasts.Policy, error)
mtastsGet func(ctx context.Context, domain string) (*mtasts.Policy, error)
mtastsCache mtasts.Cache
stsCacheUpdateTick *time.Ticker
@ -80,7 +81,7 @@ func New(_, instName string, _, inlineArgs []string) (module.Module, error) {
return &Target{
name: instName,
resolver: dns.DefaultResolver(),
dialer: net.Dial,
dialer: (&net.Dialer{}).DialContext,
mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()},
Log: log.Logger{Name: "remote"},
@ -195,6 +196,8 @@ func (rt *Target) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFr
}
func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
defer trace.StartRegion(ctx, "remote/AddRcpt").End()
if rd.msgMeta.Quarantine {
return &exterrors.SMTPError{
Code: 550,
@ -232,7 +235,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
return err
}
if err := conn.Rcpt(to); err != nil {
if err := conn.Rcpt(ctx, to); err != nil {
return moduleError(err)
}
@ -379,6 +382,8 @@ func (m *multipleErrs) SetStatus(rcptTo string, err error) {
}
func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error {
defer trace.StartRegion(ctx, "remote/Body").End()
merr := multipleErrs{
errs: make(map[string]error),
}
@ -396,6 +401,8 @@ func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buf
}
func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) {
defer trace.StartRegion(ctx, "remote/BodyNonAtomic").End()
if rd.msgMeta.Quarantine {
for _, rcpt := range rd.recipients {
c.SetStatus(rcpt, &exterrors.SMTPError{
@ -425,7 +432,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl
}
defer bodyR.Close()
err = conn.Data(header, bodyR)
err = conn.Data(ctx, header, bodyR)
for _, rcpt := range conn.Rcpts() {
c.SetStatus(rcpt, err)
}
@ -486,7 +493,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
addrs = append(addrs, nonAuthMXs...)
rd.Log.DebugMsg("considering", "mxs", addrs)
for i, addr := range addrs {
err = conn.Connect(config.Endpoint{
err = conn.Connect(ctx, config.Endpoint{
Scheme: "tcp",
Host: addr,
Port: smtpPort,
@ -517,7 +524,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
rd.Log.DebugMsg("connected", "remote_server", conn.ServerName())
if err := conn.Mail(rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
conn.Close()
return nil, moduleError(err)
}
@ -526,8 +533,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
return conn, nil
}
func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) {
stsPolicy, err := rt.mtastsGet(domain)
func (rt *Target) getSTSPolicy(ctx context.Context, domain string) (*mtasts.Policy, error) {
stsPolicy, err := rt.mtastsGet(ctx, domain)
if err != nil && !mtasts.IsNoPolicy(err) {
return nil, &exterrors.SMTPError{
Code: exterrors.SMTPCode(err, 450, 554),
@ -568,9 +575,11 @@ func (rt *Target) stsCacheUpdater() {
}
func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) {
defer trace.StartRegion(ctx, "remote/LookupAndFilterMX").End()
var policy *mtasts.Policy
if _, use := rd.rt.mxAuth[AuthMTASTS]; use {
policy, err = rd.rt.getSTSPolicy(domain)
policy, err = rd.rt.getSTSPolicy(ctx, domain)
if err != nil {
return nil, nil, false, err
}

View file

@ -42,7 +42,7 @@ func TestRemoteDelivery(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -74,7 +74,7 @@ func TestRemoteDelivery_IPLiteral(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &resolver,
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -100,7 +100,7 @@ func TestRemoteDelivery_FallbackMX(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: resolver,
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -127,7 +127,7 @@ func TestRemoteDelivery_BodyNonAtomic(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: resolver,
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -162,7 +162,7 @@ func TestRemoteDelivery_Abort(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -199,7 +199,7 @@ func TestRemoteDelivery_CommitWithoutBody(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -243,7 +243,7 @@ func TestRemoteDelivery_MAILFROMErr(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -276,7 +276,7 @@ func TestRemoteDelivery_NoMX(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: resolver,
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -313,7 +313,7 @@ func TestRemoteDelivery_NullMX(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -349,7 +349,7 @@ func TestRemoteDelivery_Quarantined(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -404,7 +404,7 @@ func TestRemoteDelivery_MAILFROMErr_Repeated(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -451,7 +451,7 @@ func TestRemoteDelivery_RcptErr(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -509,7 +509,7 @@ func TestRemoteDelivery_DownMX(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -539,7 +539,7 @@ func TestRemoteDelivery_AllMXDown(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -577,7 +577,7 @@ func TestRemoteDelivery_Split(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -623,7 +623,7 @@ func TestRemoteDelivery_Split_Fail(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -683,7 +683,7 @@ func TestRemoteDelivery_BodyErr(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -744,7 +744,7 @@ func TestRemoteDelivery_Split_BodyErr(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -807,7 +807,7 @@ func TestRemoteDelivery_Split_BodyErr_NonAtomic(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
Log: testutils.Logger(t, "remote"),
}
@ -867,7 +867,7 @@ func TestRemoteDelivery_TLSErrFallback(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
tlsConfig: &tls.Config{},
Log: testutils.Logger(t, "remote"),
@ -895,7 +895,7 @@ func TestRemoteDelivery_RequireTLS_Missing(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireTLS: true,
Log: testutils.Logger(t, "remote"),
@ -925,7 +925,7 @@ func TestRemoteDelivery_RequireTLS_Present(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
requireTLS: true,
tlsConfig: clientCfg,
@ -954,7 +954,7 @@ func TestRemoteDelivery_RequireTLS_NoErrFallback(t *testing.T) {
name: "remote",
hostname: "mx.example.com",
resolver: &mockdns.Resolver{Zones: zones},
dialer: resolver.Dial,
dialer: resolver.DialContext,
extResolver: nil,
tlsConfig: &tls.Config{},
requireTLS: true,

View file

@ -44,7 +44,7 @@ func TestRemoteDelivery_EHLO_ALabel(t *testing.T) {
tgt := mod.(*Target)
tgt.resolver = &mockdns.Resolver{Zones: zones}
tgt.dialer = resolver.Dial
tgt.dialer = resolver.DialContext
tgt.extResolver = nil
tgt.Log = testutils.Logger(t, "remote")

View file

@ -14,6 +14,7 @@ import (
"fmt"
"io"
"net"
"runtime/trace"
"github.com/emersion/go-message/textproto"
"github.com/foxcpp/maddy/internal/buffer"
@ -126,24 +127,26 @@ type delivery struct {
}
func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
defer trace.StartRegion(ctx, "smtp_downstream/Start").End()
d := &delivery{
u: u,
log: target.DeliveryLogger(u.log, msgMeta),
msgMeta: msgMeta,
mailFrom: mailFrom,
}
if err := d.connect(); err != nil {
if err := d.connect(ctx); err != nil {
return nil, err
}
if err := d.conn.Mail(mailFrom, msgMeta.SMTPOpts); err != nil {
if err := d.conn.Mail(ctx, mailFrom, msgMeta.SMTPOpts); err != nil {
d.conn.Close()
return nil, err
}
return d, nil
}
func (d *delivery) connect() error {
func (d *delivery) connect(ctx context.Context) error {
// TODO: Review possibility of connection pooling here.
var lastErr error
@ -156,7 +159,7 @@ func (d *delivery) connect() error {
conn.AddrInSMTPMsg = false
for _, endp := range d.u.endpoints {
err := conn.Connect(endp)
err := conn.Connect(ctx, endp)
if err == nil {
d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
lastErr = nil
@ -191,7 +194,7 @@ func (d *delivery) connect() error {
}
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
return moduleError(d.conn.Rcpt(rcptTo))
return moduleError(d.conn.Rcpt(ctx, rcptTo))
}
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
@ -217,7 +220,7 @@ func (d *delivery) Commit(ctx context.Context) error {
defer d.conn.Close()
defer d.body.Close()
return moduleError(d.conn.Data(d.hdr, d.body))
return moduleError(d.conn.Data(ctx, d.hdr, d.body))
}
func init() {