mirror of
https://github.com/foxcpp/maddy.git
synced 2025-04-06 14:37:37 +03:00
Instrument the SMTP code using runtime/trace
runtime/trace together with 'go tool trace' provides extremely powerful tooling for performance (latency) analysis. Since maddy prides itself on being "optimized for concurrency", it is a good idea to actually live up to this promise. Closes #144. No need to reinvent the wheel. The original issue proposed a solution to use in production to detect "performance anomalies", it is possible to use runtime/trace in production too, but the corresponding flag to enable profiler endpoint is hidden behind the 'debugflags' build tag at the moment. For SMTP code, the basic latency information can be obtained from regular logs since they include timestamps with millisecond granularity. After the issue is apparent, it is possible to deploy the server executable compiled with tracing support and obtain more information ... Also add missing context.Context arguments to smtpconn.C.
This commit is contained in:
parent
305fdddf24
commit
c4ea9a730f
20 changed files with 281 additions and 135 deletions
|
@ -11,6 +11,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"runtime/trace"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -300,6 +301,9 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: It is not possible to distinguish different commands.
|
||||||
|
defer trace.StartRegion(ctx, "command/CheckConnection").End()
|
||||||
|
|
||||||
cmdName, cmdArgs := s.expandCommand("")
|
cmdName, cmdArgs := s.expandCommand("")
|
||||||
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
||||||
}
|
}
|
||||||
|
@ -311,6 +315,8 @@ func (s *state) CheckSender(ctx context.Context, addr string) module.CheckResult
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "command/CheckSender").End()
|
||||||
|
|
||||||
cmdName, cmdArgs := s.expandCommand(addr)
|
cmdName, cmdArgs := s.expandCommand(addr)
|
||||||
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
||||||
}
|
}
|
||||||
|
@ -321,6 +327,7 @@ func (s *state) CheckRcpt(ctx context.Context, addr string) module.CheckResult {
|
||||||
if s.c.stage != StageRcpt {
|
if s.c.stage != StageRcpt {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
defer trace.StartRegion(ctx, "command/CheckRcpt").End()
|
||||||
|
|
||||||
cmdName, cmdArgs := s.expandCommand(addr)
|
cmdName, cmdArgs := s.expandCommand(addr)
|
||||||
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
return s.run(cmdName, cmdArgs, bytes.NewReader(nil))
|
||||||
|
@ -331,6 +338,8 @@ func (s *state) CheckBody(ctx context.Context, hdr textproto.Header, body buffer
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "command/CheckBody").End()
|
||||||
|
|
||||||
cmdName, cmdArgs := s.expandCommand("")
|
cmdName, cmdArgs := s.expandCommand("")
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
nettextproto "net/textproto"
|
nettextproto "net/textproto"
|
||||||
|
"runtime/trace"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
|
@ -96,6 +97,8 @@ func (d *dkimCheckState) CheckRcpt(ctx context.Context, rcptTo string) module.Ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
func (d *dkimCheckState) CheckBody(ctx context.Context, header textproto.Header, body buffer.Buffer) module.CheckResult {
|
||||||
|
defer trace.StartRegion(ctx, "verify_dkim/CheckBody").End()
|
||||||
|
|
||||||
if !header.Has("DKIM-Signature") {
|
if !header.Has("DKIM-Signature") {
|
||||||
if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine {
|
if d.c.noSigAction.Reject || d.c.noSigAction.Quarantine {
|
||||||
d.log.Printf("no signatures present")
|
d.log.Printf("no signatures present")
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime/trace"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
|
@ -323,6 +324,8 @@ func (bl *DNSBL) CheckConnection(ctx context.Context, state *smtp.ConnectionStat
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "dnsbl/CheckConnection (Early)").End()
|
||||||
|
|
||||||
ip, ok := state.RemoteAddr.(*net.TCPAddr)
|
ip, ok := state.RemoteAddr.(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
bl.log.Msg("non-TCP/IP source",
|
bl.log.Msg("non-TCP/IP source",
|
||||||
|
@ -358,6 +361,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "dnsbl/CheckConnection").End()
|
||||||
|
|
||||||
if s.msgMeta.Conn == nil {
|
if s.msgMeta.Conn == nil {
|
||||||
s.log.Msg("locally generated message, ignoring")
|
s.log.Msg("locally generated message, ignoring")
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"runtime/trace"
|
||||||
|
|
||||||
"blitiri.com.ar/go/spf"
|
"blitiri.com.ar/go/spf"
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
|
@ -267,6 +268,8 @@ func prepareMailFrom(from string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
||||||
|
defer trace.StartRegion(ctx, "apply_spf/CheckConnection").End()
|
||||||
|
|
||||||
if s.msgMeta.Conn == nil {
|
if s.msgMeta.Conn == nil {
|
||||||
s.log.Println("locally generated message, skipping")
|
s.log.Println("locally generated message, skipping")
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
|
@ -306,6 +309,8 @@ func (s *state) CheckConnection(ctx context.Context) module.CheckResult {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "apply_spf/CheckConnection (Async)").End()
|
||||||
|
|
||||||
res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom)
|
res, err := spf.CheckHostWithSender(ip.IP, s.msgMeta.Conn.Hostname, mailFrom)
|
||||||
s.log.Debugf("result: %s (%v)", res, err)
|
s.log.Debugf("result: %s (%v)", res, err)
|
||||||
s.spfFetch <- spfRes{res, err}
|
s.spfFetch <- spfRes{res, err}
|
||||||
|
@ -328,6 +333,8 @@ func (s *state) CheckBody(ctx context.Context, header textproto.Header, body buf
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "apply_spf/CheckBody").End()
|
||||||
|
|
||||||
res, ok := <-s.spfFetch
|
res, ok := <-s.spfFetch
|
||||||
if !ok {
|
if !ok {
|
||||||
return module.CheckResult{
|
return module.CheckResult{
|
||||||
|
|
|
@ -3,6 +3,7 @@ package check
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime/trace"
|
||||||
|
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
"github.com/foxcpp/maddy/internal/buffer"
|
"github.com/foxcpp/maddy/internal/buffer"
|
||||||
|
@ -57,12 +58,18 @@ type statelessCheckState struct {
|
||||||
msgMeta *module.MsgMetadata
|
msgMeta *module.MsgMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *statelessCheckState) String() string {
|
||||||
|
return s.c.modName + ":" + s.c.instName
|
||||||
|
}
|
||||||
|
|
||||||
func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
|
func (s *statelessCheckState) CheckConnection(ctx context.Context) module.CheckResult {
|
||||||
if s.c.connCheck == nil {
|
if s.c.connCheck == nil {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
defer trace.StartRegion(ctx, s.c.modName+"/CheckConnection").End()
|
||||||
|
|
||||||
originalRes := s.c.connCheck(StatelessCheckContext{
|
originalRes := s.c.connCheck(StatelessCheckContext{
|
||||||
|
Context: ctx,
|
||||||
Resolver: s.c.resolver,
|
Resolver: s.c.resolver,
|
||||||
MsgMeta: s.msgMeta,
|
MsgMeta: s.msgMeta,
|
||||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||||
|
@ -74,8 +81,10 @@ func (s *statelessCheckState) CheckSender(ctx context.Context, mailFrom string)
|
||||||
if s.c.senderCheck == nil {
|
if s.c.senderCheck == nil {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
defer trace.StartRegion(ctx, s.c.modName+"/CheckSender").End()
|
||||||
|
|
||||||
originalRes := s.c.senderCheck(StatelessCheckContext{
|
originalRes := s.c.senderCheck(StatelessCheckContext{
|
||||||
|
Context: ctx,
|
||||||
Resolver: s.c.resolver,
|
Resolver: s.c.resolver,
|
||||||
MsgMeta: s.msgMeta,
|
MsgMeta: s.msgMeta,
|
||||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||||
|
@ -87,8 +96,10 @@ func (s *statelessCheckState) CheckRcpt(ctx context.Context, rcptTo string) modu
|
||||||
if s.c.rcptCheck == nil {
|
if s.c.rcptCheck == nil {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
defer trace.StartRegion(ctx, s.c.modName+"/CheckRcpt").End()
|
||||||
|
|
||||||
originalRes := s.c.rcptCheck(StatelessCheckContext{
|
originalRes := s.c.rcptCheck(StatelessCheckContext{
|
||||||
|
Context: ctx,
|
||||||
Resolver: s.c.resolver,
|
Resolver: s.c.resolver,
|
||||||
MsgMeta: s.msgMeta,
|
MsgMeta: s.msgMeta,
|
||||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||||
|
@ -100,8 +111,10 @@ func (s *statelessCheckState) CheckBody(ctx context.Context, header textproto.He
|
||||||
if s.c.bodyCheck == nil {
|
if s.c.bodyCheck == nil {
|
||||||
return module.CheckResult{}
|
return module.CheckResult{}
|
||||||
}
|
}
|
||||||
|
defer trace.StartRegion(ctx, s.c.modName+"/CheckBody").End()
|
||||||
|
|
||||||
originalRes := s.c.bodyCheck(StatelessCheckContext{
|
originalRes := s.c.bodyCheck(StatelessCheckContext{
|
||||||
|
Context: ctx,
|
||||||
Resolver: s.c.resolver,
|
Resolver: s.c.resolver,
|
||||||
MsgMeta: s.msgMeta,
|
MsgMeta: s.msgMeta,
|
||||||
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
Logger: target.DeliveryLogger(s.c.logger, s.msgMeta),
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime/trace"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
|
@ -81,6 +82,8 @@ func (v *Verifier) FetchRecord(ctx context.Context, header textproto.Header) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
defer trace.StartRegion(ctx, "DMARC/FetchRecord").End()
|
||||||
|
|
||||||
policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain)
|
policyDomain, record, err := FetchRecord(ctx, v.resolver, fromDomain)
|
||||||
v.fetchCh <- verifyData{
|
v.fetchCh <- verifyData{
|
||||||
policyDomain: policyDomain,
|
policyDomain: policyDomain,
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime/trace"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -46,6 +47,7 @@ type Session struct {
|
||||||
// msgCtx is not used for cancellation or timeouts, only for tracing.
|
// msgCtx is not used for cancellation or timeouts, only for tracing.
|
||||||
// It is the subcontext of sessionCtx.
|
// It is the subcontext of sessionCtx.
|
||||||
msgCtx context.Context
|
msgCtx context.Context
|
||||||
|
msgTask *trace.Task
|
||||||
mailFrom string
|
mailFrom string
|
||||||
opts smtp.MailOptions
|
opts smtp.MailOptions
|
||||||
msgMeta *module.MsgMetadata
|
msgMeta *module.MsgMetadata
|
||||||
|
@ -75,6 +77,7 @@ func (s *Session) abort(ctx context.Context) {
|
||||||
s.delivery = nil
|
s.delivery = nil
|
||||||
s.deliveryErr = nil
|
s.deliveryErr = nil
|
||||||
s.msgCtx = nil
|
s.msgCtx = nil
|
||||||
|
s.msgTask.End()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
|
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
|
||||||
|
@ -138,15 +141,20 @@ func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.Mail
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
delivery, err := s.endp.pipeline.Start(ctx, msgMeta, cleanFrom)
|
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
|
||||||
|
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
|
||||||
|
defer mailTask.End()
|
||||||
|
|
||||||
|
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.msgCtx = nil
|
||||||
|
s.msgTask.End()
|
||||||
return msgMeta.ID, err
|
return msgMeta.ID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.msgMeta = msgMeta
|
s.msgMeta = msgMeta
|
||||||
s.mailFrom = cleanFrom
|
s.mailFrom = cleanFrom
|
||||||
s.delivery = delivery
|
s.delivery = delivery
|
||||||
s.msgCtx = s.sessionCtx // TODO: Trace annotations.
|
|
||||||
|
|
||||||
return msgMeta.ID, nil
|
return msgMeta.ID, nil
|
||||||
}
|
}
|
||||||
|
@ -171,6 +179,8 @@ func (s *Session) Mail(from string, opts smtp.MailOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) fetchRDNSName(ctx context.Context) {
|
func (s *Session) fetchRDNSName(ctx context.Context) {
|
||||||
|
defer trace.StartRegion(ctx, "rDNS fetch").End()
|
||||||
|
|
||||||
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
|
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
s.connState.RDNSName.Set(nil)
|
s.connState.RDNSName.Set(nil)
|
||||||
|
@ -211,7 +221,8 @@ func (s *Session) Rcpt(to string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rcptCtx := s.msgCtx
|
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
|
||||||
|
defer rcptTask.End()
|
||||||
|
|
||||||
if err := s.rcpt(rcptCtx, to); err != nil {
|
if err := s.rcpt(rcptCtx, to); err != nil {
|
||||||
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
|
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
|
||||||
|
@ -293,7 +304,8 @@ func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Heade
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Data(r io.Reader) error {
|
func (s *Session) Data(r io.Reader) error {
|
||||||
bodyCtx := s.msgCtx
|
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||||
|
defer bodyTask.End()
|
||||||
|
|
||||||
wrapErr := func(err error) error {
|
wrapErr := func(err error) error {
|
||||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||||
|
@ -317,6 +329,9 @@ func (s *Session) Data(r io.Reader) error {
|
||||||
|
|
||||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||||
s.delivery = nil
|
s.delivery = nil
|
||||||
|
s.msgCtx = nil
|
||||||
|
s.msgTask.End()
|
||||||
|
s.msgTask = nil
|
||||||
s.endp.semaphore.Release()
|
s.endp.semaphore.Release()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -332,7 +347,8 @@ func (sw statusWrapper) SetStatus(rcpt string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
||||||
bodyCtx := s.msgCtx
|
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||||
|
defer bodyTask.End()
|
||||||
|
|
||||||
wrapErr := func(err error) error {
|
wrapErr := func(err error) error {
|
||||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||||
|
@ -356,6 +372,9 @@ func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
||||||
|
|
||||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||||
s.delivery = nil
|
s.delivery = nil
|
||||||
|
s.msgCtx = nil
|
||||||
|
s.msgTask.End()
|
||||||
|
s.msgTask = nil
|
||||||
s.endp.semaphore.Release()
|
s.endp.semaphore.Release()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -671,7 +690,7 @@ func (endp *Endpoint) newSession(anonymous bool, username, password string, stat
|
||||||
}
|
}
|
||||||
|
|
||||||
if endp.resolver != nil {
|
if endp.resolver != nil {
|
||||||
rdnsCtx, cancelRDNS := context.WithCancel(context.TODO())
|
rdnsCtx, cancelRDNS := context.WithCancel(s.sessionCtx)
|
||||||
s.connState.RDNSName = future.New()
|
s.connState.RDNSName = future.New()
|
||||||
s.cancelRDNS = cancelRDNS
|
s.cancelRDNS = cancelRDNS
|
||||||
go s.fetchRDNSName(rdnsCtx)
|
go s.fetchRDNSName(rdnsCtx)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/trace"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -337,6 +338,8 @@ func (s state) RewriteRcpt(ctx context.Context, rcptTo string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error {
|
func (s state) RewriteBody(ctx context.Context, h *textproto.Header, body buffer.Buffer) error {
|
||||||
|
defer trace.StartRegion(ctx, "sign_dkim/RewriteBody").End()
|
||||||
|
|
||||||
var authUser string
|
var authUser string
|
||||||
if s.meta.Conn != nil {
|
if s.meta.Conn != nil {
|
||||||
authUser = s.meta.Conn.AuthUser
|
authUser = s.meta.Conn.AuthUser
|
||||||
|
|
|
@ -11,10 +11,7 @@ import (
|
||||||
func objectName(x interface{}) string {
|
func objectName(x interface{}) string {
|
||||||
mod, ok := x.(module.Module)
|
mod, ok := x.(module.Module)
|
||||||
if ok {
|
if ok {
|
||||||
if mod.InstanceName() == "" {
|
return mod.Name() + ":" + mod.InstanceName()
|
||||||
return mod.Name()
|
|
||||||
}
|
|
||||||
return mod.InstanceName()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, pipeline := x.(*MsgPipeline)
|
_, pipeline := x.(*MsgPipeline)
|
||||||
|
@ -22,5 +19,10 @@ func objectName(x interface{}) string {
|
||||||
return "reroute"
|
return "reroute"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
str, ok := x.(fmt.Stringer)
|
||||||
|
if ok {
|
||||||
|
return str.String()
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%T", x)
|
return fmt.Sprintf("%T", x)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/trace"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/foxcpp/maddy/internal/exterrors"
|
"github.com/foxcpp/maddy/internal/exterrors"
|
||||||
|
@ -23,10 +24,14 @@ var httpClient = &http.Client{
|
||||||
Timeout: time.Minute,
|
Timeout: time.Minute,
|
||||||
}
|
}
|
||||||
|
|
||||||
func downloadPolicy(domain string) (*Policy, error) {
|
func downloadPolicy(ctx context.Context, domain string) (*Policy, error) {
|
||||||
// TODO: Consult OCSP/CRL to detect revoked certificates?
|
// TODO: Consult OCSP/CRL to detect revoked certificates?
|
||||||
|
|
||||||
resp, err := httpClient.Get("https://mta-sts." + domain + "/.well-known/mta-sts.txt")
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://mta-sts."+domain+"/.well-known/mta-sts.txt", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -80,8 +85,8 @@ var ErrIgnorePolicy = errors.New("mtasts: policy ignored due to errors")
|
||||||
// Get reads policy from cache or tries to fetch it from Policy Host.
|
// Get reads policy from cache or tries to fetch it from Policy Host.
|
||||||
//
|
//
|
||||||
// The domain is assumed to be normalized, as done by dns.ForLookup.
|
// The domain is assumed to be normalized, as done by dns.ForLookup.
|
||||||
func (c *Cache) Get(domain string) (*Policy, error) {
|
func (c *Cache) Get(ctx context.Context, domain string) (*Policy, error) {
|
||||||
_, p, err := c.fetch(false, time.Now(), domain)
|
_, p, err := c.fetch(ctx, false, time.Now(), domain)
|
||||||
return p, err
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +123,9 @@ func (c *Cache) load(domain string) (id string, fetchTime time.Time, p *Policy,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Refresh() error {
|
func (c *Cache) Refresh() error {
|
||||||
|
refreshCtx, refreshTask := trace.NewTask(context.Background(), "mtasts.Cache/Refresh")
|
||||||
|
defer refreshTask.End()
|
||||||
|
|
||||||
dir, err := ioutil.ReadDir(c.Location)
|
dir, err := ioutil.ReadDir(c.Location)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -132,7 +140,7 @@ func (c *Cache) Refresh() error {
|
||||||
// Since otherwise we are going to have expired policy for another 6 hours,
|
// Since otherwise we are going to have expired policy for another 6 hours,
|
||||||
// which makes it useless.
|
// which makes it useless.
|
||||||
// See https://tools.ietf.org/html/rfc8461#section-10.2.
|
// See https://tools.ietf.org/html/rfc8461#section-10.2.
|
||||||
cacheHit, _, err := c.fetch(true, time.Now().Add(6*time.Hour), ent.Name())
|
cacheHit, _, err := c.fetch(refreshCtx, false, time.Now().Add(6*time.Hour), ent.Name())
|
||||||
if err != nil && err != ErrIgnorePolicy {
|
if err != nil && err != ErrIgnorePolicy {
|
||||||
c.Logger.Error("policy update error", err, "domain", ent.Name())
|
c.Logger.Error("policy update error", err, "domain", ent.Name())
|
||||||
}
|
}
|
||||||
|
@ -147,7 +155,9 @@ func (c *Cache) Refresh() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
|
func (c *Cache) fetch(ctx context.Context, ignoreDns bool, now time.Time, domain string) (cacheHit bool, p *Policy, err error) {
|
||||||
|
defer trace.StartRegion(ctx, "mtasts.Cache/fetch").End()
|
||||||
|
|
||||||
validCache := true
|
validCache := true
|
||||||
cachedId, fetchTime, cachedPolicy, err := c.load(domain)
|
cachedId, fetchTime, cachedPolicy, err := c.load(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -163,7 +173,7 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
|
||||||
|
|
||||||
var dnsId string
|
var dnsId string
|
||||||
if !ignoreDns {
|
if !ignoreDns {
|
||||||
records, err := c.Resolver.LookupTXT(context.Background(), "_mta-sts."+domain)
|
records, err := c.Resolver.LookupTXT(ctx, "_mta-sts."+domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if validCache {
|
if validCache {
|
||||||
if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound {
|
if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsNotFound {
|
||||||
|
@ -218,11 +228,15 @@ func (c *Cache) fetch(ignoreDns bool, now time.Time, domain string) (cacheHit bo
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validCache || dnsId != cachedId {
|
if !validCache || dnsId != cachedId {
|
||||||
download := downloadPolicy
|
var (
|
||||||
|
policy *Policy
|
||||||
|
err error
|
||||||
|
)
|
||||||
if c.downloadPolicy != nil {
|
if c.downloadPolicy != nil {
|
||||||
download = c.downloadPolicy
|
policy, err = c.downloadPolicy(domain)
|
||||||
|
} else {
|
||||||
|
policy, err = downloadPolicy(ctx, domain)
|
||||||
}
|
}
|
||||||
policy, err := download(domain)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if validCache {
|
if validCache {
|
||||||
c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain)
|
c.Logger.Error("failed to fetch new policy, using cache", err, "domain", domain)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mtasts
|
package mtasts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -37,7 +38,7 @@ func TestCacheGet(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -57,7 +58,7 @@ func TestCacheGet_Error_DNS(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
_, err := c.Get("example.org")
|
_, err := c.Get(context.Background(), "example.org")
|
||||||
if err != ErrIgnorePolicy {
|
if err != ErrIgnorePolicy {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -78,7 +79,7 @@ func TestCacheGet_Error_HTTPS(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
_, err := c.Get("example.org")
|
_, err := c.Get(context.Background(), "example.org")
|
||||||
if err != ErrIgnorePolicy {
|
if err != ErrIgnorePolicy {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -104,7 +105,7 @@ func TestCacheGet_Cached(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -115,7 +116,7 @@ func TestCacheGet_Cached(t *testing.T) {
|
||||||
// calling downloadPolicy.
|
// calling downloadPolicy.
|
||||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||||
|
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -146,7 +147,7 @@ func TestCacheGet_Expired(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -160,7 +161,7 @@ func TestCacheGet_Expired(t *testing.T) {
|
||||||
expectedPolicy.MX = []string{"b"}
|
expectedPolicy.MX = []string{"b"}
|
||||||
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
|
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
|
||||||
|
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -190,7 +191,7 @@ func TestCacheGet_IDChange(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -206,7 +207,7 @@ func TestCacheGet_IDChange(t *testing.T) {
|
||||||
expectedPolicy.MX = []string{"b"}
|
expectedPolicy.MX = []string{"b"}
|
||||||
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
|
c.downloadPolicy = mockDownloadPolicy(expectedPolicy, nil)
|
||||||
|
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -236,7 +237,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -251,7 +252,7 @@ func TestCacheGet_DNSDisappear(t *testing.T) {
|
||||||
resolver.Zones = nil
|
resolver.Zones = nil
|
||||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||||
|
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -281,7 +282,7 @@ func TestCacheGet_HTTPGet_ErrNoPolicy(t *testing.T) {
|
||||||
// >(for any reason), and there is no valid (non-expired) previously
|
// >(for any reason), and there is no valid (non-expired) previously
|
||||||
// >cached policy, senders MUST continue with delivery as though the
|
// >cached policy, senders MUST continue with delivery as though the
|
||||||
// >domain has not implemented MTA-STS.
|
// >domain has not implemented MTA-STS.
|
||||||
_, err := c.Get("example.org")
|
_, err := c.Get(context.Background(), "example.org")
|
||||||
if err != ErrIgnorePolicy {
|
if err != ErrIgnorePolicy {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -308,7 +309,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -325,7 +326,7 @@ func TestCacheGet_IDChange_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||||
|
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -357,7 +358,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -376,7 +377,7 @@ func TestCacheGet_IDChange_Expired_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||||
|
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error, got policy %v", policy)
|
t.Fatalf("expected error, got policy %v", policy)
|
||||||
}
|
}
|
||||||
|
@ -405,7 +406,7 @@ func TestCacheRefresh(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -427,7 +428,7 @@ func TestCacheRefresh(t *testing.T) {
|
||||||
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
c.downloadPolicy = mockDownloadPolicy(nil, errors.New("broken"))
|
||||||
|
|
||||||
// It should return the new record from cache.
|
// It should return the new record from cache.
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -459,7 +460,7 @@ func TestCacheRefresh_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(c.Location)
|
defer os.RemoveAll(c.Location)
|
||||||
|
|
||||||
policy, err := c.Get("example.org")
|
policy, err := c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -477,7 +478,7 @@ func TestCacheRefresh_Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// It should return the old record from cache.
|
// It should return the old record from cache.
|
||||||
policy, err = c.Get("example.org")
|
policy, err = c.Get(context.Background(), "example.org")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("policy get: %v", err)
|
t.Fatalf("policy get: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,9 +10,11 @@
|
||||||
package smtpconn
|
package smtpconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime/trace"
|
||||||
|
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
|
@ -27,9 +29,9 @@ import (
|
||||||
//
|
//
|
||||||
// Currently, the C object represents one session and cannot be reused.
|
// Currently, the C object represents one session and cannot be reused.
|
||||||
type C struct {
|
type C struct {
|
||||||
// Dialer to use to estabilish new network connections. Set to net.Dial by
|
// Dialer to use to estabilish new network connections. Set to net.Dialer
|
||||||
// New.
|
// DialContext by New.
|
||||||
Dialer func(network, addr string) (net.Conn, error)
|
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// Fail if the connection cannot use TLS.
|
// Fail if the connection cannot use TLS.
|
||||||
RequireTLS bool
|
RequireTLS bool
|
||||||
|
@ -60,7 +62,7 @@ type C struct {
|
||||||
// with resonable default values.
|
// with resonable default values.
|
||||||
func New() *C {
|
func New() *C {
|
||||||
return &C{
|
return &C{
|
||||||
Dialer: net.Dial,
|
Dialer: (&net.Dialer{}).DialContext,
|
||||||
TLSConfig: &tls.Config{},
|
TLSConfig: &tls.Config{},
|
||||||
Hostname: "localhost.localdomain",
|
Hostname: "localhost.localdomain",
|
||||||
}
|
}
|
||||||
|
@ -122,9 +124,11 @@ func (c *C) wrapClientErr(err error, serverName string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect actually estabilishes the network connection with the remote host.
|
// Connect actually estabilishes the network connection with the remote host.
|
||||||
func (c *C) Connect(endp config.Endpoint) error {
|
func (c *C) Connect(ctx context.Context, endp config.Endpoint) error {
|
||||||
|
defer trace.StartRegion(ctx, "smtpconn/Connect+TLS").End()
|
||||||
|
|
||||||
// TODO: Helper function to try multiple endpoints?
|
// TODO: Helper function to try multiple endpoints?
|
||||||
cl, err := c.attemptConnect(endp, c.AttemptTLS)
|
cl, err := c.attemptConnect(ctx, endp, c.AttemptTLS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.wrapClientErr(err, endp.Host)
|
return c.wrapClientErr(err, endp.Host)
|
||||||
}
|
}
|
||||||
|
@ -134,9 +138,9 @@ func (c *C) Connect(endp config.Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
|
func (c *C) attemptConnect(ctx context.Context, endp config.Endpoint, attemptTLS bool) (*smtp.Client, error) {
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
conn, err := c.Dialer(endp.Network(), endp.Address())
|
conn, err := c.Dialer(ctx, endp.Network(), endp.Address())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -197,7 +201,7 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
|
||||||
// Re-attempt without STARTTLS. It is not possible to reuse connection
|
// Re-attempt without STARTTLS. It is not possible to reuse connection
|
||||||
// since it is probably in a bad state.
|
// since it is probably in a bad state.
|
||||||
c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port)
|
c.Log.Error("TLS error, falling back to plain-text connection", err, "remote_server", endp.Host+endp.Port)
|
||||||
return c.attemptConnect(endp, false)
|
return c.attemptConnect(ctx, endp, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cl, nil
|
return cl, nil
|
||||||
|
@ -209,7 +213,9 @@ func (c *C) attemptConnect(endp config.Endpoint, attemptTLS bool) (*smtp.Client,
|
||||||
// SMTPUTF8 is forwarded if supported by the remote server, if it is not
|
// SMTPUTF8 is forwarded if supported by the remote server, if it is not
|
||||||
// supported - attempt will be done to convert addresses to the ASCII form, if
|
// supported - attempt will be done to convert addresses to the ASCII form, if
|
||||||
// this is not possible, the corresponding method (Mail or Rcpt) will fail.
|
// this is not possible, the corresponding method (Mail or Rcpt) will fail.
|
||||||
func (c *C) Mail(from string, opts smtp.MailOptions) error {
|
func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error {
|
||||||
|
defer trace.StartRegion(ctx, "smtpconn/MAIL FROM").End()
|
||||||
|
|
||||||
outOpts := smtp.MailOptions{
|
outOpts := smtp.MailOptions{
|
||||||
// Future extensions may add additional fields that should not be
|
// Future extensions may add additional fields that should not be
|
||||||
// copied blindly. So we copy only fields we know should be handled
|
// copied blindly. So we copy only fields we know should be handled
|
||||||
|
@ -268,7 +274,9 @@ func (c *C) Client() *smtp.Client {
|
||||||
//
|
//
|
||||||
// If the address is non-ASCII and cannot be converted to ASCII and the remote
|
// If the address is non-ASCII and cannot be converted to ASCII and the remote
|
||||||
// server does not support SMTPUTF8, error will be returned.
|
// server does not support SMTPUTF8, error will be returned.
|
||||||
func (c *C) Rcpt(to string) error {
|
func (c *C) Rcpt(ctx context.Context, to string) error {
|
||||||
|
defer trace.StartRegion(ctx, "smtpconn/RCPT TO").End()
|
||||||
|
|
||||||
// If necessary, the extension flag is enabled in Start.
|
// If necessary, the extension flag is enabled in Start.
|
||||||
if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok {
|
if ok, _ := c.cl.Extension("SMTPUTF8"); !address.IsASCII(to) && !ok {
|
||||||
var err error
|
var err error
|
||||||
|
@ -300,7 +308,9 @@ func (c *C) Rcpt(to string) error {
|
||||||
//
|
//
|
||||||
// If the Data command fails, the connection may be in a unclean state (e.g. in
|
// If the Data command fails, the connection may be in a unclean state (e.g. in
|
||||||
// the middle of message data stream). It is not safe to continue using it.
|
// the middle of message data stream). It is not safe to continue using it.
|
||||||
func (c *C) Data(hdr textproto.Header, body io.Reader) error {
|
func (c *C) Data(ctx context.Context, hdr textproto.Header, body io.Reader) error {
|
||||||
|
defer trace.StartRegion(ctx, "smtpconn/DATA").End()
|
||||||
|
|
||||||
wc, err := c.cl.Data()
|
wc, err := c.cl.Data()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.wrapClientErr(err, c.serverName)
|
return c.wrapClientErr(err, c.serverName)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package smtpconn
|
package smtpconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -14,11 +15,11 @@ import (
|
||||||
func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error {
|
func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.MailOptions) error {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
if err := conn.Mail(from, opts); err != nil {
|
if err := conn.Mail(context.Background(), from, opts); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, rcpt := range to {
|
for _, rcpt := range to {
|
||||||
if err := conn.Rcpt(rcpt); err != nil {
|
if err := conn.Rcpt(context.Background(), rcpt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,7 +27,7 @@ func doTestDelivery(t *testing.T, conn *C, from string, to []string, opts smtp.M
|
||||||
hdr := textproto.Header{}
|
hdr := textproto.Header{}
|
||||||
hdr.Add("B", "2")
|
hdr.Add("B", "2")
|
||||||
hdr.Add("A", "1")
|
hdr.Add("A", "1")
|
||||||
if err := conn.Data(hdr, strings.NewReader("foobar\n")); err != nil {
|
if err := conn.Data(context.Background(), hdr, strings.NewReader("foobar\n")); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +42,7 @@ func TestSMTPUTF8_Sender_UTF8_Punycode(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -70,7 +71,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Punycode(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -99,7 +100,7 @@ func TestSMTPUTF8_Sender_UTF8_Reject(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -122,7 +123,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Reject(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -144,7 +145,7 @@ func TestSMTPUTF8_Sender_UTF8_Domain(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -172,7 +173,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Domain(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -201,7 +202,7 @@ func TestSMTPUTF8_Sender_UTF8_Username(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
@ -230,7 +231,7 @@ func TestSMTPUTF8_Rcpt_UTF8_Username(t *testing.T) {
|
||||||
|
|
||||||
c := New()
|
c := New()
|
||||||
c.Log = testutils.Logger(t, "smtp_downstream")
|
c.Log = testutils.Logger(t, "smtp_downstream")
|
||||||
if err := c.Connect(config.Endpoint{
|
if err := c.Connect(context.Background(), config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: testPort,
|
Port: testPort,
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/trace"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -63,7 +64,13 @@ type delivery struct {
|
||||||
addedRcpts map[string]struct{}
|
addedRcpts map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *delivery) String() string {
|
||||||
|
return d.store.Name() + ":" + d.store.InstanceName()
|
||||||
|
}
|
||||||
|
|
||||||
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
||||||
|
defer trace.StartRegion(ctx, "sql/AddRcpt").End()
|
||||||
|
|
||||||
accountName, err := prepareUsername(rcptTo)
|
accountName, err := prepareUsername(rcptTo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &exterrors.SMTPError{
|
return &exterrors.SMTPError{
|
||||||
|
@ -104,6 +111,8 @@ func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
||||||
|
defer trace.StartRegion(ctx, "sql/Body").End()
|
||||||
|
|
||||||
if d.msgMeta.Quarantine {
|
if d.msgMeta.Quarantine {
|
||||||
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
|
if err := d.d.SpecialMailbox(specialuse.Junk, d.store.junkMbox); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -116,14 +125,20 @@ func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *delivery) Abort(ctx context.Context) error {
|
func (d *delivery) Abort(ctx context.Context) error {
|
||||||
|
defer trace.StartRegion(ctx, "sql/Abort").End()
|
||||||
|
|
||||||
return d.d.Abort()
|
return d.d.Abort()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *delivery) Commit(ctx context.Context) error {
|
func (d *delivery) Commit(ctx context.Context) error {
|
||||||
|
defer trace.StartRegion(ctx, "sql/Commit").End()
|
||||||
|
|
||||||
return d.d.Commit()
|
return d.d.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
func (store *Storage) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||||
|
defer trace.StartRegion(ctx, "sql/Start").End()
|
||||||
|
|
||||||
return &delivery{
|
return &delivery{
|
||||||
store: store,
|
store: store,
|
||||||
msgMeta: msgMeta,
|
msgMeta: msgMeta,
|
||||||
|
@ -360,6 +375,9 @@ func prepareUsername(username string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Storage) CheckPlain(username, password string) bool {
|
func (store *Storage) CheckPlain(username, password string) bool {
|
||||||
|
// TODO: Pass session context there.
|
||||||
|
defer trace.StartRegion(context.Background(), "sql/CheckPlain").End()
|
||||||
|
|
||||||
accountName, err := prepareUsername(username)
|
accountName, err := prepareUsername(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -50,6 +50,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"runtime/trace"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -438,10 +439,12 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
||||||
msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1)
|
msgMeta.ID = msgMeta.ID + "-" + strconv.Itoa(meta.TriesCount+1)
|
||||||
dl.Debugf("using message ID = %s", msgMeta.ID)
|
dl.Debugf("using message ID = %s", msgMeta.ID)
|
||||||
|
|
||||||
// TODO: Trace annotations?
|
msgCtx, msgTask := trace.NewTask(context.Background(), "Queue delivery")
|
||||||
msgCtx := context.Background()
|
defer msgTask.End()
|
||||||
|
|
||||||
delivery, err := q.Target.Start(msgCtx, msgMeta, meta.From)
|
mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
|
||||||
|
delivery, err := q.Target.Start(mailCtx, msgMeta, meta.From)
|
||||||
|
mailTask.End()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dl.Debugf("target.Start failed: %v", err)
|
dl.Debugf("target.Start failed: %v", err)
|
||||||
perr.Failed = append(perr.Failed, meta.To...)
|
perr.Failed = append(perr.Failed, meta.To...)
|
||||||
|
@ -454,7 +457,8 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
||||||
|
|
||||||
var acceptedRcpts []string
|
var acceptedRcpts []string
|
||||||
for _, rcpt := range meta.To {
|
for _, rcpt := range meta.To {
|
||||||
if err := delivery.AddRcpt(msgCtx, rcpt); err != nil {
|
rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
|
||||||
|
if err := delivery.AddRcpt(rcptCtx, rcpt); err != nil {
|
||||||
if exterrors.IsTemporaryOrUnspec(err) {
|
if exterrors.IsTemporaryOrUnspec(err) {
|
||||||
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
|
perr.TemporaryFailed = append(perr.TemporaryFailed, rcpt)
|
||||||
} else {
|
} else {
|
||||||
|
@ -466,6 +470,7 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
||||||
dl.Debugf("delivery.AddRcpt %s OK", rcpt)
|
dl.Debugf("delivery.AddRcpt %s OK", rcpt)
|
||||||
acceptedRcpts = append(acceptedRcpts, rcpt)
|
acceptedRcpts = append(acceptedRcpts, rcpt)
|
||||||
}
|
}
|
||||||
|
rcptTask.End()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(acceptedRcpts) == 0 {
|
if len(acceptedRcpts) == 0 {
|
||||||
|
@ -487,12 +492,15 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
|
||||||
|
defer bodyTask.End()
|
||||||
|
|
||||||
partDelivery, ok := delivery.(module.PartialDelivery)
|
partDelivery, ok := delivery.(module.PartialDelivery)
|
||||||
if ok {
|
if ok {
|
||||||
dl.Debugf("using delivery.BodyNonAtomic")
|
dl.Debugf("using delivery.BodyNonAtomic")
|
||||||
partDelivery.BodyNonAtomic(msgCtx, &perr, header, body)
|
partDelivery.BodyNonAtomic(bodyCtx, &perr, header, body)
|
||||||
} else {
|
} else {
|
||||||
if err := delivery.Body(msgCtx, header, body); err != nil {
|
if err := delivery.Body(bodyCtx, header, body); err != nil {
|
||||||
dl.Debugf("delivery.Body failed: %v", err)
|
dl.Debugf("delivery.Body failed: %v", err)
|
||||||
expandToPartialErr(err)
|
expandToPartialErr(err)
|
||||||
}
|
}
|
||||||
|
@ -508,13 +516,13 @@ func (q *Queue) deliver(meta *QueueMetadata, header textproto.Header, body buffe
|
||||||
if allFailed {
|
if allFailed {
|
||||||
// No recipients succeeded.
|
// No recipients succeeded.
|
||||||
dl.Debugf("delivery.Abort (all recipients failed)")
|
dl.Debugf("delivery.Abort (all recipients failed)")
|
||||||
if err := delivery.Abort(msgCtx); err != nil {
|
if err := delivery.Abort(bodyCtx); err != nil {
|
||||||
dl.Msg("delivery.Abort failed", err)
|
dl.Msg("delivery.Abort failed", err)
|
||||||
}
|
}
|
||||||
return perr
|
return perr
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := delivery.Commit(msgCtx); err != nil {
|
if err := delivery.Commit(bodyCtx); err != nil {
|
||||||
dl.Debugf("delivery.Commit failed: %v", err)
|
dl.Debugf("delivery.Commit failed: %v", err)
|
||||||
expandToPartialErr(err)
|
expandToPartialErr(err)
|
||||||
}
|
}
|
||||||
|
@ -537,6 +545,8 @@ func (qd *queueDelivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
||||||
|
defer trace.StartRegion(ctx, "queue/Body").End()
|
||||||
|
|
||||||
// Body buffer initially passed to us may not be valid after "delivery" to queue completes.
|
// Body buffer initially passed to us may not be valid after "delivery" to queue completes.
|
||||||
// storeNewMessage returns a new buffer object created from message blob stored on disk.
|
// storeNewMessage returns a new buffer object created from message blob stored on disk.
|
||||||
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
|
storedBody, err := qd.q.storeNewMessage(qd.meta, header, body)
|
||||||
|
@ -550,6 +560,8 @@ func (qd *queueDelivery) Body(ctx context.Context, header textproto.Header, body
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qd *queueDelivery) Abort(ctx context.Context) error {
|
func (qd *queueDelivery) Abort(ctx context.Context) error {
|
||||||
|
defer trace.StartRegion(ctx, "queue/Abort").End()
|
||||||
|
|
||||||
if qd.body != nil {
|
if qd.body != nil {
|
||||||
qd.q.removeFromDisk(qd.meta.MsgMeta)
|
qd.q.removeFromDisk(qd.meta.MsgMeta)
|
||||||
}
|
}
|
||||||
|
@ -557,6 +569,8 @@ func (qd *queueDelivery) Abort(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qd *queueDelivery) Commit(ctx context.Context) error {
|
func (qd *queueDelivery) Commit(ctx context.Context) error {
|
||||||
|
defer trace.StartRegion(ctx, "queue/Commit").End()
|
||||||
|
|
||||||
if qd.meta == nil {
|
if qd.meta == nil {
|
||||||
panic("queue: double Commit")
|
panic("queue: double Commit")
|
||||||
}
|
}
|
||||||
|
@ -899,14 +913,17 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
|
||||||
}
|
}
|
||||||
dl.Msg("generated failed DSN", "dsn_id", dsnID)
|
dl.Msg("generated failed DSN", "dsn_id", dsnID)
|
||||||
|
|
||||||
// TODO: Trace annotations?
|
msgCtx, msgTask := trace.NewTask(context.Background(), "DSN Delivery")
|
||||||
msgCtx := context.Background()
|
defer msgTask.End()
|
||||||
|
|
||||||
dsnDelivery, err := q.dsnPipeline.Start(msgCtx, dsnMeta, "")
|
mailCtx, mailTask := trace.NewTask(msgCtx, "MAIL FROM")
|
||||||
|
dsnDelivery, err := q.dsnPipeline.Start(mailCtx, dsnMeta, "")
|
||||||
|
mailTask.End()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
|
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
|
dl.Error("failed to enqueue DSN", err, "dsn_id", dsnID)
|
||||||
|
@ -916,15 +933,23 @@ func (q *Queue) emitDSN(meta *QueueMetadata, header textproto.Header) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = dsnDelivery.AddRcpt(msgCtx, meta.From); err != nil {
|
rcptCtx, rcptTask := trace.NewTask(msgCtx, "RCPT TO")
|
||||||
|
if err = dsnDelivery.AddRcpt(rcptCtx, meta.From); err != nil {
|
||||||
|
rcptTask.End()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = dsnDelivery.Body(msgCtx, dsnHeader, dsnBody); err != nil {
|
rcptTask.End()
|
||||||
|
|
||||||
|
bodyCtx, bodyTask := trace.NewTask(msgCtx, "DATA")
|
||||||
|
if err = dsnDelivery.Body(bodyCtx, dsnHeader, dsnBody); err != nil {
|
||||||
|
bodyTask.End()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = dsnDelivery.Commit(msgCtx); err != nil {
|
if err = dsnDelivery.Commit(bodyCtx); err != nil {
|
||||||
|
bodyTask.End()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
bodyTask.End()
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package remote
|
package remote
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -33,7 +34,7 @@ func TestRemoteDelivery_AuthMX_Fail(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
|
@ -63,12 +64,12 @@ func TestRemoteDelivery_AuthMX_MTASTS(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
tlsConfig: clientCfg,
|
tlsConfig: clientCfg,
|
||||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||||
if domain != "example.invalid" {
|
if domain != "example.invalid" {
|
||||||
return nil, errors.New("Wrong domain in lookup")
|
return nil, errors.New("Wrong domain in lookup")
|
||||||
}
|
}
|
||||||
|
@ -114,12 +115,12 @@ func TestRemoteDelivery_AuthMX_PreferAuth(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
tlsConfig: clientCfg,
|
tlsConfig: clientCfg,
|
||||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||||
if domain != "example.invalid" {
|
if domain != "example.invalid" {
|
||||||
return nil, errors.New("Wrong domain in lookup")
|
return nil, errors.New("Wrong domain in lookup")
|
||||||
}
|
}
|
||||||
|
@ -166,12 +167,12 @@ func TestRemoteDelivery_MTASTS_SkipNonMatching(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
tlsConfig: clientCfg,
|
tlsConfig: clientCfg,
|
||||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||||
if domain != "example.invalid" {
|
if domain != "example.invalid" {
|
||||||
return nil, errors.New("Wrong domain in lookup")
|
return nil, errors.New("Wrong domain in lookup")
|
||||||
}
|
}
|
||||||
|
@ -208,11 +209,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_Fail(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||||
if domain != "example.invalid" {
|
if domain != "example.invalid" {
|
||||||
return nil, errors.New("Wrong domain in lookup")
|
return nil, errors.New("Wrong domain in lookup")
|
||||||
}
|
}
|
||||||
|
@ -251,11 +252,11 @@ func TestRemoteDelivery_AuthMX_MTASTS_NoPolicy(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
mxAuth: map[string]struct{}{AuthMTASTS: {}},
|
||||||
mtastsGet: func(domain string) (*mtasts.Policy, error) {
|
mtastsGet: func(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||||
if domain != "example.invalid" {
|
if domain != "example.invalid" {
|
||||||
return nil, errors.New("Wrong domain in lookup")
|
return nil, errors.New("Wrong domain in lookup")
|
||||||
}
|
}
|
||||||
|
@ -290,7 +291,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
||||||
|
@ -321,7 +322,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_Fail(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
||||||
|
@ -354,7 +355,7 @@ func TestRemoteDelivery_AuthMX_CommonDomain_NotETLDp1(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
mxAuth: map[string]struct{}{AuthCommonDomain: {}},
|
||||||
|
@ -405,7 +406,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: extResolver,
|
extResolver: extResolver,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||||
|
@ -452,7 +453,7 @@ func TestRemoteDelivery_AuthMX_DNSSEC_Fail(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: extResolver,
|
extResolver: extResolver,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||||
|
@ -506,7 +507,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &resolver,
|
resolver: &resolver,
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: extResolver,
|
extResolver: extResolver,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||||
|
@ -556,7 +557,7 @@ func TestRemoteDelivery_MXAuth_IPLiteral_Fail(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &resolver,
|
resolver: &resolver,
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: extResolver,
|
extResolver: extResolver,
|
||||||
requireMXAuth: true,
|
requireMXAuth: true,
|
||||||
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
mxAuth: map[string]struct{}{AuthDNSSEC: {}},
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/trace"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -57,12 +58,12 @@ type Target struct {
|
||||||
mxAuth map[string]struct{}
|
mxAuth map[string]struct{}
|
||||||
|
|
||||||
resolver dns.Resolver
|
resolver dns.Resolver
|
||||||
dialer func(network, addr string) (net.Conn, error)
|
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
extResolver *dns.ExtResolver
|
extResolver *dns.ExtResolver
|
||||||
|
|
||||||
// This is the callback that is usually mtastsCache.Get,
|
// This is the callback that is usually mtastsCache.Get,
|
||||||
// but replaced by tests to mock mtasts.Cache.
|
// but replaced by tests to mock mtasts.Cache.
|
||||||
mtastsGet func(domain string) (*mtasts.Policy, error)
|
mtastsGet func(ctx context.Context, domain string) (*mtasts.Policy, error)
|
||||||
|
|
||||||
mtastsCache mtasts.Cache
|
mtastsCache mtasts.Cache
|
||||||
stsCacheUpdateTick *time.Ticker
|
stsCacheUpdateTick *time.Ticker
|
||||||
|
@ -80,7 +81,7 @@ func New(_, instName string, _, inlineArgs []string) (module.Module, error) {
|
||||||
return &Target{
|
return &Target{
|
||||||
name: instName,
|
name: instName,
|
||||||
resolver: dns.DefaultResolver(),
|
resolver: dns.DefaultResolver(),
|
||||||
dialer: net.Dial,
|
dialer: (&net.Dialer{}).DialContext,
|
||||||
mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()},
|
mtastsCache: mtasts.Cache{Resolver: dns.DefaultResolver()},
|
||||||
Log: log.Logger{Name: "remote"},
|
Log: log.Logger{Name: "remote"},
|
||||||
|
|
||||||
|
@ -195,6 +196,8 @@ func (rt *Target) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
|
func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
|
||||||
|
defer trace.StartRegion(ctx, "remote/AddRcpt").End()
|
||||||
|
|
||||||
if rd.msgMeta.Quarantine {
|
if rd.msgMeta.Quarantine {
|
||||||
return &exterrors.SMTPError{
|
return &exterrors.SMTPError{
|
||||||
Code: 550,
|
Code: 550,
|
||||||
|
@ -232,7 +235,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.Rcpt(to); err != nil {
|
if err := conn.Rcpt(ctx, to); err != nil {
|
||||||
return moduleError(err)
|
return moduleError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,6 +382,8 @@ func (m *multipleErrs) SetStatus(rcptTo string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error {
|
func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buffer buffer.Buffer) error {
|
||||||
|
defer trace.StartRegion(ctx, "remote/Body").End()
|
||||||
|
|
||||||
merr := multipleErrs{
|
merr := multipleErrs{
|
||||||
errs: make(map[string]error),
|
errs: make(map[string]error),
|
||||||
}
|
}
|
||||||
|
@ -396,6 +401,8 @@ func (rd *remoteDelivery) Body(ctx context.Context, header textproto.Header, buf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) {
|
func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusCollector, header textproto.Header, b buffer.Buffer) {
|
||||||
|
defer trace.StartRegion(ctx, "remote/BodyNonAtomic").End()
|
||||||
|
|
||||||
if rd.msgMeta.Quarantine {
|
if rd.msgMeta.Quarantine {
|
||||||
for _, rcpt := range rd.recipients {
|
for _, rcpt := range rd.recipients {
|
||||||
c.SetStatus(rcpt, &exterrors.SMTPError{
|
c.SetStatus(rcpt, &exterrors.SMTPError{
|
||||||
|
@ -425,7 +432,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl
|
||||||
}
|
}
|
||||||
defer bodyR.Close()
|
defer bodyR.Close()
|
||||||
|
|
||||||
err = conn.Data(header, bodyR)
|
err = conn.Data(ctx, header, bodyR)
|
||||||
for _, rcpt := range conn.Rcpts() {
|
for _, rcpt := range conn.Rcpts() {
|
||||||
c.SetStatus(rcpt, err)
|
c.SetStatus(rcpt, err)
|
||||||
}
|
}
|
||||||
|
@ -486,7 +493,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
|
||||||
addrs = append(addrs, nonAuthMXs...)
|
addrs = append(addrs, nonAuthMXs...)
|
||||||
rd.Log.DebugMsg("considering", "mxs", addrs)
|
rd.Log.DebugMsg("considering", "mxs", addrs)
|
||||||
for i, addr := range addrs {
|
for i, addr := range addrs {
|
||||||
err = conn.Connect(config.Endpoint{
|
err = conn.Connect(ctx, config.Endpoint{
|
||||||
Scheme: "tcp",
|
Scheme: "tcp",
|
||||||
Host: addr,
|
Host: addr,
|
||||||
Port: smtpPort,
|
Port: smtpPort,
|
||||||
|
@ -517,7 +524,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
|
||||||
|
|
||||||
rd.Log.DebugMsg("connected", "remote_server", conn.ServerName())
|
rd.Log.DebugMsg("connected", "remote_server", conn.ServerName())
|
||||||
|
|
||||||
if err := conn.Mail(rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
|
if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, moduleError(err)
|
return nil, moduleError(err)
|
||||||
}
|
}
|
||||||
|
@ -526,8 +533,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *Target) getSTSPolicy(domain string) (*mtasts.Policy, error) {
|
func (rt *Target) getSTSPolicy(ctx context.Context, domain string) (*mtasts.Policy, error) {
|
||||||
stsPolicy, err := rt.mtastsGet(domain)
|
stsPolicy, err := rt.mtastsGet(ctx, domain)
|
||||||
if err != nil && !mtasts.IsNoPolicy(err) {
|
if err != nil && !mtasts.IsNoPolicy(err) {
|
||||||
return nil, &exterrors.SMTPError{
|
return nil, &exterrors.SMTPError{
|
||||||
Code: exterrors.SMTPCode(err, 450, 554),
|
Code: exterrors.SMTPCode(err, 450, 554),
|
||||||
|
@ -568,9 +575,11 @@ func (rt *Target) stsCacheUpdater() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) {
|
func (rd *remoteDelivery) lookupAndFilter(ctx context.Context, domain string) (authMXs, nonAuthMXs []string, requireTLS bool, err error) {
|
||||||
|
defer trace.StartRegion(ctx, "remote/LookupAndFilterMX").End()
|
||||||
|
|
||||||
var policy *mtasts.Policy
|
var policy *mtasts.Policy
|
||||||
if _, use := rd.rt.mxAuth[AuthMTASTS]; use {
|
if _, use := rd.rt.mxAuth[AuthMTASTS]; use {
|
||||||
policy, err = rd.rt.getSTSPolicy(domain)
|
policy, err = rd.rt.getSTSPolicy(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, false, err
|
return nil, nil, false, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,7 @@ func TestRemoteDelivery(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,7 @@ func TestRemoteDelivery_IPLiteral(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &resolver,
|
resolver: &resolver,
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ func TestRemoteDelivery_FallbackMX(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func TestRemoteDelivery_BodyNonAtomic(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ func TestRemoteDelivery_Abort(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -199,7 +199,7 @@ func TestRemoteDelivery_CommitWithoutBody(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -243,7 +243,7 @@ func TestRemoteDelivery_MAILFROMErr(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -276,7 +276,7 @@ func TestRemoteDelivery_NoMX(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -313,7 +313,7 @@ func TestRemoteDelivery_NullMX(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -349,7 +349,7 @@ func TestRemoteDelivery_Quarantined(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -404,7 +404,7 @@ func TestRemoteDelivery_MAILFROMErr_Repeated(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -451,7 +451,7 @@ func TestRemoteDelivery_RcptErr(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -509,7 +509,7 @@ func TestRemoteDelivery_DownMX(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -539,7 +539,7 @@ func TestRemoteDelivery_AllMXDown(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -577,7 +577,7 @@ func TestRemoteDelivery_Split(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -623,7 +623,7 @@ func TestRemoteDelivery_Split_Fail(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -683,7 +683,7 @@ func TestRemoteDelivery_BodyErr(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -744,7 +744,7 @@ func TestRemoteDelivery_Split_BodyErr(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -807,7 +807,7 @@ func TestRemoteDelivery_Split_BodyErr_NonAtomic(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
}
|
}
|
||||||
|
@ -867,7 +867,7 @@ func TestRemoteDelivery_TLSErrFallback(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
tlsConfig: &tls.Config{},
|
tlsConfig: &tls.Config{},
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
|
@ -895,7 +895,7 @@ func TestRemoteDelivery_RequireTLS_Missing(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireTLS: true,
|
requireTLS: true,
|
||||||
Log: testutils.Logger(t, "remote"),
|
Log: testutils.Logger(t, "remote"),
|
||||||
|
@ -925,7 +925,7 @@ func TestRemoteDelivery_RequireTLS_Present(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
requireTLS: true,
|
requireTLS: true,
|
||||||
tlsConfig: clientCfg,
|
tlsConfig: clientCfg,
|
||||||
|
@ -954,7 +954,7 @@ func TestRemoteDelivery_RequireTLS_NoErrFallback(t *testing.T) {
|
||||||
name: "remote",
|
name: "remote",
|
||||||
hostname: "mx.example.com",
|
hostname: "mx.example.com",
|
||||||
resolver: &mockdns.Resolver{Zones: zones},
|
resolver: &mockdns.Resolver{Zones: zones},
|
||||||
dialer: resolver.Dial,
|
dialer: resolver.DialContext,
|
||||||
extResolver: nil,
|
extResolver: nil,
|
||||||
tlsConfig: &tls.Config{},
|
tlsConfig: &tls.Config{},
|
||||||
requireTLS: true,
|
requireTLS: true,
|
||||||
|
|
|
@ -44,7 +44,7 @@ func TestRemoteDelivery_EHLO_ALabel(t *testing.T) {
|
||||||
|
|
||||||
tgt := mod.(*Target)
|
tgt := mod.(*Target)
|
||||||
tgt.resolver = &mockdns.Resolver{Zones: zones}
|
tgt.resolver = &mockdns.Resolver{Zones: zones}
|
||||||
tgt.dialer = resolver.Dial
|
tgt.dialer = resolver.DialContext
|
||||||
tgt.extResolver = nil
|
tgt.extResolver = nil
|
||||||
tgt.Log = testutils.Logger(t, "remote")
|
tgt.Log = testutils.Logger(t, "remote")
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime/trace"
|
||||||
|
|
||||||
"github.com/emersion/go-message/textproto"
|
"github.com/emersion/go-message/textproto"
|
||||||
"github.com/foxcpp/maddy/internal/buffer"
|
"github.com/foxcpp/maddy/internal/buffer"
|
||||||
|
@ -126,24 +127,26 @@ type delivery struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
func (u *Downstream) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFrom string) (module.Delivery, error) {
|
||||||
|
defer trace.StartRegion(ctx, "smtp_downstream/Start").End()
|
||||||
|
|
||||||
d := &delivery{
|
d := &delivery{
|
||||||
u: u,
|
u: u,
|
||||||
log: target.DeliveryLogger(u.log, msgMeta),
|
log: target.DeliveryLogger(u.log, msgMeta),
|
||||||
msgMeta: msgMeta,
|
msgMeta: msgMeta,
|
||||||
mailFrom: mailFrom,
|
mailFrom: mailFrom,
|
||||||
}
|
}
|
||||||
if err := d.connect(); err != nil {
|
if err := d.connect(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.conn.Mail(mailFrom, msgMeta.SMTPOpts); err != nil {
|
if err := d.conn.Mail(ctx, mailFrom, msgMeta.SMTPOpts); err != nil {
|
||||||
d.conn.Close()
|
d.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *delivery) connect() error {
|
func (d *delivery) connect(ctx context.Context) error {
|
||||||
// TODO: Review possibility of connection pooling here.
|
// TODO: Review possibility of connection pooling here.
|
||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
|
@ -156,7 +159,7 @@ func (d *delivery) connect() error {
|
||||||
conn.AddrInSMTPMsg = false
|
conn.AddrInSMTPMsg = false
|
||||||
|
|
||||||
for _, endp := range d.u.endpoints {
|
for _, endp := range d.u.endpoints {
|
||||||
err := conn.Connect(endp)
|
err := conn.Connect(ctx, endp)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
|
d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
|
||||||
lastErr = nil
|
lastErr = nil
|
||||||
|
@ -191,7 +194,7 @@ func (d *delivery) connect() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
func (d *delivery) AddRcpt(ctx context.Context, rcptTo string) error {
|
||||||
return moduleError(d.conn.Rcpt(rcptTo))
|
return moduleError(d.conn.Rcpt(ctx, rcptTo))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
func (d *delivery) Body(ctx context.Context, header textproto.Header, body buffer.Buffer) error {
|
||||||
|
@ -217,7 +220,7 @@ func (d *delivery) Commit(ctx context.Context) error {
|
||||||
defer d.conn.Close()
|
defer d.conn.Close()
|
||||||
defer d.body.Close()
|
defer d.body.Close()
|
||||||
|
|
||||||
return moduleError(d.conn.Data(d.hdr, d.body))
|
return moduleError(d.conn.Data(ctx, d.hdr, d.body))
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue