target/remote: Implement STARTTLS Everywhere list support

This commit is contained in:
fox.cpp 2019-12-26 15:08:41 +03:00
parent 21b589b5da
commit c0a73bc3d0
No known key found for this signature in database
GPG key ID: E76D97CCEDE90B6C
7 changed files with 587 additions and 17 deletions

View file

@ -3,10 +3,16 @@ package remote
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/foxcpp/go-mtasts"
"github.com/foxcpp/go-mtasts/preload"
"github.com/foxcpp/maddy/internal/config"
"github.com/foxcpp/maddy/internal/dns"
"github.com/foxcpp/maddy/internal/exterrors"
@ -77,8 +83,8 @@ type (
// CheckConn is called to check whether the policy permits to use this
// connection.
//
// tlsLevel contains the TLS security level estabilished by checks
// executed before.
// tlsLevel and mxLevel contain the TLS security level estabilished by
// checks executed before.
//
// domain is passed to the CheckConn to allow simpler implementation
// of stateless policy objects.
@ -86,7 +92,7 @@ type (
// If tlsState.HandshakeCompleted is false, TLS is not used. If
// tlsState.VerifiedChains is nil, InsecureSkipVerify was used (no
// ServerName or PKI check was done).
CheckConn(ctx context.Context, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error)
CheckConn(ctx context.Context, mxLevel MXLevel, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error)
// Reset cleans the internal object state for use with another message.
// newMsg may be nil if object is not needed anymore.
@ -213,7 +219,7 @@ func (c *mtastsDelivery) CheckMX(ctx context.Context, mxLevel MXLevel, domain, m
return MX_MTASTS, nil
}
func (c *mtastsDelivery) CheckConn(ctx context.Context, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
func (c *mtastsDelivery) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
policyI, err := c.policyFut.GetContext(ctx)
if err != nil {
c.c.log.DebugMsg("MTA-STS error", "err", err)
@ -255,6 +261,277 @@ func (c *mtastsDelivery) Reset(msgMeta *module.MsgMetadata) {
}
}
var (
// Delay between list expiry and first update attempt.
preloadUpdateGrace = 5 * time.Minute
// Minimal time between preload list update attempts.
//
// This is adjusted for preloadUpdateGrace so we will the chance to make 10
// attempts to update list before it expires.
preloadUpdateCooldown = 30 * time.Second
)
type (
FuncPreloadList = func(*http.Client, preload.Source) (*preload.List, error)
stsPreloadPolicy struct {
l *preload.List
lLock sync.RWMutex
log log.Logger
updaterStop chan struct{}
sourcePath string
client *http.Client
listDownload FuncPreloadList
enforceTesting bool
}
)
func NewSTSPreloadPolicy(debug bool, client *http.Client, listDownload FuncPreloadList, cfg *config.Map) (*stsPreloadPolicy, error) {
p := &stsPreloadPolicy{
log: log.Logger{Name: "remote/preload", Debug: debug},
updaterStop: make(chan struct{}),
client: client,
listDownload: preload.Download,
}
var sourcePath string
cfg.String("source", false, false, "eff", &sourcePath)
cfg.Bool("enforceTesting", false, true, &p.enforceTesting)
if _, err := cfg.Process(); err != nil {
return nil, err
}
var err error
p.l, err = p.load(client, listDownload, sourcePath)
if err != nil {
return nil, err
}
p.sourcePath = sourcePath
go p.updater()
return p, nil
}
func (p *stsPreloadPolicy) load(client *http.Client, listDownload FuncPreloadList, sourcePath string) (*preload.List, error) {
var (
l *preload.List
err error
)
src := preload.Source{
ListURI: sourcePath,
}
switch {
case strings.HasPrefix(sourcePath, "file://"):
// Load list from FS.
path := strings.TrimPrefix(sourcePath, "file://")
f, err := os.Open(path)
if err != nil {
return nil, err
}
// If the list is provided by an external FS source, it is its
// responsibility to make sure it is valid. Hard-fail if it is not.
l, err = preload.Read(f)
if err != nil {
return nil, fmt.Errorf("remote/preload: %w", err)
}
defer f.Close()
p.log.DebugMsg("loaded list from FS", "entries", len(l.Policies), "path", path)
case sourcePath == "eff":
src = preload.STARTTLSEverywhere
fallthrough
case strings.HasPrefix(sourcePath, "http://"):
// XXX: Only for testing, remove later.
fallthrough
case strings.HasPrefix(sourcePath, "https://"):
// Download list using HTTPS.
// TODO: Cache on disk and update it asynchronously to reduce start-up
// time. This will also reduce persistent attacker ability to prevent
// list (re-)discovery.
p.log.DebugMsg("downloading list", "uri", sourcePath)
l, err = listDownload(client, src)
if err != nil {
return nil, fmt.Errorf("remote/preload: %w", err)
}
p.log.DebugMsg("downloaded list", "entries", len(l.Policies), "uri", sourcePath)
default:
return nil, fmt.Errorf("remote/preload: unknown list source or unsupported schema: %v", sourcePath)
}
return l, nil
}
func (p *stsPreloadPolicy) updater() {
for {
updateDelay := time.Until(time.Time(p.l.Expires)) - preloadUpdateGrace
if updateDelay <= 0 {
updateDelay = preloadUpdateCooldown
}
// TODO: Increase update delay for multiple failures.
t := time.NewTimer(updateDelay)
p.log.DebugMsg("sleeping to update", "delay", updateDelay)
select {
case <-t.C:
// Attempt to update the list.
newList, err := p.load(p.client, p.listDownload, p.sourcePath)
if err != nil {
p.log.Error("failed to update list", err)
continue
}
if err := p.update(newList); err != nil {
p.log.Error("failed to update list", err)
continue
}
p.log.DebugMsg("updated list", "entries", len(newList.Policies))
case <-p.updaterStop:
t.Stop()
p.updaterStop <- struct{}{}
return
}
}
}
func (p *stsPreloadPolicy) update(newList *preload.List) error {
if newList.Expired() {
return exterrors.WithFields(errors.New("the new STARTLS Everywhere list is expired"),
map[string]interface{}{
"timestamp": newList.Timestamp,
"expires": newList.Expires,
})
}
p.lLock.Lock()
defer p.lLock.Unlock()
if time.Time(newList.Timestamp).Before(time.Time(p.l.Timestamp)) {
return exterrors.WithFields(errors.New("the new list is older than the currently used one"),
map[string]interface{}{
"old_timestamp": p.l.Timestamp,
"new_timestamp": newList.Timestamp,
"expires": newList.Expires,
})
}
p.l = newList
return nil
}
type preloadDelivery struct {
*stsPreloadPolicy
mtastsPresent bool
}
func (p *stsPreloadPolicy) Start(*module.MsgMetadata) DeliveryPolicy {
return &preloadDelivery{stsPreloadPolicy: p}
}
func (p *preloadDelivery) Reset(*module.MsgMetadata) {}
func (p *preloadDelivery) PrepareDomain(ctx context.Context, domain string) {}
func (p *preloadDelivery) PrepareConn(ctx context.Context, mx string) {}
func (p *preloadDelivery) CheckMX(ctx context.Context, mxLevel MXLevel, domain, mx string, dnssec bool) (MXLevel, error) {
// MTA-STS policy was discovered and took effect already. Do not use
// preload list.
if mxLevel == MX_MTASTS {
p.mtastsPresent = true
return MXNone, nil
}
p.lLock.RLock()
defer p.lLock.RUnlock()
if p.l.Expired() {
p.log.Msg("STARTTLS Everywhere list is expired, ignoring")
return MXNone, nil
}
ent, ok := p.l.Lookup(domain)
if !ok {
p.log.DebugMsg("no entry", "domain", domain)
return MXNone, nil
}
sts := ent.STS(p.l)
if !sts.Match(mx) {
if sts.Mode == mtasts.ModeEnforce || p.enforceTesting {
return MXNone, &exterrors.SMTPError{
Code: 550,
EnhancedCode: exterrors.EnhancedCode{5, 7, 0},
Message: "Failed to estabilish the MX record authenticity (STARTTLS Everywhere)",
}
}
p.log.Msg("MX does not match published non-enforced STARTLS Everywhere entry", "mx", mx, "domain", domain)
return MXNone, nil
}
p.log.DebugMsg("MX OK", "domain", domain)
return MX_MTASTS, nil
}
func (p *preloadDelivery) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
// MTA-STS policy was discovered and took effect already. Do not use
// preload list. We cannot check level for MX_MTASTS because we can set
// it too in CheckMX.
if p.mtastsPresent {
return TLSNone, nil
}
p.lLock.RLock()
defer p.lLock.RUnlock()
if p.l.Expired() {
p.log.Msg("STARTTLS Everywhere list is expired, ignoring")
return TLSNone, nil
}
ent, ok := p.l.Lookup(domain)
if !ok {
p.log.DebugMsg("no entry", "domain", domain)
return TLSNone, nil
}
if ent.Mode != mtasts.ModeEnforce && !p.enforceTesting {
return TLSNone, nil
}
if !tlsState.HandshakeComplete {
return TLSNone, &exterrors.SMTPError{
Code: 451,
EnhancedCode: exterrors.EnhancedCode{4, 7, 1},
Message: "TLS is required but unavailable or failed (STARTTLS Everywhere)",
}
}
if tlsState.VerifiedChains == nil {
return TLSNone, &exterrors.SMTPError{
Code: 451,
EnhancedCode: exterrors.EnhancedCode{4, 7, 1},
Message: "Recipient server TLS certificate is not trusted but " +
"authentication is required by STARTTLS Everywhere list",
Misc: map[string]interface{}{
"tls_level": tlsLevel,
},
}
}
p.log.DebugMsg("TLS OK", "domain", domain)
return TLSNone, nil
}
func (p *stsPreloadPolicy) Close() error {
p.updaterStop <- struct{}{}
<-p.updaterStop
return nil
}
type dnssecPolicy struct{}
func (dnssecPolicy) Start(*module.MsgMetadata) DeliveryPolicy {
@ -276,7 +553,7 @@ func (dnssecPolicy) CheckMX(ctx context.Context, mxLevel MXLevel, domain, mx str
return MXNone, nil
}
func (dnssecPolicy) CheckConn(ctx context.Context, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
func (dnssecPolicy) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
return TLSNone, nil
}
@ -341,7 +618,7 @@ func (c *daneDelivery) CheckMX(ctx context.Context, mxLevel MXLevel, domain, mx
return MXNone, nil
}
func (c *daneDelivery) CheckConn(ctx context.Context, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
func (c *daneDelivery) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
// No DNSSEC support.
if c.c.extResolver == nil {
return TLSNone, nil
@ -449,7 +726,7 @@ func (l localPolicy) CheckMX(ctx context.Context, mxLevel MXLevel, domain, mx st
return MXNone, nil
}
func (l localPolicy) CheckConn(ctx context.Context, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
func (l localPolicy) CheckConn(ctx context.Context, mxLevel MXLevel, tlsLevel TLSLevel, domain, mx string, tlsState tls.ConnectionState) (TLSLevel, error) {
if tlsLevel < l.minTLSLevel {
return TLSNone, &exterrors.SMTPError{
Code: 451,
@ -465,6 +742,8 @@ func (l localPolicy) CheckConn(ctx context.Context, tlsLevel TLSLevel, domain, m
func (rt *Target) policyMatcher(name string) func(*config.Map, *config.Node) (interface{}, error) {
return func(m *config.Map, node *config.Node) (interface{}, error) {
// TODO: Fix rt.Log.Debug propagation. This function is called in
// arbitrary order before or after the 'debug' directive is handled.
switch len(node.Args) {
case 0:
case 1:
@ -482,17 +761,20 @@ func (rt *Target) policyMatcher(name string) func(*config.Map, *config.Node) (in
)
switch name {
case "mtasts":
policy, err = NewMTASTSPolicy(rt.resolver, rt.Log.Debug, config.NewMap(m.Globals, node))
policy, err = NewMTASTSPolicy(rt.resolver, log.DefaultLogger.Debug, config.NewMap(m.Globals, node))
case "dane":
if node.Children != nil {
return nil, m.MatchErr("policy offers no additional configuration")
}
policy = NewDANEPolicy(rt.extResolver, rt.Log.Debug)
policy = NewDANEPolicy(rt.extResolver, log.DefaultLogger.Debug)
case "dnssec":
if node.Children != nil {
return nil, m.MatchErr("policy offers no additional configuration")
}
policy = &dnssecPolicy{}
case "sts_preload":
policy, err = NewSTSPreloadPolicy(log.DefaultLogger.Debug, http.DefaultClient, preload.Download,
config.NewMap(m.Globals, node))
case "local":
policy, err = NewLocalPolicy(config.NewMap(m.Globals, node))
default:
@ -508,17 +790,22 @@ func (rt *Target) policyMatcher(name string) func(*config.Map, *config.Node) (in
func (rt *Target) defaultPolicy(globals map[string]interface{}, name string) func() (interface{}, error) {
return func() (interface{}, error) {
// TODO: Fix rt.Log.Debug propagation. This function is called in
// arbitrary order before or after the 'debug' directive is handled.
var (
policy Policy
err error
)
switch name {
case "mtasts":
policy, err = NewMTASTSPolicy(rt.resolver, rt.Log.Debug, config.NewMap(globals, &config.Node{}))
policy, err = NewMTASTSPolicy(rt.resolver, log.DefaultLogger.Debug, config.NewMap(globals, &config.Node{}))
case "dane":
policy = NewDANEPolicy(rt.extResolver, rt.Log.Debug)
policy = NewDANEPolicy(rt.extResolver, log.DefaultLogger.Debug)
case "dnssec":
policy = &dnssecPolicy{}
case "sts_preload":
policy, err = NewSTSPreloadPolicy(log.DefaultLogger.Debug, http.DefaultClient, preload.Download,
config.NewMap(globals, &config.Node{}))
case "local":
policy, err = NewLocalPolicy(config.NewMap(globals, &config.Node{}))
default: