mirror of
https://github.com/foxcpp/maddy.git
synced 2025-04-05 14:07:38 +03:00
Implement & integrate generic SASL authentication support
This should make it possible to implement OAuth and TLS client certificates authentication.
This commit is contained in:
parent
0507fb89f4
commit
eaaadfa6df
9 changed files with 855 additions and 509 deletions
154
internal/auth/sasl.go
Normal file
154
internal/auth/sasl.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/foxcpp/maddy/internal/config"
|
||||
modconfig "github.com/foxcpp/maddy/internal/config/module"
|
||||
"github.com/foxcpp/maddy/internal/log"
|
||||
"github.com/foxcpp/maddy/internal/module"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedMech = errors.New("Unsupported SASL mechanism")
|
||||
)
|
||||
|
||||
// SASLAuth is a wrapper that initializes sasl.Server using authenticators that
|
||||
// call maddy module objects.
|
||||
//
|
||||
// It supports reporting of multiple authorization identities so multiple
|
||||
// accounts can be associated with a single set of credentials.
|
||||
type SASLAuth struct {
|
||||
Log log.Logger
|
||||
OnlyFirstID bool
|
||||
|
||||
Plain []module.PlainAuth
|
||||
}
|
||||
|
||||
func (s *SASLAuth) SASLMechanisms() []string {
|
||||
var mechs []string
|
||||
|
||||
if len(s.Plain) != 0 {
|
||||
mechs = append(mechs, sasl.Plain, sasl.Login)
|
||||
}
|
||||
|
||||
return mechs
|
||||
}
|
||||
|
||||
func (s *SASLAuth) AuthPlain(username, password string) ([]string, error) {
|
||||
if len(s.Plain) == 0 {
|
||||
return nil, ErrUnsupportedMech
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
accounts := make([]string, 0, 1)
|
||||
for _, p := range s.Plain {
|
||||
pAccs, err := p.AuthPlain(username, password)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
if s.OnlyFirstID {
|
||||
return pAccs, nil
|
||||
}
|
||||
accounts = append(accounts, pAccs...)
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return nil, fmt.Errorf("no auth. provider accepted creds, last err: %w", lastErr)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func filterIdentity(accounts []string, identity string) ([]string, error) {
|
||||
if identity == "" {
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
matchFound := false
|
||||
for _, acc := range accounts {
|
||||
if precis.UsernameCaseMapped.Compare(acc, identity) {
|
||||
accounts = []string{identity}
|
||||
matchFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matchFound {
|
||||
return nil, errors.New("auth: invalid credentials")
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// CreateSASL creates the sasl.Server instance for the corresponding mechanism.
|
||||
//
|
||||
// successCb will be called with the slice of authorization identities
|
||||
// associated with credentials used.
|
||||
// If it fails - authentication will fail too.
|
||||
func (s *SASLAuth) CreateSASL(mech string, remoteAddr net.Addr, successCb func([]string) error) sasl.Server {
|
||||
switch mech {
|
||||
case sasl.Plain:
|
||||
return sasl.NewPlainServer(func(identity, username, password string) error {
|
||||
accounts, err := s.AuthPlain(username, password)
|
||||
if err != nil {
|
||||
s.Log.Error("authentication failed", err, "username", username, "identity", identity, "src_ip", remoteAddr)
|
||||
return errors.New("auth: invalid credentials")
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
accounts = []string{username}
|
||||
}
|
||||
accounts, err = filterIdentity(accounts, identity)
|
||||
if err != nil {
|
||||
s.Log.Error("not authorized", err, "username", username, "identity", identity, "src_ip", remoteAddr)
|
||||
return errors.New("auth: invalid credentials")
|
||||
}
|
||||
|
||||
return successCb(accounts)
|
||||
})
|
||||
case sasl.Login:
|
||||
return sasl.NewLoginServer(func(username, password string) error {
|
||||
accounts, err := s.AuthPlain(username, password)
|
||||
if err != nil {
|
||||
s.Log.Error("authentication failed", err, "username", username, "src_ip", remoteAddr)
|
||||
return errors.New("auth: invalid credentials")
|
||||
}
|
||||
|
||||
return successCb(accounts)
|
||||
})
|
||||
}
|
||||
return FailingSASLServ{Err: ErrUnsupportedMech}
|
||||
}
|
||||
|
||||
// AddProvider adds the SASL authentication provider to its mapping by parsing
|
||||
// the 'auth' configuration directive.
|
||||
func (s *SASLAuth) AddProvider(m *config.Map, node *config.Node) error {
|
||||
mod, err := modconfig.SASLAuthDirective(m, node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
saslAuth := mod.(module.SASLProvider)
|
||||
for _, mech := range saslAuth.SASLMechanisms() {
|
||||
switch mech {
|
||||
case sasl.Login, sasl.Plain:
|
||||
plainAuth, ok := saslAuth.(module.PlainAuth)
|
||||
if !ok {
|
||||
return m.MatchErr("auth: provider does not implement PlainAuth even though it reports PLAIN/LOGIN mechanism")
|
||||
}
|
||||
s.Plain = append(s.Plain, plainAuth)
|
||||
default:
|
||||
return m.MatchErr("auth: unknown SASL mechanism")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type FailingSASLServ struct{ Err error }
|
||||
|
||||
func (s FailingSASLServ) Next([]byte) ([]byte, bool, error) {
|
||||
return nil, true, s.Err
|
||||
}
|
103
internal/auth/sasl_test.go
Normal file
103
internal/auth/sasl_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/foxcpp/maddy/internal/module"
|
||||
"github.com/foxcpp/maddy/internal/testutils"
|
||||
)
|
||||
|
||||
type mockAuth struct {
|
||||
db map[string][]string
|
||||
}
|
||||
|
||||
func (mockAuth) SASLMechanisms() []string {
|
||||
return []string{sasl.Plain, sasl.Login}
|
||||
}
|
||||
|
||||
func (m mockAuth) AuthPlain(username, _ string) ([]string, error) {
|
||||
ids, ok := m.db[username]
|
||||
if !ok {
|
||||
return nil, errors.New("invalid creds")
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func TestCreateSASL(t *testing.T) {
|
||||
a := SASLAuth{
|
||||
Log: testutils.Logger(t, "saslauth"),
|
||||
Plain: []module.PlainAuth{
|
||||
&mockAuth{
|
||||
db: map[string][]string{
|
||||
"user1": []string{"user1a", "user1b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("XWHATEVER", func(t *testing.T) {
|
||||
srv := a.CreateSASL("XWHATEVER", &net.TCPAddr{}, func([]string) error { return nil })
|
||||
_, _, err := srv.Next([]byte(""))
|
||||
if err == nil {
|
||||
t.Error("No error for XWHATEVER use")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PLAIN", func(t *testing.T) {
|
||||
var ids []string
|
||||
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
|
||||
ids = passed
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, err := srv.Next([]byte("\x00user1\x00aa"))
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(ids, []string{"user1a", "user1b"}) {
|
||||
t.Error("Wrong auth. identities passed to callback:", ids)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PLAIN with autorization identity", func(t *testing.T) {
|
||||
var ids []string
|
||||
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
|
||||
ids = passed
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, err := srv.Next([]byte("user1a\x00user1\x00aa"))
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(ids, []string{"user1a"}) {
|
||||
t.Error("Wrong auth. identities passed to callback:", ids)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PLAIN with wrong authorization identity", func(t *testing.T) {
|
||||
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, err := srv.Next([]byte("user1c\x00user1\x00aa"))
|
||||
if err == nil {
|
||||
t.Error("Next should fail")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PLAIN with wrong authentication identity", func(t *testing.T) {
|
||||
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, err := srv.Next([]byte("\x00user2\x00aa"))
|
||||
if err == nil {
|
||||
t.Error("Next should fail")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -5,8 +5,8 @@ import (
|
|||
"github.com/foxcpp/maddy/internal/module"
|
||||
)
|
||||
|
||||
func AuthDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
var provider module.PlainAuth
|
||||
func SASLAuthDirective(m *config.Map, node *config.Node) (interface{}, error) {
|
||||
var provider module.SASLProvider
|
||||
if err := ModuleFromNode(node.Args, node, m.Globals, &provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -19,8 +19,10 @@ import (
|
|||
imapserver "github.com/emersion/go-imap/server"
|
||||
"github.com/emersion/go-message"
|
||||
_ "github.com/emersion/go-message/charset"
|
||||
"github.com/emersion/go-sasl"
|
||||
i18nlevel "github.com/foxcpp/go-imap-i18nlevel"
|
||||
"github.com/foxcpp/go-imap-sql/children"
|
||||
"github.com/foxcpp/maddy/internal/auth"
|
||||
"github.com/foxcpp/maddy/internal/config"
|
||||
modconfig "github.com/foxcpp/maddy/internal/config/module"
|
||||
"github.com/foxcpp/maddy/internal/log"
|
||||
|
@ -32,13 +34,14 @@ type Endpoint struct {
|
|||
addrs []string
|
||||
serv *imapserver.Server
|
||||
listeners []net.Listener
|
||||
Auth module.PlainAuth
|
||||
Store module.Storage
|
||||
|
||||
updater imapbackend.BackendUpdater
|
||||
tlsConfig *tls.Config
|
||||
listenersWg sync.WaitGroup
|
||||
|
||||
saslAuth auth.SASLAuth
|
||||
|
||||
Log log.Logger
|
||||
}
|
||||
|
||||
|
@ -46,6 +49,9 @@ func New(modName string, addrs []string) (module.Module, error) {
|
|||
endp := &Endpoint{
|
||||
addrs: addrs,
|
||||
Log: log.Logger{Name: "imap"},
|
||||
saslAuth: auth.SASLAuth{
|
||||
Log: log.Logger{Name: "imap/saslauth"},
|
||||
},
|
||||
}
|
||||
|
||||
return endp, nil
|
||||
|
@ -58,7 +64,9 @@ func (endp *Endpoint) Init(cfg *config.Map) error {
|
|||
ioErrors bool
|
||||
)
|
||||
|
||||
cfg.Custom("auth", false, true, nil, modconfig.AuthDirective, &endp.Auth)
|
||||
cfg.Callback("auth", func(m *config.Map, node *config.Node) error {
|
||||
return endp.saslAuth.AddProvider(m, node)
|
||||
})
|
||||
cfg.Custom("storage", false, true, nil, modconfig.StorageDirective, &endp.Store)
|
||||
cfg.Custom("tls", true, true, nil, config.TLSDirective, &endp.tlsConfig)
|
||||
cfg.Bool("insecure_auth", false, false, &insecureAuth)
|
||||
|
@ -113,6 +121,14 @@ func (endp *Endpoint) Init(cfg *config.Map) error {
|
|||
return err
|
||||
}
|
||||
|
||||
for _, mech := range endp.saslAuth.SASLMechanisms() {
|
||||
endp.serv.EnableAuth(mech, func(c imapserver.Conn) sasl.Server {
|
||||
return endp.saslAuth.CreateSASL(mech, c.Info().RemoteAddr, func(ids []string) error {
|
||||
return endp.openAccount(c, ids)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if err := endp.setupListeners(addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -183,8 +199,19 @@ func (endp *Endpoint) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (endp *Endpoint) openAccount(c imapserver.Conn, identities []string) error {
|
||||
u, err := endp.Store.GetOrCreateUser(identities[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.Context()
|
||||
ctx.State = imap.AuthenticatedState
|
||||
ctx.User = u
|
||||
return nil
|
||||
}
|
||||
|
||||
func (endp *Endpoint) Login(connInfo *imap.ConnInfo, username, password string) (imapbackend.User, error) {
|
||||
_, err := endp.Auth.AuthPlain(username, password)
|
||||
_, err := endp.saslAuth.AuthPlain(username, password)
|
||||
if err != nil {
|
||||
endp.Log.Error("authentication failed", err, "username", username, "src_ip", connInfo.RemoteAddr)
|
||||
return nil, imapbackend.ErrInvalidCredentials
|
||||
|
|
506
internal/endpoint/smtp/session.go
Normal file
506
internal/endpoint/smtp/session.go
Normal file
|
@ -0,0 +1,506 @@
|
|||
package smtp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/internal/address"
|
||||
"github.com/foxcpp/maddy/internal/buffer"
|
||||
"github.com/foxcpp/maddy/internal/dns"
|
||||
"github.com/foxcpp/maddy/internal/exterrors"
|
||||
"github.com/foxcpp/maddy/internal/log"
|
||||
"github.com/foxcpp/maddy/internal/module"
|
||||
"github.com/foxcpp/maddy/internal/msgpipeline"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
endp *Endpoint
|
||||
|
||||
// Specific for this session.
|
||||
// sessionCtx is not used for cancellation or timeouts, only for tracing.
|
||||
sessionCtx context.Context
|
||||
cancelRDNS func()
|
||||
connState module.ConnState
|
||||
repeatedMailErrs int
|
||||
loggedRcptErrors int
|
||||
|
||||
// Specific for the currently handled message.
|
||||
// msgCtx is not used for cancellation or timeouts, only for tracing.
|
||||
// It is the subcontext of sessionCtx.
|
||||
// Mutex is used to prevent Close from accessing inconsistent state when it
|
||||
// is called asynchronously to any SMTP command.
|
||||
msgLock sync.Mutex
|
||||
msgCtx context.Context
|
||||
msgTask *trace.Task
|
||||
mailFrom string
|
||||
opts smtp.MailOptions
|
||||
msgMeta *module.MsgMetadata
|
||||
delivery module.Delivery
|
||||
deliveryErr error
|
||||
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func (s *Session) Reset() {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
if s.delivery != nil {
|
||||
s.abort(s.msgCtx)
|
||||
}
|
||||
s.endp.Log.DebugMsg("reset")
|
||||
}
|
||||
|
||||
func (s *Session) releaseLimits() {
|
||||
_, domain, err := address.Split(s.mailFrom)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr, ok := s.msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
addr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
|
||||
}
|
||||
s.endp.limits.ReleaseMsg(addr.IP, domain)
|
||||
}
|
||||
|
||||
func (s *Session) abort(ctx context.Context) {
|
||||
if err := s.delivery.Abort(ctx); err != nil {
|
||||
s.endp.Log.Error("delivery abort failed", err)
|
||||
}
|
||||
s.log.Msg("aborted", "msg_id", s.msgMeta.ID)
|
||||
|
||||
s.mailFrom = ""
|
||||
s.opts = smtp.MailOptions{}
|
||||
s.msgMeta = nil
|
||||
s.delivery = nil
|
||||
s.deliveryErr = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
}
|
||||
|
||||
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
|
||||
var err error
|
||||
msgMeta := &module.MsgMetadata{
|
||||
Conn: &s.connState,
|
||||
SMTPOpts: opts,
|
||||
}
|
||||
|
||||
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
|
||||
// used.
|
||||
for _, ch := range from {
|
||||
if ch > 128 && !opts.UTF8 {
|
||||
return "", &exterrors.SMTPError{
|
||||
Code: 550,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
|
||||
Message: "SMTPUTF8 is required for non-ASCII senders",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode punycode, normalize to NFC and case-fold address.
|
||||
cleanFrom, err := address.CleanDomain(from)
|
||||
if err != nil {
|
||||
return "", &exterrors.SMTPError{
|
||||
Code: 553,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 1, 7},
|
||||
Message: "Unable to normalize the sender address",
|
||||
}
|
||||
}
|
||||
|
||||
msgMeta.ID, err = msgpipeline.GenerateMsgID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
msgMeta.OriginalFrom = from
|
||||
|
||||
_, domain, err := address.Split(cleanFrom)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
remoteIP, ok := msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
remoteIP = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
|
||||
}
|
||||
if err := s.endp.limits.TakeMsg(context.Background(), remoteIP.IP, domain); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if s.connState.AuthUser != "" {
|
||||
s.log.Msg("incoming message",
|
||||
"src_host", msgMeta.Conn.Hostname,
|
||||
"src_ip", msgMeta.Conn.RemoteAddr.String(),
|
||||
"sender", from,
|
||||
"msg_id", msgMeta.ID,
|
||||
"username", s.connState.AuthUser,
|
||||
)
|
||||
} else {
|
||||
s.log.Msg("incoming message",
|
||||
"src_host", msgMeta.Conn.Hostname,
|
||||
"src_ip", msgMeta.Conn.RemoteAddr.String(),
|
||||
"sender", from,
|
||||
"msg_id", msgMeta.ID,
|
||||
)
|
||||
}
|
||||
|
||||
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
|
||||
|
||||
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
|
||||
defer mailTask.End()
|
||||
|
||||
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
|
||||
if err != nil {
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
return msgMeta.ID, err
|
||||
}
|
||||
|
||||
s.msgMeta = msgMeta
|
||||
s.mailFrom = cleanFrom
|
||||
s.delivery = delivery
|
||||
|
||||
return msgMeta.ID, nil
|
||||
}
|
||||
|
||||
func (s *Session) Mail(from string, opts smtp.MailOptions) error {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
if !s.endp.deferServerReject {
|
||||
// Will initialize s.msgCtx.
|
||||
msgID, err := s.startDelivery(s.sessionCtx, from, opts)
|
||||
if err != nil {
|
||||
if err != context.DeadlineExceeded {
|
||||
s.log.Error("MAIL FROM error", err, "msg_id", msgID)
|
||||
}
|
||||
return s.endp.wrapErr(msgID, !opts.UTF8, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep the MAIL FROM argument for deferred startDelivery.
|
||||
s.mailFrom = from
|
||||
s.opts = opts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) fetchRDNSName(ctx context.Context) {
|
||||
defer trace.StartRegion(ctx, "rDNS fetch").End()
|
||||
|
||||
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
s.connState.RDNSName.Set(nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
name, err := dns.LookupAddr(ctx, s.endp.resolver, tcpAddr.IP)
|
||||
if err != nil {
|
||||
dnsErr, ok := err.(*net.DNSError)
|
||||
if ok && dnsErr.IsNotFound {
|
||||
s.connState.RDNSName.Set(nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
reason, misc := exterrors.UnwrapDNSErr(err)
|
||||
misc["reason"] = reason
|
||||
s.log.Error("rDNS error", exterrors.WithFields(err, misc), "src_ip", s.connState.RemoteAddr)
|
||||
s.connState.RDNSName.Set(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.connState.RDNSName.Set(name, nil)
|
||||
}
|
||||
|
||||
func (s *Session) Rcpt(to string) error {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
// deferServerReject = true and this is the first RCPT TO command.
|
||||
if s.delivery == nil {
|
||||
// If we already attempted to initialize the delivery -
|
||||
// fail again.
|
||||
if s.deliveryErr != nil {
|
||||
s.repeatedMailErrs++
|
||||
// The deliveryErr is already wrapped.
|
||||
return s.deliveryErr
|
||||
}
|
||||
|
||||
// It will initialize s.msgCtx.
|
||||
msgID, err := s.startDelivery(s.sessionCtx, s.mailFrom, s.opts)
|
||||
if err != nil {
|
||||
if err != context.DeadlineExceeded {
|
||||
s.log.Error("MAIL FROM error (deferred)", err, "rcpt", to, "msg_id", msgID)
|
||||
}
|
||||
s.deliveryErr = s.endp.wrapErr(msgID, !s.opts.UTF8, err)
|
||||
return s.deliveryErr
|
||||
}
|
||||
}
|
||||
|
||||
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
|
||||
defer rcptTask.End()
|
||||
|
||||
if err := s.rcpt(rcptCtx, to); err != nil {
|
||||
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
|
||||
s.log.Error("RCPT error", err, "rcpt", to)
|
||||
s.loggedRcptErrors++
|
||||
if s.loggedRcptErrors == s.endp.maxLoggedRcptErrors {
|
||||
s.log.Msg("too many RCPT errors, possible dictonary attack", "src_ip", s.connState.RemoteAddr, "msg_id", s.msgMeta.ID)
|
||||
}
|
||||
}
|
||||
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
|
||||
}
|
||||
s.endp.Log.Msg("RCPT ok", "rcpt", to, "msg_id", s.msgMeta.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) rcpt(ctx context.Context, to string) error {
|
||||
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
|
||||
// used.
|
||||
if !address.IsASCII(to) && !s.opts.UTF8 {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 553,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
|
||||
Message: "SMTPUTF8 is required for non-ASCII recipients",
|
||||
}
|
||||
}
|
||||
cleanTo, err := address.CleanDomain(to)
|
||||
if err != nil {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 501,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 1, 2},
|
||||
Message: "Unable to normalize the recipient address",
|
||||
}
|
||||
}
|
||||
|
||||
return s.delivery.AddRcpt(ctx, cleanTo)
|
||||
}
|
||||
|
||||
func (s *Session) Logout() error {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
if s.delivery != nil {
|
||||
s.abort(s.msgCtx)
|
||||
|
||||
if s.repeatedMailErrs > s.endp.maxLoggedRcptErrors {
|
||||
s.log.Msg("MAIL FROM repeated error a lot of times, possible dictonary attack", "count", s.repeatedMailErrs, "src_ip", s.connState.RemoteAddr)
|
||||
}
|
||||
}
|
||||
if s.cancelRDNS != nil {
|
||||
s.cancelRDNS()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Header, buffer.Buffer, error) {
|
||||
bufr := bufio.NewReader(r)
|
||||
header, err := textproto.ReadHeader(bufr)
|
||||
if err != nil {
|
||||
return textproto.Header{}, nil, err
|
||||
}
|
||||
|
||||
if s.endp.submission {
|
||||
// The MsgMetadata is passed by pointer all the way down.
|
||||
if err := s.submissionPrepare(s.msgMeta, &header); err != nil {
|
||||
return textproto.Header{}, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
buf, err := s.endp.buffer(bufr)
|
||||
if err != nil {
|
||||
return textproto.Header{}, nil, err
|
||||
}
|
||||
|
||||
return header, buf, nil
|
||||
}
|
||||
|
||||
func (s *Session) Data(r io.Reader) error {
|
||||
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
wrapErr := func(err error) error {
|
||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
|
||||
}
|
||||
|
||||
header, buf, err := s.prepareBody(bodyCtx, r)
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := buf.Remove(); err != nil {
|
||||
s.log.Error("failed to remove buffered body", err)
|
||||
}
|
||||
|
||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||
s.delivery = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
s.msgTask = nil
|
||||
s.releaseLimits()
|
||||
}()
|
||||
|
||||
if err := s.checkRoutingLoops(header); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if err := s.delivery.Body(bodyCtx, header, buf); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if err := s.delivery.Commit(bodyCtx); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type statusWrapper struct {
|
||||
sc smtp.StatusCollector
|
||||
s *Session
|
||||
}
|
||||
|
||||
func (sw statusWrapper) SetStatus(rcpt string, err error) {
|
||||
sw.sc.SetStatus(rcpt, sw.s.endp.wrapErr(sw.s.msgMeta.ID, !sw.s.opts.UTF8, err))
|
||||
}
|
||||
|
||||
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
||||
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
wrapErr := func(err error) error {
|
||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
|
||||
}
|
||||
|
||||
header, buf, err := s.prepareBody(bodyCtx, r)
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := buf.Remove(); err != nil {
|
||||
s.log.Error("failed to remove buffered body", err)
|
||||
}
|
||||
|
||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||
s.delivery = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
s.msgTask = nil
|
||||
s.releaseLimits()
|
||||
}()
|
||||
|
||||
if err := s.checkRoutingLoops(header); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
s.delivery.(module.PartialDelivery).BodyNonAtomic(bodyCtx, statusWrapper{sc, s}, header, buf)
|
||||
|
||||
// We can't really tell whether it is failed completely or succeeded
|
||||
// so always commit. Should be harmless, anyway.
|
||||
if err := s.delivery.Commit(bodyCtx); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) checkRoutingLoops(header textproto.Header) error {
|
||||
// RFC 5321 Section 6.3:
|
||||
// >Simple counting of the number of "Received:" header fields in a
|
||||
// >message has proven to be an effective, although rarely optimal,
|
||||
// >method of detecting loops in mail systems.
|
||||
receivedCount := 0
|
||||
for f := header.FieldsByKey("Received"); f.Next(); {
|
||||
receivedCount++
|
||||
}
|
||||
if receivedCount > s.endp.maxReceived {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 4, 6},
|
||||
Message: fmt.Sprintf("Too many Received header fields (%d), possible forwarding loop", receivedCount),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (endp *Endpoint) wrapErr(msgId string, mangleUTF8 bool, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return &smtp.SMTPError{
|
||||
Code: 451,
|
||||
EnhancedCode: smtp.EnhancedCode{4, 4, 5},
|
||||
Message: "High load, try again later",
|
||||
}
|
||||
}
|
||||
|
||||
res := &smtp.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: smtp.EnhancedCodeNotSet,
|
||||
// Err on the side of caution if the error lacks SMTP annotations. If
|
||||
// we just pass the error text through, we might accidenetally disclose
|
||||
// details of server configuration.
|
||||
Message: "Internal server error",
|
||||
}
|
||||
|
||||
if exterrors.IsTemporary(err) {
|
||||
res.Code = 451
|
||||
}
|
||||
|
||||
ctxInfo := exterrors.Fields(err)
|
||||
ctxCode, ok := ctxInfo["smtp_code"].(int)
|
||||
if ok {
|
||||
res.Code = ctxCode
|
||||
}
|
||||
ctxEnchCode, ok := ctxInfo["smtp_enchcode"].(exterrors.EnhancedCode)
|
||||
if ok {
|
||||
res.EnhancedCode = smtp.EnhancedCode(ctxEnchCode)
|
||||
}
|
||||
ctxMsg, ok := ctxInfo["smtp_msg"].(string)
|
||||
if ok {
|
||||
res.Message = ctxMsg
|
||||
}
|
||||
|
||||
if smtpErr, ok := err.(*smtp.SMTPError); ok {
|
||||
endp.Log.Printf("plain SMTP error returned, this is deprecated")
|
||||
res.Code = smtpErr.Code
|
||||
res.EnhancedCode = smtpErr.EnhancedCode
|
||||
res.Message = smtpErr.Message
|
||||
}
|
||||
|
||||
if msgId != "" {
|
||||
res.Message += " (msg ID = " + msgId + ")"
|
||||
}
|
||||
|
||||
// INTERNATIONALIZATION: See RFC 6531 Section 3.7.4.1.
|
||||
if mangleUTF8 {
|
||||
b := strings.Builder{}
|
||||
b.Grow(len(res.Message))
|
||||
for _, ch := range res.Message {
|
||||
if ch > 128 {
|
||||
b.WriteRune('?')
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
res.Message = b.String()
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
|
@ -1,31 +1,27 @@
|
|||
package smtp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"math/rand"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/maddy/internal/address"
|
||||
"github.com/foxcpp/maddy/internal/auth"
|
||||
"github.com/foxcpp/maddy/internal/buffer"
|
||||
"github.com/foxcpp/maddy/internal/config"
|
||||
modconfig "github.com/foxcpp/maddy/internal/config/module"
|
||||
"github.com/foxcpp/maddy/internal/dns"
|
||||
"github.com/foxcpp/maddy/internal/exterrors"
|
||||
"github.com/foxcpp/maddy/internal/future"
|
||||
"github.com/foxcpp/maddy/internal/limits"
|
||||
"github.com/foxcpp/maddy/internal/log"
|
||||
|
@ -34,492 +30,9 @@ import (
|
|||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
endp *Endpoint
|
||||
|
||||
// Specific for this session.
|
||||
// sessionCtx is not used for cancellation or timeouts, only for tracing.
|
||||
sessionCtx context.Context
|
||||
cancelRDNS func()
|
||||
connState module.ConnState
|
||||
repeatedMailErrs int
|
||||
loggedRcptErrors int
|
||||
|
||||
// Specific for the currently handled message.
|
||||
// msgCtx is not used for cancellation or timeouts, only for tracing.
|
||||
// It is the subcontext of sessionCtx.
|
||||
// Mutex is used to prevent Close from accessing inconsistent state when it
|
||||
// is called asynchronously to any SMTP command.
|
||||
msgLock sync.Mutex
|
||||
msgCtx context.Context
|
||||
msgTask *trace.Task
|
||||
mailFrom string
|
||||
opts smtp.MailOptions
|
||||
msgMeta *module.MsgMetadata
|
||||
delivery module.Delivery
|
||||
deliveryErr error
|
||||
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func (s *Session) Reset() {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
if s.delivery != nil {
|
||||
s.abort(s.msgCtx)
|
||||
}
|
||||
s.endp.Log.DebugMsg("reset")
|
||||
}
|
||||
|
||||
func (s *Session) releaseLimits() {
|
||||
_, domain, err := address.Split(s.mailFrom)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr, ok := s.msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
addr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
|
||||
}
|
||||
s.endp.limits.ReleaseMsg(addr.IP, domain)
|
||||
}
|
||||
|
||||
func (s *Session) abort(ctx context.Context) {
|
||||
if err := s.delivery.Abort(ctx); err != nil {
|
||||
s.endp.Log.Error("delivery abort failed", err)
|
||||
}
|
||||
s.log.Msg("aborted", "msg_id", s.msgMeta.ID)
|
||||
|
||||
s.mailFrom = ""
|
||||
s.opts = smtp.MailOptions{}
|
||||
s.msgMeta = nil
|
||||
s.delivery = nil
|
||||
s.deliveryErr = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
}
|
||||
|
||||
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
|
||||
var err error
|
||||
msgMeta := &module.MsgMetadata{
|
||||
Conn: &s.connState,
|
||||
SMTPOpts: opts,
|
||||
}
|
||||
|
||||
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
|
||||
// used.
|
||||
for _, ch := range from {
|
||||
if ch > 128 && !opts.UTF8 {
|
||||
return "", &exterrors.SMTPError{
|
||||
Code: 550,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
|
||||
Message: "SMTPUTF8 is required for non-ASCII senders",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode punycode, normalize to NFC and case-fold address.
|
||||
cleanFrom, err := address.CleanDomain(from)
|
||||
if err != nil {
|
||||
return "", &exterrors.SMTPError{
|
||||
Code: 553,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 1, 7},
|
||||
Message: "Unable to normalize the sender address",
|
||||
}
|
||||
}
|
||||
|
||||
msgMeta.ID, err = msgpipeline.GenerateMsgID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
msgMeta.OriginalFrom = from
|
||||
|
||||
_, domain, err := address.Split(cleanFrom)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
remoteIP, ok := msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
remoteIP = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
|
||||
}
|
||||
if err := s.endp.limits.TakeMsg(context.Background(), remoteIP.IP, domain); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if s.connState.AuthUser != "" {
|
||||
s.log.Msg("incoming message",
|
||||
"src_host", msgMeta.Conn.Hostname,
|
||||
"src_ip", msgMeta.Conn.RemoteAddr.String(),
|
||||
"sender", from,
|
||||
"msg_id", msgMeta.ID,
|
||||
"username", s.connState.AuthUser,
|
||||
)
|
||||
} else {
|
||||
s.log.Msg("incoming message",
|
||||
"src_host", msgMeta.Conn.Hostname,
|
||||
"src_ip", msgMeta.Conn.RemoteAddr.String(),
|
||||
"sender", from,
|
||||
"msg_id", msgMeta.ID,
|
||||
)
|
||||
}
|
||||
|
||||
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
|
||||
|
||||
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
|
||||
defer mailTask.End()
|
||||
|
||||
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
|
||||
if err != nil {
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
return msgMeta.ID, err
|
||||
}
|
||||
|
||||
s.msgMeta = msgMeta
|
||||
s.mailFrom = cleanFrom
|
||||
s.delivery = delivery
|
||||
|
||||
return msgMeta.ID, nil
|
||||
}
|
||||
|
||||
func (s *Session) Mail(from string, opts smtp.MailOptions) error {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
if !s.endp.deferServerReject {
|
||||
// Will initialize s.msgCtx.
|
||||
msgID, err := s.startDelivery(s.sessionCtx, from, opts)
|
||||
if err != nil {
|
||||
if err != context.DeadlineExceeded {
|
||||
s.log.Error("MAIL FROM error", err, "msg_id", msgID)
|
||||
}
|
||||
return s.endp.wrapErr(msgID, !opts.UTF8, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep the MAIL FROM argument for deferred startDelivery.
|
||||
s.mailFrom = from
|
||||
s.opts = opts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) fetchRDNSName(ctx context.Context) {
|
||||
defer trace.StartRegion(ctx, "rDNS fetch").End()
|
||||
|
||||
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
s.connState.RDNSName.Set(nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
name, err := dns.LookupAddr(ctx, s.endp.resolver, tcpAddr.IP)
|
||||
if err != nil {
|
||||
dnsErr, ok := err.(*net.DNSError)
|
||||
if ok && dnsErr.IsNotFound {
|
||||
s.connState.RDNSName.Set(nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
reason, misc := exterrors.UnwrapDNSErr(err)
|
||||
misc["reason"] = reason
|
||||
s.log.Error("rDNS error", exterrors.WithFields(err, misc), "src_ip", s.connState.RemoteAddr)
|
||||
s.connState.RDNSName.Set(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.connState.RDNSName.Set(name, nil)
|
||||
}
|
||||
|
||||
func (s *Session) Rcpt(to string) error {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
// deferServerReject = true and this is the first RCPT TO command.
|
||||
if s.delivery == nil {
|
||||
// If we already attempted to initialize the delivery -
|
||||
// fail again.
|
||||
if s.deliveryErr != nil {
|
||||
s.repeatedMailErrs++
|
||||
// The deliveryErr is already wrapped.
|
||||
return s.deliveryErr
|
||||
}
|
||||
|
||||
// It will initialize s.msgCtx.
|
||||
msgID, err := s.startDelivery(s.sessionCtx, s.mailFrom, s.opts)
|
||||
if err != nil {
|
||||
if err != context.DeadlineExceeded {
|
||||
s.log.Error("MAIL FROM error (deferred)", err, "rcpt", to, "msg_id", msgID)
|
||||
}
|
||||
s.deliveryErr = s.endp.wrapErr(msgID, !s.opts.UTF8, err)
|
||||
return s.deliveryErr
|
||||
}
|
||||
}
|
||||
|
||||
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
|
||||
defer rcptTask.End()
|
||||
|
||||
if err := s.rcpt(rcptCtx, to); err != nil {
|
||||
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
|
||||
s.log.Error("RCPT error", err, "rcpt", to)
|
||||
s.loggedRcptErrors++
|
||||
if s.loggedRcptErrors == s.endp.maxLoggedRcptErrors {
|
||||
s.log.Msg("too many RCPT errors, possible dictonary attack", "src_ip", s.connState.RemoteAddr, "msg_id", s.msgMeta.ID)
|
||||
}
|
||||
}
|
||||
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
|
||||
}
|
||||
s.endp.Log.Msg("RCPT ok", "rcpt", to, "msg_id", s.msgMeta.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) rcpt(ctx context.Context, to string) error {
|
||||
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
|
||||
// used.
|
||||
if !address.IsASCII(to) && !s.opts.UTF8 {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 553,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
|
||||
Message: "SMTPUTF8 is required for non-ASCII recipients",
|
||||
}
|
||||
}
|
||||
cleanTo, err := address.CleanDomain(to)
|
||||
if err != nil {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 501,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 1, 2},
|
||||
Message: "Unable to normalize the recipient address",
|
||||
}
|
||||
}
|
||||
|
||||
return s.delivery.AddRcpt(ctx, cleanTo)
|
||||
}
|
||||
|
||||
func (s *Session) Logout() error {
|
||||
s.msgLock.Lock()
|
||||
defer s.msgLock.Unlock()
|
||||
|
||||
if s.delivery != nil {
|
||||
s.abort(s.msgCtx)
|
||||
|
||||
if s.repeatedMailErrs > s.endp.maxLoggedRcptErrors {
|
||||
s.log.Msg("MAIL FROM repeated error a lot of times, possible dictonary attack", "count", s.repeatedMailErrs, "src_ip", s.connState.RemoteAddr)
|
||||
}
|
||||
}
|
||||
if s.cancelRDNS != nil {
|
||||
s.cancelRDNS()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Header, buffer.Buffer, error) {
|
||||
bufr := bufio.NewReader(r)
|
||||
header, err := textproto.ReadHeader(bufr)
|
||||
if err != nil {
|
||||
return textproto.Header{}, nil, err
|
||||
}
|
||||
|
||||
if s.endp.submission {
|
||||
// The MsgMetadata is passed by pointer all the way down.
|
||||
if err := s.submissionPrepare(s.msgMeta, &header); err != nil {
|
||||
return textproto.Header{}, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
buf, err := s.endp.buffer(bufr)
|
||||
if err != nil {
|
||||
return textproto.Header{}, nil, err
|
||||
}
|
||||
|
||||
return header, buf, nil
|
||||
}
|
||||
|
||||
func (s *Session) Data(r io.Reader) error {
|
||||
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
wrapErr := func(err error) error {
|
||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
|
||||
}
|
||||
|
||||
header, buf, err := s.prepareBody(bodyCtx, r)
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := buf.Remove(); err != nil {
|
||||
s.log.Error("failed to remove buffered body", err)
|
||||
}
|
||||
|
||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||
s.delivery = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
s.msgTask = nil
|
||||
s.releaseLimits()
|
||||
}()
|
||||
|
||||
if err := s.checkRoutingLoops(header); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if err := s.delivery.Body(bodyCtx, header, buf); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if err := s.delivery.Commit(bodyCtx); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type statusWrapper struct {
|
||||
sc smtp.StatusCollector
|
||||
s *Session
|
||||
}
|
||||
|
||||
func (sw statusWrapper) SetStatus(rcpt string, err error) {
|
||||
sw.sc.SetStatus(rcpt, sw.s.endp.wrapErr(sw.s.msgMeta.ID, !sw.s.opts.UTF8, err))
|
||||
}
|
||||
|
||||
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
|
||||
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
|
||||
defer bodyTask.End()
|
||||
|
||||
wrapErr := func(err error) error {
|
||||
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
|
||||
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
|
||||
}
|
||||
|
||||
header, buf, err := s.prepareBody(bodyCtx, r)
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := buf.Remove(); err != nil {
|
||||
s.log.Error("failed to remove buffered body", err)
|
||||
}
|
||||
|
||||
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
|
||||
s.delivery = nil
|
||||
s.msgCtx = nil
|
||||
s.msgTask.End()
|
||||
s.msgTask = nil
|
||||
s.releaseLimits()
|
||||
}()
|
||||
|
||||
if err := s.checkRoutingLoops(header); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
s.delivery.(module.PartialDelivery).BodyNonAtomic(bodyCtx, statusWrapper{sc, s}, header, buf)
|
||||
|
||||
// We can't really tell whether it is failed completely or succeeded
|
||||
// so always commit. Should be harmless, anyway.
|
||||
if err := s.delivery.Commit(bodyCtx); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) checkRoutingLoops(header textproto.Header) error {
|
||||
// RFC 5321 Section 6.3:
|
||||
// >Simple counting of the number of "Received:" header fields in a
|
||||
// >message has proven to be an effective, although rarely optimal,
|
||||
// >method of detecting loops in mail systems.
|
||||
receivedCount := 0
|
||||
for f := header.FieldsByKey("Received"); f.Next(); {
|
||||
receivedCount++
|
||||
}
|
||||
if receivedCount > s.endp.maxReceived {
|
||||
return &exterrors.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: exterrors.EnhancedCode{5, 4, 6},
|
||||
Message: fmt.Sprintf("Too many Received header fields (%d), possible forwarding loop", receivedCount),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (endp *Endpoint) wrapErr(msgId string, mangleUTF8 bool, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return &smtp.SMTPError{
|
||||
Code: 451,
|
||||
EnhancedCode: smtp.EnhancedCode{4, 4, 5},
|
||||
Message: "High load, try again later",
|
||||
}
|
||||
}
|
||||
|
||||
res := &smtp.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: smtp.EnhancedCodeNotSet,
|
||||
// Err on the side of caution if the error lacks SMTP annotations. If
|
||||
// we just pass the error text through, we might accidenetally disclose
|
||||
// details of server configuration.
|
||||
Message: "Internal server error",
|
||||
}
|
||||
|
||||
if exterrors.IsTemporary(err) {
|
||||
res.Code = 451
|
||||
}
|
||||
|
||||
ctxInfo := exterrors.Fields(err)
|
||||
ctxCode, ok := ctxInfo["smtp_code"].(int)
|
||||
if ok {
|
||||
res.Code = ctxCode
|
||||
}
|
||||
ctxEnchCode, ok := ctxInfo["smtp_enchcode"].(exterrors.EnhancedCode)
|
||||
if ok {
|
||||
res.EnhancedCode = smtp.EnhancedCode(ctxEnchCode)
|
||||
}
|
||||
ctxMsg, ok := ctxInfo["smtp_msg"].(string)
|
||||
if ok {
|
||||
res.Message = ctxMsg
|
||||
}
|
||||
|
||||
if smtpErr, ok := err.(*smtp.SMTPError); ok {
|
||||
endp.Log.Printf("plain SMTP error returned, this is deprecated")
|
||||
res.Code = smtpErr.Code
|
||||
res.EnhancedCode = smtpErr.EnhancedCode
|
||||
res.Message = smtpErr.Message
|
||||
}
|
||||
|
||||
if msgId != "" {
|
||||
res.Message += " (msg ID = " + msgId + ")"
|
||||
}
|
||||
|
||||
// INTERNATIONALIZATION: See RFC 6531 Section 3.7.4.1.
|
||||
if mangleUTF8 {
|
||||
b := strings.Builder{}
|
||||
b.Grow(len(res.Message))
|
||||
for _, ch := range res.Message {
|
||||
if ch > 128 {
|
||||
b.WriteRune('?')
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
res.Message = b.String()
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type Endpoint struct {
|
||||
hostname string
|
||||
Auth module.PlainAuth
|
||||
saslAuth auth.SASLAuth
|
||||
serv *smtp.Server
|
||||
name string
|
||||
addrs []string
|
||||
|
@ -572,10 +85,6 @@ func (endp *Endpoint) Init(cfg *config.Map) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if endp.Auth != nil {
|
||||
endp.Log.Debugf("authentication provider: %s %s", endp.Auth.(module.Module).Name(), endp.Auth.(module.Module).InstanceName())
|
||||
}
|
||||
|
||||
addresses := make([]config.Endpoint, 0, len(endp.addrs))
|
||||
for _, addr := range endp.addrs {
|
||||
saddr, err := config.ParseEndpoint(addr)
|
||||
|
@ -695,7 +204,9 @@ func (endp *Endpoint) setConfig(cfg *config.Map) error {
|
|||
ioDebug bool
|
||||
)
|
||||
|
||||
cfg.Custom("auth", false, false, nil, modconfig.AuthDirective, &endp.Auth)
|
||||
cfg.Callback("auth", func(m *config.Map, node *config.Node) error {
|
||||
return endp.saslAuth.AddProvider(m, node)
|
||||
})
|
||||
cfg.String("hostname", true, true, "", &endp.hostname)
|
||||
cfg.Duration("write_timeout", false, false, 1*time.Minute, &endp.serv.WriteTimeout)
|
||||
cfg.Duration("read_timeout", false, false, 10*time.Minute, &endp.serv.ReadTimeout)
|
||||
|
@ -738,14 +249,32 @@ func (endp *Endpoint) setConfig(cfg *config.Map) error {
|
|||
endp.pipeline.Log = log.Logger{Name: "smtp/pipeline", Debug: endp.Log.Debug}
|
||||
endp.pipeline.FirstPipeline = true
|
||||
|
||||
endp.serv.AuthDisabled = endp.Auth == nil
|
||||
endp.serv.AuthDisabled = len(endp.saslAuth.SASLMechanisms()) == 0
|
||||
if endp.submission {
|
||||
endp.authAlwaysRequired = true
|
||||
|
||||
if endp.Auth == nil {
|
||||
if len(endp.saslAuth.SASLMechanisms()) == 0 {
|
||||
return fmt.Errorf("%s: auth. provider must be set for submission endpoint", endp.name)
|
||||
}
|
||||
}
|
||||
for _, mech := range endp.saslAuth.SASLMechanisms() {
|
||||
// TODO: The code below lacks handling to set AuthPassword. Don't
|
||||
// override sasl.Plain handler so Login() will be called as usual.
|
||||
if mech == sasl.Plain {
|
||||
continue
|
||||
}
|
||||
|
||||
endp.serv.EnableAuth(mech, func(c *smtp.Conn) sasl.Server {
|
||||
state := c.State()
|
||||
if err := endp.pipeline.RunEarlyChecks(context.TODO(), &state); err != nil {
|
||||
return auth.FailingSASLServ{Err: endp.wrapErr("", true, err)}
|
||||
}
|
||||
|
||||
return endp.saslAuth.CreateSASL(mech, state.RemoteAddr, func(ids []string) error {
|
||||
c.SetSession(endp.newSession(false, ids[0], "", &state))
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// INTERNATIONALIZATION: See RFC 6531 Section 3.3.
|
||||
endp.serv.Domain, err = idna.ToASCII(endp.hostname)
|
||||
|
@ -794,7 +323,7 @@ func (endp *Endpoint) setupListeners(addresses []config.Endpoint) error {
|
|||
}
|
||||
|
||||
func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
|
||||
if endp.Auth == nil {
|
||||
if endp.serv.AuthDisabled {
|
||||
return nil, smtp.ErrAuthUnsupported
|
||||
}
|
||||
|
||||
|
@ -803,7 +332,7 @@ func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password stri
|
|||
return nil, endp.wrapErr("", true, err)
|
||||
}
|
||||
|
||||
_, err := endp.Auth.AuthPlain(username, password)
|
||||
_, err := endp.saslAuth.AuthPlain(username, password)
|
||||
if err != nil {
|
||||
// TODO: Update fail2ban filters.
|
||||
endp.Log.Error("authentication failed", err, "username", username, "src_ip", state.RemoteAddr)
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/foxcpp/go-mockdns"
|
||||
"github.com/foxcpp/maddy/internal/auth"
|
||||
"github.com/foxcpp/maddy/internal/config"
|
||||
"github.com/foxcpp/maddy/internal/exterrors"
|
||||
"github.com/foxcpp/maddy/internal/module"
|
||||
|
@ -27,7 +28,7 @@ const testMsg = "From: <sender@example.org>\r\n" +
|
|||
"\r\n" +
|
||||
"foobar\r\n"
|
||||
|
||||
func testEndpoint(t *testing.T, modName string, auth module.PlainAuth, tgt module.DeliveryTarget, checks []module.Check, cfg []config.Node) *Endpoint {
|
||||
func testEndpoint(t *testing.T, modName string, authMod module.PlainAuth, tgt module.DeliveryTarget, checks []module.Check, cfg []config.Node) *Endpoint {
|
||||
t.Helper()
|
||||
|
||||
mod, err := New(modName, []string{"tcp://127.0.0.1:" + testPort})
|
||||
|
@ -63,7 +64,7 @@ func testEndpoint(t *testing.T, modName string, auth module.PlainAuth, tgt modul
|
|||
},
|
||||
)
|
||||
|
||||
if auth != nil {
|
||||
if authMod != nil {
|
||||
cfg = append(cfg, config.Node{
|
||||
Name: "auth",
|
||||
Args: []string{"dummy"},
|
||||
|
@ -77,7 +78,10 @@ func testEndpoint(t *testing.T, modName string, auth module.PlainAuth, tgt modul
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
endp.Auth = auth
|
||||
endp.saslAuth = auth.SASLAuth{
|
||||
Log: testutils.Logger(t, "smtp/saslauth"),
|
||||
Plain: []module.PlainAuth{authMod},
|
||||
}
|
||||
|
||||
endp.pipeline = msgpipeline.Mock(tgt, checks)
|
||||
endp.pipeline.Hostname = "mx.example.com"
|
||||
|
|
|
@ -14,3 +14,21 @@ var (
|
|||
type PlainAuth interface {
|
||||
AuthPlain(username, password string) ([]string, error)
|
||||
}
|
||||
|
||||
// SASLProvider is the interface implemented by modules and used by protocol
|
||||
// endpoints that rely on SASL framework for user authentication.
|
||||
//
|
||||
// This actual interface is only used to indicate that the module is a
|
||||
// SASL-compatible auth. provider. For each unique value returned by
|
||||
// SASLMechanisms, the module object should also implement the coresponding
|
||||
// mechanism-specific interface.
|
||||
//
|
||||
// *Rationale*: There is no single generic interface that would handle any SASL
|
||||
// mechanism while permiting the use of a credentials set estabilished once with
|
||||
// multiple auth. providers at once.
|
||||
//
|
||||
// Per-mechanism interfaces:
|
||||
// - PLAIN => PlainAuth
|
||||
type SASLProvider interface {
|
||||
SASLMechanisms() []string
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/emersion/go-message/textproto"
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/foxcpp/maddy/internal/buffer"
|
||||
"github.com/foxcpp/maddy/internal/config"
|
||||
)
|
||||
|
@ -15,6 +16,10 @@ import (
|
|||
// and the actual server code (but the latter is kinda pointless).
|
||||
type Dummy struct{ instName string }
|
||||
|
||||
func (d *Dummy) SASLMechanisms() []string {
|
||||
return []string{sasl.Plain, sasl.Login}
|
||||
}
|
||||
|
||||
func (d *Dummy) AuthPlain(username, _ string) ([]string, error) {
|
||||
return []string{username}, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue