diff --git a/framework/config/map.go b/framework/config/map.go index 10b1762..c85c19f 100644 --- a/framework/config/map.go +++ b/framework/config/map.go @@ -20,6 +20,7 @@ package config import ( "errors" + "fmt" "reflect" "strconv" "strings" @@ -305,6 +306,16 @@ func (m *Map) DataSize(name string, inheritGlobal, required bool, defaultVal int }, store) } +func ParseBool(s string) (bool, error) { + switch strings.ToLower(s) { + case "1", "true", "on", "yes": + return true, nil + case "0", "false", "off", "no": + return false, nil + } + return false, fmt.Errorf("bool argument should be 'yes' or 'no'") +} + // Bool maps presence of some configuration directive to a boolean variable. // Additionally, 'name yes' and 'name no' are mapped to true and false // correspondingly. @@ -327,13 +338,11 @@ func (m *Map) Bool(name string, inheritGlobal, defaultVal bool, store *bool) { return nil, NodeErr(node, "expected exactly 1 argument") } - switch strings.ToLower(node.Args[0]) { - case "1", "true", "on", "yes": - return true, nil - case "0", "false", "off", "no": - return false, nil + b, err := ParseBool(node.Args[0]) + if err != nil { + return nil, NodeErr(node, "bool argument should be 'yes' or 'no'") } - return nil, NodeErr(node, "bool argument should be 'yes' or 'no'") + return b, nil }, store) } diff --git a/internal/smtpconn/smtpconn.go b/internal/smtpconn/smtpconn.go index 7f66bd2..ec42974 100644 --- a/internal/smtpconn/smtpconn.go +++ b/internal/smtpconn/smtpconn.go @@ -259,12 +259,15 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, return false, nil, nil, err } - if endp.IsTLS() || !starttls { - return endp.IsTLS(), cl, conn, nil + if !starttls { + return false, cl, conn, nil } if ok, _ := cl.Extension("STARTTLS"); !ok { - return false, cl, conn, nil + if err := cl.Quit(); err != nil { + cl.Close() + } + return false, nil, nil, fmt.Errorf("TLS required but unsupported by downstream") } cfg := tlsConfig.Clone() diff --git a/internal/target/smtp/smtp_downstream.go b/internal/target/smtp/smtp_downstream.go index f03fc82..8880f6c 100644 --- a/internal/target/smtp/smtp_downstream.go +++ b/internal/target/smtp/smtp_downstream.go @@ -29,7 +29,6 @@ package smtp_downstream import ( "context" "crypto/tls" - "errors" "fmt" "net" "runtime/trace" @@ -54,12 +53,11 @@ type Downstream struct { lmtp bool targetsArg []string - requireTLS bool - attemptStartTLS bool - hostname string - endpoints []config.Endpoint - saslFactory saslClientFactory - tlsConfig tls.Config + starttls bool + hostname string + endpoints []config.Endpoint + saslFactory saslClientFactory + tlsConfig tls.Config connectTimeout time.Duration commandTimeout time.Duration @@ -89,10 +87,34 @@ func NewDownstream(modName, instName string, _, inlineArgs []string) (module.Mod } func (u *Downstream) Init(cfg *config.Map) error { + var attemptTLS *bool + var targetsArg []string cfg.Bool("debug", true, false, &u.log.Debug) - cfg.Bool("require_tls", false, false, &u.requireTLS) - cfg.Bool("attempt_starttls", false, !u.lmtp, &u.attemptStartTLS) + cfg.Callback("require_tls", func(m *config.Map, node config.Node) error { + u.log.Msg("require_tls directive is deprecated and ignored") + return nil + }) + cfg.Callback("attempt_starttls", func(m *config.Map, node config.Node) error { + u.log.Msg("attempt_starttls directive is deprecated and equivalent to starttls") + + if len(node.Args) == 0 { + trueVal := true + attemptTLS = &trueVal + return nil + } + if len(node.Args) != 1 { + return config.NodeErr(node, "expected exactly 1 argument") + } + + b, err := config.ParseBool(node.Args[0]) + if err != nil { + return err + } + attemptTLS = &b + return nil + }) + cfg.Bool("starttls", false, !u.lmtp, &u.starttls) cfg.String("hostname", true, true, "", &u.hostname) cfg.StringList("targets", false, false, nil, &targetsArg) cfg.Custom("auth", false, false, func() (interface{}, error) { @@ -109,6 +131,10 @@ func (u *Downstream) Init(cfg *config.Map) error { return err } + if attemptTLS != nil { + u.starttls = *attemptTLS + } + // INTERNATIONALIZATION: See RFC 6531 Section 3.7.1. var err error u.hostname, err = idna.ToASCII(u.hostname) @@ -201,14 +227,11 @@ func (d *delivery) connect(ctx context.Context) error { } for _, endp := range d.u.endpoints { - var ( - didTLS bool - err error - ) + var err error if d.u.lmtp { - didTLS, err = conn.ConnectLMTP(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig) + _, err = conn.ConnectLMTP(ctx, endp, d.u.starttls, &d.u.tlsConfig) } else { - didTLS, err = conn.Connect(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig) + _, err = conn.Connect(ctx, endp, d.u.starttls, &d.u.tlsConfig) } if err != nil { if len(d.u.endpoints) != 1 { @@ -220,12 +243,6 @@ func (d *delivery) connect(ctx context.Context) error { d.log.DebugMsg("connected", "downstream_server", conn.ServerName()) - if !didTLS && d.u.requireTLS { - conn.Close() - lastErr = errors.New("TLS is required, but unsupported by downstream") - continue - } - lastErr = nil break } diff --git a/internal/target/smtp/smtp_downstream_test.go b/internal/target/smtp/smtp_downstream_test.go index 011ca04..31ef295 100644 --- a/internal/target/smtp/smtp_downstream_test.go +++ b/internal/target/smtp/smtp_downstream_test.go @@ -207,7 +207,7 @@ func TestDownstreamDelivery_MAILErr(t *testing.T) { testutils.CheckSMTPErr(t, err, 550, exterrors.EnhancedCode{5, 1, 2}, "Hey") } -func TestDownstreamDelivery_AttemptTLS(t *testing.T) { +func TestDownstreamDelivery_StartTLS(t *testing.T) { clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort) defer srv.Close() defer testutils.CheckSMTPConnLeak(t, srv) @@ -221,9 +221,9 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) { Port: testPort, }, }, - tlsConfig: *clientCfg.Clone(), - attemptStartTLS: true, - log: testutils.Logger(t, "target.smtp"), + tlsConfig: *clientCfg.Clone(), + starttls: true, + log: testutils.Logger(t, "target.smtp"), } testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"}) @@ -235,85 +235,7 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) { } } -func TestDownstreamDelivery_AttemptTLS_Fallback(t *testing.T) { - be, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort) - defer srv.Close() - defer testutils.CheckSMTPConnLeak(t, srv) - - mod := &Downstream{ - hostname: "mx.example.invalid", - endpoints: []config.Endpoint{ - { - Scheme: "tcp", - Host: "127.0.0.1", - Port: testPort, - }, - }, - attemptStartTLS: true, - log: testutils.Logger(t, "target.smtp"), - } - - testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"}) - be.CheckMsg(t, 0, "test@example.invalid", []string{"rcpt@example.invalid"}) -} - -func TestDownstreamDelivery_RequireTLS(t *testing.T) { - clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort) - defer srv.Close() - defer testutils.CheckSMTPConnLeak(t, srv) - - mod := &Downstream{ - hostname: "mx.example.invalid", - endpoints: []config.Endpoint{ - { - Scheme: "tcp", - Host: "127.0.0.1", - Port: testPort, - }, - }, - tlsConfig: *clientCfg.Clone(), - attemptStartTLS: true, - requireTLS: true, - log: testutils.Logger(t, "target.smtp"), - } - - testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"}) - be.CheckMsg(t, 0, "test@example.invalid", []string{"rcpt@example.invalid"}) - tlsState, ok := be.Messages[0].Conn.TLSConnectionState() - if !ok || !tlsState.HandshakeComplete { - t.Fatal("Message was not delivered over TLS") - } -} - -func TestDownstreamDelivery_RequireTLS_Implicit(t *testing.T) { - clientCfg, be, srv := testutils.SMTPServerTLS(t, "127.0.0.1:"+testPort) - defer srv.Close() - defer testutils.CheckSMTPConnLeak(t, srv) - - mod := &Downstream{ - hostname: "mx.example.invalid", - endpoints: []config.Endpoint{ - { - Scheme: "tls", - Host: "127.0.0.1", - Port: testPort, - }, - }, - tlsConfig: *clientCfg.Clone(), - attemptStartTLS: true, - requireTLS: true, - log: testutils.Logger(t, "target.smtp"), - } - - testutils.DoTestDelivery(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"}) - be.CheckMsg(t, 0, "test@example.invalid", []string{"rcpt@example.invalid"}) - tlsState, ok := be.Messages[0].Conn.TLSConnectionState() - if !ok || !tlsState.HandshakeComplete { - t.Fatal("Message was not delivered over TLS") - } -} - -func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) { +func TestDownstreamDelivery_StartTLS_NoFallback(t *testing.T) { _, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort) defer srv.Close() defer testutils.CheckSMTPConnLeak(t, srv) @@ -327,9 +249,8 @@ func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) { Port: testPort, }, }, - attemptStartTLS: true, - requireTLS: true, - log: testutils.Logger(t, "target.smtp"), + starttls: true, + log: testutils.Logger(t, "target.smtp"), } _, err := testutils.DoTestDeliveryErr(t, mod, "test@example.invalid", []string{"rcpt@example.invalid"}) diff --git a/internal/target/smtp/smtputf8_test.go b/internal/target/smtp/smtputf8_test.go index d99eaf1..74aae23 100644 --- a/internal/target/smtp/smtputf8_test.go +++ b/internal/target/smtp/smtputf8_test.go @@ -40,6 +40,10 @@ func TestDownstreamDelivery_EHLO_ALabel(t *testing.T) { Name: "hostname", Args: []string{"ั‚ะตัั‚.invalid"}, }, + { + Name: "starttls", + Args: []string{"no"}, + }, }, })); err != nil { t.Fatal(err)