Generalize message flow restrictions

Set of flow restrictions is represented as a "limits" module instance
that can be either created inline via "limits" directive in some modules
(including "remote" target and "smtp" endpoint) or defined globally and
referenced in configuration of modules mentioned above.

This permits a variety of use cases, including shared and separate
counters for various endpoints and also "modules group" style sharing
described in #195.
This commit is contained in:
fox.cpp 2020-02-15 14:03:35 +03:00
parent 100ed13784
commit c3ebbb05a0
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
13 changed files with 480 additions and 149 deletions

213
internal/limits/limits.go Normal file
View file

@ -0,0 +1,213 @@
// Package limit provides a module object that can be used to restrict the
// concurrency and rate of the messages flow globally or on per-source,
// per-destination basis.
//
// Note, all domain inputs are interpreted with the assumption they are already
// normalized.
//
// Low-level components are available in the limiters/ subpackage.
package limits
import (
"context"
"net"
"strconv"
"time"
"github.com/foxcpp/maddy/internal/config"
"github.com/foxcpp/maddy/internal/limits/limiters"
"github.com/foxcpp/maddy/internal/module"
)
type Group struct {
instName string
global limiters.MultiLimit
ip *limiters.BucketSet // BucketSet of MultiLimit
source *limiters.BucketSet // BucketSet of MultiLimit
dest *limiters.BucketSet // BucketSet of MultiLimit
}
func New(_, instName string, _, _ []string) (module.Module, error) {
return &Group{
instName: instName,
}, nil
}
func (g *Group) Init(cfg *config.Map) error {
var (
globalL []limiters.L
ipL []func() limiters.L
sourceL []func() limiters.L
destL []func() limiters.L
)
for _, child := range cfg.Block.Children {
if len(child.Args) < 1 {
return config.NodeErr(&child, "at least two arguments are required")
}
var (
ctor func() limiters.L
err error
)
switch kind := child.Args[0]; kind {
case "rate":
ctor, err = rateCtor(cfg, child.Args[1:])
case "concurrency":
ctor, err = concurrencyCtor(cfg, child.Args[1:])
default:
return config.NodeErr(&child, "unknown limit kind: %v", kind)
}
if err != nil {
return err
}
switch scope := child.Name; scope {
case "all":
globalL = append(globalL, ctor())
case "ip":
ipL = append(ipL, ctor)
case "source":
sourceL = append(sourceL, ctor)
case "destination":
destL = append(destL, ctor)
default:
return config.NodeErr(&child, "unknown limit scope: %v", scope)
}
}
// 20010 is slightly higher than the default max. recipients count in
// endpoint/smtp.
g.global = limiters.MultiLimit{Wrapped: globalL}
if len(ipL) != 0 {
g.ip = limiters.NewBucketSet(func() limiters.L {
l := make([]limiters.L, 0, len(ipL))
for _, ctor := range ipL {
l = append(l, ctor())
}
return &limiters.MultiLimit{Wrapped: l}
}, 1*time.Minute, 20010)
}
if len(sourceL) != 0 {
g.source = limiters.NewBucketSet(func() limiters.L {
l := make([]limiters.L, 0, len(sourceL))
for _, ctor := range sourceL {
l = append(l, ctor())
}
return &limiters.MultiLimit{Wrapped: l}
}, 1*time.Minute, 20010)
}
if len(destL) != 0 {
g.dest = limiters.NewBucketSet(func() limiters.L {
l := make([]limiters.L, 0, len(sourceL))
for _, ctor := range sourceL {
l = append(l, ctor())
}
return &limiters.MultiLimit{Wrapped: l}
}, 1*time.Minute, 20010)
}
return nil
}
func rateCtor(cfg *config.Map, args []string) (func() limiters.L, error) {
period := 1 * time.Second
burst := 0
switch len(args) {
case 2:
var err error
period, err = time.ParseDuration(args[1])
if err != nil {
return nil, cfg.MatchErr("%v", err)
}
case 1:
var err error
burst, err = strconv.Atoi(args[0])
if err != nil {
return nil, cfg.MatchErr("%v", err)
}
case 0:
return nil, cfg.MatchErr("at least burst size is needed")
}
return func() limiters.L {
return limiters.NewRate(burst, period)
}, nil
}
func concurrencyCtor(cfg *config.Map, args []string) (func() limiters.L, error) {
if len(args) != 1 {
return nil, cfg.MatchErr("max concurrency value is needed")
}
max, err := strconv.Atoi(args[0])
if err != nil {
return nil, cfg.MatchErr("%v", err)
}
return func() limiters.L {
return limiters.NewSemaphore(max)
}, nil
}
func (g *Group) TakeMsg(ctx context.Context, addr net.IP, sourceDomain string) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := g.global.TakeContext(ctx); err != nil {
return err
}
if g.ip != nil {
if err := g.ip.TakeContext(ctx, addr.String()); err != nil {
g.global.Release()
return err
}
}
if g.source != nil {
if err := g.source.TakeContext(ctx, sourceDomain); err != nil {
g.global.Release()
g.ip.Release(addr.String())
return err
}
}
return nil
}
func (g *Group) TakeDest(ctx context.Context, domain string) error {
if g.dest == nil {
return nil
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return g.dest.TakeContext(ctx, domain)
}
func (g *Group) ReleaseMsg(addr net.IP, sourceDomain string) {
g.global.Release()
if g.ip != nil {
g.ip.Release(addr.String())
}
if g.source != nil {
g.source.Release(sourceDomain)
}
}
func (g *Group) ReleaseDest(domain string) {
if g.dest == nil {
return
}
g.dest.Release(domain)
}
func (g *Group) Name() string {
return "limits"
}
func (g *Group) InstanceName() string {
return g.instName
}
func init() {
module.Register("limits", New)
}