File: server_middleware.go

package info (click to toggle)
ntfy 2.11.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 19,364 kB
  • sloc: javascript: 16,782; makefile: 282; sh: 105; php: 21; python: 19
file content (132 lines) | stat: -rw-r--r-- 3,488 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package server

import (
	"net/http"

	"heckel.io/ntfy/v2/util"
)

// contextKey is the key type for values that the middleware in this file
// attaches to a request via withContext. Using a dedicated type keeps these
// keys distinct from keys of any other type.
type contextKey int

// Keys for request-scoped values. The numbering starts at 2586 rather than 0.
// NOTE(review): the offset looks arbitrary (presumably just to avoid trivially
// small values) — confirm before relying on the concrete numbers.
const (
	// contextRateVisitor keys the visitor whose limits are applied to the
	// request (set in limitRequestsWithTopic).
	contextRateVisitor contextKey = iota + 2586
	// contextTopic keys the topic resolved from the request path (set in
	// limitRequestsWithTopic).
	contextTopic
	// contextMatrixPushKey is not set by the middleware in this file.
	// NOTE(review): presumably populated by the Matrix push handling — verify.
	contextMatrixPushKey
)

// limitRequests wraps next with a per-visitor request rate check.
//
// Visitors whose IP is listed in s.config.VisitorRequestExemptIPAddrs skip the
// check entirely (v.RequestAllowed is not even called for them). For everyone
// else, next only runs if v.RequestAllowed() reports true; otherwise
// errHTTPTooManyRequestsLimitRequests is returned.
func (s *Server) limitRequests(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		exempt := util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip)
		// Short-circuit matters: RequestAllowed must not be invoked for exempt IPs.
		if !exempt && !v.RequestAllowed() {
			return errHTTPTooManyRequestsLimitRequests
		}
		return next(w, r, v)
	}
}

// limitRequestsWithTopic is the topic-aware variant of limitRequests.
//
// It resolves the topic from the request path (returning the lookup error
// as-is on failure), picks the visitor to rate-limit against (the topic's
// RateVisitor if it has one, otherwise the requesting visitor), and stores both
// that visitor and the topic in the request context under contextRateVisitor
// and contextTopic. Requests from IPs in s.config.VisitorRequestExemptIPAddrs
// are never rate-limited; all others are rejected with
// errHTTPTooManyRequestsLimitRequests when the chosen visitor's
// RequestAllowed() reports false.
func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		t, err := s.topicFromPath(r.URL.Path)
		if err != nil {
			return err
		}

		// Default to the requester; a topic-level rate visitor takes precedence.
		limited := v
		if topicVisitor := t.RateVisitor(); topicVisitor != nil {
			limited = topicVisitor
		}

		values := map[contextKey]any{
			contextRateVisitor: limited,
			contextTopic:       t,
		}
		r = withContext(r, values)

		switch {
		case util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip):
			// Exemption is decided by the requester's IP, not the rate visitor's.
		case !limited.RequestAllowed():
			return errHTTPTooManyRequestsLimitRequests
		}
		return next(w, r, v)
	}
}

// ensureWebEnabled only lets requests through to next when s.config.WebRoot is
// non-empty; otherwise it answers with errHTTPNotFound.
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if webEnabled := s.config.WebRoot != ""; webEnabled {
			return next(w, r, v)
		}
		return errHTTPNotFound
	}
}

// ensureWebPushEnabled only lets requests through to next when both
// s.config.WebRoot and s.config.WebPushPublicKey are non-empty; if either is
// empty it answers with errHTTPNotFound.
func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if s.config.WebRoot != "" && s.config.WebPushPublicKey != "" {
			return next(w, r, v)
		}
		return errHTTPNotFound
	}
}

// ensureUserManager only lets requests through to next when the server has a
// user manager (s.userManager is non-nil); otherwise it answers with
// errHTTPNotFound.
func (s *Server) ensureUserManager(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if s.userManager != nil {
			return next(w, r, v)
		}
		return errHTTPNotFound
	}
}

// ensureUser requires that the visitor has a user attached. It is layered on
// top of ensureUserManager, so a missing user manager yields errHTTPNotFound
// first; after that, a visitor whose User() is nil gets errHTTPUnauthorized.
func (s *Server) ensureUser(next handleFunc) handleFunc {
	requireUser := func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if v.User() != nil {
			return next(w, r, v)
		}
		return errHTTPUnauthorized
	}
	return s.ensureUserManager(requireUser)
}

// ensureAdmin requires that the visitor has a user attached and that this user
// is an admin. It is layered on top of ensureUserManager, so a missing user
// manager yields errHTTPNotFound first; after that, a visitor without a user or
// with a non-admin user gets errHTTPUnauthorized.
func (s *Server) ensureAdmin(next handleFunc) handleFunc {
	return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		// Check for nil explicitly (as ensureUser does) instead of relying on
		// IsAdmin tolerating a nil receiver, so a visitor without a user can
		// never cause a nil dereference here.
		if u := v.User(); u == nil || !u.IsAdmin() {
			return errHTTPUnauthorized
		}
		return next(w, r, v)
	})
}

// ensureCallsEnabled only lets requests through to next when
// s.config.TwilioAccount is non-empty and the server has a user manager; if
// either is missing it answers with errHTTPNotFound.
func (s *Server) ensureCallsEnabled(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if s.config.TwilioAccount != "" && s.userManager != nil {
			return next(w, r, v)
		}
		return errHTTPNotFound
	}
}

// ensurePaymentsEnabled only lets requests through to next when
// s.config.StripeSecretKey is non-empty and s.stripe is non-nil; if either is
// missing it answers with errHTTPNotFound.
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if s.config.StripeSecretKey != "" && s.stripe != nil {
			return next(w, r, v)
		}
		return errHTTPNotFound
	}
}

// ensureStripeCustomer requires that the visitor's user has a Stripe customer
// ID. It is layered on top of ensureUser, which rejects visitors without a user
// before this check runs; a user whose Billing.StripeCustomerID is empty gets
// errHTTPBadRequestNotAPaidUser.
func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
	requireCustomer := func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		customerID := v.User().Billing.StripeCustomerID
		if customerID != "" {
			return next(w, r, v)
		}
		return errHTTPBadRequestNotAPaidUser
	}
	return s.ensureUser(requireCustomer)
}

// withAccountSync runs next and, only if it succeeded, calls
// s.publishSyncEventAsync for the visitor. An error from next is returned
// unchanged and no sync event is published in that case.
func (s *Server) withAccountSync(next handleFunc) handleFunc {
	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
		if err := next(w, r, v); err != nil {
			return err
		}
		s.publishSyncEventAsync(v)
		return nil
	}
}