mirror of
https://github.com/foxcpp/maddy.git
synced 2025-04-06 14:37:37 +03:00
Instrument the SMTP code using runtime/trace
runtime/trace together with 'go tool trace' provides extremely powerful tooling for performance (latency) analysis. Since maddy prides itself on being "optimized for concurrency", it is a good idea to actually live up to this promise. Closes #144. No need to reinvent the wheel. The original issue proposed a solution to use in production to detect "performance anomalies", it is possible to use runtime/trace in production too, but the corresponding flag to enable profiler endpoint is hidden behind the 'debugflags' build tag at the moment. For SMTP code, the basic latency information can be obtained from regular logs since they include timestamps with millisecond granularity. After the issue is apparent, it is possible to deploy the server executable compiled with tracing support and obtain more information ... Also add missing context.Context arguments to smtpconn.C.
This commit is contained in:
parent
305fdddf24
commit
c4ea9a730f
20 changed files with 281 additions and 135 deletions
|
@ -11,6 +11,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime/trace"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -300,6 +301,9 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
|||
return module.CheckResult{}
|
||||
}
|
||||
|
||||
// TODO: It is not possible to distinguish different commands.
|
||||
defer trace.StartRegion(ctx, "command/CheckConnection").End()
|
||||
|
||||
cmdName, cmdArgs := s.expandCommand("")
|
||||
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
||||
}
|
||||
|
@ -311,6 +315,8 @@ func (s *state) CheckSender(ctx context.Context, addr string) module.CheckResult
|
|||
return module.CheckResult{}
|
||||
}
|
||||
|
||||
defer trace.StartRegion(ctx, "command/CheckSender").End()
|
||||
|
||||
cmdName, cmdArgs := s.expandCommand(addr)
|
||||
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
||||
}
|
||||
|
@ -321,6 +327,7 @@ func (s *state) CheckRcpt(ctx context.Context, addr string) module.CheckResult {
|
|||
if s.c.stage != StageRcpt {
|
||||
return module.CheckResult{}
|
||||
}
|
||||
defer trace.StartRegion(ctx, "command/CheckRcpt").End()
|
||||
|
||||
cmdName, cmdArgs := s.expandCommand(addr)
|
||||
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
||||
|
@ -331,6 +338,8 @@ func (s *state) CheckBody(ctx context.Context, hdr textproto.Header, body buffer
|
|||
return module.CheckResult{}
|
||||
}
|
||||
|
||||
defer trace.StartRegion(ctx, "command/CheckBody").End()
|
||||
|
||||
cmdName, cmdArgs := s.expandCommand("")
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
nettextproto "net/textproto"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
|
@ -96,6 +97,8 @@ func (d *dkimCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.Ch
|
|||
}
|
||||
|
||||
func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||
defer trace.StartRegion(ctx, "verify_dkim/CheckBody").End()
|
||||
|
||||
if !header.Has("DKIM-Signature") {
|
||||
if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine {
|
||||
d.log.Printf("no signatures present")
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
|
@ -323,6 +324,8 @@ func (bl *DNSBL) CheckConnection(ctx context.Context, state *smtp.ConnectionStat
|
|||
return nil
|
||||
}
|
||||
|
||||
defer trace.StartRegion(ctx, "dnsbl/CheckConnection (Early)").End()
|
||||
|
||||
ip, ok := state.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
bl.log.Msg("non-TCP/IP source",
|
||||
|
@ -358,6 +361,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
|||
return module.CheckResult{}
|
||||
}
|
||||
|
||||
defer trace.StartRegion(ctx, "dnsbl/CheckConnection").End()
|
||||
|
||||
if s.msgMeta.Conn == nil {
|
||||
s.log.Msg("locally generated message, ignoring")
|
||||
return module.CheckResult{}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"runtime/trace"
|
||||
|
||||
"blitiri.com.ar/go/spf"
|
||||
"github.com/emersion/go-message/textproto"
|
||||
|
@ -267,6 +268,8 @@ func prepareMailFrom(from string) (string, error) {
|
|||
}
|
||||
|
||||
func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
defer trace.StartRegion(ctx, "apply_spf/CheckConnection").End()
|
||||
|
||||
if s.msgMeta.Conn == nil {
|
||||
s.log.Println("locally generated message, skipping")
|
||||
return module.CheckResult{}
|
||||
|
@ -306,6 +309,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
|||
}
|
||||
}()
|
||||
|
||||
defer trace.StartRegion(ctx, "apply_spf/CheckConnection (Async)").End()
|
||||
|
||||
res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom)
|
||||
s.log.Debugf("result: %s (%v)", res, err)
|
||||
s.spfFetch <- spfRes{res, err}
|
||||
|
@ -328,6 +333,8 @@ func (s *state) CheckBody(ctx context.Context, header textproto.Header, body buf
|
|||
return module.CheckResult{}
|
||||
}
|
||||
|
||||
defer trace.StartRegion(ctx, "apply_spf/CheckBody").End()
|
||||
|
||||
res, ok := <-s.spfFetch
|
||||
if !ok {
|
||||
return module.CheckResult{
|
||||
|
|
|
@ -3,6 +3,7 @@ package check
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/trace"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/foxcpp/maddy/internal/buffer"
|
||||
|
@ -57,12 +58,18 @@ type statelessCheckState struct {
|
|||
msgMeta *module.MsgMetadata
|
||||
}
|
||||
|
||||
func (s *statelessCheckState) String() string {
|
||||
return s.c.modName + ":" + s.c.instName
|
||||
}
|
||||
|
||||
func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||
if s.c.connCheck == nil {
|
||||
return module.CheckResult{}
|
||||
}
|
||||
defer trace.StartRegion(ctx, s.c.modName+"/CheckConnection").End()
|
||||
|
||||
originalRes := s.c.connCheck(StatelessCheckContext{
|
||||
Context: ctx,
|
||||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
|
@ -74,8 +81,10 @@ func (s *statelessCheckState) CheckSender(ctx context.Context, mailFrom string)
|
|||
if s.c.senderCheck == nil {
|
||||
return module.CheckResult{}
|
||||
}
|
||||
defer trace.StartRegion(ctx, s.c.modName+"/CheckSender").End()
|
||||
|
||||
originalRes := s.c.senderCheck(StatelessCheckContext{
|
||||
Context: ctx,
|
||||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
|
@ -87,8 +96,10 @@ func (s *statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modu
|
|||
if s.c.rcptCheck == nil {
|
||||
return module.CheckResult{}
|
||||
}
|
||||
defer trace.StartRegion(ctx, s.c.modName+"/CheckRcpt").End()
|
||||
|
||||
originalRes := s.c.rcptCheck(StatelessCheckContext{
|
||||
Context: ctx,
|
||||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
|
@ -100,8 +111,10 @@ func (s *statelessCheckState) CheckBody(ctx context.Context, header textproto.He
|
|||
if s.c.bodyCheck == nil {
|
||||
return module.CheckResult{}
|
||||
}
|
||||
defer trace.StartRegion(ctx, s.c.modName+"/CheckBody").End()
|
||||
|
||||
originalRes := s.c.bodyCheck(StatelessCheckContext{
|
||||
Context: ctx,
|
||||
Resolver: s.c.resolver,
|
||||
MsgMeta: s.msgMeta,
|
||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
|
@ -81,6 +82,8 @@ func (v *Verifier) FetchRecord(ctx context.Context, header textproto.Header) {
|
|||
}
|
||||
}()
|
||||
|
||||
defer trace.StartRegion(ctx, "DMARC/FetchRecord").End()
|
||||
|
||||
policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain)
|
||||
v.fetchCh <- verifyData{
|
||||
policyDomain: policyDomain,
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -46,6 +47,7 @@ type Session struct {
|
|||
// msgCtx is not used for cancellation or timeouts, only for tracing.
|
||||
// It is the subcontext of sessionCtx.
|
||||
msgCtx context.Context
|
||||
msgTask *trace.Task
|
||||
mailFrom string
|
||||
opts smtp.MailOptions
|
||||
msgMeta *module.MsgMetadata
|
||||
|
@ -75,6 +77,7 @@ func (s *Session) abort(ctx context.Context) {
|
|||
s.delivery = nil
|
||||
s.deliveryErr = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
}
|
||||
|
||||
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
|
||||
|
@ -138,15 +141,20 @@ func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.Mail
|
|||
)
|
||||
}
|
||||
|
||||
delivery, err := s.endp.pipeline.Start(ctx, msgMeta, cleanFrom)
|
||||
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
|
||||
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
|
||||
defer mailTask.End()
|
||||
|
||||
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
|
||||
if err != nil {
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
return msgMeta.ID, err
|
||||
}
|
||||
|
||||
s.msgMeta = msgMeta
|
||||
s.mailFrom = cleanFrom
|
||||
s.delivery = delivery
|
||||
s.msgCtx = s.sessionCtx // TODO: Trace annotations.
|
||||
|
||||
return msgMeta.ID, nil
|
||||
}
|
||||
|
@ -171,6 +179,8 @@ func (s *Session) Mail(from string, opts smtp.MailOptions) error {
|
|||
}
|
||||
|
||||
func (s *Session) fetchRDNSName(ctx context.Context) {
|
||||
defer trace.StartRegion(ctx, "rDNS fetch").End()
|
||||
|
||||
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
s.connState.RDNSName.Set(nil)
|
||||
|
@ -211,7 +221,8 @@ func (s *Session) Rcpt(to string) error {
|
|||
}
|
||||
}
|
||||
|
||||
rcptCtx := s.msgCtx
|
||||
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
|
||||
defer rcptTask.End()
|
||||
|
||||
if err := s.rcpt(rcptCtx, to); err != nil {
|
||||
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
|
||||
|
@ -293,7 +304,8 @@ func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Heade
|
|||
}
|
||||
|
||||
func (s *Session) Data(r io.Reader) error {
|
||||
bodyCtx := s.msgCtx
|
||||
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
wrapErr := func(err error) error {
|
||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||
|
@ -317,6 +329,9 @@ func (s *Session) Data(r io.Reader) error {
|
|||
|
||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||
s.delivery = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
s.msgTask = nil
|
||||
s.endp.semaphore.Release()
|
||||
|
||||
return nil
|
||||
|
@ -332,7 +347,8 @@ func (sw statusWrapper) SetStatus(rcpt string, err error) {
|
|||
}
|
||||
|
||||
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
||||
bodyCtx := s.msgCtx
|
||||
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
wrapErr := func(err error) error {
|
||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||
|
@ -356,6 +372,9 @@ func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
|||
|
||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||
s.delivery = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
s.msgTask = nil
|
||||
s.endp.semaphore.Release()
|
||||
|
||||
return nil
|
||||
|
@ -671,7 +690,7 @@ func (endp *Endpoint) newSession(anonymous bool, username, password string, stat
|
|||
}
|
||||
|
||||
if endp.resolver != nil {
|
||||
rdnsCtx, cancelRDNS := context.WithCancel(context.TODO())
|
||||
rdnsCtx, cancelRDNS := context.WithCancel(s.sessionCtx)
|
||||
s.connState.RDNSName = future.New()
|
||||
s.cancelRDNS = cancelRDNS
|
||||
go s.fetchRDNSName(rdnsCtx)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"net/mail"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -337,6 +338,8 @@ func (s state) RewriteRcpt(ctx context.Context, rcptTo string) (string, error) {
|
|||
}
|
||||
|
||||
func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error {
|
||||
defer trace.StartRegion(ctx, "sign_dkim/RewriteBody").End()
|
||||
|
||||
var authUser string
|
||||
if s.meta.Conn != nil {
|
||||
authUser = s.meta.Conn.AuthUser
|
||||
|
|
|
@ -11,10 +11,7 @@ import (
|
|||
func objectName(x interface{}) string {
|
||||
mod, ok := x.(module.Module)
|
||||
if ok {
|
||||
if mod.InstanceName() == "" {
|
||||
return mod.Name()
|
||||
}
|
||||
return mod.InstanceName()
|
||||
return mod.Name() + ":" + mod.InstanceName()
|
||||
}
|
||||
|
||||
_, pipeline := x.(*MsgPipeline)
|
||||
|
@ -22,5 +19,10 @@ func objectName(x interface{}) string {
|
|||
return "reroute"
|
||||
}
|
||||
|
||||
str, ok := x.(fmt.Stringer)
|
||||
if ok {
|
||||
return str.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%T", x)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"time"
|
||||
|
||||
"github.com/foxcpp/maddy/internal/exterrors"
|
||||
|
@ -23,10 +24,14 @@ var httpClient = &http.Client{
|
|||
Timeout: time.Minute,
|
||||
}
|
||||
|
||||
func downloadPolicy(domain string) (*Policy, error) {
|
||||
func downloadPolicy(ctx context.Context, domain string) (*Policy, error) {
|
||||
// TODO: Consult OCSP/CRL to detect revoked certificates?
|
||||
|
||||
resp, err := httpClient.Get("https://mta-sts." + domain + "/.well-known/mta-sts.txt")
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://mta-sts."+domain+"/.well-known/mta-sts.txt", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,8 +85,8 @@ var ErrIgnorePolicy = errors.New("mtasts: policy ignored due to errors")
|
|||
// Get reads policy from cache or tries to fetch it from Policy Host.
|
||||
//
|
||||
// The domain is assumed to be normalized, as done by dns.ForLookup.
|
||||
func (c *Cache) Get(domain string) (*Policy, error) {
|
||||
_, p, err := c.fetch(false, time.Now(), domain)
|
||||
func (c *Cache) Get(ctx context.Context, domain string) (*Policy, error) {
|
||||
_, p, err := c.fetch(ctx, false, time.Now(), domain)
|
||||
return p, err
|
||||
}
|
||||
|
||||
|
@ -118,6 +123,9 @@ func (c *Cache) load(domain string) (id string, fetchTime time.Time, p *Policy,
|
|||
}
|
||||
|
||||
func (c *Cache) Refresh() error {
|
||||
refreshCtx, refreshTask := trace.NewTask(context.Background(), "mtasts.Cache/Refresh")
|
||||
defer refreshTask.End()
|
||||
|
||||
dir, err := ioutil.ReadDir(c.Location)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -132,7 +140,7 @@ func (c *Cache) Refresh() error {
|
|||
// Since otherwise we are going to have expired policy for another 6 hours,
|
||||
// which makes it useless.
|
||||
// See https://tools.ietf.org/html/rfc8461#section-10.2.
|
||||
cacheHit, _, err := c.fetch(true, time.Now().Add(6*time.Hour), ent.Name())
|
||||
cacheHit, _, err := c.fetch(refreshCtx, false, time.Now().Add(6*time.Hour), ent.Name())
|
||||
if err != nil && err != ErrIgnorePolicy {
|
||||
c.Logger.Error("policy update error", err, "domain", ent.Name())
|
||||
}
|
||||
|
@ -147,7 +155,9 @@ func (c *Cache) Refresh() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
|
||||
func (c *Cache) fetch(ctx context.Context, ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
|
||||
defer trace.StartRegion(ctx, "mtasts.Cache/fetch").End()
|
||||
|
||||
validCache := true
|
||||
cachedId, fetchTime, cachedPolicy, err := c.load(domain)
|
||||
if err != nil {
|
||||
|
@ -163,7 +173,7 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
|
|||
|
||||
var dnsId string
|
||||
if !ignoreDns {
|
||||
records, err := c.Resolver.LookupTXT(context.Background(), "_mta-sts."+domain)
|
||||
records, err := c.Resolver.LookupTXT(ctx, "_mta-sts."+domain)
|
||||
if err != nil {
|
||||
if validCache {
|
||||
if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound {
|
||||
|
@ -218,11 +228,15 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
|
|||
}
|
||||
|
||||
if !validCache || dnsId != cachedId {
|
||||
download := downloadPolicy
|
||||
var (
|
||||
policy *Policy
|
||||
err error
|
||||
)
|
||||
if c.downloadPolicy != nil {
|
||||
download = c.downloadPolicy
|
||||
policy, err = c.downloadPolicy(domain)
|
||||
} else {
|
||||
policy, err = downloadPolicy(ctx, domain)
|
||||
}
|
||||
policy, err := download(domain)
|
||||
if err != nil {
|
||||
if validCache {
|
||||
c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mtasts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -37,7 +38,7 @@ func TestCacheGet(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -57,7 +58,7 @@ func TestCacheGet_Error_DNS(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
_, err := c.Get("example.org")
|
||||
_, err := c.Get(context.Background(), "example.org")
|
||||
if err != ErrIgnorePolicy {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -78,7 +79,7 @@ func TestCacheGet_Error_HTTPS(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
_, err := c.Get("example.org")
|
||||
_, err := c.Get(context.Background(), "example.org")
|
||||
if err != ErrIgnorePolicy {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -104,7 +105,7 @@ func TestCacheGet_Cached(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -115,7 +116,7 @@ func TestCacheGet_Cached(t *testing.T) {
|
|||
// calling downloadPolicy.
|
||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -146,7 +147,7 @@ func TestCacheGet_Expired(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -160,7 +161,7 @@ func TestCacheGet_Expired(t *testing.T) {
|
|||
expectedPolicy.MX = []string{"b"}
|
||||
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
|
||||
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -190,7 +191,7 @@ func TestCacheGet_IDChange(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -206,7 +207,7 @@ func TestCacheGet_IDChange(t *testing.T) {
|
|||
expectedPolicy.MX = []string{"b"}
|
||||
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
|
||||
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -236,7 +237,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -251,7 +252,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
|
|||
resolver.Zones = nil
|
||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -281,7 +282,7 @@ func TestCacheGet_HTTPGet_ErrNoPolicy(t *testing.T) {
|
|||
// >(for any reason), and there is no valid (non-expired) previously
|
||||
// >cached policy, senders MUST continue with delivery as though the
|
||||
// >domain has not implemented MTA-STS.
|
||||
_, err := c.Get("example.org")
|
||||
_, err := c.Get(context.Background(), "example.org")
|
||||
if err != ErrIgnorePolicy {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -308,7 +309,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -325,7 +326,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
|
|||
}
|
||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -357,7 +358,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -376,7 +377,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
|
|||
}
|
||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got policy %v", policy)
|
||||
}
|
||||
|
@ -405,7 +406,7 @@ func TestCacheRefresh(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -427,7 +428,7 @@ func TestCacheRefresh(t *testing.T) {
|
|||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||
|
||||
// It should return the new record from cache.
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -459,7 +460,7 @@ func TestCacheRefresh_Error(t *testing.T) {
|
|||
}
|
||||
defer os.RemoveAll(c.Location)
|
||||
|
||||
policy, err := c.Get("example.org")
|
||||
policy, err := c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
@ -477,7 +478,7 @@ func TestCacheRefresh_Error(t *testing.T) {
|
|||
}
|
||||
|
||||
// It should return the old record from cache.
|
||||
policy, err = c.Get("example.org")
|
||||
policy, err = c.Get(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("policy get: %v", err)
|
||||
}
|
||||
|
|
|
@ -10,9 +10,11 @@
|
|||
package smtpconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"runtime/trace"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-smtp"
|
||||
|
@ -27,9 +29,9 @@ import (
|
|||
//
|
||||
// Currently, the C object represents one session and cannot be reused.
|
||||
type C struct {
|
||||
// Dialer to use to estabilish new network connections. Set to net.Dial by
|
||||
// New.
|
||||
Dialer func(network, addr string) (net.Conn, error)
|
||||
// Dialer to use to estabilish new network connections. Set to net.Dialer
|
||||
// DialContext by New.
|
||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Fail if the connection cannot use TLS.
|
||||
RequireTLS bool
|
||||
|
@ -60,7 +62,7 @@ type C struct {
|
|||
// with resonable default values.
|
||||
func New() *C {
|
||||
return &C{
|
||||
Dialer: net.Dial,
|
||||
Dialer: (&net.Dialer{}).DialContext,
|
||||
TLSConfig: &tls.Config{},
|
||||
Hostname: "localhost.localdomain",
|
||||
}
|
||||
|
@ -122,9 +124,11 @@ func (c *C) wrapClientErr(err error, serverName string) error {
|
|||
}
|
||||
|
||||
// Connect actually estabilishes the network connection with the remote host.
|
||||
func (c *C) Connect(endp config.Endpoint) error {
|
||||
func (c *C) Connect(ctx context.Context, endp config.Endpoint) error {
|
||||
defer trace.StartRegion(ctx, "smtpconn/Connect+TLS").End()
|
||||
|
||||
// TODO: Helper function to try multiple endpoints?
|
||||
cl, err := c.attemptConnect(endp, c.AttemptTLS)
|
||||
cl, err := c.attemptConnect(ctx, endp, c.AttemptTLS)
|
||||
if err != nil {
|
||||
return c.wrapClientErr(err, endp.Host)
|
||||
}
|
||||
|
@ -134,9 +138,9 @@ func (c *C) Connect(endp config.Endpoint) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
|
||||
func (c *C) attemptConnect(ctx context.Context, endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
|
||||
var conn net.Conn
|
||||
conn, err := c.Dialer(endp.Network(), endp.Address())
|
||||
conn, err := c.Dialer(ctx, endp.Network(), endp.Address())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -197,7 +201,7 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
|
|||
// Re-attempt without STARTTLS. It is not possible to reuse connection
|
||||
// since it is probably in a bad state.
|
||||
c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port)
|
||||
return c.attemptConnect(endp, false)
|
||||
return c.attemptConnect(ctx, endp, false)
|
||||
}
|
||||
|
||||
return cl, nil
|
||||
|
@ -209,7 +213,9 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
|
|||
// SMTPUTF8 is forwarded if supported by the remote server, if it is not
|
||||
// supported - attempt will be done to convert addresses to the ASCII form, if
|
||||
// this is not possible, the corresponding method (Mail or Rcpt) will fail.
|
||||
func (c *C) Mail(from string, opts smtp.MailOptions) error {
|
||||
func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error {
|
||||
defer trace.StartRegion(ctx, "smtpconn/MAIL FROM").End()
|
||||
|
||||
outOpts := smtp.MailOptions{
|
||||
// Future extensions may add additional fields that should not be
|
||||
// copied blindly. So we copy only fields we know should be handled
|
||||
|
@ -268,7 +274,9 @@ func (c *C) Client() *smtp.Client {
|
|||
//
|
||||
// If the address is non-ASCII and cannot be converted to ASCII and the remote
|
||||
// server does not support SMTPUTF8, error will be returned.
|
||||
func (c *C) Rcpt(to string) error {
|
||||
func (c *C) Rcpt(ctx context.Context, to string) error {
|
||||
defer trace.StartRegion(ctx, "smtpconn/RCPT TO").End()
|
||||
|
||||
// If necessary, the extension flag is enabled in Start.
|
||||
if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok {
|
||||
var err error
|
||||
|
@ -300,7 +308,9 @@ func (c *C) Rcpt(to string) error {
|
|||
//
|
||||
// If the Data command fails, the connection may be in a unclean state (e.g. in
|
||||
// the middle of message data stream). It is not safe to continue using it.
|
||||
func (c *C) Data(hdr textproto.Header, body io.Reader) error {
|
||||
func (c *C) Data(ctx context.Context, hdr textproto.Header, body io.Reader) error {
|
||||
defer trace.StartRegion(ctx, "smtpconn/DATA").End()
|
||||
|
||||
wc, err := c.cl.Data()
|
||||
if err != nil {
|
||||
return c.wrapClientErr(err, c.serverName)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package smtpconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -14,11 +15,11 @@ import (
|
|||
func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error {
|
||||
t.Helper()
|
||||
|
||||
if err := conn.Mail(from, opts); err != nil {
|
||||
if err := conn.Mail(context.Background(), from, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rcpt := range to {
|
||||
if err := conn.Rcpt(rcpt); err != nil {
|
||||
if err := conn.Rcpt(context.Background(), rcpt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -26,7 +27,7 @@ func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.M
|
|||
hdr := textproto.Header{}
|
||||
hdr.Add("B", "2")
|
||||
hdr.Add("A", "1")
|
||||
if err := conn.Data(hdr, strings.NewReader("foobar\n")); err != nil {
|
||||
if err := conn.Data(context.Background(), hdr, strings.NewReader("foobar\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -41,7 +42,7 @@ func TestSMTPUTF8_Sender_UTF8_Punycode(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -70,7 +71,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Punycode(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -99,7 +100,7 @@ func TestSMTPUTF8_Sender_UTF8_Reject(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -122,7 +123,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Reject(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -144,7 +145,7 @@ func TestSMTPUTF8_Sender_UTF8_Domain(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -172,7 +173,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Domain(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -201,7 +202,7 @@ func TestSMTPUTF8_Sender_UTF8_Username(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
@ -230,7 +231,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Username(t *testing.T) {
|
|||
|
||||
c := New()
|
||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||
if err := c.Connect(config.Endpoint{
|
||||
if err := c.Connect(context.Background(), config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: "127.0.0.1",
|
||||
Port: testPort,
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -63,7 +64,13 @@ type delivery struct {
|
|||
addedRcpts map[string]struct{}
|
||||
}
|
||||
|
||||
func (d *delivery) String() string {
|
||||
return d.store.Name() + ":" + d.store.InstanceName()
|
||||
}
|
||||
|
||||
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
||||
defer trace.StartRegion(ctx, "sql/AddRcpt").End()
|
||||
|
||||
accountName, err := prepareUsername(rcptTo)
|
||||
if err != nil {
|
||||
return &exterrors.SMTPError{
|
||||
|
@ -104,6 +111,8 @@ func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
|||
}
|
||||
|
||||
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
||||
defer trace.StartRegion(ctx, "sql/Body").End()
|
||||
|
||||
if d.msgMeta.Quarantine {
|
||||
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
|
||||
return err
|
||||
|
@ -116,14 +125,20 @@ func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffe
|
|||
}
|
||||
|
||||
func (d *delivery) Abort(ctx context.Context) error {
|
||||
defer trace.StartRegion(ctx, "sql/Abort").End()
|
||||
|
||||
return d.d.Abort()
|
||||
}
|
||||
|
||||
func (d *delivery) Commit(ctx context.Context) error {
|
||||
defer trace.StartRegion(ctx, "sql/Commit").End()
|
||||
|
||||
return d.d.Commit()
|
||||
}
|
||||
|
||||
func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
defer trace.StartRegion(ctx, "sql/Start").End()
|
||||
|
||||
return &delivery{
|
||||
store: store,
|
||||
msgMeta: msgMeta,
|
||||
|
@ -360,6 +375,9 @@ func prepareUsername(username string) (string, error) {
|
|||
}
|
||||
|
||||
func (store *Storage) CheckPlain(username, password string) bool {
|
||||
// TODO: Pass session context there.
|
||||
defer trace.StartRegion(context.Background(), "sql/CheckPlain").End()
|
||||
|
||||
accountName, err := prepareUsername(username)
|
||||
if err != nil {
|
||||
return false
|
||||
|
|
|
@ -50,6 +50,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"runtime/trace"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -438,10 +439,12 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
|||
msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1)
|
||||
dl.Debugf("using message ID = %s", msgMeta.ID)
|
||||
|
||||
// TODO: Trace annotations?
|
||||
msgCtx := context.Background()
|
||||
msgCtx, msgTask := trace.NewTask(context.Background(), "Queue delivery")
|
||||
defer msgTask.End()
|
||||
|
||||
delivery, err := q.Target.Start(msgCtx, msgMeta, meta.From)
|
||||
mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
|
||||
delivery, err := q.Target.Start(mailCtx, msgMeta, meta.From)
|
||||
mailTask.End()
|
||||
if err != nil {
|
||||
dl.Debugf("target.Start failed: %v", err)
|
||||
perr.Failed = append(perr.Failed, meta.To...)
|
||||
|
@ -454,7 +457,8 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
|||
|
||||
var acceptedRcpts []string
|
||||
for _, rcpt := range meta.To {
|
||||
if err := delivery.AddRcpt(msgCtx, rcpt); err != nil {
|
||||
rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
|
||||
if err := delivery.AddRcpt(rcptCtx, rcpt); err != nil {
|
||||
if exterrors.IsTemporaryOrUnspec(err) {
|
||||
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
|
||||
} else {
|
||||
|
@ -466,6 +470,7 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
|||
dl.Debugf("delivery.AddRcpt %s OK", rcpt)
|
||||
acceptedRcpts = append(acceptedRcpts, rcpt)
|
||||
}
|
||||
rcptTask.End()
|
||||
}
|
||||
|
||||
if len(acceptedRcpts) == 0 {
|
||||
|
@ -487,12 +492,15 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
|||
}
|
||||
}
|
||||
|
||||
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
partDelivery, ok := delivery.(module.PartialDelivery)
|
||||
if ok {
|
||||
dl.Debugf("using delivery.BodyNonAtomic")
|
||||
partDelivery.BodyNonAtomic(msgCtx, &perr, header, body)
|
||||
partDelivery.BodyNonAtomic(bodyCtx, &perr, header, body)
|
||||
} else {
|
||||
if err := delivery.Body(msgCtx, header, body); err != nil {
|
||||
if err := delivery.Body(bodyCtx, header, body); err != nil {
|
||||
dl.Debugf("delivery.Body failed: %v", err)
|
||||
expandToPartialErr(err)
|
||||
}
|
||||
|
@ -508,13 +516,13 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
|||
if allFailed {
|
||||
// No recipients succeeded.
|
||||
dl.Debugf("delivery.Abort (all recipients failed)")
|
||||
if err := delivery.Abort(msgCtx); err != nil {
|
||||
if err := delivery.Abort(bodyCtx); err != nil {
|
||||
dl.Msg("delivery.Abort failed", err)
|
||||
}
|
||||
return perr
|
||||
}
|
||||
|
||||
if err := delivery.Commit(msgCtx); err != nil {
|
||||
if err := delivery.Commit(bodyCtx); err != nil {
|
||||
dl.Debugf("delivery.Commit failed: %v", err)
|
||||
expandToPartialErr(err)
|
||||
}
|
||||
|
@ -537,6 +545,8 @@ func (qd *queueDelivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
|||
}
|
||||
|
||||
func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
||||
defer trace.StartRegion(ctx, "queue/Body").End()
|
||||
|
||||
// Body buffer initially passed to us may not be valid after "delivery" to queue completes.
|
||||
// storeNewMessage returns a new buffer object created from message blob stored on disk.
|
||||
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
|
||||
|
@ -550,6 +560,8 @@ func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body
|
|||
}
|
||||
|
||||
func (qd *queueDelivery) Abort(ctx context.Context) error {
|
||||
defer trace.StartRegion(ctx, "queue/Abort").End()
|
||||
|
||||
if qd.body != nil {
|
||||
qd.q.removeFromDisk(qd.meta.MsgMeta)
|
||||
}
|
||||
|
@ -557,6 +569,8 @@ func (qd *queueDelivery) Abort(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (qd *queueDelivery) Commit(ctx context.Context) error {
|
||||
defer trace.StartRegion(ctx, "queue/Commit").End()
|
||||
|
||||
if qd.meta == nil {
|
||||
panic("queue: double Commit")
|
||||
}
|
||||
|
@ -899,14 +913,17 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
|
|||
}
|
||||
dl.Msg("generated failed DSN", "dsn_id", dsnID)
|
||||
|
||||
// TODO: Trace annotations?
|
||||
msgCtx := context.Background()
|
||||
msgCtx, msgTask := trace.NewTask(context.Background(), "DSN Delivery")
|
||||
defer msgTask.End()
|
||||
|
||||
dsnDelivery, err := q.dsnPipeline.Start(msgCtx, dsnMeta, "")
|
||||
mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
|
||||
dsnDelivery, err := q.dsnPipeline.Start(mailCtx, dsnMeta, "")
|
||||
mailTask.End()
|
||||
if err != nil {
|
||||
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
|
||||
|
@ -916,15 +933,23 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
|
|||
}
|
||||
}()
|
||||
|
||||
if err = dsnDelivery.AddRcpt(msgCtx, meta.From); err != nil {
|
||||
rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
|
||||
if err = dsnDelivery.AddRcpt(rcptCtx, meta.From); err != nil {
|
||||
rcptTask.End()
|
||||
return
|
||||
}
|
||||
if err = dsnDelivery.Body(msgCtx, dsnHeader, dsnBody); err != nil {
|
||||
rcptTask.End()
|
||||
|
||||
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
|
||||
if err = dsnDelivery.Body(bodyCtx, dsnHeader, dsnBody); err != nil {
|
||||
bodyTask.End()
|
||||
return
|
||||
}
|
||||
if err = dsnDelivery.Commit(msgCtx); err != nil {
|
||||
if err = dsnDelivery.Commit(bodyCtx); err != nil {
|
||||
bodyTask.End()
|
||||
return
|
||||
}
|
||||
bodyTask.End()
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -33,7 +34,7 @@ func TestRemoteDelivery_AuthMX_Fail(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
|
@ -63,12 +64,12 @@ func TestRemoteDelivery_AuthMX_MTASTS(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
tlsConfig: clientCfg,
|
||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
||||
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||
if domain != "example.invalid" {
|
||||
return nil, errors.New("Wrong domain in lookup")
|
||||
}
|
||||
|
@ -114,12 +115,12 @@ func TestRemoteDelivery_AuthMX_PreferAuth(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
tlsConfig: clientCfg,
|
||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
||||
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||
if domain != "example.invalid" {
|
||||
return nil, errors.New("Wrong domain in lookup")
|
||||
}
|
||||
|
@ -166,12 +167,12 @@ func TestRemoteDelivery_MTASTS_SkipNonMatching(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
tlsConfig: clientCfg,
|
||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
||||
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||
if domain != "example.invalid" {
|
||||
return nil, errors.New("Wrong domain in lookup")
|
||||
}
|
||||
|
@ -208,11 +209,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_Fail(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
||||
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||
if domain != "example.invalid" {
|
||||
return nil, errors.New("Wrong domain in lookup")
|
||||
}
|
||||
|
@ -251,11 +252,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_NoPolicy(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
||||
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||
if domain != "example.invalid" {
|
||||
return nil, errors.New("Wrong domain in lookup")
|
||||
}
|
||||
|
@ -290,7 +291,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
||||
|
@ -321,7 +322,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_Fail(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
||||
|
@ -354,7 +355,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_NotETLDp1(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
||||
|
@ -405,7 +406,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: extResolver,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||
|
@ -452,7 +453,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC_Fail(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: extResolver,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||
|
@ -506,7 +507,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &resolver,
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: extResolver,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||
|
@ -556,7 +557,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral_Fail(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &resolver,
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: extResolver,
|
||||
requireMXAuth: true,
|
||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -57,12 +58,12 @@ type Target struct {
|
|||
mxAuth map[string]struct{}
|
||||
|
||||
resolver dns.Resolver
|
||||
dialer func(network, addr string) (net.Conn, error)
|
||||
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
extResolver *dns.ExtResolver
|
||||
|
||||
// This is the callback that is usually mtastsCache.Get,
|
||||
// but replaced by tests to mock mtasts.Cache.
|
||||
mtastsGet func(domain string) (*mtasts.Policy, error)
|
||||
mtastsGet func(ctx context.Context, domain string) (*mtasts.Policy, error)
|
||||
|
||||
mtastsCache mtasts.Cache
|
||||
stsCacheUpdateTick *time.Ticker
|
||||
|
@ -80,7 +81,7 @@ func New(_, instName string, _, inlineArgs []string) (module.Module, error) {
|
|||
return &Target{
|
||||
name: instName,
|
||||
resolver: dns.DefaultResolver(),
|
||||
dialer: net.Dial,
|
||||
dialer: (&net.Dialer{}).DialContext,
|
||||
mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()},
|
||||
Log: log.Logger{Name: "remote"},
|
||||
|
||||
|
@ -195,6 +196,8 @@ func (rt *Target) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFr
|
|||
}
|
||||
|
||||
func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
|
||||
defer trace.StartRegion(ctx, "remote/AddRcpt").End()
|
||||
|
||||
if rd.msgMeta.Quarantine {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 550,
|
||||
|
@ -232,7 +235,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := conn.Rcpt(to); err != nil {
|
||||
if err := conn.Rcpt(ctx, to); err != nil {
|
||||
return moduleError(err)
|
||||
}
|
||||
|
||||
|
@ -379,6 +382,8 @@ func (m *multipleErrs) SetStatus(rcptTo string, err error) {
|
|||
}
|
||||
|
||||
func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error {
|
||||
defer trace.StartRegion(ctx, "remote/Body").End()
|
||||
|
||||
merr := multipleErrs{
|
||||
errs: make(map[string]error),
|
||||
}
|
||||
|
@ -396,6 +401,8 @@ func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buf
|
|||
}
|
||||
|
||||
func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) {
|
||||
defer trace.StartRegion(ctx, "remote/BodyNonAtomic").End()
|
||||
|
||||
if rd.msgMeta.Quarantine {
|
||||
for _, rcpt := range rd.recipients {
|
||||
c.SetStatus(rcpt, &exterrors.SMTPError{
|
||||
|
@ -425,7 +432,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl
|
|||
}
|
||||
defer bodyR.Close()
|
||||
|
||||
err = conn.Data(header, bodyR)
|
||||
err = conn.Data(ctx, header, bodyR)
|
||||
for _, rcpt := range conn.Rcpts() {
|
||||
c.SetStatus(rcpt, err)
|
||||
}
|
||||
|
@ -486,7 +493,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
|
|||
addrs = append(addrs, nonAuthMXs...)
|
||||
rd.Log.DebugMsg("considering", "mxs", addrs)
|
||||
for i, addr := range addrs {
|
||||
err = conn.Connect(config.Endpoint{
|
||||
err = conn.Connect(ctx, config.Endpoint{
|
||||
Scheme: "tcp",
|
||||
Host: addr,
|
||||
Port: smtpPort,
|
||||
|
@ -517,7 +524,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
|
|||
|
||||
rd.Log.DebugMsg("connected", "remote_server", conn.ServerName())
|
||||
|
||||
if err := conn.Mail(rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
|
||||
if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
|
||||
conn.Close()
|
||||
return nil, moduleError(err)
|
||||
}
|
||||
|
@ -526,8 +533,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) {
|
||||
stsPolicy, err := rt.mtastsGet(domain)
|
||||
func (rt *Target) getSTSPolicy(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||
stsPolicy, err := rt.mtastsGet(ctx, domain)
|
||||
if err != nil && !mtasts.IsNoPolicy(err) {
|
||||
return nil, &exterrors.SMTPError{
|
||||
Code: exterrors.SMTPCode(err, 450, 554),
|
||||
|
@ -568,9 +575,11 @@ func (rt *Target) stsCacheUpdater() {
|
|||
}
|
||||
|
||||
func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) {
|
||||
defer trace.StartRegion(ctx, "remote/LookupAndFilterMX").End()
|
||||
|
||||
var policy *mtasts.Policy
|
||||
if _, use := rd.rt.mxAuth[AuthMTASTS]; use {
|
||||
policy, err = rd.rt.getSTSPolicy(domain)
|
||||
policy, err = rd.rt.getSTSPolicy(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ func TestRemoteDelivery(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ func TestRemoteDelivery_IPLiteral(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &resolver,
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ func TestRemoteDelivery_FallbackMX(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: resolver,
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ func TestRemoteDelivery_BodyNonAtomic(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: resolver,
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ func TestRemoteDelivery_Abort(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -199,7 +199,7 @@ func TestRemoteDelivery_CommitWithoutBody(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -243,7 +243,7 @@ func TestRemoteDelivery_MAILFROMErr(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -276,7 +276,7 @@ func TestRemoteDelivery_NoMX(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: resolver,
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -313,7 +313,7 @@ func TestRemoteDelivery_NullMX(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -349,7 +349,7 @@ func TestRemoteDelivery_Quarantined(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -404,7 +404,7 @@ func TestRemoteDelivery_MAILFROMErr_Repeated(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -451,7 +451,7 @@ func TestRemoteDelivery_RcptErr(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -509,7 +509,7 @@ func TestRemoteDelivery_DownMX(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -539,7 +539,7 @@ func TestRemoteDelivery_AllMXDown(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -577,7 +577,7 @@ func TestRemoteDelivery_Split(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -623,7 +623,7 @@ func TestRemoteDelivery_Split_Fail(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -683,7 +683,7 @@ func TestRemoteDelivery_BodyErr(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -744,7 +744,7 @@ func TestRemoteDelivery_Split_BodyErr(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -807,7 +807,7 @@ func TestRemoteDelivery_Split_BodyErr_NonAtomic(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
}
|
||||
|
@ -867,7 +867,7 @@ func TestRemoteDelivery_TLSErrFallback(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
tlsConfig: &tls.Config{},
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
|
@ -895,7 +895,7 @@ func TestRemoteDelivery_RequireTLS_Missing(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireTLS: true,
|
||||
Log: testutils.Logger(t, "remote"),
|
||||
|
@ -925,7 +925,7 @@ func TestRemoteDelivery_RequireTLS_Present(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
requireTLS: true,
|
||||
tlsConfig: clientCfg,
|
||||
|
@ -954,7 +954,7 @@ func TestRemoteDelivery_RequireTLS_NoErrFallback(t *testing.T) {
|
|||
name: "remote",
|
||||
hostname: "mx.example.com",
|
||||
resolver: &mockdns.Resolver{Zones: zones},
|
||||
dialer: resolver.Dial,
|
||||
dialer: resolver.DialContext,
|
||||
extResolver: nil,
|
||||
tlsConfig: &tls.Config{},
|
||||
requireTLS: true,
|
||||
|
|
|
@ -44,7 +44,7 @@ func TestRemoteDelivery_EHLO_ALabel(t *testing.T) {
|
|||
|
||||
tgt := mod.(*Target)
|
||||
tgt.resolver = &mockdns.Resolver{Zones: zones}
|
||||
tgt.dialer = resolver.Dial
|
||||
tgt.dialer = resolver.DialContext
|
||||
tgt.extResolver = nil
|
||||
tgt.Log = testutils.Logger(t, "remote")
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime/trace"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/foxcpp/maddy/internal/buffer"
|
||||
|
@ -126,24 +127,26 @@ type delivery struct {
|
|||
}
|
||||
|
||||
func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||
defer trace.StartRegion(ctx, "smtp_downstream/Start").End()
|
||||
|
||||
d := &delivery{
|
||||
u: u,
|
||||
log: target.DeliveryLogger(u.log, msgMeta),
|
||||
msgMeta: msgMeta,
|
||||
mailFrom: mailFrom,
|
||||
}
|
||||
if err := d.connect(); err != nil {
|
||||
if err := d.connect(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.conn.Mail(mailFrom, msgMeta.SMTPOpts); err != nil {
|
||||
if err := d.conn.Mail(ctx, mailFrom, msgMeta.SMTPOpts); err != nil {
|
||||
d.conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *delivery) connect() error {
|
||||
func (d *delivery) connect(ctx context.Context) error {
|
||||
// TODO: Review possibility of connection pooling here.
|
||||
var lastErr error
|
||||
|
||||
|
@ -156,7 +159,7 @@ func (d *delivery) connect() error {
|
|||
conn.AddrInSMTPMsg = false
|
||||
|
||||
for _, endp := range d.u.endpoints {
|
||||
err := conn.Connect(endp)
|
||||
err := conn.Connect(ctx, endp)
|
||||
if err == nil {
|
||||
d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
|
||||
lastErr = nil
|
||||
|
@ -191,7 +194,7 @@ func (d *delivery) connect() error {
|
|||
}
|
||||
|
||||
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
||||
return moduleError(d.conn.Rcpt(rcptTo))
|
||||
return moduleError(d.conn.Rcpt(ctx, rcptTo))
|
||||
}
|
||||
|
||||
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
||||
|
@ -217,7 +220,7 @@ func (d *delivery) Commit(ctx context.Context) error {
|
|||
defer d.conn.Close()
|
||||
defer d.body.Close()
|
||||
|
||||
return moduleError(d.conn.Data(d.hdr, d.body))
|
||||
return moduleError(d.conn.Data(ctx, d.hdr, d.body))
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue