Implement & integrate generic SASL authentication support

This should make it possible to implement OAuth and TLS client
certificates authentication.
This commit is contained in:
fox.cpp 2020-02-27 21:40:04 +03:00
parent 0507fb89f4
commit eaaadfa6df
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
9 changed files with 855 additions and 509 deletions

154
internal/auth/sasl.go Normal file
View file

@ -0,0 +1,154 @@
package auth
import (
"errors"
"fmt"
"net"
"github.com/emersion/go-sasl"
"github.com/foxcpp/maddy/internal/config"
modconfig "github.com/foxcpp/maddy/internal/config/module"
"github.com/foxcpp/maddy/internal/log"
"github.com/foxcpp/maddy/internal/module"
"golang.org/x/text/secure/precis"
)
var (
ErrUnsupportedMech = errors.New("Unsupported SASL mechanism")
)
// SASLAuth is a wrapper that initializes sasl.Server using authenticators that
// call maddy module objects.
//
// It supports reporting of multiple authorization identities so multiple
// accounts can be associated with a single set of credentials.
type SASLAuth struct {
Log log.Logger
OnlyFirstID bool
Plain []module.PlainAuth
}
func (s *SASLAuth) SASLMechanisms() []string {
var mechs []string
if len(s.Plain) != 0 {
mechs = append(mechs, sasl.Plain, sasl.Login)
}
return mechs
}
func (s *SASLAuth) AuthPlain(username, password string) ([]string, error) {
if len(s.Plain) == 0 {
return nil, ErrUnsupportedMech
}
var lastErr error
accounts := make([]string, 0, 1)
for _, p := range s.Plain {
pAccs, err := p.AuthPlain(username, password)
if err != nil {
lastErr = err
continue
}
if s.OnlyFirstID {
return pAccs, nil
}
accounts = append(accounts, pAccs...)
}
if len(accounts) == 0 {
return nil, fmt.Errorf("no auth. provider accepted creds, last err: %w", lastErr)
}
return accounts, nil
}
func filterIdentity(accounts []string, identity string) ([]string, error) {
if identity == "" {
return accounts, nil
}
matchFound := false
for _, acc := range accounts {
if precis.UsernameCaseMapped.Compare(acc, identity) {
accounts = []string{identity}
matchFound = true
break
}
}
if !matchFound {
return nil, errors.New("auth: invalid credentials")
}
return accounts, nil
}
// CreateSASL creates the sasl.Server instance for the corresponding mechanism.
//
// successCb will be called with the slice of authorization identities
// associated with credentials used.
// If it fails - authentication will fail too.
func (s *SASLAuth) CreateSASL(mech string, remoteAddr net.Addr, successCb func([]string) error) sasl.Server {
switch mech {
case sasl.Plain:
return sasl.NewPlainServer(func(identity, username, password string) error {
accounts, err := s.AuthPlain(username, password)
if err != nil {
s.Log.Error("authentication failed", err, "username", username, "identity", identity, "src_ip", remoteAddr)
return errors.New("auth: invalid credentials")
}
if len(accounts) == 0 {
accounts = []string{username}
}
accounts, err = filterIdentity(accounts, identity)
if err != nil {
s.Log.Error("not authorized", err, "username", username, "identity", identity, "src_ip", remoteAddr)
return errors.New("auth: invalid credentials")
}
return successCb(accounts)
})
case sasl.Login:
return sasl.NewLoginServer(func(username, password string) error {
accounts, err := s.AuthPlain(username, password)
if err != nil {
s.Log.Error("authentication failed", err, "username", username, "src_ip", remoteAddr)
return errors.New("auth: invalid credentials")
}
return successCb(accounts)
})
}
return FailingSASLServ{Err: ErrUnsupportedMech}
}
// AddProvider adds the SASL authentication provider to its mapping by parsing
// the 'auth' configuration directive.
func (s *SASLAuth) AddProvider(m *config.Map, node *config.Node) error {
mod, err := modconfig.SASLAuthDirective(m, node)
if err != nil {
return err
}
saslAuth := mod.(module.SASLProvider)
for _, mech := range saslAuth.SASLMechanisms() {
switch mech {
case sasl.Login, sasl.Plain:
plainAuth, ok := saslAuth.(module.PlainAuth)
if !ok {
return m.MatchErr("auth: provider does not implement PlainAuth even though it reports PLAIN/LOGIN mechanism")
}
s.Plain = append(s.Plain, plainAuth)
default:
return m.MatchErr("auth: unknown SASL mechanism")
}
}
return nil
}
type FailingSASLServ struct{ Err error }
func (s FailingSASLServ) Next([]byte) ([]byte, bool, error) {
return nil, true, s.Err
}

103
internal/auth/sasl_test.go Normal file
View file

@ -0,0 +1,103 @@
package auth
import (
"errors"
"net"
"reflect"
"testing"
"github.com/emersion/go-sasl"
"github.com/foxcpp/maddy/internal/module"
"github.com/foxcpp/maddy/internal/testutils"
)
type mockAuth struct {
db map[string][]string
}
func (mockAuth) SASLMechanisms() []string {
return []string{sasl.Plain, sasl.Login}
}
func (m mockAuth) AuthPlain(username, _ string) ([]string, error) {
ids, ok := m.db[username]
if !ok {
return nil, errors.New("invalid creds")
}
return ids, nil
}
func TestCreateSASL(t *testing.T) {
a := SASLAuth{
Log: testutils.Logger(t, "saslauth"),
Plain: []module.PlainAuth{
&mockAuth{
db: map[string][]string{
"user1": []string{"user1a", "user1b"},
},
},
},
}
t.Run("XWHATEVER", func(t *testing.T) {
srv := a.CreateSASL("XWHATEVER", &net.TCPAddr{}, func([]string) error { return nil })
_, _, err := srv.Next([]byte(""))
if err == nil {
t.Error("No error for XWHATEVER use")
}
})
t.Run("PLAIN", func(t *testing.T) {
var ids []string
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
ids = passed
return nil
})
_, _, err := srv.Next([]byte("\x00user1\x00aa"))
if err != nil {
t.Error("Unexpected error:", err)
}
if !reflect.DeepEqual(ids, []string{"user1a", "user1b"}) {
t.Error("Wrong auth. identities passed to callback:", ids)
}
})
t.Run("PLAIN with autorization identity", func(t *testing.T) {
var ids []string
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
ids = passed
return nil
})
_, _, err := srv.Next([]byte("user1a\x00user1\x00aa"))
if err != nil {
t.Error("Unexpected error:", err)
}
if !reflect.DeepEqual(ids, []string{"user1a"}) {
t.Error("Wrong auth. identities passed to callback:", ids)
}
})
t.Run("PLAIN with wrong authorization identity", func(t *testing.T) {
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
return nil
})
_, _, err := srv.Next([]byte("user1c\x00user1\x00aa"))
if err == nil {
t.Error("Next should fail")
}
})
t.Run("PLAIN with wrong authentication identity", func(t *testing.T) {
srv := a.CreateSASL("PLAIN", &net.TCPAddr{}, func(passed []string) error {
return nil
})
_, _, err := srv.Next([]byte("\x00user2\x00aa"))
if err == nil {
t.Error("Next should fail")
}
})
}

View file

