maddy/internal/endpoint/smtp/smtp.go
2019-12-07 05:42:11 +03:00

679 lines
17 KiB
Go

package smtp
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"math/rand"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/internal/address"
"github.com/foxcpp/maddy/internal/buffer"
"github.com/foxcpp/maddy/internal/config"
modconfig "github.com/foxcpp/maddy/internal/config/module"
"github.com/foxcpp/maddy/internal/dns"
"github.com/foxcpp/maddy/internal/exterrors"
"github.com/foxcpp/maddy/internal/future"
"github.com/foxcpp/maddy/internal/limiters"
"github.com/foxcpp/maddy/internal/log"
"github.com/foxcpp/maddy/internal/module"
"github.com/foxcpp/maddy/internal/msgpipeline"
"github.com/foxcpp/maddy/internal/target"
"golang.org/x/net/idna"
)
// Session is the per-connection state of an SMTP/LMTP client served by an
// Endpoint (it is what Login/AnonymousLogin return via newSession). It also
// carries the state of the message transaction currently in progress.
type Session struct {
endp *Endpoint
// Specific for this session.
// cancelRDNS cancels the reverse DNS lookup started by newSession; it stays
// nil when no lookup was started (endp.resolver == nil).
cancelRDNS func()
// connState describes the client connection; MsgMetadata.Conn points at it.
connState module.ConnState
// repeatedMailErrs counts RCPT commands answered with an already recorded
// deliveryErr; Logout compares it against endp.maxLoggedRcptErrors.
repeatedMailErrs int
// loggedRcptErrors counts RCPT errors logged so far; logging stops once it
// reaches endp.maxLoggedRcptErrors.
loggedRcptErrors int
// Specific for the currently handled message.
// mailFrom holds the MAIL FROM argument stored by Mail (startDelivery may
// also set the normalized form); Rcpt uses it to start a deferred delivery.
mailFrom string
// opts are the MAIL FROM parameters of the current transaction.
opts smtp.MailOptions
// msgMeta is the metadata of the started delivery; nil until startDelivery
// succeeds.
msgMeta *module.MsgMetadata
// delivery is the pipeline delivery in progress; nil when none is running.
// While it is non-nil, one endp.semaphore slot is held.
delivery module.Delivery
// deliveryErr is the already wrapped error of a failed deferred
// startDelivery; Rcpt returns it again instead of retrying.
deliveryErr error
// log is the session logger, initialized from endp.Log in newSession.
log log.Logger
}
// Reset discards the current message transaction.
//
// An in-progress delivery is aborted (which also releases its semaphore slot
// and clears per-message state). When no delivery is running, the
// per-message fields are still cleared. In deferred sender reject mode they
// may hold the sender and a recorded startDelivery error from the abandoned
// transaction, and those must not leak into the next one.
func (s *Session) Reset() {
	if s.delivery != nil {
		s.abort()
	} else {
		s.mailFrom = ""
		s.opts = smtp.MailOptions{}
		s.msgMeta = nil
		s.deliveryErr = nil
	}
	s.endp.Log.DebugMsg("reset")
}
// abort cancels the in-progress delivery, gives its concurrency slot back
// and clears all per-message state of the session.
//
// It dereferences s.delivery and s.msgMeta, so it must only be called while
// a delivery is running.
func (s *Session) abort() {
	s.endp.semaphore.Release()

	if abortErr := s.delivery.Abort(); abortErr != nil {
		s.endp.Log.Error("delivery abort failed", abortErr)
	}
	s.log.Msg("aborted", "msg_id", s.msgMeta.ID)

	s.mailFrom, s.opts = "", smtp.MailOptions{}
	s.msgMeta, s.delivery, s.deliveryErr = nil, nil, nil
}
// startDelivery begins handling of a new message from the given sender.
//
// It performs these steps in order:
//   - rejects non-ASCII senders unless SMTPUTF8 was requested;
//   - normalizes the address and generates a message ID;
//   - waits (up to 5 seconds in total) for the global rate limiter and a
//     concurrency semaphore slot;
//   - starts a delivery through the message pipeline.
//
// On success s.msgMeta, s.mailFrom and s.delivery are populated. The
// semaphore slot stays held until the delivery is committed (Data/LMTPData)
// or aborted (abort). On failure the session fields are left untouched and
// no semaphore slot is kept. The returned string is the message ID, or ""
// if the failure happened before it was generated.
func (s *Session) startDelivery(from string, opts smtp.MailOptions) (string, error) {
	msgMeta := &module.MsgMetadata{
		Conn:     &s.connState,
		SMTPOpts: opts,
	}

	// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
	// used. This is the same check rcpt applies to recipients.
	if !address.IsASCII(from) && !opts.UTF8 {
		return "", &exterrors.SMTPError{
			Code:         550,
			EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
			Message:      "SMTPUTF8 is required for non-ASCII senders",
		}
	}

	// Decode punycode, normalize to NFC and case-fold address.
	cleanFrom, err := address.CleanDomain(from)
	if err != nil {
		return "", &exterrors.SMTPError{
			Code:         553,
			EnhancedCode: exterrors.EnhancedCode{5, 1, 7},
			Message:      "Unable to normalize the sender address",
		}
	}

	msgMeta.ID, err = msgpipeline.GenerateMsgID()
	if err != nil {
		return "", err
	}
	msgMeta.OriginalFrom = cleanFrom

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := s.endp.ratelimit.TakeContext(ctx); err != nil {
		return "", err
	}
	if err := s.endp.semaphore.TakeContext(ctx); err != nil {
		return "", err
	}

	if s.connState.AuthUser != "" {
		s.log.Msg("incoming message",
			"src_host", msgMeta.Conn.Hostname,
			"src_ip", msgMeta.Conn.RemoteAddr.String(),
			"sender", from,
			"msg_id", msgMeta.ID,
			"username", s.connState.AuthUser,
		)
	} else {
		s.log.Msg("incoming message",
			"src_host", msgMeta.Conn.Hostname,
			"src_ip", msgMeta.Conn.RemoteAddr.String(),
			"sender", from,
			"msg_id", msgMeta.ID,
		)
	}

	delivery, err := s.endp.pipeline.Start(msgMeta, cleanFrom)
	if err != nil {
		// No delivery exists, so neither abort nor Data/LMTPData will ever
		// release the slot taken above. Give it back here, otherwise every
		// rejected sender permanently consumes concurrency.
		s.endp.semaphore.Release()
		return msgMeta.ID, err
	}

	s.msgMeta = msgMeta
	s.mailFrom = cleanFrom
	s.delivery = delivery
	return msgMeta.ID, nil
}
// Mail handles the MAIL FROM command.
//
// Unless defer_sender_reject is enabled, the delivery is started immediately
// and any failure is reported on this command. Either way, the raw sender
// and options are remembered so that Rcpt can start the delivery later in
// deferred mode.
func (s *Session) Mail(from string, opts smtp.MailOptions) error {
	// A new MAIL command begins a new transaction. Drop any sender error a
	// deferred startDelivery recorded for an earlier transaction. Otherwise
	// Rcpt would keep replaying it for the new sender.
	s.deliveryErr = nil

	if !s.endp.deferServerReject {
		msgID, err := s.startDelivery(from, opts)
		if err != nil {
			if err != context.DeadlineExceeded {
				s.log.Error("MAIL FROM error", err, "msg_id", msgID)
			}
			return s.endp.wrapErr(msgID, !opts.UTF8, err)
		}
	}

	// Keep the MAIL FROM argument for deferred startDelivery.
	s.mailFrom = from
	s.opts = opts
	return nil
}
// fetchRDNSName resolves the PTR name of the client's IP address and stores
// it in s.connState.RDNSName. The future is always resolved. It gets nil when
// the remote address is not a TCP address or the lookup fails (the failure
// is logged).
func (s *Session) fetchRDNSName(ctx context.Context) {
	tcpAddr, isTCP := s.connState.RemoteAddr.(*net.TCPAddr)
	if !isTCP {
		s.connState.RDNSName.Set(nil)
		return
	}

	name, lookupErr := dns.LookupAddr(ctx, s.endp.resolver, tcpAddr.IP)
	if lookupErr == nil {
		s.connState.RDNSName.Set(name)
		return
	}

	reason, fields := exterrors.UnwrapDNSErr(lookupErr)
	fields["reason"] = reason
	s.log.Error("rDNS error", exterrors.WithFields(lookupErr, fields), "src_ip", s.connState.RemoteAddr)
	s.connState.RDNSName.Set(nil)
}
// Rcpt handles the RCPT TO command.
//
// When no delivery is running yet (defer_sender_reject mode, first RCPT of
// the transaction), the delivery is started here. A failed start is recorded
// in s.deliveryErr and replayed on later RCPTs. Recipient errors are logged
// up to endp.maxLoggedRcptErrors times per session.
func (s *Session) Rcpt(to string) error {
	if s.delivery == nil {
		if s.deliveryErr != nil {
			// Starting the delivery was already attempted and failed; repeat
			// the error (it is already wrapped) and count the repetition.
			s.repeatedMailErrs++
			return s.deliveryErr
		}

		msgID, startErr := s.startDelivery(s.mailFrom, s.opts)
		if startErr != nil {
			if startErr != context.DeadlineExceeded {
				s.log.Error("MAIL FROM error (deferred)", startErr, "rcpt", to, "msg_id", msgID)
			}
			s.deliveryErr = s.endp.wrapErr(msgID, !s.opts.UTF8, startErr)
			return s.deliveryErr
		}
	}

	rcptErr := s.rcpt(to)
	if rcptErr == nil {
		s.endp.Log.Msg("RCPT ok", "rcpt", to, "msg_id", s.msgMeta.ID)
		return nil
	}

	if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
		s.log.Error("RCPT error", rcptErr, "rcpt", to)
		s.loggedRcptErrors++
		if s.loggedRcptErrors == s.endp.maxLoggedRcptErrors {
			s.log.Msg("too many RCPT errors, possible dictonary attack", "src_ip", s.connState.RemoteAddr, "msg_id", s.msgMeta.ID)
		}
	}
	return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, rcptErr)
}
// rcpt validates and normalizes a single recipient address and adds it to
// the running delivery.
//
// Non-ASCII recipients are refused (553 5.6.7) unless SMTPUTF8 was used.
// Addresses that cannot be normalized are refused with 501 5.1.2.
func (s *Session) rcpt(to string) error {
	// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
	// used.
	nonASCII := !address.IsASCII(to)
	if nonASCII && !s.opts.UTF8 {
		return &exterrors.SMTPError{
			Code:         553,
			EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
			Message:      "SMTPUTF8 is required for non-ASCII recipients",
		}
	}

	cleanTo, normErr := address.CleanDomain(to)
	if normErr != nil {
		return &exterrors.SMTPError{
			Code:         501,
			EnhancedCode: exterrors.EnhancedCode{5, 1, 2},
			Message:      "Unable to normalize the recipient address",
		}
	}

	return s.delivery.AddRcpt(cleanTo)
}
// Logout is called when the client connection ends.
//
// It aborts any delivery still in progress, and logs a possible dictionary
// attack when RCPT was answered with a recorded sender error more than
// endp.maxLoggedRcptErrors times. It also cancels the background rDNS
// lookup. It always returns nil.
func (s *Session) Logout() error {
	if s.delivery != nil {
		s.abort()
	}

	// repeatedMailErrs only grows while s.delivery is nil (see Rcpt), so this
	// check must not depend on a delivery being active at logout.
	if s.repeatedMailErrs > s.endp.maxLoggedRcptErrors {
		s.log.Msg("MAIL FROM repeated error a lot of times, possible dictonary attack", "count", s.repeatedMailErrs, "src_ip", s.connState.RemoteAddr)
	}

	if s.cancelRDNS != nil {
		s.cancelRDNS()
	}
	return nil
}
// prepareBody reads the message header from r and buffers the rest of the
// message in memory. On submission endpoints the header first goes through
// submissionPrepare. A freshly generated Received field is added to the
// header. On error a zero header and nil buffer are returned.
func (s *Session) prepareBody(r io.Reader) (textproto.Header, buffer.Buffer, error) {
	fail := func(err error) (textproto.Header, buffer.Buffer, error) {
		return textproto.Header{}, nil, err
	}

	br := bufio.NewReader(r)
	hdr, err := textproto.ReadHeader(br)
	if err != nil {
		return fail(err)
	}

	if s.endp.submission {
		// The MsgMetadata is passed by pointer all the way down.
		if prepErr := s.submissionPrepare(s.msgMeta, &hdr); prepErr != nil {
			return fail(prepErr)
		}
	}

	// TODO: Disk buffering.
	body, err := buffer.BufferInMemory(br)
	if err != nil {
		return fail(err)
	}

	receivedVal, err := target.GenerateReceived(context.TODO(), s.msgMeta, s.endp.hostname, s.msgMeta.OriginalFrom)
	if err != nil {
		return fail(err)
	}
	hdr.Add("Received", receivedVal)

	return hdr, body, nil
}
// Data handles the DATA command in SMTP mode. The whole message is passed to
// the delivery atomically and committed.
//
// On success the delivery is detached from the session and its semaphore
// slot is released. On failure the error is logged and wrapped, and the
// delivery is left in place for Reset to abort.
func (s *Session) Data(r io.Reader) error {
	header, buf, err := s.prepareBody(r)
	if err == nil {
		err = s.delivery.Body(header, buf)
	}
	if err == nil {
		err = s.delivery.Commit()
	}
	if err != nil {
		s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
		return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
	}

	s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
	// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
	s.delivery = nil
	s.endp.semaphore.Release()
	return nil
}
// statusWrapper forwards per-recipient statuses to a go-smtp StatusCollector.
// Each error first passes through the endpoint's wrapErr, so it is turned into
// a client-facing SMTP reply tagged with the session's message ID.
type statusWrapper struct {
	sc smtp.StatusCollector
	s  *Session
}

