target/smtp: Check-in accidentally reverted attempt_starttls changes

This commit is contained in:
fox.cpp 2025-01-25 14:51:35 +03:00
parent cff6cfaca6
commit be0ec6b7cf
No known key found for this signature in database
GPG key ID: 5B991F6215D2FCC0
5 changed files with 70 additions and 116 deletions

View file

@ -20,6 +20,7 @@ package config
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
@ -305,6 +306,16 @@ func (m *Map) DataSize(name string, inheritGlobal, required bool, defaultVal int
}, store)
}
func ParseBool(s string) (bool, error) {
switch strings.ToLower(s) {
case "1", "true", "on", "yes":
return true, nil
case "0", "false", "off", "no":
return false, nil
}
return false, fmt.Errorf("bool argument should be 'yes' or 'no'")
}
// Bool maps presence of some configuration directive to a boolean variable.
// Additionally, 'name yes' and 'name no' are mapped to true and false
// correspondingly.
@ -327,13 +338,11 @@ func (m *Map) Bool(name string, inheritGlobal, defaultVal bool, store *bool) {
return nil, NodeErr(node, "expected exactly 1 argument")
}
switch strings.ToLower(node.Args[0]) {
case "1", "true", "on", "yes":
return true, nil
case "0", "false", "off", "no":
return false, nil
b, err := ParseBool(node.Args[0])
if err != nil {
return nil, NodeErr(node, "bool argument should be 'yes' or 'no'")
}
return nil, NodeErr(node, "bool argument should be 'yes' or 'no'")
return b, nil
}, store)
}

View file

@ -259,12 +259,15 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint,
return false, nil, nil, err
}
if endp.IsTLS() || !starttls {
return endp.IsTLS(), cl, conn, nil
if !starttls {
return false, cl, conn, nil
}
if ok, _ := cl.Extension("STARTTLS"); !ok {
return false, cl, conn, nil
if err := cl.Quit(); err != nil {
cl.Close()
}
return false, nil, nil, fmt.Errorf("TLS required but unsupported by downstream")
}
cfg := tlsConfig.Clone()

View file

@ -29,7 +29,6 @@ package smtp_downstream
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"runtime/trace"
@ -54,12 +53,11 @@ type Downstream struct {
lmtp bool
targetsArg []string
requireTLS bool
attemptStartTLS bool
hostname string
endpoints []config.Endpoint
saslFactory saslClientFactory
tlsConfig tls.Config
starttls bool
hostname string
endpoints []config.Endpoint
saslFactory saslClientFactory
tlsConfig tls.Config
connectTimeout time.Duration
commandTimeout time.Duration
@ -89,10 +87,34 @@ func NewDownstream(modName, instName string, _, inlineArgs []string) (module.Mod
}
func (u *Downstream) Init(cfg *config.Map) error {
var attemptTLS *bool
var targetsArg []string
cfg.Bool("debug", true, false, &u.log.Debug)
cfg.Bool("require_tls", false, false, &u.requireTLS)
cfg.Bool("attempt_starttls", false, !u.lmtp, &u.attemptStartTLS)
cfg.Callback("require_tls", func(m *config.Map, node config.Node) error {
u.log.Msg("require_tls directive is deprecated and ignored")
return nil
})
cfg.Callback("attempt_starttls", func(m *config.Map, node config.Node) error {
u.log.Msg("attempt_starttls directive is deprecated and equivalent to starttls")
if len(node.Args) == 0 {
trueVal := true
attemptTLS = &trueVal
return nil
}
if len(node.Args) != 1 {
return config.NodeErr(node, "expected exactly 1 argument")
}
b, err := config.ParseBool(node.Args[0])
if err != nil {
return err
}
attemptTLS = &b
return nil
})
cfg.Bool("starttls", false, !u.lmtp, &u.starttls)
cfg.String("hostname", true, true, "", &u.hostname)
cfg.StringList("targets", false, false, nil, &targetsArg)
cfg.Custom("auth", false, false, func() (interface{}, error) {
@ -109,6 +131,10 @@ func (u *Downstream) Init(cfg *config.Map) error {
return err
}
if attemptTLS != nil {
u.starttls = *attemptTLS
}
// INTERNATIONALIZATION: See RFC 6531 Section 3.7.1.
var err error
u.hostname, err = idna.ToASCII(u.hostname)
@ -201,14 +227,11 @@ func (d *delivery) connect(ctx context.Context) error {
}
for _, endp := range d.u.endpoints {
var (
didTLS bool
err error
)
var err error
if d.u.lmtp {
didTLS, err = conn.ConnectLMTP(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig)
_, err = conn.ConnectLMTP(ctx, endp, d.u.starttls, &d.u.tlsConfig)
} else {
didTLS, err = conn.Connect(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig)
_, err = conn.Connect(ctx, endp, d.u.starttls, &d.u.tlsConfig)
}
if err != nil {
if len(d.u.endpoints) != 1 {
@ -220,12 +243,6 @@ func (d *delivery) connect(ctx context.Context) error {
d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
if !didTLS && d.u.requireTLS {
conn.Close()
lastErr = errors.New("TLS is required, but unsupported by downstream")
continue
}
lastErr = nil
break
}

View file

@ -207,7 +207,7 @@ func TestDownstreamDelivery_MAILErr(t *testing.T) {
testutils.CheckSMTPErr(t, err, 550, exterrors.EnhancedCode{5, 1, 2}, "Hey")
}
func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
func TestDownstreamDelivery_StartTLS(t *testing.T) {
clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
@ -221,9 +221,9 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
Port: testPort,
},
},
tlsConfig: *clientCfg.Clone(),
attemptStartTLS: true,
log: testutils.Logger(t, "target.smtp"),
tlsConfig: *clientCfg.Clone(),
starttls: true,
log: testutils.Logger(t, "target.smtp"),
}
testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"})
@ -235,85 +235,7 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
}
}
func TestDownstreamDelivery_AttemptTLS_Fallback(t *testing.T) {
be, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
mod := &Downstream{
hostname: "mx.example.invalid",
endpoints: []config.Endpoint{
{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
},
},
attemptStartTLS: true,
log: testutils.Logger(t, "target.smtp"),
}
testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"})
be.CheckMsg(t, 0, "test@example.invalid", []string{"rcpt@example.invalid"})
}
func TestDownstreamDelivery_RequireTLS(t *testing.T) {
clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
mod := &Downstream{
hostname: "mx.example.invalid",
endpoints: []config.Endpoint{
{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
},
},
tlsConfig: *clientCfg.Clone(),
attemptStartTLS: true,
requireTLS: true,
log: testutils.Logger(t, "target.smtp"),
}
testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"})
be.CheckMsg(t, 0, "test@example.invalid", []string{"rcpt@example.invalid"})
tlsState, ok := be.Messages[0].Conn.TLSConnectionState()
if !ok || !tlsState.HandshakeComplete {
t.Fatal("Message was not delivered over TLS")
}
}
func TestDownstreamDelivery_RequireTLS_Implicit(t *testing.T) {
clientCfg, be, srv := testutils.SMTPServerTLS(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
mod := &Downstream{
hostname: "mx.example.invalid",
endpoints: []config.Endpoint{
{
Scheme: "tls",
Host: "127.0.0.1",
Port: testPort,
},
},
tlsConfig: *clientCfg.Clone(),
attemptStartTLS: true,
requireTLS: true,
log: testutils.Logger(t, "target.smtp"),
}
testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"})
be.CheckMsg(t, 0, "test@example.invalid", []string{"rcpt@example.invalid"})
tlsState, ok := be.Messages[0].Conn.TLSConnectionState()
if !ok || !tlsState.HandshakeComplete {
t.Fatal("Message was not delivered over TLS")
}
}
func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) {
func TestDownstreamDelivery_StartTLS_NoFallback(t *testing.T) {
_, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
@ -327,9 +249,8 @@ func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) {
Port: testPort,
},
},
attemptStartTLS: true,
requireTLS: true,
log: testutils.Logger(t, "target.smtp"),
starttls: true,
log: testutils.Logger(t, "target.smtp"),
}
_, err := testutils.DoTestDeliveryErr(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"})

View file

@ -40,6 +40,10 @@ func TestDownstreamDelivery_EHLO_ALabel(t *testing.T) {
Name: "hostname",
Args: []string{"тест.invalid"},
},
{
Name: "starttls",
Args: []string{"no"},
},
},
})); err != nil {
t.Fatal(err)