@ -5,8 +5,8 @@ import (
"github.com/foxcpp/maddy/internal/module"
)
func AuthDirective(m *config.Map, node *config.Node) (interface{}, error) {
var provider module.PlainAuth
func SASLAuthDirective(m *config.Map, node *config.Node) (interface{}, error) {
var provider module.SASLProvider
if err := ModuleFromNode(node.Args, node, m.Globals, &provider); err != nil {
return nil, err
}

View file

@ -19,8 +19,10 @@ import (
imapserver "github.com/emersion/go-imap/server"
"github.com/emersion/go-message"
_ "github.com/emersion/go-message/charset"
"github.com/emersion/go-sasl"
i18nlevel "github.com/foxcpp/go-imap-i18nlevel"
"github.com/foxcpp/go-imap-sql/children"
"github.com/foxcpp/maddy/internal/auth"
"github.com/foxcpp/maddy/internal/config"
modconfig "github.com/foxcpp/maddy/internal/config/module"
"github.com/foxcpp/maddy/internal/log"
@ -32,13 +34,14 @@ type Endpoint struct {
addrs []string
serv *imapserver.Server
listeners []net.Listener
Auth module.PlainAuth
Store module.Storage
updater imapbackend.BackendUpdater
tlsConfig *tls.Config
listenersWg sync.WaitGroup
saslAuth auth.SASLAuth
Log log.Logger
}
@ -46,6 +49,9 @@ func New(modName string, addrs []string) (module.Module, error) {
endp := &Endpoint{
addrs: addrs,
Log: log.Logger{Name: "imap"},
saslAuth: auth.SASLAuth{
Log: log.Logger{Name: "imap/saslauth"},
},
}
return endp, nil
@ -58,7 +64,9 @@ func (endp *Endpoint) Init(cfg *config.Map) error {
ioErrors bool
)
cfg.Custom("auth", false, true, nil, modconfig.AuthDirective, &endp.Auth)
cfg.Callback("auth", func(m *config.Map, node *config.Node) error {
return endp.saslAuth.AddProvider(m, node)
})
cfg.Custom("storage", false, true, nil, modconfig.StorageDirective, &endp.Store)
cfg.Custom("tls", true, true, nil, config.TLSDirective, &endp.tlsConfig)
cfg.Bool("insecure_auth", false, false, &insecureAuth)
@ -113,6 +121,14 @@ func (endp *Endpoint) Init(cfg *config.Map) error {
return err
}
for _, mech := range endp.saslAuth.SASLMechanisms() {
endp.serv.EnableAuth(mech, func(c imapserver.Conn) sasl.Server {
return endp.saslAuth.CreateSASL(mech, c.Info().RemoteAddr, func(ids []string) error {
return endp.openAccount(c, ids)
})
})
}
if err := endp.setupListeners(addresses); err != nil {
return err
}
@ -183,8 +199,19 @@ func (endp *Endpoint) Close() error {
return nil
}
func (endp *Endpoint) openAccount(c imapserver.Conn, identities []string) error {
u, err := endp.Store.GetOrCreateUser(identities[0])
if err != nil {
return err
}
ctx := c.Context()
ctx.State = imap.AuthenticatedState
ctx.User = u
return nil
}
func (endp *Endpoint) Login(connInfo *imap.ConnInfo, username, password string) (imapbackend.User, error) {
_, err := endp.Auth.AuthPlain(username, password)
_, err := endp.saslAuth.AuthPlain(username, password)
if err != nil {
endp.Log.Error("authentication failed", err, "username", username, "src_ip", connInfo.RemoteAddr)
return nil, imapbackend.ErrInvalidCredentials

View file

@ -0,0 +1,506 @@
package smtp
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"runtime/trace"
"strings"
"sync"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/internal/address"
"github.com/foxcpp/maddy/internal/buffer"
"github.com/foxcpp/maddy/internal/dns"
"github.com/foxcpp/maddy/internal/exterrors"
"github.com/foxcpp/maddy/internal/log"
"github.com/foxcpp/maddy/internal/module"
"github.com/foxcpp/maddy/internal/msgpipeline"
)
type Session struct {
endp *Endpoint
// Specific for this session.
// sessionCtx is not used for cancellation or timeouts, only for tracing.
sessionCtx context.Context
cancelRDNS func()
connState module.ConnState
repeatedMailErrs int
loggedRcptErrors int
// Specific for the currently handled message.
// msgCtx is not used for cancellation or timeouts, only for tracing.
// It is the subcontext of sessionCtx.
// Mutex is used to prevent Close from accessing inconsistent state when it
// is called asynchronously to any SMTP command.
msgLock sync.Mutex
msgCtx context.Context
msgTask *trace.Task
mailFrom string
opts smtp.MailOptions
msgMeta *module.MsgMetadata
delivery module.Delivery
deliveryErr error
log log.Logger
}
func (s *Session) Reset() {
s.msgLock.Lock()
defer s.msgLock.Unlock()
if s.delivery != nil {
s.abort(s.msgCtx)
}
s.endp.Log.DebugMsg("reset")
}
func (s *Session) releaseLimits() {
_, domain, err := address.Split(s.mailFrom)
if err != nil {
return
}
addr, ok := s.msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
if !ok {
addr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}
s.endp.limits.ReleaseMsg(addr.IP, domain)
}
func (s *Session) abort(ctx context.Context) {
if err := s.delivery.Abort(ctx); err != nil {
s.endp.Log.Error("delivery abort failed", err)
}
s.log.Msg("aborted", "msg_id", s.msgMeta.ID)
s.mailFrom = ""
s.opts = smtp.MailOptions{}
s.msgMeta = nil
s.delivery = nil
s.deliveryErr = nil
s.msgCtx = nil
s.msgTask.End()
}
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
var err error
msgMeta := &module.MsgMetadata{
Conn: &s.connState,
SMTPOpts: opts,
}
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
// used.
for _, ch := range from {
if ch > 128 && !opts.UTF8 {
return "", &exterrors.SMTPError{
Code: 550,
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
Message: "SMTPUTF8 is required for non-ASCII senders",
}
}
}
// Decode punycode, normalize to NFC and case-fold address.
cleanFrom, err := address.CleanDomain(from)
if err != nil {
return "", &exterrors.SMTPError{
Code: 553,
EnhancedCode: exterrors.EnhancedCode{5, 1, 7},
Message: "Unable to normalize the sender address",
}
}
msgMeta.ID, err = msgpipeline.GenerateMsgID()
if err != nil {
return "", err
}
msgMeta.OriginalFrom = from
_, domain, err := address.Split(cleanFrom)
if err != nil {
return "", err
}
remoteIP, ok := msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
if !ok {
remoteIP = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}
if err := s.endp.limits.TakeMsg(context.Background(), remoteIP.IP, domain); err != nil {
return "", err
}
if s.connState.AuthUser != "" {
s.log.Msg("incoming message",
"src_host", msgMeta.Conn.Hostname,
"src_ip", msgMeta.Conn.RemoteAddr.String(),
"sender", from,
"msg_id", msgMeta.ID,
"username", s.connState.AuthUser,
)
} else {
s.log.Msg("incoming message",
"src_host", msgMeta.Conn.Hostname,
"src_ip", msgMeta.Conn.RemoteAddr.String(),
"sender", from,
"msg_id", msgMeta.ID,
)
}
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
defer mailTask.End()
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
if err != nil {
s.msgCtx = nil
s.msgTask.End()
return msgMeta.ID, err
}
s.msgMeta = msgMeta
s.mailFrom = cleanFrom
s.delivery = delivery
return msgMeta.ID, nil
}
func (s *Session) Mail(from string, opts smtp.MailOptions) error {
s.msgLock.Lock()
defer s.msgLock.Unlock()
if !s.endp.deferServerReject {
// Will initialize s.msgCtx.
msgID, err := s.startDelivery(s.sessionCtx, from, opts)
if err != nil {
if err != context.DeadlineExceeded {
s.log.Error("MAIL FROM error", err, "msg_id", msgID)
}
return s.endp.wrapErr(msgID, !opts.UTF8, err)
}
}
// Keep the MAIL FROM argument for deferred startDelivery.
s.mailFrom = from
s.opts = opts
return nil
}
func (s *Session) fetchRDNSName(ctx context.Context) {
defer trace.StartRegion(ctx, "rDNS fetch").End()
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
if !ok {
s.connState.RDNSName.Set(nil, nil)
return
}
name, err := dns.LookupAddr(ctx, s.endp.resolver, tcpAddr.IP)
if err != nil {
dnsErr, ok := err.(*net.DNSError)
if ok && dnsErr.IsNotFound {
s.connState.RDNSName.Set(nil, nil)
return
}
reason, misc := exterrors.UnwrapDNSErr(err)
misc["reason"] = reason
s.log.Error("rDNS error", exterrors.WithFields(err, misc), "src_ip", s.connState.RemoteAddr)
s.connState.RDNSName.Set(nil, err)
return
}
s.connState.RDNSName.Set(name, nil)
}
func (s *Session) Rcpt(to string) error {
s.msgLock.Lock()
defer s.msgLock.Unlock()
// deferServerReject = true and this is the first RCPT TO command.
if s.delivery == nil {
// If we already attempted to initialize the delivery -
// fail again.
if s.deliveryErr != nil {
s.repeatedMailErrs++
// The deliveryErr is already wrapped.
return s.deliveryErr
}
// It will initialize s.msgCtx.
msgID, err := s.startDelivery(s.sessionCtx, s.mailFrom, s.opts)
if err != nil {
if err != context.DeadlineExceeded {
s.log.Error("MAIL FROM error (deferred)", err, "rcpt", to, "msg_id", msgID)
}
s.deliveryErr = s.endp.wrapErr(msgID, !s.opts.UTF8, err)
return s.deliveryErr
}
}
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
defer rcptTask.End()
if err := s.rcpt(rcptCtx, to); err != nil {
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
s.log.Error("RCPT error", err, "rcpt", to)
s.loggedRcptErrors++
if s.loggedRcptErrors == s.endp.maxLoggedRcptErrors {
s.log.Msg("too many RCPT errors, possible dictonary attack", "src_ip", s.connState.RemoteAddr, "msg_id", s.msgMeta.ID)
}
}
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
}
s.endp.Log.Msg("RCPT ok", "rcpt", to, "msg_id", s.msgMeta.ID)
return nil
}
func (s *Session) rcpt(ctx context.Context, to string) error {
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
// used.
if !address.IsASCII(to) && !s.opts.UTF8 {
return &exterrors.SMTPError{
Code: 553,
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
Message: "SMTPUTF8 is required for non-ASCII recipients",
}
}
cleanTo, err := address.CleanDomain(to)
if err != nil {
return &exterrors.SMTPError{
Code: 501,
EnhancedCode: exterrors.EnhancedCode{5, 1, 2},
Message: "Unable to normalize the recipient address",
}
}
return s.delivery.AddRcpt(ctx, cleanTo)
}
func (s *Session) Logout() error {
s.msgLock.Lock()
defer s.msgLock.Unlock()
if s.delivery != nil {
s.abort(s.msgCtx)
if s.repeatedMailErrs > s.endp.maxLoggedRcptErrors {
s.log.Msg("MAIL FROM repeated error a lot of times, possible dictonary attack", "count", s.repeatedMailErrs, "src_ip", s.connState.RemoteAddr)
}
}
if s.cancelRDNS != nil {
s.cancelRDNS()
}
return nil
}
func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Header, buffer.Buffer, error) {
bufr := bufio.NewReader(r)
header, err := textproto.ReadHeader(bufr)
if err != nil {
return textproto.Header{}, nil, err
}
if s.endp.submission {
// The MsgMetadata is passed by pointer all the way down.
if err := s.submissionPrepare(s.msgMeta, &header); err != nil {
return textproto.Header{}, nil, err
}
}
buf, err := s.endp.buffer(bufr)
if err != nil {
return textproto.Header{}, nil, err
}
return header, buf, nil
}
func (s *Session) Data(r io.Reader) error {
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
}
header, buf, err := s.prepareBody(bodyCtx, r)
if err != nil {
return wrapErr(err)
}
defer func() {
if err := buf.Remove(); err != nil {
s.log.Error("failed to remove buffered body", err)
}
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.releaseLimits()
}()
if err := s.checkRoutingLoops(header); err != nil {
return wrapErr(err)
}
if err := s.delivery.Body(bodyCtx, header, buf); err != nil {
return wrapErr(err)
}
if err := s.delivery.Commit(bodyCtx); err != nil {
return wrapErr(err)
}
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
return nil
}
type statusWrapper struct {
sc smtp.StatusCollector
s *Session
}
func (sw statusWrapper) SetStatus(rcpt string, err error) {
sw.sc.SetStatus(rcpt, sw.s.endp.wrapErr(sw.s.msgMeta.ID, !sw.s.opts.UTF8, err))
}
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
}
header, buf, err := s.prepareBody(bodyCtx, r)
if err != nil {
return wrapErr(err)
}
defer func() {
if err := buf.Remove(); err != nil {
s.log.Error("failed to remove buffered body", err)
}
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.releaseLimits()
}()
if err := s.checkRoutingLoops(header); err != nil {
return wrapErr(err)
}
s.delivery.(module.PartialDelivery).BodyNonAtomic(bodyCtx, statusWrapper{sc, s}, header, buf)
// We can't really tell whether it is failed completely or succeeded
// so always commit. Should be harmless, anyway.
if err := s.delivery.Commit(bodyCtx); err != nil {
return wrapErr(err)
}
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
return nil
}
func (s *Session) checkRoutingLoops(header textproto.Header) error {
// RFC 5321 Section 6.3:
// >Simple counting of the number of "Received:" header fields in a
// >message has proven to be an effective, although rarely optimal,
// >method of detecting loops in mail systems.
receivedCount := 0
for f := header.FieldsByKey("Received"); f.Next(); {
receivedCount++
}
if receivedCount > s.endp.maxReceived {
return &exterrors.SMTPError{
Code: 554,
EnhancedCode: exterrors.EnhancedCode{5, 4, 6},
Message: fmt.Sprintf("Too many Received header fields (%d), possible forwarding loop", receivedCount),
}
}
return nil
}
func (endp *Endpoint) wrapErr(msgId string, mangleUTF8 bool, err error) error {
if err == nil {
return nil
}
if errors.Is(err, context.DeadlineExceeded) {
return &smtp.SMTPError{
Code: 451,
EnhancedCode: smtp.EnhancedCode{4, 4, 5},
Message: "High load, try again later",
}
}
res := &smtp.SMTPError{
Code: 554,
EnhancedCode: smtp.EnhancedCodeNotSet,
// Err on the side of caution if the error lacks SMTP annotations. If
// we just pass the error text through, we might accidenetally disclose
// details of server configuration.
Message: "Internal server error",
}
if exterrors.IsTemporary(err) {
res.Code = 451
}
ctxInfo := exterrors.Fields(err)
ctxCode, ok := ctxInfo["smtp_code"].(int)
if ok {
res.Code = ctxCode
}
ctxEnchCode, ok := ctxInfo["smtp_enchcode"].(exterrors.EnhancedCode)
if ok {
res.EnhancedCode = smtp.EnhancedCode(ctxEnchCode)
}
ctxMsg, ok := ctxInfo["smtp_msg"].(string)
if ok {
res.Message = ctxMsg
}
if smtpErr, ok := err.(*smtp.SMTPError); ok {
endp.Log.Printf("plain SMTP error returned, this is deprecated")
res.Code = smtpErr.Code
res.EnhancedCode = smtpErr.EnhancedCode
res.Message = smtpErr.Message
}
if msgId != "" {
res.Message += " (msg ID = " + msgId + ")"
}
// INTERNATIONALIZATION: See RFC 6531 Section 3.7.4.1.
if mangleUTF8 {
b := strings.Builder{}
b.Grow(len(res.Message))
for _, ch := range res.Message {
if ch > 128 {
b.WriteRune('?')
} else {
b.WriteRune(ch)
}
}
res.Message = b.String()
}
return res
}