// SetStatus records err, wrapped by wrapErr, as the status of rcpt.
func (sw statusWrapper) SetStatus(rcpt string, err error) {
	sess := sw.s
	sw.sc.SetStatus(rcpt, sess.endp.wrapErr(sess.msgMeta.ID, !sess.opts.UTF8, err))
}
// LMTPData handles the DATA command in LMTP mode, where a separate status is
// reported for every recipient.
//
// The body is passed to the delivery with BodyNonAtomic, and per-recipient
// results go to sc through statusWrapper. The delivery is then always
// committed. s.delivery must implement module.PartialDelivery; otherwise the
// type assertion panics.
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
	header, buf, err := s.prepareBody(r)
	if err == nil {
		partial := s.delivery.(module.PartialDelivery)
		partial.BodyNonAtomic(statusWrapper{sc: sc, s: s}, header, buf)

		// We can't really tell whether it is failed completely or succeeded
		// so always commit. Should be harmless, anyway.
		err = s.delivery.Commit()
	}
	if err != nil {
		s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
		return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
	}

	s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
	// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
	s.delivery = nil
	s.endp.semaphore.Release()
	return nil
}
// wrapErr converts err into an *smtp.SMTPError suitable for sending to the
// client.
//
// Mapping rules:
//   - a nil err yields nil;
//   - context.DeadlineExceeded (also when wrapped) yields a temporary
//     451 4.4.5 "High load" reply;
//   - otherwise the reply defaults to 554 "Internal server error" (451 when
//     exterrors.IsTemporary reports the error as temporary);
//   - the smtp_code, smtp_enchcode and smtp_msg fields attached via
//     exterrors override those defaults;
//   - finally, the fields of a bare *smtp.SMTPError take precedence.
//
// If msgID is non-empty it is appended to the message. If mangleUTF8 is set
// (the client did not use SMTPUTF8), every non-ASCII character of the message
// is replaced with '?'.
func (endp *Endpoint) wrapErr(msgID string, mangleUTF8 bool, err error) error {
	if err == nil {
		return nil
	}

	if errors.Is(err, context.DeadlineExceeded) {
		return &smtp.SMTPError{
			Code:         451,
			EnhancedCode: smtp.EnhancedCode{4, 4, 5},
			Message:      "High load, try again later",
		}
	}

	res := &smtp.SMTPError{
		Code:         554,
		EnhancedCode: smtp.EnhancedCodeNotSet,
		// Err on the side of caution if the error lacks SMTP annotations. If
		// we just pass the error text through, we might accidentally disclose
		// details of server configuration.
		Message: "Internal server error",
	}
	if exterrors.IsTemporary(err) {
		res.Code = 451
	}

	ctxInfo := exterrors.Fields(err)
	if ctxCode, ok := ctxInfo["smtp_code"].(int); ok {
		res.Code = ctxCode
	}
	if ctxEnchCode, ok := ctxInfo["smtp_enchcode"].(exterrors.EnhancedCode); ok {
		res.EnhancedCode = smtp.EnhancedCode(ctxEnchCode)
	}
	if ctxMsg, ok := ctxInfo["smtp_msg"].(string); ok {
		res.Message = ctxMsg
	}

	if smtpErr, ok := err.(*smtp.SMTPError); ok {
		endp.Log.Printf("plain SMTP error returned, this is deprecated")
		res.Code = smtpErr.Code
		res.EnhancedCode = smtpErr.EnhancedCode
		res.Message = smtpErr.Message
	}

	if msgID != "" {
		res.Message += " (msg ID = " + msgID + ")"
	}

	// INTERNATIONALIZATION: See RFC 6531 Section 3.7.4.1.
	if mangleUTF8 {
		b := strings.Builder{}
		b.Grow(len(res.Message))
		for _, ch := range res.Message {
			// ASCII ends at 127; U+0080 and above must be replaced as well.
			if ch > 127 {
				b.WriteRune('?')
			} else {
				b.WriteRune(ch)
			}
		}
		res.Message = b.String()
	}

	return res
}
// Endpoint is the module behind the "smtp", "submission" and "lmtp" endpoint
// names registered in init. It owns a go-smtp server, the listeners it serves
// on and the message pipeline that incoming messages are passed to.
type Endpoint struct {
// hostname is the "hostname" directive; serv.Domain holds its IDNA ASCII
// (A-label) form.
hostname string
// Auth is the provider from the "auth" directive; nil disables authentication.
Auth module.AuthProvider
serv *smtp.Server
// name is the module name passed to New.
name string
// addrs are the raw listen addresses passed to New; Init parses them.
addrs []string
listeners []net.Listener
pipeline *msgpipeline.MsgPipeline
// resolver is used for rDNS lookups of clients and handed to the pipeline.
resolver dns.Resolver
// ratelimit and semaphore are both taken in startDelivery; a semaphore slot
// is held for as long as a delivery is in progress.
ratelimit limiters.Rate
semaphore limiters.Semaphore
// authAlwaysRequired makes AnonymousLogin fail; set for submission.
authAlwaysRequired bool
// submission enables submissionPrepare processing of message headers.
submission bool
// lmtp switches the go-smtp server into LMTP mode.
lmtp bool
// deferServerReject is the "defer_sender_reject" directive: when set, the
// delivery is started on the first RCPT TO instead of on MAIL FROM.
deferServerReject bool
// maxLoggedRcptErrors limits how many RCPT errors are logged per session.
maxLoggedRcptErrors int
// listenersWg tracks the Serve goroutines started by setupListeners; Close
// waits on it.
listenersWg sync.WaitGroup
Log log.Logger
}
// Name returns the module name the endpoint was created with.
func (endp *Endpoint) Name() string { return endp.name }

