File: smtp.go

package info (click to toggle)
golang-github-nicholas-fedor-shoutrrr 0.8.17-1
  • links: PTS, VCS
  • area: main
  • in suites: forky
  • size: 4,332 kB
  • sloc: sh: 61; makefile: 5
file content (402 lines) | stat: -rw-r--r-- 10,695 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
package smtp

import (
	"context"
	"crypto/rand"
	"crypto/tls"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"net/smtp"
	"net/url"
	"os"
	"strconv"
	"time"

	"github.com/nicholas-fedor/shoutrrr/internal/failures"
	"github.com/nicholas-fedor/shoutrrr/pkg/format"
	"github.com/nicholas-fedor/shoutrrr/pkg/services/standard"
	"github.com/nicholas-fedor/shoutrrr/pkg/types"
)

const (
	contentHTML      = "text/html; charset=\"UTF-8\""
	contentPlain     = "text/plain; charset=\"UTF-8\""
	contentMultipart = "multipart/alternative; boundary=%s"
	DefaultSMTPPort  = 25 // DefaultSMTPPort is the standard port for SMTP communication.
	boundaryByteLen  = 8  // boundaryByteLen is the number of bytes for the multipart boundary.
)

// ErrNoAuth is a sentinel error indicating no authentication is required.
var ErrNoAuth = errors.New("no authentication required")

// Static errors for SMTP operations.
var (
	ErrServerNoStartTLS = errors.New("server does not support StartTLS")
)

// Service sends notifications to given email addresses via SMTP.
type Service struct {
	standard.Standard
	standard.Templater
	Config            *Config
	multipartBoundary string
	propKeyResolver   format.PropKeyResolver
}

// Initialize loads ServiceConfig from configURL and sets logger for this Service.
func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) error {
	service.SetLogger(logger)
	service.Config = &Config{
		Port:        DefaultSMTPPort,
		ToAddresses: nil,
		Subject:     "",
		Auth:        AuthTypes.Unknown,
		UseStartTLS: true,
		UseHTML:     false,
		Encryption:  EncMethods.Auto,
		ClientHost:  "localhost",
	}

	pkr := format.NewPropKeyResolver(service.Config)

	if err := service.Config.setURL(&pkr, configURL); err != nil {
		return err
	}

	if service.Config.Auth == AuthTypes.Unknown {
		if service.Config.Username != "" {
			service.Config.Auth = AuthTypes.Plain
		} else {
			service.Config.Auth = AuthTypes.None
		}
	}

	service.propKeyResolver = pkr

	return nil
}

// GetID returns the service identifier.
func (service *Service) GetID() string {
	return Scheme
}