View file

@ -1,31 +1,27 @@
package smtp
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"runtime/trace"
"strings"
"sync"
"time"
"math/rand"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/foxcpp/maddy/internal/address"
"github.com/foxcpp/maddy/internal/auth"
"github.com/foxcpp/maddy/internal/buffer"
"github.com/foxcpp/maddy/internal/config"
modconfig "github.com/foxcpp/maddy/internal/config/module"
"github.com/foxcpp/maddy/internal/dns"
"github.com/foxcpp/maddy/internal/exterrors"
"github.com/foxcpp/maddy/internal/future"
"github.com/foxcpp/maddy/internal/limits"
"github.com/foxcpp/maddy/internal/log"
@ -34,492 +30,9 @@ import (
"golang.org/x/net/idna"
)
type Session struct {
endp *Endpoint
// Specific for this session.
// sessionCtx is not used for cancellation or timeouts, only for tracing.
sessionCtx context.Context
cancelRDNS func()
connState module.ConnState
repeatedMailErrs int
loggedRcptErrors int
// Specific for the currently handled message.
// msgCtx is not used for cancellation or timeouts, only for tracing.
// It is the subcontext of sessionCtx.
// Mutex is used to prevent Close from accessing inconsistent state when it
// is called asynchronously to any SMTP command.
msgLock sync.Mutex
msgCtx context.Context
msgTask *trace.Task
mailFrom string
opts smtp.MailOptions
msgMeta *module.MsgMetadata
delivery module.Delivery
deliveryErr error
log log.Logger
}
func (s *Session) Reset() {
s.msgLock.Lock()
defer s.msgLock.Unlock()
if s.delivery != nil {
s.abort(s.msgCtx)
}
s.endp.Log.DebugMsg("reset")
}
func (s *Session) releaseLimits() {
_, domain, err := address.Split(s.mailFrom)
if err != nil {
return
}
addr, ok := s.msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
if !ok {
addr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}
s.endp.limits.ReleaseMsg(addr.IP, domain)
}
func (s *Session) abort(ctx context.Context) {
if err := s.delivery.Abort(ctx); err != nil {
s.endp.Log.Error("delivery abort failed", err)
}
s.log.Msg("aborted", "msg_id", s.msgMeta.ID)
s.mailFrom = ""
s.opts = smtp.MailOptions{}
s.msgMeta = nil
s.delivery = nil
s.deliveryErr = nil
s.msgCtx = nil
s.msgTask.End()
}
func (s *Session) startDelivery(ctx context.Context, from string, opts smtp.MailOptions) (string, error) {
var err error
msgMeta := &module.MsgMetadata{
Conn: &s.connState,
SMTPOpts: opts,
}
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
// used.
for _, ch := range from {
if ch > 128 && !opts.UTF8 {
return "", &exterrors.SMTPError{
Code: 550,
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
Message: "SMTPUTF8 is required for non-ASCII senders",
}
}
}
// Decode punycode, normalize to NFC and case-fold address.
cleanFrom, err := address.CleanDomain(from)
if err != nil {
return "", &exterrors.SMTPError{
Code: 553,
EnhancedCode: exterrors.EnhancedCode{5, 1, 7},
Message: "Unable to normalize the sender address",
}
}
msgMeta.ID, err = msgpipeline.GenerateMsgID()
if err != nil {
return "", err
}
msgMeta.OriginalFrom = from
_, domain, err := address.Split(cleanFrom)
if err != nil {
return "", err
}
remoteIP, ok := msgMeta.Conn.RemoteAddr.(*net.TCPAddr)
if !ok {
remoteIP = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}
if err := s.endp.limits.TakeMsg(context.Background(), remoteIP.IP, domain); err != nil {
return "", err
}
if s.connState.AuthUser != "" {
s.log.Msg("incoming message",
"src_host", msgMeta.Conn.Hostname,
"src_ip", msgMeta.Conn.RemoteAddr.String(),
"sender", from,
"msg_id", msgMeta.ID,
"username", s.connState.AuthUser,
)
} else {
s.log.Msg("incoming message",
"src_host", msgMeta.Conn.Hostname,
"src_ip", msgMeta.Conn.RemoteAddr.String(),
"sender", from,
"msg_id", msgMeta.ID,
)
}
s.msgCtx, s.msgTask = trace.NewTask(ctx, "Incoming Message")
mailCtx, mailTask := trace.NewTask(s.msgCtx, "MAIL FROM")
defer mailTask.End()
delivery, err := s.endp.pipeline.Start(mailCtx, msgMeta, cleanFrom)
if err != nil {
s.msgCtx = nil
s.msgTask.End()
return msgMeta.ID, err
}
s.msgMeta = msgMeta
s.mailFrom = cleanFrom
s.delivery = delivery
return msgMeta.ID, nil
}
func (s *Session) Mail(from string, opts smtp.MailOptions) error {
s.msgLock.Lock()
defer s.msgLock.Unlock()
if !s.endp.deferServerReject {
// Will initialize s.msgCtx.
msgID, err := s.startDelivery(s.sessionCtx, from, opts)
if err != nil {
if err != context.DeadlineExceeded {
s.log.Error("MAIL FROM error", err, "msg_id", msgID)
}
return s.endp.wrapErr(msgID, !opts.UTF8, err)
}
}
// Keep the MAIL FROM argument for deferred startDelivery.
s.mailFrom = from
s.opts = opts
return nil
}
func (s *Session) fetchRDNSName(ctx context.Context) {
defer trace.StartRegion(ctx, "rDNS fetch").End()
tcpAddr, ok := s.connState.RemoteAddr.(*net.TCPAddr)
if !ok {
s.connState.RDNSName.Set(nil, nil)
return
}
name, err := dns.LookupAddr(ctx, s.endp.resolver, tcpAddr.IP)
if err != nil {
dnsErr, ok := err.(*net.DNSError)
if ok && dnsErr.IsNotFound {
s.connState.RDNSName.Set(nil, nil)
return
}
reason, misc := exterrors.UnwrapDNSErr(err)
misc["reason"] = reason
s.log.Error("rDNS error", exterrors.WithFields(err, misc), "src_ip", s.connState.RemoteAddr)
s.connState.RDNSName.Set(nil, err)
return
}
s.connState.RDNSName.Set(name, nil)
}
func (s *Session) Rcpt(to string) error {
s.msgLock.Lock()
defer s.msgLock.Unlock()
// deferServerReject = true and this is the first RCPT TO command.
if s.delivery == nil {
// If we already attempted to initialize the delivery -
// fail again.
if s.deliveryErr != nil {
s.repeatedMailErrs++
// The deliveryErr is already wrapped.
return s.deliveryErr
}
// It will initialize s.msgCtx.
msgID, err := s.startDelivery(s.sessionCtx, s.mailFrom, s.opts)
if err != nil {
if err != context.DeadlineExceeded {
s.log.Error("MAIL FROM error (deferred)", err, "rcpt", to, "msg_id", msgID)
}
s.deliveryErr = s.endp.wrapErr(msgID, !s.opts.UTF8, err)
return s.deliveryErr
}
}
rcptCtx, rcptTask := trace.NewTask(s.msgCtx, "RCPT TO")
defer rcptTask.End()
if err := s.rcpt(rcptCtx, to); err != nil {
if s.loggedRcptErrors < s.endp.maxLoggedRcptErrors {
s.log.Error("RCPT error", err, "rcpt", to)
s.loggedRcptErrors++
if s.loggedRcptErrors == s.endp.maxLoggedRcptErrors {
s.log.Msg("too many RCPT errors, possible dictonary attack", "src_ip", s.connState.RemoteAddr, "msg_id", s.msgMeta.ID)
}
}
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
}
s.endp.Log.Msg("RCPT ok", "rcpt", to, "msg_id", s.msgMeta.ID)
return nil
}
func (s *Session) rcpt(ctx context.Context, to string) error {
// INTERNATIONALIZATION: Do not permit non-ASCII addresses unless SMTPUTF8 is
// used.
if !address.IsASCII(to) && !s.opts.UTF8 {
return &exterrors.SMTPError{
Code: 553,
EnhancedCode: exterrors.EnhancedCode{5, 6, 7},
Message: "SMTPUTF8 is required for non-ASCII recipients",
}
}
cleanTo, err := address.CleanDomain(to)
if err != nil {
return &exterrors.SMTPError{
Code: 501,
EnhancedCode: exterrors.EnhancedCode{5, 1, 2},
Message: "Unable to normalize the recipient address",
}
}
return s.delivery.AddRcpt(ctx, cleanTo)
}
func (s *Session) Logout() error {
s.msgLock.Lock()
defer s.msgLock.Unlock()
if s.delivery != nil {
s.abort(s.msgCtx)
if s.repeatedMailErrs > s.endp.maxLoggedRcptErrors {
s.log.Msg("MAIL FROM repeated error a lot of times, possible dictonary attack", "count", s.repeatedMailErrs, "src_ip", s.connState.RemoteAddr)
}
}
if s.cancelRDNS != nil {
s.cancelRDNS()
}
return nil
}
func (s *Session) prepareBody(ctx context.Context, r io.Reader) (textproto.Header, buffer.Buffer, error) {
bufr := bufio.NewReader(r)
header, err := textproto.ReadHeader(bufr)
if err != nil {
return textproto.Header{}, nil, err
}
if s.endp.submission {
// The MsgMetadata is passed by pointer all the way down.
if err := s.submissionPrepare(s.msgMeta, &header); err != nil {
return textproto.Header{}, nil, err
}
}
buf, err := s.endp.buffer(bufr)
if err != nil {
return textproto.Header{}, nil, err
}
return header, buf, nil
}
func (s *Session) Data(r io.Reader) error {
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
}
header, buf, err := s.prepareBody(bodyCtx, r)
if err != nil {
return wrapErr(err)
}
defer func() {
if err := buf.Remove(); err != nil {
s.log.Error("failed to remove buffered body", err)
}
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.releaseLimits()
}()
if err := s.checkRoutingLoops(header); err != nil {
return wrapErr(err)
}
if err := s.delivery.Body(bodyCtx, header, buf); err != nil {
return wrapErr(err)
}
if err := s.delivery.Commit(bodyCtx); err != nil {
return wrapErr(err)
}
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
return nil
}
type statusWrapper struct {
sc smtp.StatusCollector
s *Session
}
func (sw statusWrapper) SetStatus(rcpt string, err error) {
sw.sc.SetStatus(rcpt, sw.s.endp.wrapErr(sw.s.msgMeta.ID, !sw.s.opts.UTF8, err))
}
func (s *Session) LMTPData(r io.Reader, sc smtp.StatusCollector) error {
bodyCtx, bodyTask := trace.NewTask(s.msgCtx, "DATA")
defer bodyTask.End()
wrapErr := func(err error) error {
s.log.Error("DATA error", err, "msg_id", s.msgMeta.ID)
return s.endp.wrapErr(s.msgMeta.ID, !s.opts.UTF8, err)
}
header, buf, err := s.prepareBody(bodyCtx, r)
if err != nil {
return wrapErr(err)
}
defer func() {
if err := buf.Remove(); err != nil {
s.log.Error("failed to remove buffered body", err)
}
// go-smtp will call Reset, but it will call Abort if delivery is non-nil.
s.delivery = nil
s.msgCtx = nil
s.msgTask.End()
s.msgTask = nil
s.releaseLimits()
}()
if err := s.checkRoutingLoops(header); err != nil {
return wrapErr(err)
}
s.delivery.(module.PartialDelivery).BodyNonAtomic(bodyCtx, statusWrapper{sc, s}, header, buf)
// We can't really tell whether it is failed completely or succeeded
// so always commit. Should be harmless, anyway.
if err := s.delivery.Commit(bodyCtx); err != nil {
return wrapErr(err)
}
s.log.Msg("accepted", "msg_id", s.msgMeta.ID)
return nil
}
func (s *Session) checkRoutingLoops(header textproto.Header) error {
// RFC 5321 Section 6.3:
// >Simple counting of the number of "Received:" header fields in a
// >message has proven to be an effective, although rarely optimal,
// >method of detecting loops in mail systems.
receivedCount := 0
for f := header.FieldsByKey("Received"); f.Next(); {
receivedCount++
}
if receivedCount > s.endp.maxReceived {
return &exterrors.SMTPError{
Code: 554,
EnhancedCode: exterrors.EnhancedCode{5, 4, 6},
Message: fmt.Sprintf("Too many Received header fields (%d), possible forwarding loop", receivedCount),
}
}
return nil
}
func (endp *Endpoint) wrapErr(msgId string, mangleUTF8 bool, err error) error {
if err == nil {
return nil
}
if errors.Is(err, context.DeadlineExceeded) {
return &smtp.SMTPError{
Code: 451,
EnhancedCode: smtp.EnhancedCode{4, 4, 5},
Message: "High load, try again later",
}
}
res := &smtp.SMTPError{
Code: 554,
EnhancedCode: smtp.EnhancedCodeNotSet,
// Err on the side of caution if the error lacks SMTP annotations. If
// we just pass the error text through, we might accidenetally disclose
// details of server configuration.
Message: "Internal server error",
}
if exterrors.IsTemporary(err) {
res.Code = 451
}
ctxInfo := exterrors.Fields(err)
ctxCode, ok := ctxInfo["smtp_code"].(int)
if ok {
res.Code = ctxCode
}
ctxEnchCode, ok := ctxInfo["smtp_enchcode"].(exterrors.EnhancedCode)
if ok {
res.EnhancedCode = smtp.EnhancedCode(ctxEnchCode)
}
ctxMsg, ok := ctxInfo["smtp_msg"].(string)
if ok {
res.Message = ctxMsg
}
if smtpErr, ok := err.(*smtp.SMTPError); ok {
endp.Log.Printf("plain SMTP error returned, this is deprecated")
res.Code = smtpErr.Code
res.EnhancedCode = smtpErr.EnhancedCode
res.Message = smtpErr.Message
}
if msgId != "" {
res.Message += " (msg ID = " + msgId + ")"
}
// INTERNATIONALIZATION: See RFC 6531 Section 3.7.4.1.
if mangleUTF8 {
b := strings.Builder{}
b.Grow(len(res.Message))
for _, ch := range res.Message {
if ch > 128 {
b.WriteRune('?')
} else {
b.WriteRune(ch)
}
}
res.Message = b.String()
}
return res
}
type Endpoint struct {
hostname string
Auth module.PlainAuth
saslAuth auth.SASLAuth
serv *smtp.Server
name string
addrs []string
@ -572,10 +85,6 @@ func (endp *Endpoint) Init(cfg *config.Map) error {
return err
}
if endp.Auth != nil {
endp.Log.Debugf("authentication provider: %s %s", endp.Auth.(module.Module).Name(), endp.Auth.(module.Module).InstanceName())
}
addresses := make([]config.Endpoint, 0, len(endp.addrs))
for _, addr := range endp.addrs {
saddr, err := config.ParseEndpoint(addr)
@ -695,7 +204,9 @@ func (endp *Endpoint) setConfig(cfg *config.Map) error {
ioDebug bool
)
cfg.Custom("auth", false, false, nil, modconfig.AuthDirective, &endp.Auth)
cfg.Callback("auth", func(m *config.Map, node *config.Node) error {
return endp.saslAuth.AddProvider(m, node)
})
cfg.String("hostname", true, true, "", &endp.hostname)
cfg.Duration("write_timeout", false, false, 1*time.Minute, &endp.serv.WriteTimeout)
cfg.Duration("read_timeout", false, false, 10*time.Minute, &endp.serv.ReadTimeout)
@ -738,14 +249,32 @@ func (endp *Endpoint) setConfig(cfg *config.Map) error {
endp.pipeline.Log = log.Logger{Name: "smtp/pipeline", Debug: endp.Log.Debug}
endp.pipeline.FirstPipeline = true
endp.serv.AuthDisabled = endp.Auth == nil
endp.serv.AuthDisabled = len(endp.saslAuth.SASLMechanisms()) == 0
if endp.submission {
endp.authAlwaysRequired = true
if endp.Auth == nil {
if len(endp.saslAuth.SASLMechanisms()) == 0 {
return fmt.Errorf("%s: auth. provider must be set for submission endpoint", endp.name)
}
}
for _, mech := range endp.saslAuth.SASLMechanisms() {
// TODO: The code below lacks handling to set AuthPassword. Don't
// override sasl.Plain handler so Login() will be called as usual.
if mech == sasl.Plain {
continue
}
endp.serv.EnableAuth(mech, func(c *smtp.Conn) sasl.Server {
state := c.State()
if err := endp.pipeline.RunEarlyChecks(context.TODO(), &state); err != nil {
return auth.FailingSASLServ{Err: endp.wrapErr("", true, err)}
}
return endp.saslAuth.CreateSASL(mech, state.RemoteAddr, func(ids []string) error {
c.SetSession(endp.newSession(false, ids[0], "", &state))
return nil
})
})
}
// INTERNATIONALIZATION: See RFC 6531 Section 3.3.
endp.serv.Domain, err = idna.ToASCII(endp.hostname)
@ -794,7 +323,7 @@ func (endp *Endpoint) setupListeners(addresses []config.Endpoint) error {
}
func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
if endp.Auth == nil {
if endp.serv.AuthDisabled {
return nil, smtp.ErrAuthUnsupported
}
@ -803,7 +332,7 @@ func (endp *Endpoint) Login(state *smtp.ConnectionState, username, password stri
return nil, endp.wrapErr("", true, err)
}
_, err := endp.Auth.AuthPlain(username, password)
_, err := endp.saslAuth.AuthPlain(username, password)
if err != nil {
// TODO: Update fail2ban filters.
endp.Log.Error("authentication failed", err, "username", username, "src_ip", state.RemoteAddr)

View file

@ -13,6 +13,7 @@ import (
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/foxcpp/go-mockdns"
"github.com/foxcpp/maddy/internal/auth"
"github.com/foxcpp/maddy/internal/config"
"github.com/foxcpp/maddy/internal/exterrors"
"github.com/foxcpp/maddy/internal/module"
@ -27,7 +28,7 @@ const testMsg = "From: <sender@example.org>\r\n" +
"\r\n" +
"foobar\r\n"
func testEndpoint(t *testing.T, modName string, auth module.PlainAuth, tgt module.DeliveryTarget, checks []module.Check, cfg []config.Node) *Endpoint {
func testEndpoint(t *testing.T, modName string, authMod module.PlainAuth, tgt module.DeliveryTarget, checks []module.Check, cfg []config.Node) *Endpoint {
t.Helper()
mod, err := New(modName, []string{"tcp://127.0.0.1:" + testPort})
@ -63,7 +64,7 @@ func testEndpoint(t *testing.T, modName string, auth module.PlainAuth, tgt modul
},
)
if auth != nil {
if authMod != nil {
cfg = append(cfg, config.Node{
Name: "auth",
Args: []string{"dummy"},
@ -77,7 +78,10 @@ func testEndpoint(t *testing.T, modName string, auth module.PlainAuth, tgt modul
t.Fatal(err)
}
endp.Auth = auth
endp.saslAuth = auth.SASLAuth{
Log: testutils.Logger(t, "smtp/saslauth"),
Plain: []module.PlainAuth{authMod},
}
endp.pipeline = msgpipeline.Mock(tgt, checks)
endp.pipeline.Hostname = "mx.example.com"

View file

@ -14,3 +14,21 @@ var (
type PlainAuth interface {
AuthPlain(username, password string) ([]string, error)
}
// SASLProvider is the interface implemented by modules and used by protocol
// endpoints that rely on SASL framework for user authentication.
//
// This actual interface is only used to indicate that the module is a
// SASL-compatible auth. provider. For each unique value returned by
// SASLMechanisms, the module object should also implement the coresponding
// mechanism-specific interface.
//
// *Rationale*: There is no single generic interface that would handle any SASL
// mechanism while permiting the use of a credentials set estabilished once with
// multiple auth. providers at once.
//
// Per-mechanism interfaces:
// - PLAIN => PlainAuth
type SASLProvider interface {
SASLMechanisms() []string
}

View file

@ -4,6 +4,7 @@ import (
"context"
"github.com/emersion/go-message/textproto"
"github.com/emersion/go-sasl"
"github.com/foxcpp/maddy/internal/buffer"
"github.com/foxcpp/maddy/internal/config"
)
@ -15,6 +16,10 @@ import (
// and the actual server code (but the latter is kinda pointless).
type Dummy struct{ instName string }
func (d *Dummy) SASLMechanisms() []string {
return []string{sasl.Plain, sasl.Login}
}
func (d *Dummy) AuthPlain(username, _ string) ([]string, error) {
return []string{username}, nil
}