// InstanceName returns the instance name; endpoints use their module name.
func (endp *Endpoint) InstanceName() string { return endp.name }
// New creates an unconfigured Endpoint for the given module name and listen
// addresses. The module name selects the flavor: "submission" enables
// submission processing and "lmtp" enables LMTP mode. The system resolver is
// used by default. It never fails.
func New(modName string, addrs []string) (module.Module, error) {
	endp := new(Endpoint)
	endp.name = modName
	endp.addrs = addrs
	endp.resolver = net.DefaultResolver
	endp.Log = log.Logger{Name: modName}

	switch modName {
	case "submission":
		endp.submission = true
	case "lmtp":
		endp.lmtp = true
	}
	return endp, nil
}
// Init creates the go-smtp server, applies the configuration from cfg, binds
// all configured listen addresses and starts serving on them.
//
// A listen address counts as local if it is a unix socket or a 127.0.0.x
// host. When any address is not local, warnings are logged if plaintext
// authentication is allowed or TLS is not configured. Without a TLS
// configuration, plaintext authentication is always allowed.
func (endp *Endpoint) Init(cfg *config.Map) error {
	endp.serv = smtp.NewServer(endp)
	endp.serv.ErrorLog = endp.Log
	endp.serv.LMTP = endp.lmtp
	endp.serv.EnableSMTPUTF8 = true
	if err := endp.setConfig(cfg); err != nil {
		return err
	}

	if endp.Auth != nil {
		endp.Log.Debugf("authentication provider: %s %s", endp.Auth.(module.Module).Name(), endp.Auth.(module.Module).InstanceName())
	}

	addresses := make([]config.Endpoint, 0, len(endp.addrs))
	for _, addr := range endp.addrs {
		saddr, err := config.ParseEndpoint(addr)
		if err != nil {
			// Prefix with the module name like every other error here and keep
			// the parser's error so the operator can see what is wrong.
			return fmt.Errorf("%s: invalid address %s: %w", endp.name, addr, err)
		}
		addresses = append(addresses, saddr)
	}

	if err := endp.setupListeners(addresses); err != nil {
		for _, l := range endp.listeners {
			l.Close()
		}
		return err
	}

	allLocal := true
	for _, addr := range addresses {
		if addr.Scheme != "unix" && !strings.HasPrefix(addr.Host, "127.0.0.") {
			allLocal = false
			break
		}
	}

	if endp.serv.AllowInsecureAuth && !allLocal {
		endp.Log.Println("authentication over unencrypted connections is allowed, this is insecure configuration and should be used only for testing!")
	}
	if endp.serv.TLSConfig == nil {
		if !allLocal {
			endp.Log.Println("TLS is disabled, this is insecure configuration and should be used only for testing!")
		}
		endp.serv.AllowInsecureAuth = true
	}

	return nil
}
// setConfig reads the endpoint directives from cfg into the endpoint and its
// go-smtp server. It hands every unknown directive to msgpipeline.New to build
// the message pipeline.
//
// The configured hostname is converted to its IDNA ASCII form for
// serv.Domain, and that value is also used as the pipeline hostname.
// Submission endpoints require an auth provider.
func (endp *Endpoint) setConfig(cfg *config.Map) error {
	var ioDebug bool

	cfg.Custom("auth", false, false, nil, modconfig.AuthDirective, &endp.Auth)
	cfg.String("hostname", true, true, "", &endp.hostname)
	cfg.Duration("write_timeout", false, false, 1*time.Minute, &endp.serv.WriteTimeout)
	cfg.Duration("read_timeout", false, false, 10*time.Minute, &endp.serv.ReadTimeout)
	cfg.DataSize("max_message_size", false, false, 32*1024*1024, &endp.serv.MaxMessageBytes)
	cfg.Int("max_recipients", false, false, 20000, &endp.serv.MaxRecipients)
	cfg.Custom("tls", true, true, nil, config.TLSDirective, &endp.serv.TLSConfig)
	cfg.Bool("insecure_auth", false, false, &endp.serv.AllowInsecureAuth)
	cfg.Bool("io_debug", false, false, &ioDebug)
	cfg.Bool("debug", true, false, &endp.Log.Debug)
	cfg.Bool("defer_sender_reject", false, true, &endp.deferServerReject)
	cfg.Int("max_logged_rcpt_errors", false, false, 5, &endp.maxLoggedRcptErrors)
	cfg.Custom("ratelimit", false, false, func() (interface{}, error) {
		return limiters.NewRate(10, time.Second), nil
	}, config.GlobalRateLimit, &endp.ratelimit)
	cfg.Custom("concurrency", false, false, func() (interface{}, error) {
		return limiters.NewSemaphore(1000), nil
	}, config.ConcurrencyLimit, &endp.semaphore)
	cfg.AllowUnknown()
	unknown, err := cfg.Process()
	if err != nil {
		return err
	}

	endp.pipeline, err = msgpipeline.New(cfg.Globals, unknown)
	if err != nil {
		return err
	}

	// INTERNATIONALIZATION: See RFC 6531 Section 3.3.
	endp.serv.Domain, err = idna.ToASCII(endp.hostname)
	if err != nil {
		return fmt.Errorf("%s: can not represent the hostname as an A-label name: %w", endp.name, err)
	}

	// serv.Domain must be computed (above) before it is copied into the
	// pipeline; otherwise the pipeline would get an empty hostname.
	endp.pipeline.Hostname = endp.serv.Domain
	endp.pipeline.Resolver = endp.resolver
	endp.pipeline.Log = log.Logger{Name: "smtp/pipeline", Debug: endp.Log.Debug}

	endp.serv.AuthDisabled = endp.Auth == nil
	if endp.submission {
		endp.authAlwaysRequired = true
		if endp.Auth == nil {
			return fmt.Errorf("%s: auth. provider must be set for submission endpoint", endp.name)
		}
	}

	if ioDebug {
		endp.serv.Debug = endp.Log.DebugWriter()
		endp.Log.Println("I/O debugging is on! It may leak passwords in logs, be careful!")
	}

	return nil
}
// setupListeners binds every address and starts serving SMTP on it in a
// separate goroutine tracked by endp.listenersWg. TLS addresses (SMTPS) are
// wrapped with the server's TLS configuration.
//
// Every listener that was opened is recorded in endp.listeners before the
// function can fail on a later address. On error the caller closes the
// listeners recorded so far.
func (endp *Endpoint) setupListeners(addresses []config.Endpoint) error {
	for _, addr := range addresses {
		// Validate before binding. A socket opened here but not yet recorded in
		// endp.listeners could not be closed by the caller's cleanup.
		if addr.IsTLS() && endp.serv.TLSConfig == nil {
			return fmt.Errorf("%s: can't bind on SMTPS endpoint without TLS configuration", endp.name)
		}

		l, err := net.Listen(addr.Network(), addr.Address())
		if err != nil {
			return fmt.Errorf("%s: %w", endp.name, err)
		}
		endp.Log.Printf("listening on %v", addr)

		if addr.IsTLS() {
			l = tls.NewListener(l, endp.serv.TLSConfig)
		}

		endp.listeners = append(endp.listeners, l)
		endp.listenersWg.Add(1)
		addr := addr
		go func() {
			defer endp.listenersWg.Done()
			if err := endp.serv.Serve(l); err != nil {
				endp.Log.Printf("failed to serve %s: %s", addr, err)
			}
		}()
	}
	return nil
}
// Login creates a session for a client authenticating with PLAIN
// credentials.
//
// It fails with smtp.ErrAuthUnsupported when no auth provider is configured
// and with a wrapped error when the pipeline's early checks reject the
// connection. Failed credential checks are logged and rejected.
func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
	if endp.Auth == nil {
		return nil, smtp.ErrAuthUnsupported
	}

	// Executed before authentication and session initialization.
	checkErr := endp.pipeline.RunEarlyChecks(state)
	if checkErr != nil {
		return nil, endp.wrapErr("", true, checkErr)
	}

	authenticated := endp.Auth.CheckPlain(username, password)
	if !authenticated {
		endp.Log.Msg("authentication failed", "username", username, "src_ip", state.RemoteAddr)
		return nil, errors.New("Invalid credentials")
	}

	const anonymous = false
	return endp.newSession(anonymous, username, password, state), nil
}
// AnonymousLogin creates a session for a client that did not authenticate.
//
// It is refused with smtp.ErrAuthRequired when the endpoint always requires
// authentication. Otherwise the pipeline's early checks run against the
// connection before the session is created.
func (endp *Endpoint) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
	if endp.authAlwaysRequired {
		return nil, smtp.ErrAuthRequired
	}

	// Executed before authentication and session initialization.
	checkErr := endp.pipeline.RunEarlyChecks(state)
	if checkErr != nil {
		return nil, endp.wrapErr("", true, checkErr)
	}

	const anonymous = true
	return endp.newSession(anonymous, "", "", state), nil
}
// newSession builds a Session for the given connection state and
// credentials.
//
// The protocol name recorded in the connection state is "LMTP" on LMTP
// servers, and otherwise "ESMTPS" or "ESMTP" depending on whether a TLS
// handshake was completed. When a resolver is configured, a cancellable
// reverse DNS lookup is started in the background. Its result is delivered
// through connState.RDNSName.
func (endp *Endpoint) newSession(anonymous bool, username, password string, state *smtp.ConnectionState) smtp.Session {
	s := &Session{
		endp: endp,
		log:  endp.Log,
		connState: module.ConnState{
			ConnectionState: *state,
			AuthUser:        username,
			AuthPassword:    password,
		},
	}

	// A completed handshake in the TLS state means the connection uses TLS.
	switch {
	case endp.serv.LMTP:
		s.connState.Proto = "LMTP"
	case state.TLS.HandshakeComplete:
		s.connState.Proto = "ESMTPS"
	default:
		s.connState.Proto = "ESMTP"
	}

	if endp.resolver == nil {
		return s
	}

	rdnsCtx, cancelRDNS := context.WithCancel(context.TODO())
	s.connState.RDNSName = future.New()
	s.cancelRDNS = cancelRDNS
	go s.fetchRDNSName(rdnsCtx)
	return s
}
// Close stops the SMTP server and blocks until all Serve goroutines started
// by setupListeners have exited. It always returns nil.
func (endp *Endpoint) Close() error {
	defer endp.listenersWg.Wait()
	endp.serv.Close()
	return nil
}
// init registers the three endpoint flavors served by this module and seeds
// math/rand.
func init() {
	for _, name := range []string{"smtp", "submission", "lmtp"} {
		module.RegisterEndpoint(name, New)
	}
	rand.Seed(time.Now().UnixNano())
}