// Send sends a notification message to email recipients.
func (service *Service) Send(message string, params *types.Params) error {
	config := service.Config.Clone()
	if err := service.propKeyResolver.UpdateConfigFromParams(&config, params); err != nil {
		return fail(FailApplySendParams, err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
	defer cancel()

	client, err := getClientConnection(ctx, service.Config)
	if err != nil {
		return fail(FailGetSMTPClient, err)
	}

	return service.doSend(client, message, &config)
}

// getClientConnection establishes a connection to the SMTP server using the provided configuration.
func getClientConnection(ctx context.Context, config *Config) (*smtp.Client, error) {
	var (
		conn net.Conn
		err  error
	)

	addr := net.JoinHostPort(config.Host, strconv.FormatUint(uint64(config.Port), 10))

	if useImplicitTLS(config.Encryption, config.Port) {
		dialer := &tls.Dialer{
			Config: &tls.Config{
				ServerName: config.Host,
				MinVersion: tls.VersionTLS12, // Enforce TLS 1.2 or higher
			},
		}
		conn, err = dialer.DialContext(ctx, "tcp", addr)
	} else {
		dialer := &net.Dialer{}
		conn, err = dialer.DialContext(ctx, "tcp", addr)
	}

	if err != nil {
		return nil, fail(FailConnectToServer, err)
	}

	client, err := smtp.NewClient(conn, config.Host)
	if err != nil {
		return nil, fail(FailCreateSMTPClient, err)
	}

	return client, nil
}

// doSend sends an email message using the provided SMTP client and configuration.
func (service *Service) doSend(client *smtp.Client, message string, config *Config) failure {
	config.FixEmailTags()

	clientHost := service.resolveClientHost(config)

	if err := client.Hello(clientHost); err != nil {
		return fail(FailHandshake, err)
	}

	if config.UseHTML {
		b := make([]byte, boundaryByteLen)
		if _, err := rand.Read(b); err != nil {
			return fail(FailUnknown, err) // Fallback error for rare case
		}

		service.multipartBoundary = hex.EncodeToString(b)
	}

	if config.UseStartTLS && !useImplicitTLS(config.Encryption, config.Port) {
		if supported, _ := client.Extension("StartTLS"); !supported {
			if config.RequireStartTLS {
				return fail(FailEnableStartTLS, ErrServerNoStartTLS)
			}

			service.Logf(
				"Warning: StartTLS enabled, but server does not support it. Connection is unencrypted",
			)
		} else {
			if err := client.StartTLS(&tls.Config{
				ServerName: config.Host,
				MinVersion: tls.VersionTLS12,
				MaxVersion: tls.VersionTLS13,
			}); err != nil {
				return fail(FailEnableStartTLS, err)
			}
		}
	}

	if auth, err := service.getAuth(config); err != nil {
		return err
	} else if auth != nil {
		if err := client.Auth(auth); err != nil {
			return fail(FailAuthenticating, err)
		}
	}

	var errs []error

	for _, toAddress := range config.ToAddresses {
		if err := service.sendToRecipient(client, toAddress, config, message); err != nil {
			errs = append(errs, fail(FailSendRecipient, err, toAddress))
			service.Logf("Failed to send to %q: %v", toAddress, err)

			continue
		}

		service.Logf("Mail successfully sent to %q!", toAddress)
	}

	// Send the QUIT command and close the connection.
	if err := client.Quit(); err != nil {
		errs = append(errs, fail(FailClosingSession, err))
	}

	if len(errs) > 0 {
		return failures.Wrap(
			"failed to send to some recipients",
			FailSendRecipient,
			errors.Join(errs...),
		)
	}

	return nil
}

// resolveClientHost determines the client hostname to use in the SMTP handshake.
func (service *Service) resolveClientHost(config *Config) string {
	if config.ClientHost != "auto" {
		return config.ClientHost
	}

	hostname, err := os.Hostname()
	if err != nil {
		service.Logf("Failed to get hostname, falling back to localhost: %v", err)

		return "localhost"
	}

	return hostname
}

// getAuth returns the appropriate SMTP authentication mechanism based on the configuration.
//
//nolint:exhaustive,nilnil
func (service *Service) getAuth(config *Config) (smtp.Auth, failure) {
	switch config.Auth {
	case AuthTypes.None:
		return nil, nil // No auth required, proceed without error
	case AuthTypes.Plain:
		return smtp.PlainAuth("", config.Username, config.Password, config.Host), nil
	case AuthTypes.CRAMMD5:
		return smtp.CRAMMD5Auth(config.Username, config.Password), nil
	case AuthTypes.OAuth2:
		return OAuth2Auth(config.Username, config.Password), nil
	case AuthTypes.Unknown:
		return nil, fail(FailAuthType, nil, config.Auth.String())
	default:
		return nil, fail(FailAuthType, nil, config.Auth.String())
	}
}

// sendToRecipient sends an email to a single recipient using the provided SMTP client.
func (service *Service) sendToRecipient(
	client *smtp.Client,
	toAddress string,
	config *Config,
	message string,
) failure {
	// Set the sender and recipient first
	if err := client.Mail(config.FromAddress); err != nil {
		return fail(FailSetSender, err)
	}

	if err := client.Rcpt(toAddress); err != nil {
		return fail(FailSetRecipient, err)
	}

	// Send the email body.
	writeCloser, err := client.Data()
	if err != nil {
		return fail(FailOpenDataStream, err)
	}

	if err := writeHeaders(writeCloser, service.getHeaders(toAddress, config.Subject)); err != nil {
		return err
	}

	var ferr failure
	if config.UseHTML {
		ferr = service.writeMultipartMessage(writeCloser, message)
	} else {
		ferr = service.writeMessagePart(writeCloser, message, "plain")
	}

	if ferr != nil {
		return ferr
	}

	if err = writeCloser.Close(); err != nil {
		return fail(FailCloseDataStream, err)
	}

	return nil
}

// getHeaders constructs email headers for the SMTP message.
func (service *Service) getHeaders(toAddress string, subject string) map[string]string {
	conf := service.Config

	var contentType string
	if conf.UseHTML {
		contentType = fmt.Sprintf(contentMultipart, service.multipartBoundary)
	} else {
		contentType = contentPlain
	}

	return map[string]string{
		"Subject":      subject,
		"Date":         time.Now().Format(time.RFC1123Z),
		"To":           toAddress,
		"From":         fmt.Sprintf("%s <%s>", conf.FromName, conf.FromAddress),
		"MIME-version": "1.0",
		"Content-Type": contentType,
	}
}

// writeMultipartMessage writes a multipart email message to the provided writer.
func (service *Service) writeMultipartMessage(writeCloser io.WriteCloser, message string) failure {
	if err := writeMultipartHeader(writeCloser, service.multipartBoundary, contentPlain); err != nil {
		return fail(FailPlainHeader, err)
	}

	if err := service.writeMessagePart(writeCloser, message, "plain"); err != nil {
		return err
	}

	if err := writeMultipartHeader(writeCloser, service.multipartBoundary, contentHTML); err != nil {
		return fail(FailHTMLHeader, err)
	}

	if err := service.writeMessagePart(writeCloser, message, "HTML"); err != nil {
		return err
	}

	if err := writeMultipartHeader(writeCloser, service.multipartBoundary, ""); err != nil {
		return fail(FailMultiEndHeader, err)
	}

	return nil
}

// writeMessagePart writes a single part of an email message using the specified template.
func (service *Service) writeMessagePart(
	writeCloser io.WriteCloser,
	message string,
	template string,
) failure {
	if tpl, found := service.GetTemplate(template); found {
		data := make(map[string]string)

		data["message"] = message
		if err := tpl.Execute(writeCloser, data); err != nil {
			return fail(FailMessageTemplate, err)
		}
	} else {
		content := message
		if template == "HTML" {
			content = fmt.Sprintf("<pre>%s</pre>", message)
		}

		if _, err := fmt.Fprint(writeCloser, content); err != nil {
			return fail(FailMessageRaw, err)
		}
	}

	return nil
}

// writeMultipartHeader writes a multipart boundary header to the provided writer.
func writeMultipartHeader(writeCloser io.WriteCloser, boundary string, contentType string) error {
	suffix := "\n"
	if len(contentType) < 1 {
		suffix = "--"
	}

	if _, err := fmt.Fprintf(writeCloser, "\n\n--%s%s", boundary, suffix); err != nil {
		return fmt.Errorf("writing multipart boundary: %w", err)
	}

	if len(contentType) > 0 {
		if _, err := fmt.Fprintf(writeCloser, "Content-Type: %s\n\n", contentType); err != nil {
			return fmt.Errorf("writing content type header: %w", err)
		}
	}

	return nil
}

// writeHeaders writes email headers to the provided writer.
func writeHeaders(writeCloser io.WriteCloser, headers map[string]string) failure {
	for key, val := range headers {
		if _, err := fmt.Fprintf(writeCloser, "%s: %s\n", key, val); err != nil {
			return fail(FailWriteHeaders, err)
		}
	}

	_, err := fmt.Fprintln(writeCloser)
	if err != nil {
		return fail(FailWriteHeaders, err)
	}

	return nil
}