Rework config directives iteration

Create more high-level wrapper (config.Map) instead of walking syntax
tree directly.
This commit is contained in:
fox.cpp 2019-03-25 23:16:28 +03:00 committed by emersion
parent 9c27d4416f
commit 84d150a00f
12 changed files with 590 additions and 352 deletions

View file

@ -26,7 +26,7 @@ func main() {
}
defer f.Close()
config, err := config.ReadConfig(f, absCfg)
config, err := config.Read(f, absCfg)
if err != nil {
log.Fatalf("Cannot parse %q: %v", configpath, err)
}

153
config.go Normal file
View file

@ -0,0 +1,153 @@
package maddy
import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"github.com/emersion/maddy/config"
"github.com/emersion/maddy/module"
)
/*
Config matchers for module interfaces.
*/
func authProvider(modName string) (module.AuthProvider, error) {
modObj := module.GetInstance(modName)
if modObj == nil {
return nil, fmt.Errorf("unknown auth. provider instance: %s", modName)
}
provider, ok := modObj.(module.AuthProvider)
if !ok {
return nil, fmt.Errorf("module %s doesn't implements auth. provider interface", modObj.Name())
}
return provider, nil
}
func authDirective(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
modObj, err := authProvider(node.Args[0])
if err != nil {
return nil, m.MatchErr("%s", err.Error())
}
return modObj, nil
}
func storageBackend(modName string) (module.Storage, error) {
modObj := module.GetInstance(modName)
if modObj == nil {
return nil, fmt.Errorf("unknown storage backend instance: %s", modName)
}
backend, ok := modObj.(module.Storage)
if !ok {
return nil, fmt.Errorf("module %s doesn't implements storage interface", modObj.Name())
}
return backend, nil
}
func storageDirective(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
modObj, err := storageBackend(node.Args[0])
if err != nil {
return nil, m.MatchErr("%s", err.Error())
}
return modObj, nil
}
func deliveryTarget(modName string) (module.DeliveryTarget, error) {
mod := module.GetInstance(modName)
if mod == nil {
return nil, fmt.Errorf("unknown delivery target instance: %s", modName)
}
target, ok := mod.(module.DeliveryTarget)
if !ok {
return nil, fmt.Errorf("module %s doesn't implements delivery target interface", mod.Name())
}
return target, nil
}
func deliveryDirective(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
modObj, err := deliveryTarget(node.Args[0])
if err != nil {
return nil, m.MatchErr("%s", err.Error())
}
return modObj, nil
}
func defaultAuthProvider() (interface{}, error) {
res, err := authProvider("default_auth")
if err != nil {
res, err = authProvider("default")
if err != nil {
return nil, errors.New("missing default auth. provider, must set custom")
}
}
return res, nil
}
func defaultStorage() (interface{}, error) {
res, err := storageBackend("default_storage")
if err != nil {
res, err = storageBackend("default")
if err != nil {
return nil, errors.New("missing default storage backend, must set custom")
}
}
return res, nil
}
func errorsDirective(m *config.Map, node *config.Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
output := node.Args[0]
var w io.Writer
switch output {
case "off":
w = ioutil.Discard
case "stdout":
w = os.Stdout
case "stderr":
w = os.Stderr
default:
f, err := os.OpenFile(output, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
if err != nil {
return nil, err
}
w = f
}
return log.New(w, "imap ", log.LstdFlags), nil
}

316
config/map.go Normal file
View file

@ -0,0 +1,316 @@
package config
import (
"errors"
"fmt"
"reflect"
"strconv"
)
type matcher struct {
name string
required bool
inheritGlobal bool
defaultVal func() (interface{}, error)
mapper func(*Map, *Node) (interface{}, error)
store reflect.Value
}
func (m *matcher) match(map_ *Map, node *Node) error {
val, err := m.mapper(map_, node)
if err != nil {
return err
}
valRefl := reflect.ValueOf(val)
// Convert untyped nil into typed nil. Otherwise it will panic.
if !valRefl.IsValid() {
valRefl = reflect.Zero(m.store.Type())
}
m.store.Set(valRefl)
return nil
}
type Map struct {
allowUnknown bool
// Set to currently processed node when defaultVal or mapper functions are
// called.
curNode *Node
entries map[string]matcher
}
// MatchErr returns error with formatted message, if called from defaultVal or
// mapper functions - message will be prepended with information about
// processed config node.
func (m *Map) MatchErr(format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
if m.curNode != nil {
return fmt.Errorf("%s:%d %s: %s", m.curNode.File, m.curNode.Line, m.curNode.Name, msg)
} else {
return errors.New(msg)
}
}
func (m *Map) AllowUnknown() {
m.allowUnknown = true
}
//ffs, give me generics already
func (m *Map) Bool(name string, inheritGlobal bool, store *bool) {
m.Custom(name, inheritGlobal, false, func() (interface{}, error) {
return false, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 0 {
return nil, m.MatchErr("unexpected arguments")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
return true, nil
}, store)
}
func (m *Map) String(name string, inheritGlobal, required bool, defaultVal string, store *string) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
return node.Args[0], nil
}, store)
}
func (m *Map) Int(name string, inheritGlobal, required bool, defaultVal int, store *int) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
i, err := strconv.Atoi(node.Args[0])
if err != nil {
return nil, m.MatchErr("invalid integer: %s", node.Args[0])
}
return i, nil
}, store)
}
func (m *Map) UInt(name string, inheritGlobal, required bool, defaultVal uint, store *uint) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
i, err := strconv.ParseUint(node.Args[0], 10, 32)
if err != nil {
return nil, m.MatchErr("invalid integer: %s", node.Args[0])
}
return uint(i), nil
}, store)
}
func (m *Map) Int32(name string, inheritGlobal, required bool, defaultVal int32, store *int32) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
i, err := strconv.ParseInt(node.Args[0], 10, 32)
if err != nil {
return nil, m.MatchErr("invalid integer: %s", node.Args[0])
}
return int32(i), nil
}, store)
}
func (m *Map) UInt32(name string, inheritGlobal, required bool, defaultVal uint32, store *uint32) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
i, err := strconv.ParseUint(node.Args[0], 10, 32)
if err != nil {
return nil, m.MatchErr("invalid integer: %s", node.Args[0])
}
return uint32(i), nil
}, store)
}
func (m *Map) Int64(name string, inheritGlobal, required bool, defaultVal int64, store *int64) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
i, err := strconv.ParseInt(node.Args[0], 10, 64)
if err != nil {
return nil, m.MatchErr("invalid integer: %s", node.Args[0])
}
return i, nil
}, store)
}
func (m *Map) UInt64(name string, inheritGlobal, required bool, defaultVal uint64, store *uint64) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
if len(node.Children) != 0 {
return nil, m.MatchErr("can't declare block here")
}
i, err := strconv.ParseUint(node.Args[0], 10, 64)
if err != nil {
return nil, m.MatchErr("invalid integer: %s", node.Args[0])
}
return i, nil
}, store)
}
func (m *Map) Float(name string, inheritGlobal, required bool, defaultVal float64, store *float64) {
m.Custom(name, inheritGlobal, required, func() (interface{}, error) {
return defaultVal, nil
}, func(m *Map, node *Node) (interface{}, error) {
if len(node.Args) != 1 {
return nil, m.MatchErr("expected 1 argument")
}
f, err := strconv.ParseFloat(node.Args[0], 64)
if err != nil {
return nil, m.MatchErr("invalid float: %s", node.Args[0])
}
return f, nil
}, store)
}
// Custom maps configuration directive
func (m *Map) Custom(name string, inheritGlobal, required bool, defaultVal func() (interface{}, error), mapper func(*Map, *Node) (interface{}, error), store interface{}) {
if m.entries == nil {
m.entries = make(map[string]matcher)
}
val := reflect.ValueOf(store).Elem()
if !val.CanSet() {
panic("Map.Custom: store argument must be settable (a pointer)")
}
if _, ok := m.entries[name]; ok {
panic("Map.Custom: duplicate matcher")
}
m.entries[name] = matcher{
name: name,
inheritGlobal: inheritGlobal,
required: required,
defaultVal: defaultVal,
mapper: mapper,
store: val,
}
}
func (m *Map) Process(globalCfg map[string]Node, tree *Node) (unmatched []Node, err error) {
unmatched = make([]Node, 0, len(tree.Children))
matched := make(map[string]bool)
for _, subnode := range tree.Children {
m.curNode = &subnode
if matched[subnode.Name] {
return nil, m.MatchErr("duplicate directive")
}
matcher, ok := m.entries[subnode.Name]
if !ok {
if !m.allowUnknown {
return nil, m.MatchErr("unexpected directive")
}
unmatched = append(unmatched, subnode)
continue
}
if err := matcher.match(m, m.curNode); err != nil {
return nil, err
}
matched[subnode.Name] = true
}
m.curNode = nil
for _, matcher := range m.entries {
if matched[matcher.name] {
continue
}
globalNode, ok := globalCfg[matcher.name]
if matcher.inheritGlobal && ok {
m.curNode = &globalNode
if err := matcher.match(m, m.curNode); err != nil {
m.curNode = nil
return nil, err
}
m.curNode = nil
} else if !matcher.required {
if matcher.defaultVal == nil {
continue
}
val, err := matcher.defaultVal()
if err != nil {
return nil, err
}
if val == nil {
return nil, m.MatchErr("missing required directive: %s", matcher.name)
}
valRefl := reflect.ValueOf(val)
// Convert untyped nil into typed nil. Otherwise it will panic.
if !valRefl.IsValid() {
valRefl = reflect.Zero(matcher.store.Type())
}
matcher.store.Set(valRefl)
continue
} else {
return nil, m.MatchErr("missing required directive: %s", matcher.name)
}
}
return unmatched, nil
}

