diff --git a/docs/man/maddy-targets.5.scd b/docs/man/maddy-targets.5.scd index 50f835a..dfc1caf 100644 --- a/docs/man/maddy-targets.5.scd +++ b/docs/man/maddy-targets.5.scd @@ -151,6 +151,22 @@ need to have support from all servers. It is based on the assumption that server referenced by MX record is likely the final destination and therefore there is only need to secure communication towards it and not beyond. +*Syntax*: conn_reuse_limit _integer_ ++ +*Default*: 10 + +Amount of times the same SMTP connection can be used. +Connections are never reused if the previous DATA command failed. + +*Syntax*: conn_max_idle_count _integer_ ++ +*Default*: 10 + +Max. amount of idle connections per recipient domains to keep in cache. + +*Syntax*: conn_max_idle_time _integer_ ++ +*Default*: 150 (2.5 min) + +Amount of time the idle connection is still considered potentially usable. + ## Security policies *Syntax*: mx_auth _config block_ ++ diff --git a/internal/smtpconn/pool/pool.go b/internal/smtpconn/pool/pool.go new file mode 100644 index 0000000..812a478 --- /dev/null +++ b/internal/smtpconn/pool/pool.go @@ -0,0 +1,140 @@ +package pool + +import ( + "context" + "sync" + "time" +) + +type Conn interface { + Usable() bool + Close() error +} + +type Config struct { + New func(ctx context.Context, key string) (Conn, error) + MaxKeys int + MaxConnsPerKey int + MaxConnLifetimeSec int64 + StaleKeyLifetimeSec int64 +} + +type slot struct { + c chan Conn + // To keep slot size smaller it is just a unix timestamp. + lastUse int64 +} + +type P struct { + cfg Config + keys map[string]slot + keysLock sync.Mutex +} + +func New(cfg Config) *P { + if cfg.New == nil { + cfg.New = func(context.Context, string) (Conn, error) { + return nil, nil + } + } + + return &P{ + cfg: cfg, + keys: make(map[string]slot, cfg.MaxKeys), + } +} + +func (p *P) Get(ctx context.Context, key string) (Conn, error) { + // TODO: See if it is possible to get rid of this lock. 
+ p.keysLock.Lock() + defer p.keysLock.Unlock() + + bucket, ok := p.keys[key] + if !ok { + return p.cfg.New(ctx, key) + } + + if time.Now().Unix()-bucket.lastUse > p.cfg.MaxConnLifetimeSec { + // Drop bucket. + close(bucket.c) + for conn := range bucket.c { + conn.Close() + } + delete(p.keys, key) + + return p.cfg.New(ctx, key) + } + + for { + var conn Conn + select { + case conn, ok = <-bucket.c: + if !ok { + return p.cfg.New(ctx, key) + } + default: + return p.cfg.New(ctx, key) + } + + if !conn.Usable() { + continue + } + + return conn, nil + } +} + +func (p *P) Return(key string, c Conn) { + p.keysLock.Lock() + defer p.keysLock.Unlock() + + if p.keys == nil { + return + } + + bucket, ok := p.keys[key] + if !ok { + // Garbage-collect stale buckets. + if len(p.keys) == p.cfg.MaxKeys { + for k, v := range p.keys { + if v.lastUse+p.cfg.StaleKeyLifetimeSec > time.Now().Unix() { + continue + } + + close(v.c) + for conn := range v.c { + conn.Close() + } + delete(p.keys, k) + } + } + + bucket = slot{ + c: make(chan Conn, p.cfg.MaxConnsPerKey), + lastUse: time.Now().Unix(), + } + p.keys[key] = bucket + } + + select { + case bucket.c <- c: + bucket.lastUse = time.Now().Unix() + default: + // Let it go, let it go... 
+ c.Close() + } +} + +func (p *P) Close() { + p.keysLock.Lock() + defer p.keysLock.Unlock() + + for k, v := range p.keys { + close(v.c) + for conn := range v.c { + conn.Close() + } + delete(p.keys, k) + } + p.keys = nil +} diff --git a/internal/smtpconn/smtpconn.go b/internal/smtpconn/smtpconn.go index e8f4733..448328b 100644 --- a/internal/smtpconn/smtpconn.go +++ b/internal/smtpconn/smtpconn.go @@ -12,6 +12,7 @@ package smtpconn import ( "context" "crypto/tls" + "errors" "io" "net" "runtime/trace" @@ -368,6 +369,14 @@ func (c *C) LMTPData(ctx context.Context, hdr textproto.Header, body io.Reader, return nil } +func (c *C) Noop() error { + if c.cl == nil { + return errors.New("smtpconn: not connected") + } + + return c.cl.Noop() +} + // Close sends the QUIT command, if it fail - it directly closes the // connection. func (c *C) Close() error { diff --git a/internal/target/remote/connect.go b/internal/target/remote/connect.go index dfaab54..885387c 100644 --- a/internal/target/remote/connect.go +++ b/internal/target/remote/connect.go @@ -19,6 +19,29 @@ type mxConn struct { // Domain this MX belongs to. domain string dnssecOk bool + + // Errors occurred previously on this connection. + errored bool + + reuseLimit int + + // Amount of times connection was used for an SMTP transaction. + transactions int + + // MX/TLS security level established for this connection. 
+ mxLevel MXLevel + tlsLevel TLSLevel +} + +func (c *mxConn) Usable() bool { + if c.C == nil || c.transactions > c.reuseLimit || c.C.Client() == nil { + return false + } + return c.C.Client().Reset() == nil +} + +func (c *mxConn) Close() error { + return c.C.Close() } func isVerifyError(err error) bool { @@ -103,7 +126,7 @@ retry: return tlsLevel, tlsErr, nil } -func (rd *remoteDelivery) attemptMX(ctx context.Context, conn mxConn, record *net.MX) error { +func (rd *remoteDelivery) attemptMX(ctx context.Context, conn *mxConn, record *net.MX) error { mxLevel := MXNone connCtx, cancel := context.WithCancel(ctx) @@ -122,7 +145,7 @@ func (rd *remoteDelivery) attemptMX(ctx context.Context, conn mxConn, record *ne p.PrepareConn(ctx, record.Host) } - tlsLevel, tlsErr, err := rd.connect(connCtx, conn, record.Host, rd.rt.tlsConfig) + tlsLevel, tlsErr, err := rd.connect(connCtx, *conn, record.Host, rd.rt.tlsConfig) if err != nil { return err } @@ -144,30 +167,8 @@ func (rd *remoteDelivery) attemptMX(ctx context.Context, conn mxConn, record *ne } } - if rd.msgMeta.SMTPOpts.RequireTLS { - if tlsLevel < TLSAuthenticated { - conn.Close() - return &exterrors.SMTPError{ - Code: 550, - EnhancedCode: exterrors.EnhancedCode{5, 7, 30}, - Message: "TLS it not available or unauthenticated but required (REQUIRETLS)", - Misc: map[string]interface{}{ - "tls_level": tlsLevel, - }, - } - } - if mxLevel < MX_MTASTS { - conn.Close() - return &exterrors.SMTPError{ - Code: 550, - EnhancedCode: exterrors.EnhancedCode{5, 7, 30}, - Message: "Failed to estabilish the MX record authenticity (REQUIRETLS)", - Misc: map[string]interface{}{ - "mx_level": mxLevel, - }, - } - } - } + conn.mxLevel = mxLevel + conn.tlsLevel = tlsLevel mxLevelCnt.WithLabelValues(rd.rt.Name(), mxLevel.String()).Inc() tlsLevelCnt.WithLabelValues(rd.rt.Name(), tlsLevel.String()).Inc() @@ -180,9 +181,84 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string return c.C, nil } + pooledConn, err := 
rd.rt.pool.Get(ctx, domain) + if err != nil { + return nil, err + } + + var conn *mxConn + // Ignore pool for connections with REQUIRETLS to avoid "pool poisoning" + // where attacker can make messages undeliverable by forcing reuse of old + // connection with weaker security. + if pooledConn != nil && !rd.msgMeta.SMTPOpts.RequireTLS { + conn = pooledConn.(*mxConn) + rd.Log.Msg("reusing cached connection", "domain", domain, "transactions_counter", conn.transactions) + } else { + rd.Log.DebugMsg("opening new connection", "domain", domain, "cache_ignored", pooledConn != nil) + conn, err = rd.newConn(ctx, domain) + if err != nil { + return nil, err + } + } + + if rd.msgMeta.SMTPOpts.RequireTLS { + if conn.tlsLevel < TLSAuthenticated { + conn.Close() + return nil, &exterrors.SMTPError{ + Code: 550, + EnhancedCode: exterrors.EnhancedCode{5, 7, 30}, + Message: "TLS is not available or unauthenticated but required (REQUIRETLS)", + Misc: map[string]interface{}{ + "tls_level": conn.tlsLevel, + }, + } + } + if conn.mxLevel < MX_MTASTS { + conn.Close() + return nil, &exterrors.SMTPError{ + Code: 550, + EnhancedCode: exterrors.EnhancedCode{5, 7, 30}, + Message: "Failed to establish the MX record authenticity (REQUIRETLS)", + Misc: map[string]interface{}{ + "mx_level": conn.mxLevel, + }, + } + } + } + + region := trace.StartRegion(ctx, "remote/limits.TakeDest") + if err := rd.rt.limits.TakeDest(ctx, domain); err != nil { + region.End() + return nil, err + } + region.End() + + // Relaxed REQUIRETLS mode is not conforming to the specification strictly + // but allows to start deploying client support for REQUIRETLS without the + // requirement for servers in the whole world to support it. The assumption + // behind it is that MX for the recipient domain is the final destination + // and all other forwarders behind it already have secure connection to + // each other. 
Therefore it is enough to enforce strict security only on + // the path to the MX even if it does not support the REQUIRETLS to propagate + // this requirement further. + if ok, _ := conn.Client().Extension("REQUIRETLS"); rd.rt.relaxedREQUIRETLS && !ok { + rd.msgMeta.SMTPOpts.RequireTLS = false + } + + if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil { + conn.Close() + return nil, err + } + + rd.connections[domain] = conn + return conn.C, nil +} + +func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, error) { conn := mxConn{ - C: smtpconn.New(), - domain: domain, + reuseLimit: rd.rt.connReuseLimit, + C: smtpconn.New(), + domain: domain, } conn.Dialer = rd.rt.dialer @@ -202,13 +278,6 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string } conn.dnssecOk = dnssecOk - region = trace.StartRegion(ctx, "remote/limits.TakeDest") - if err := rd.rt.limits.TakeDest(ctx, domain); err != nil { - region.End() - return nil, err - } - region.End() - var lastErr error region = trace.StartRegion(ctx, "remote/Connect+TLS") for _, record := range records { @@ -220,7 +289,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string } } - if err := rd.attemptMX(ctx, conn, record); err != nil { + if err := rd.attemptMX(ctx, &conn, record); err != nil { if len(records) != 0 { rd.Log.Error("cannot use MX", err, "remote_server", record.Host, "domain", domain) } @@ -245,25 +314,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string } } - // Relaxed REQUIRETLS mode is not conforming to the specification strictly - // but allows to start deploying client support for REQUIRETLS without the - // requirement for servers in the whole world to support it. The assumption - // behind it is that MX for the recipient domain is the final destination - // and all other forwarders behind it already have secure connection to - // each other. 
Therefore it is enough to enforce strict security only on - // the path to the MX even if it does not support the REQUIRETLS to propagate - // this requirement further. - if ok, _ := conn.Client().Extension("REQUIRETLS"); rd.rt.relaxedREQUIRETLS && !ok { - rd.msgMeta.SMTPOpts.RequireTLS = false - } - - if err := conn.Mail(ctx, rd.mailFrom, rd.msgMeta.SMTPOpts); err != nil { - conn.Close() - return nil, err - } - - rd.connections[domain] = conn - return conn.C, nil + return &conn, nil } func (rd *remoteDelivery) lookupMX(ctx context.Context, domain string) (dnssecOk bool, records []*net.MX, err error) { diff --git a/internal/target/remote/remote.go b/internal/target/remote/remote.go index 6f6f996..ceb2434 100644 --- a/internal/target/remote/remote.go +++ b/internal/target/remote/remote.go @@ -25,6 +25,7 @@ import ( "github.com/foxcpp/maddy/internal/limits" "github.com/foxcpp/maddy/internal/log" "github.com/foxcpp/maddy/internal/module" + "github.com/foxcpp/maddy/internal/smtpconn/pool" "github.com/foxcpp/maddy/internal/target" "golang.org/x/net/idna" ) @@ -98,6 +99,9 @@ type Target struct { allowSecOverride bool relaxedREQUIRETLS bool + pool *pool.P + connReuseLimit int + Log log.Logger } @@ -107,6 +111,7 @@ func New(_, instName string, _, inlineArgs []string) (module.Module, error) { if len(inlineArgs) != 0 { return nil, errors.New("remote: inline arguments are not used") } + // Keep this synchronized with testTarget. return &Target{ name: instName, resolver: dns.DefaultResolver(), @@ -151,10 +156,21 @@ func (rt *Target) Init(cfg *config.Map) error { }, &rt.limits) cfg.Bool("requiretls_override", false, true, &rt.allowSecOverride) cfg.Bool("relaxed_requiretls", false, true, &rt.relaxedREQUIRETLS) + cfg.Int("conn_reuse_limit", false, false, 10, &rt.connReuseLimit) + + poolCfg := pool.Config{ + MaxKeys: 20000, + MaxConnsPerKey: 10, // basically, max. 
amount of idle connections in cache + MaxConnLifetimeSec: 150, // 2.5 mins, half of recommended idle time from RFC 5321 + StaleKeyLifetimeSec: 60 * 5, // should be bigger than MaxConnLifetimeSec + } + cfg.Int("conn_max_idle_count", false, false, 10, &poolCfg.MaxConnsPerKey) + cfg.Int64("conn_max_idle_time", false, false, 150, &poolCfg.MaxConnLifetimeSec) if _, err := cfg.Process(); err != nil { return err } + rt.pool = pool.New(poolCfg) // INTERNATIONALIZATION: See RFC 6531 Section 3.7.1. rt.hostname, err = idna.ToASCII(rt.hostname) @@ -180,6 +196,8 @@ func (rt *Target) Close() error { p.Close() } + rt.pool.Close() + return nil } @@ -198,7 +216,7 @@ type remoteDelivery struct { Log log.Logger recipients []string - connections map[string]mxConn + connections map[string]*mxConn policies []DeliveryPolicy } @@ -258,7 +276,7 @@ func (rt *Target) Start(ctx context.Context, msgMeta *module.MsgMetadata, mailFr mailFrom: mailFrom, msgMeta: msgMeta, Log: target.DeliveryLogger(rt.Log, msgMeta), - connections: map[string]mxConn{}, + connections: map[string]*mxConn{}, policies: policies, }, nil } @@ -397,7 +415,8 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl var wg sync.WaitGroup - for _, conn := range rd.connections { + for i, conn := range rd.connections { + i := i conn := conn wg.Add(1) go func() { @@ -416,6 +435,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl for _, rcpt := range conn.Rcpts() { c.SetStatus(rcpt, err) } + rd.connections[i].errored = err != nil }() } @@ -434,11 +454,17 @@ func (rd *remoteDelivery) Commit(ctx context.Context) error { func (rd *remoteDelivery) Close() error { for _, conn := range rd.connections { - rd.Log.Debugf("disconnected from %s", conn.ServerName()) - rd.rt.limits.ReleaseDest(conn.domain) + conn.transactions++ - conn.Close() + if conn.C == nil || conn.transactions > rd.rt.connReuseLimit || conn.C.Client() == nil || conn.errored { + rd.Log.Debugf("disconnected from 
%s (errored=%v,transactions=%v,disconnected before=%v)", + conn.ServerName(), conn.errored, conn.transactions, conn.C.Client() == nil) + conn.Close() + } else { + rd.Log.Debugf("returning connection for %s to pool", conn.ServerName()) + rd.rt.pool.Return(conn.domain, conn) + } } var ( diff --git a/internal/target/remote/remote_test.go b/internal/target/remote/remote_test.go index 200775d..d7b5562 100644 --- a/internal/target/remote/remote_test.go +++ b/internal/target/remote/remote_test.go @@ -22,6 +22,7 @@ import ( "github.com/foxcpp/maddy/internal/exterrors" "github.com/foxcpp/maddy/internal/limits" "github.com/foxcpp/maddy/internal/module" + "github.com/foxcpp/maddy/internal/smtpconn/pool" "github.com/foxcpp/maddy/internal/testutils" ) @@ -43,6 +44,12 @@ func testTarget(t *testing.T, zones map[string]mockdns.Zone, extResolver *dns.Ex Log: testutils.Logger(t, "remote"), policies: extraPolicies, limits: &limits.Group{}, + pool: pool.New(pool.Config{ + MaxKeys: 20000, + MaxConnsPerKey: 10, // basically, max. 
amount of idle connections in cache + MaxConnLifetimeSec: 150, // 2.5 mins, half of recommended idle time from RFC 5321 + StaleKeyLifetimeSec: 60 * 5, // should be bigger than MaxConnLifetimeSec + }), } return &tgt @@ -975,3 +982,30 @@ func TestMain(m *testing.M) { smtpPort = *remoteSmtpPort os.Exit(m.Run()) } + +func TestRemoteDelivery_ConnReuse(t *testing.T) { + be, srv := testutils.SMTPServer(t, "127.0.0.1:"+smtpPort) + defer srv.Close() + defer testutils.CheckSMTPConnLeak(t, srv) + zones := map[string]mockdns.Zone{ + "example.invalid.": { + MX: []net.MX{{Host: "mx.example.invalid.", Pref: 10}}, + }, + "mx.example.invalid.": { + A: []string{"127.0.0.1"}, + }, + } + + tgt := testTarget(t, zones, nil, nil) + tgt.connReuseLimit = 5 + defer tgt.Close() + testutils.DoTestDelivery(t, tgt, "test@example.com", []string{"test@example.invalid"}) + testutils.DoTestDelivery(t, tgt, "test@example.com", []string{"test@example.invalid"}) + + be.CheckMsg(t, 0, "test@example.com", []string{"test@example.invalid"}) + be.CheckMsg(t, 1, "test@example.com", []string{"test@example.invalid"}) + + if len(be.SourceEndpoints) != 1 { + t.Fatal("Only one session should be used, found", be.SourceEndpoints) + } +} diff --git a/internal/testutils/smtp_server.go b/internal/testutils/smtp_server.go index 504a87a..ddec5e9 100644 --- a/internal/testutils/smtp_server.go +++ b/internal/testutils/smtp_server.go @@ -28,6 +28,8 @@ type SMTPMessage struct { type SMTPBackend struct { Messages []*SMTPMessage MailFromCounter int + SessionCounter int + SourceEndpoints map[string]struct{} AuthErr error MailErr error @@ -40,6 +42,11 @@ func (be *SMTPBackend) Login(state *smtp.ConnectionState, username, password str if be.AuthErr != nil { return nil, be.AuthErr } + be.SessionCounter++ + if be.SourceEndpoints == nil { + be.SourceEndpoints = make(map[string]struct{}) + } + be.SourceEndpoints[state.RemoteAddr.String()] = struct{}{} return &session{ backend: be, user: username, @@ -49,6 +56,11 @@ func (be 
*SMTPBackend) Login(state *smtp.ConnectionState, username, password str } func (be *SMTPBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { + be.SessionCounter++ + if be.SourceEndpoints == nil { + be.SourceEndpoints = make(map[string]struct{}) + } + be.SourceEndpoints[state.RemoteAddr.String()] = struct{}{} return &session{backend: be, state: state}, nil }