View file

@ -10,11 +10,25 @@ import (
"github.com/mholt/caddy/caddyfile"
)
// name arg0 arg1 {
// children0
// children1
// }
type Node struct {
Name string
Args []string
// First string at node's line.
Name string
// Any strings placed after node name.
Args []string
// If node is a block - all nodes inside the block. Can be nil.
Children []Node
Snippet bool
// Whether current parsed node is a snippet. Always false for all nodes
// returned from Read because snippets are expanded before it returns.
Snippet bool
File string
Line int
}
type parseContext struct {
@ -27,6 +41,8 @@ type parseContext struct {
func (ctx *parseContext) readCfgNode() (Node, error) {
node := Node{}
node.File = ctx.File()
node.Line = ctx.Line()
if ctx.Val() == "{" {
ctx.parens++
return node, ctx.SyntaxErr("block header")
@ -191,7 +207,7 @@ func readCfgTree(r io.Reader, location string, depth int) (nodes []Node, snips m
return root.Children, ctx.snippets, nil
}
func ReadConfig(r io.Reader, location string) (nodes []Node, err error) {
func Read(r io.Reader, location string) (nodes []Node, err error) {
nodes, _, err = readCfgTree(r, location, 1)
return
}

View file

@ -9,7 +9,7 @@ import (
var defaultDriver, defaultDsn string
func initDefaultStorage(globalCfg map[string][]string) {
func initDefaultStorage(globalCfg map[string]config.Node) {
if defaultDriver == "" {
defaultDriver = "sqlite3"
}
@ -17,7 +17,7 @@ func initDefaultStorage(globalCfg map[string][]string) {
defaultDsn = "maddy.db"
}
mod, err := NewSQLMail("default", globalCfg, config.Node{
mod, err := NewSQLMail("default", globalCfg, config.Node{ //TODO!
Name: "sqlmail",
Args: []string{"default"},
Children: []config.Node{

132
imap.go
View file

@ -4,8 +4,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
@ -32,92 +30,31 @@ type IMAPEndpoint struct {
listenersWg sync.WaitGroup
}
func NewIMAPEndpoint(instName string, globalCfg map[string][]string, cfg config.Node) (module.Module, error) {
func NewIMAPEndpoint(instName string, globalCfg map[string]config.Node, rawCfg config.Node) (module.Module, error) {
endp := new(IMAPEndpoint)
endp.name = instName
var (
err error
errorArgs []string
tlsArgs []string
errorLog *log.Logger
insecureAuth bool
ioDebug bool
)
for _, entry := range cfg.Children {
switch entry.Name {
case "auth":
endp.Auth, err = authProvider(entry.Args)
if err != nil {
return nil, err
}
case "storage":
endp.Store, err = storageBackend(entry.Args)
if err != nil {
return nil, err
}
case "tls":
tlsArgs = entry.Args
case "insecure_auth":
log.Printf("imap %s: authentication over unencrypted connections is allowed, this is insecure configuration and should be used only for testing!", instName)
insecureAuth = true
case "io_debug":
log.Printf("imap %s: I/O debugging is on!", instName)
ioDebug = true
case "errors":
errorArgs = entry.Args
default:
return nil, fmt.Errorf("unknown config directive: %s", entry.Name)
}
}
cfg := config.Map{}
cfg.Custom("auth", false, true, defaultAuthProvider, authDirective, &endp.Auth)
cfg.Custom("storage", false, true, defaultStorage, storageDirective, &endp.Store)
cfg.Custom("tls", true, true, nil, tlsDirective, &endp.tlsConfig)
cfg.Bool("insecure_auth", false, &insecureAuth)
cfg.Bool("io_debug", false, &insecureAuth)
cfg.Custom("errors", false, false, func() (interface{}, error) {
return log.New(os.Stderr, "imap ", log.LstdFlags), nil
}, errorsDirective, &errorLog)
if globalTls, ok := globalCfg["tls"]; tlsArgs == nil && ok {
tlsArgs = globalTls
}
if tlsArgs == nil {
return nil, errors.New("TLS is not configured")
}
endp.tlsConfig = new(tls.Config)
if err := setTLS(tlsArgs, &endp.tlsConfig); err != nil {
if _, err := cfg.Process(globalCfg, &rawCfg); err != nil {
return nil, err
}
// Print warning only if TLS is in per-module configuration.
// Otherwise we will print it when reading global config.
if endp.tlsConfig == nil && globalCfg["tls"] == nil {
log.Printf("imap %s: TLS is disabled, this is insecure configuration and should be used only for testing!", endp.name)
}
if endp.tlsConfig == nil {
insecureAuth = true
}
if endp.Auth == nil {
endp.Auth, err = authProvider([]string{"default_auth"})
if err != nil {
endp.Auth, err = authProvider([]string{"default"})
if err != nil {
return nil, errors.New("missing default auth. provider, must set custom")
}
}
log.Printf("imap %s: using %s auth. provider (%s %s)",
instName, endp.Auth.(module.Module).InstanceName(),
endp.Auth.(module.Module).Name(), endp.Auth.(module.Module).Version(),
)
}
if endp.Store == nil {
endp.Store, err = storageBackend([]string{"default_storage"})
if err != nil {
endp.Store, err = storageBackend([]string{"default"})
if err != nil {
return nil, errors.New("missing default storage backend, must set custom")
}
}
log.Printf("imap %s: using %s storage backend (%s %s)",
instName, endp.Store.(module.Module).InstanceName(),
endp.Store.(module.Module).Name(), endp.Store.(module.Module).Version(),
)
}
addresses := make([]Address, 0, len(cfg.Args))
for _, addr := range cfg.Args {
addresses := make([]Address, 0, len(rawCfg.Args))
for _, addr := range rawCfg.Args {
saddr, err := standardizeAddress(addr)
if err != nil {
return nil, fmt.Errorf("invalid address: %s", instName)
@ -134,12 +71,6 @@ func NewIMAPEndpoint(instName string, globalCfg map[string][]string, cfg config.
if ioDebug {
endp.serv.Debug = os.Stderr
}
if errorArgs != nil {
err := setIMAPErrors(instName, errorArgs, endp.serv)
if err != nil {
return nil, err
}
}
if err := endp.enableExtensions(); err != nil {
return nil, err
@ -173,6 +104,14 @@ func NewIMAPEndpoint(instName string, globalCfg map[string][]string, cfg config.
}()
}
if endp.serv.AllowInsecureAuth {
log.Printf("imap %s: authentication over unencrypted connections is allowed, this is insecure configuration and should be used only for testing!", endp.name)
}
if endp.serv.TLSConfig == nil {
log.Printf("imap %s: TLS is disabled, this is insecure configuration and should be used only for testing!", endp.name)
endp.serv.AllowInsecureAuth = true
}
return endp, nil
}
@ -226,33 +165,6 @@ func (endp *IMAPEndpoint) enableExtensions() error {
return nil
}
func setIMAPErrors(instName string, args []string, s *imapserver.Server) error {
if len(args) != 1 {
return fmt.Errorf("missing errors directive values")
}
output := args[0]
var w io.Writer
switch output {
case "off":
w = ioutil.Discard
case "stdout":
w = os.Stdout
case "stderr":
w = os.Stderr
default:
f, err := os.OpenFile(output, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
if err != nil {
return err
}
w = f
}
s.ErrorLog = log.New(w, "imap "+instName+" ", log.LstdFlags)
return nil
}
func init() {
module.Register("imap", NewIMAPEndpoint)
}

View file

@ -1,13 +1,11 @@
package maddy
import (
"errors"
"fmt"
"io"
"log"
"os"
"os/signal"
"strings"
"syscall"
"github.com/emersion/maddy/config"
@ -16,18 +14,13 @@ import (
func Start(cfg []config.Node) error {
instances := make([]module.Module, 0, len(cfg))
globalCfg := make(map[string][]string)
globalCfg := make(map[string]config.Node)
defaultPresent := false
for _, block := range cfg {
switch block.Name {
case "tls":
if len(block.Args) == 1 && block.Args[0] == "off" {
log.Println("TLS is disabled, this is insecure configuration and should be used only for testing!")
}
fallthrough
case "hostname":
globalCfg[block.Name] = block.Args
case "tls", "hostname":
globalCfg[block.Name] = block
continue
default:
if len(block.Args) != 0 && block.Args[0] == "default" {
@ -57,18 +50,11 @@ func Start(cfg []config.Node) error {
factory := module.GetMod(modName)
if factory == nil {
return fmt.Errorf("unknown module: %s", modName)
return fmt.Errorf("%s:%d: unknown module: %s", block.File, block.Line, modName)
}
if mod := module.GetInstance(instName); mod != nil {
if !strings.HasPrefix(instName, "default") {
return fmt.Errorf("module instance named %s already exists", instName)
}
// Clean up default module before replacing it.
if closer, ok := mod.(io.Closer); ok {
closer.Close()
}
return fmt.Errorf("%s:%d: module instance named %s already exists", block.File, block.Line, instName)
}
inst, err := factory(instName, globalCfg, block)
@ -99,57 +85,3 @@ func Start(cfg []config.Node) error {
return nil
}
func authProvider(args []string) (module.AuthProvider, error) {
if len(args) != 1 {
return nil, errors.New("auth: expected 1 argument")
}
authName := args[0]
authMod := module.GetInstance(authName)
if authMod == nil {
return nil, fmt.Errorf("unknown auth. provider instance: %s", authName)
}
provider, ok := authMod.(module.AuthProvider)
if !ok {
return nil, fmt.Errorf("module %s doesn't implements auth. provider interface", authMod.Name())
}
return provider, nil
}
func storageBackend(args []string) (module.Storage, error) {
if len(args) != 1 {
return nil, errors.New("storage: expected 1 argument")
}
authName := args[0]
authMod := module.GetInstance(authName)
if authMod == nil {
return nil, fmt.Errorf("unknown storage backend instance: %s", authName)
}
provider, ok := authMod.(module.Storage)
if !ok {
return nil, fmt.Errorf("module %s doesn't implements storage interface", authMod.Name())
}
return provider, nil
}
func deliveryTarget(args []string) (module.DeliveryTarget, error) {
if len(args) != 1 {
return nil, errors.New("delivery: expected 1 argument")
}
modName := args[0]
mod := module.GetInstance(modName)
if mod == nil {
return nil, fmt.Errorf("unknown delivery target instance: %s", modName)
}
target, ok := mod.(module.DeliveryTarget)
if !ok {
return nil, fmt.Errorf("module %s doesn't implements delivery target interface", mod.Name())
}
return target, nil
}

View file

@ -39,4 +39,4 @@ type Module interface {
}
// FuncNewModule is function that creates new instance of module with specified name.
type FuncNewModule func(name string, globalCfg map[string][]string, cfg config.Node) (Module, error)
type FuncNewModule func(name string, globalCfg map[string]config.Node, cfg config.Node) (Module, error)

150
smtp.go
View file

@ -10,7 +10,6 @@ import (
"net"
"net/mail"
"os"
"strconv"
"strings"
"sync"
@ -35,7 +34,7 @@ func (u SMTPUser) Send(from string, to []string, r io.Reader) error {
SrcTLSState: u.state.TLS,
SrcHostname: u.state.Hostname,
SrcAddr: u.state.RemoteAddr,
OurHostname: u.endp.domain,
OurHostname: u.endp.serv.Domain,
Ctx: make(map[string]interface{}),
}
@ -89,12 +88,10 @@ type SMTPEndpoint struct {
Auth module.AuthProvider
serv *smtp.Server
name string
domain string
listeners []net.Listener
pipeline []SMTPPipelineStep
submission bool
tlsConfig *tls.Config
listenersWg sync.WaitGroup
}
@ -111,7 +108,7 @@ func (endp *SMTPEndpoint) Version() string {
return VersionStr
}
func NewSMTPEndpoint(instName string, globalCfg map[string][]string, cfg config.Node) (module.Module, error) {
func NewSMTPEndpoint(instName string, globalCfg map[string]config.Node, cfg config.Node) (module.Module, error) {
endp := new(SMTPEndpoint)
endp.name = instName
endp.serv = smtp.NewServer(endp)
@ -129,68 +126,54 @@ func NewSMTPEndpoint(instName string, globalCfg map[string][]string, cfg config.
addresses = append(addresses, saddr)
}
if err := endp.setupListeners(addresses, endp.tlsConfig); err != nil {
if err := endp.setupListeners(addresses); err != nil {
for _, l := range endp.listeners {
l.Close()
}
return nil, err
}
if endp.serv.AllowInsecureAuth {
log.Printf("smtp %s: authentication over unencrypted connections is allowed, this is insecure configuration and should be used only for testing!", endp.name)
}
if endp.serv.TLSConfig == nil && !endp.serv.LMTP {
log.Printf("smtp %s: TLS is disabled, this is insecure configuration and should be used only for testing!", endp.name)
endp.serv.AllowInsecureAuth = true
}
return endp, nil
}
func (endp *SMTPEndpoint) setConfig(globalCfg map[string][]string, cfg config.Node) error {
func (endp *SMTPEndpoint) setConfig(globalCfg map[string]config.Node, rawCfg config.Node) error {
var (
err error
tlsArgs []string
ioDebug bool
localDeliveryDefault string
localDeliveryOpts map[string]string
remoteDeliveryDefault string
remoteDeliveryOpts map[string]string
)
maxIdle := -1
maxMsgBytes := -1
insecureAuth := false
ioDebug := false
for _, entry := range cfg.Children {
cfg := config.Map{}
cfg.Custom("auth", false, true, defaultAuthProvider, authDirective, &endp.Auth)
cfg.String("hostname", true, true, "", &endp.serv.Domain)
cfg.Int("max_idle", false, false, 60, &endp.serv.MaxIdleSeconds)
cfg.Int("max_message_size", false, false, 32*1024*1024, &endp.serv.MaxMessageBytes)
cfg.Int("max_recipients", false, false, 255, &endp.serv.MaxRecipients)
cfg.Custom("tls", true, true, nil, tlsDirective, &endp.serv.TLSConfig)
cfg.Bool("insecure_auth", false, &endp.serv.AllowInsecureAuth)
cfg.Bool("io_debug", false, &ioDebug)
cfg.Bool("submission", false, &endp.submission)
cfg.AllowUnknown()
remainingDirs, err := cfg.Process(globalCfg, &rawCfg)
if err != nil {
return err
}
for _, entry := range remainingDirs {
switch entry.Name {
case "auth":
endp.Auth, err = authProvider(entry.Args)
if err != nil {
return err
}
case "hostname":
if len(entry.Args) != 1 {
return errors.New("hostname: expected 1 argument")
}
endp.domain = entry.Args[0]
case "max_idle":
if len(entry.Args) != 1 {
return errors.New("max_idle: expected 1 argument")
}
maxIdle, err = strconv.Atoi(entry.Args[0])
if err != nil {
return errors.New("max_idle: invalid integer value")
}
case "max_message_size":
if len(entry.Args) != 1 {
return errors.New("max_message_size: expected 1 argument")
}
maxMsgBytes, err = strconv.Atoi(entry.Args[0])
if err != nil {
return errors.New("max_message_size: invalid integer value")
}
case "tls":
tlsArgs = entry.Args
case "insecure_auth":
log.Printf("smtp %s: authentication over unencrypted connections is allowed, this is insecure configuration and should be used only for testing!", endp.name)
insecureAuth = true
case "io_debug":
log.Printf("smtp %v: I/O debugging is on!\n", endp.name)
ioDebug = true
case "submission":
endp.submission = true
case "local_delivery":
if len(entry.Args) == 0 {
return errors.New("local_delivery: expected at least 1 argument")
@ -226,37 +209,6 @@ func (endp *SMTPEndpoint) setConfig(globalCfg map[string][]string, cfg config.No
}
}
if globalTls, ok := globalCfg["tls"]; tlsArgs == nil && ok {
tlsArgs = globalTls
}
if tlsArgs == nil {
return errors.New("TLS is not configured")
}
endp.tlsConfig = new(tls.Config)
if err := setTLS(tlsArgs, &endp.tlsConfig); err != nil {
return err
}
// Print warning only if TLS is in per-module configuration.
// Otherwise we will print it when reading global config.
if endp.tlsConfig == nil && globalCfg["tls"] == nil {
log.Printf("smtp %s: TLS is disabled, this is insecure configuration and should be used only for testing!", endp.name)
}
if endp.tlsConfig == nil {
insecureAuth = true
} else {
endp.serv.TLSConfig = endp.tlsConfig
}
if globalDomain, ok := globalCfg["hostname"]; endp.domain == "" && ok {
if len(globalDomain) != 1 {
return errors.New("hostname: expected 1 argument")
}
endp.domain = globalDomain[0]
}
if endp.domain == "" {
return fmt.Errorf("hostname is not set")
}
if len(endp.pipeline) == 0 {
err := endp.setDefaultPipeline(localDeliveryDefault, remoteDeliveryDefault, localDeliveryOpts, remoteDeliveryOpts)
if err != nil {
@ -267,28 +219,6 @@ func (endp *SMTPEndpoint) setConfig(globalCfg map[string][]string, cfg config.No
endp.pipeline = append([]SMTPPipelineStep{submissionPrepareStep{}, requireAuthStep{}}, endp.pipeline...)
}
if endp.Auth == nil {
endp.Auth, err = authProvider([]string{"default_auth"})
if err != nil {
endp.Auth, err = authProvider([]string{"default"})
if err != nil {
return errors.New("missing default auth. provider, must set custom")
}
}
log.Printf("smtp %s: using %s auth. provider (%s %s)",
endp.name, endp.Auth.(module.Module).InstanceName(),
endp.Auth.(module.Module).Name(), endp.Auth.(module.Module).Version(),
)
}
endp.serv.AllowInsecureAuth = insecureAuth
endp.serv.Domain = endp.domain
if maxMsgBytes != -1 {
endp.serv.MaxMessageBytes = maxMsgBytes
}
if maxIdle != -1 {
endp.serv.MaxIdleSeconds = maxIdle
}
if ioDebug {
endp.serv.Debug = os.Stderr
}
@ -296,7 +226,7 @@ func (endp *SMTPEndpoint) setConfig(globalCfg map[string][]string, cfg config.No
return nil
}
func (endp *SMTPEndpoint) setupListeners(addresses []Address, tlsConf *tls.Config) error {
func (endp *SMTPEndpoint) setupListeners(addresses []Address) error {
var smtpUsed, lmtpUsed bool
for _, addr := range addresses {
if addr.Scheme == "smtp" || addr.Scheme == "smtps" {
@ -321,10 +251,10 @@ func (endp *SMTPEndpoint) setupListeners(addresses []Address, tlsConf *tls.Confi
log.Printf("smtp: listening on %v\n", addr)
if addr.IsTLS() {
if endp.tlsConfig == nil {
if endp.serv.TLSConfig == nil {
return errors.New("can't bind on SMTPS endpoint without TLS configuration")
}
l = tls.NewListener(l, tlsConf)
l = tls.NewListener(l, endp.serv.TLSConfig)
}
endp.listeners = append(endp.listeners, l)
@ -347,21 +277,19 @@ func (endp *SMTPEndpoint) setupListeners(addresses []Address, tlsConf *tls.Confi
}
func (endp *SMTPEndpoint) setDefaultPipeline(localDeliveryName, remoteDeliveryName string, localOpts, remoteOpts map[string]string) error {
log.Printf("smtp %s: using default pipeline configuration (submission=%v)", endp.InstanceName(), endp.submission)
var err error
var localDelivery module.DeliveryTarget
if localDeliveryName == "" {
localDelivery, err = deliveryTarget([]string{"default_local_delivery"})
localDelivery, err = deliveryTarget("default_local_delivery")
if err != nil {
localDelivery, err = deliveryTarget([]string{"default"})
localDelivery, err = deliveryTarget("default")
if err != nil {
return err
}
}
localOpts = map[string]string{"local_only": ""}
} else {
localDelivery, err = deliveryTarget([]string{localDeliveryName})
localDelivery, err = deliveryTarget(localDeliveryName)
if err != nil {
return err
}
@ -372,7 +300,7 @@ func (endp *SMTPEndpoint) setDefaultPipeline(localDeliveryName, remoteDeliveryNa
remoteDeliveryName = "default_remote_delivery"
remoteOpts = map[string]string{"remote_only": ""}
}
remoteDelivery, err := deliveryTarget([]string{remoteDeliveryName})
remoteDelivery, err := deliveryTarget(remoteDeliveryName)
if err != nil {
return err
}

View file

@ -3,9 +3,7 @@ package maddy
import (
"bytes"
"errors"
"fmt"
"io"
"strconv"
"strings"
"time"
@ -33,51 +31,28 @@ func (sqlm *SQLMail) Version() string {
return sqlmail.VersionStr
}
func NewSQLMail(instName string, globalCfg map[string][]string, cfg config.Node) (module.Module, error) {
func NewSQLMail(instName string, globalCfg map[string]config.Node, rawCfg config.Node) (module.Module, error) {
var driver string
var dsn string
var appendlimitSet bool
appendlimitVal := int64(-1)
opts := imapsqlmail.Opts{}
for _, entry := range cfg.Children {
switch entry.Name {
case "driver":
if len(entry.Args) != 1 {
return nil, fmt.Errorf("driver: expected 1 argument")
}
driver = entry.Args[0]
case "dsn":
if len(entry.Args) != 1 {
return nil, fmt.Errorf("dsn: expected 1 argument")
}
dsn = entry.Args[0]
case "appendlimit":
if len(entry.Args) != 1 {
return nil, errors.New("appendlimit: expected 1 argument")
}
lim, err := strconv.Atoi(entry.Args[0])
if lim == -1 {
continue
}
cfg := config.Map{}
cfg.String("driver", false, true, "", &driver)
cfg.String("dsn", false, true, "", &dsn)
cfg.Int64("appendlimit", false, false, 32*1024*1024, &appendlimitVal)
lim32 := uint32(lim)
if err != nil {
return nil, errors.New("appendlimit: invalid value")
}
opts.MaxMsgBytes = &lim32
appendlimitSet = true
default:
return nil, fmt.Errorf("unknown directive: %s", entry.Name)
}
if _, err := cfg.Process(globalCfg, &rawCfg); err != nil {
return nil, err
}
if !appendlimitSet {
lim := uint32(32 * 1024 * 1024) // 32 MiB
opts.MaxMsgBytes = &lim
if appendlimitVal == -1 {
opts.MaxMsgBytes = nil
} else {
opts.MaxMsgBytes = new(uint32)
*opts.MaxMsgBytes = uint32(appendlimitVal)
}
sqlm, err := imapsqlmail.NewBackend(driver, dsn, opts)
mod := SQLMail{

View file

@ -15,7 +15,7 @@ import (
"github.com/google/uuid"
)
func NewSubmissionEndpoint(instName string, globalCfg map[string][]string, cfg config.Node) (module.Module, error) {
func NewSubmissionEndpoint(instName string, globalCfg map[string]config.Node, cfg config.Node) (module.Module, error) {
cfg.Children = append(cfg.Children, config.Node{Name: "submission"})
return NewSMTPEndpoint(instName, globalCfg, cfg)
}

28
tls.go
View file

@ -10,29 +10,35 @@ import (
"math/big"
"net"
"time"
"github.com/emersion/maddy/config"
)
func setTLS(args []string, config **tls.Config) error {
switch len(args) {
func tlsDirective(m *config.Map, node *config.Node) (interface{}, error) {
switch len(node.Args) {
case 1:
switch args[0] {
switch node.Args[0] {
case "off":
*config = nil
return nil
return nil, nil
case "self_signed":
if err := makeSelfSignedCert(*config); err != nil {
return err
cfg := &tls.Config{}
if err := makeSelfSignedCert(cfg); err != nil {
return nil, err
}
return &cfg, nil
}
case 2:
if cert, err := tls.LoadX509KeyPair(args[0], args[1]); err != nil {
return err
cfg := tls.Config{}
if cert, err := tls.LoadX509KeyPair(node.Args[0], node.Args[1]); err != nil {
return nil, err
} else {
(*config).Certificates = append((*config).Certificates, cert)
cfg.Certificates = append(cfg.Certificates, cert)
}
default:
return nil, m.MatchErr("expected 1 or 2 arguments")
}
return nil
return nil, nil
}
func makeSelfSignedCert(config *tls.Config) error {