File: server.go

package info (click to toggle)
golang-github-henrybear327-go-proton-api 1.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,088 kB
  • sloc: sh: 55; makefile: 26
file content (242 lines) | stat: -rw-r--r-- 6,607 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
package server

import (
	"net/http"
	"net/http/httptest"
	"sync"
	"time"

	"github.com/Masterminds/semver/v3"
	"github.com/bradenaw/juniper/xslices"
	"github.com/gin-gonic/gin"
	"github.com/henrybear327/go-proton-api"
	"github.com/henrybear327/go-proton-api/server/backend"
)

// AuthCacher can optionally be plugged into the server to cache proxied auth
// calls. Both lookups follow the comma-ok shape; the bool presumably reports
// whether an entry was found (the exact semantics are up to the implementation).
type AuthCacher interface {
	// GetAuthInfo looks up the cached auth info for username.
	GetAuthInfo(username string) (proton.AuthInfo, bool)
	// GetAuth looks up the cached auth for username.
	GetAuth(username string) (proton.Auth, bool)

	// SetAuthInfo stores info as the auth info for username.
	SetAuthInfo(username string, info proton.AuthInfo)
	// SetAuth stores auth as the auth for username.
	SetAuth(username string, auth proton.Auth)
}

// StatusHook is a function that can be used to modify the response code of a call.
// It receives the incoming request and returns a status code together with a
// bool. NOTE(review): presumably the bool reports whether the hook applies to
// this request (i.e. whether the returned status code should be used) — confirm
// against the code that consumes statusHooks.
type StatusHook func(*http.Request) (int, bool)

// Server is a test API server: an httptest.Server whose requests are routed
// through a gin engine to a backend.Backend, optionally acting as a proxy to
// another origin. Create one with New and release it with Close.
type Server struct {
	// r is the gin router.
	r *gin.Engine

	// s is the underlying server.
	s *httptest.Server

	// b is the server backend, which manages accounts, messages, attachments, etc.
	b *backend.Backend

	// callWatchers are the call watchers registered via AddCallWatcher;
	// callWatchersLock guards the slice.
	callWatchers     []callWatcher
	callWatchersLock sync.RWMutex

	// statusHooks are hooks that can be used to modify the response code of a call.
	// They are registered via AddStatusHook; statusHooksLock guards the slice.
	statusHooks     []StatusHook
	statusHooksLock sync.RWMutex

	// domain is the test server domain; new users' addresses are derived from it.
	domain string

	// minAppVersion is the minimum app version that the server will accept.
	// NOTE(review): SetMinAppVersion writes this without holding a lock; if
	// request handlers read it concurrently that is a data race — confirm.
	minAppVersion *semver.Version

	// proxyOrigin is the URL of the origin server when the server is a proxy.
	proxyOrigin string

	// proxyTransport is the transport to use when the server is a proxy.
	// Its idle connections are closed by Close.
	proxyTransport *http.Transport

	// authCacher can optionally be set to cache proxied auth calls.
	authCacher AuthCacher

	// offline is whether to pretend the server is offline and return 5xx errors.
	// NOTE(review): SetOffline writes this without holding a lock; if request
	// handlers read it concurrently that is a data race — confirm.
	offline bool

	// rateLimit is the rate limiter for the server.
	rateLimit *rateLimiter
}

// New builds a Server. A default server builder is created, each Option is
// applied to it in the order given, and the configured builder then produces
// the returned Server.
func New(opts ...Option) *Server {
	b := newServerBuilder()
	for i := range opts {
		opts[i].config(b)
	}
	return b.build()
}

// GetHostURL returns the base URL of the underlying test HTTP server, which is
// the API root that clients should send their calls to.
func (s *Server) GetHostURL() string {
	hostURL := s.s.URL
	return hostURL
}

// GetProxyURL returns the API root for calls that should be proxied: the host
// URL followed by the "/proxy" path prefix.
func (s *Server) GetProxyURL() string {
	return s.GetHostURL() + "/proxy"
}

// GetDomain reports the server's domain (e.g. "proton.local"); CreateUser
// appends it to the username to form each new user's address.
func (s *Server) GetDomain() string {
	domain := s.domain
	return domain
}

// AddCallWatcher registers fn to be notified of calls received by the server,
// restricted to the given paths. The watcher is constructed first; only the
// append to the shared slice happens under callWatchersLock.
func (s *Server) AddCallWatcher(fn func(Call), paths ...string) {
	watcher := newCallWatcher(fn, paths...)

	s.callWatchersLock.Lock()
	s.callWatchers = append(s.callWatchers, watcher)
	s.callWatchersLock.Unlock()
}

// AddStatusHook registers fn as an additional status hook. The shared hook
// slice is only modified while statusHooksLock is held.
func (s *Server) AddStatusHook(fn StatusHook) {
	s.statusHooksLock.Lock()
	s.statusHooks = append(s.statusHooks, fn)
	s.statusHooksLock.Unlock()
}

// CreateUser creates a new server user with the given username and password.
// A single address will be created for the user, derived from the username and
// the server's domain (username@domain), using the same settings as
// CreateAddress.
//
// It returns the new user's ID and the new address's ID. If the address cannot
// be created, the just-created user is removed again (best effort), so a failed
// call does not leave behind a user without its address.
func (s *Server) CreateUser(username string, password []byte) (string, string, error) {
	userID, err := s.b.CreateUser(username, password)
	if err != nil {
		return "", "", err
	}

	addrID, err := s.CreateAddress(userID, username+"@"+s.domain, password)
	if err != nil {
		// Roll back the partially created user. The rollback error is
		// deliberately ignored: the address error is the root cause the
		// caller needs to see.
		_ = s.b.RemoveUser(userID)

		return "", "", err
	}

	return userID, addrID, nil
}

// RemoveUser removes the user with the given ID by delegating to the backend's
// RemoveUser.
func (s *Server) RemoveUser(userID string) error {
	return s.b.RemoveUser(userID)
}

// RefreshUser forwards the given refresh flag for the user to the backend's
// RefreshUser.
func (s *Server) RefreshUser(userID string, refresh proton.RefreshFlag) error {
	return s.b.RefreshUser(userID, refresh)
}

// GetUserKeyIDs returns the IDs of the given user's keys, in the same order as
// the keys appear on the user returned by the backend.
func (s *Server) GetUserKeyIDs(userID string) ([]string, error) {
	user, err := s.b.GetUser(userID)
	if err != nil {
		return nil, err
	}

	keyID := func(k proton.Key) string { return k.ID }

	return xslices.Map(user.Keys, keyID), nil
}

// CreateUserKey creates an additional key for the user, protected by password,
// by delegating to the backend's CreateUserKey.
func (s *Server) CreateUserKey(userID string, password []byte) error {
	return s.b.CreateUserKey(userID, password)
}

// RemoveUserKey removes the key keyID from the user by delegating to the
// backend's RemoveUserKey.
func (s *Server) RemoveUserKey(userID, keyID string) error {
	return s.b.RemoveUserKey(userID, keyID)
}

// CreateAddress creates an address with the given email for the user and
// returns its ID. The address is always created with status
// proton.AddressStatusEnabled and type proton.AddressTypeOriginal; the boolean
// argument is fixed to true (its meaning is defined by backend.CreateAddress).
func (s *Server) CreateAddress(userID, email string, password []byte) (string, error) {
	return s.b.CreateAddress(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal)
}

// CreateAddressAsUpdate is like CreateAddress (same fixed status, type and
// boolean argument) but goes through the backend's CreateAddressAsUpdate.
// NOTE(review): presumably this also records the creation as an event update
// — confirm in the backend.
func (s *Server) CreateAddressAsUpdate(userID, email string, password []byte) (string, error) {
	return s.b.CreateAddressAsUpdate(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal)
}

// ChangeAddressType sets the type of the user's address addrID to addrType by
// delegating to the backend's ChangeAddressType.
func (s *Server) ChangeAddressType(userID, addrID string, addrType proton.AddressType) error {
	return s.b.ChangeAddressType(userID, addrID, addrType)
}

// RemoveAddress removes the address addrID from the user by delegating to the
// backend's RemoveAddress.
func (s *Server) RemoveAddress(userID, addrID string) error {
	return s.b.RemoveAddress(userID, addrID)
}

// CreateAddressKey creates an additional key for the user's address addrID,
// protected by password, by delegating to the backend's CreateAddressKey.
func (s *Server) CreateAddressKey(userID, addrID string, password []byte) error {
	return s.b.CreateAddressKey(userID, addrID, password)
}

// RemoveAddressKey removes the key keyID from the user's address addrID by
// delegating to the backend's RemoveAddressKey.
func (s *Server) RemoveAddressKey(userID, addrID, keyID string) error {
	return s.b.RemoveAddressKey(userID, addrID, keyID)
}

// CreateLabel creates a label of type labelType with the given name and parent
// ID for the user, and returns only the new label's ID.
func (s *Server) CreateLabel(userID, name, parentID string, labelType proton.LabelType) (string, error) {
	created, err := s.b.CreateLabel(userID, name, parentID, labelType)
	if err != nil {
		return "", err
	}
	return created.ID, nil
}

// GetLabels returns the user's labels as reported by the backend's GetLabels.
func (s *Server) GetLabels(userID string) ([]proton.Label, error) {
	return s.b.GetLabels(userID)
}

// LabelMessage applies the label labelID to the user's message msgID.
// The backend's LabelMessages takes the label ID before the message ID, hence
// the swapped argument order relative to this method's parameters.
func (s *Server) LabelMessage(userID, msgID, labelID string) error {
	return s.b.LabelMessages(userID, labelID, msgID)
}

// UnlabelMessage removes the label labelID from the user's message msgID.
// The backend's UnlabelMessages takes the label ID before the message ID, hence
// the swapped argument order relative to this method's parameters.
func (s *Server) UnlabelMessage(userID, msgID, labelID string) error {
	return s.b.UnlabelMessages(userID, labelID, msgID)
}

// AddAddressCreatedEvent delegates to the backend's AddAddressCreatedUpdate
// for the user's address addrID. NOTE(review): presumably this queues an
// "address created" update for the user's event stream — confirm in backend.
func (s *Server) AddAddressCreatedEvent(userID, addrID string) error {
	return s.b.AddAddressCreatedUpdate(userID, addrID)
}

// AddLabelCreatedEvent delegates to the backend's AddLabelCreatedUpdate for the
// user's label labelID. NOTE(review): presumably this queues a "label created"
// update for the user's event stream — confirm in backend.
func (s *Server) AddLabelCreatedEvent(userID, labelID string) error {
	return s.b.AddLabelCreatedUpdate(userID, labelID)
}

// AddMessageCreatedEvent delegates to the backend's AddMessageCreatedUpdate for
// the user's message messageID. NOTE(review): presumably this queues a "message
// created" update for the user's event stream — confirm in backend.
func (s *Server) AddMessageCreatedEvent(userID, messageID string) error {
	return s.b.AddMessageCreatedUpdate(userID, messageID)
}

// SetMaxUpdatesPerEvent forwards maxUpdates to the backend's
// SetMaxUpdatesPerEvent, which (judging by its name) caps how many updates are
// reported in a single event.
func (s *Server) SetMaxUpdatesPerEvent(maxUpdates int) {
	s.b.SetMaxUpdatesPerEvent(maxUpdates)
}

// SetAuthLife forwards authLife to the backend's SetAuthLife.
// NOTE(review): presumably this is the lifetime of auth sessions issued by the
// backend — confirm in backend.
func (s *Server) SetAuthLife(authLife time.Duration) {
	s.b.SetAuthLife(authLife)
}

// SetMinAppVersion sets the minimum app version that the server will accept.
//
// NOTE(review): the field is written without any lock; if request handlers
// read it concurrently, calling this while the server is serving is a data
// race — confirm how the field is read.
func (s *Server) SetMinAppVersion(minAppVersion *semver.Version) {
	s.minAppVersion = minAppVersion
}

// SetOffline sets whether the server pretends to be offline and returns 5xx
// errors.
//
// NOTE(review): the field is written without any lock; if request handlers
// read it concurrently, calling this while the server is serving is a data
// race — confirm how the field is read.
func (s *Server) SetOffline(offline bool) {
	s.offline = offline
}

// RevokeUser deletes every session the backend reports for the user. It stops
// and returns the error of the first session deletion that fails, so sessions
// after that one are left in place.
func (s *Server) RevokeUser(userID string) error {
	sessions, err := s.b.GetSessions(userID)
	if err != nil {
		return err
	}
	for i := range sessions {
		if delErr := s.b.DeleteSession(userID, sessions[i].UID); delErr != nil {
			return delErr
		}
	}
	return nil
}

// Close shuts down the test server and then closes the proxy transport's idle
// connections.
//
// The HTTP server is closed first because httptest.Server.Close blocks until
// all outstanding requests have completed. A proxied request still in flight
// can therefore hand its upstream connection back to the transport's idle
// pool before that pool is drained. Draining first would leave such
// connections open. proxyTransport is nil-checked because
// CloseIdleConnections cannot be called on a nil *http.Transport.
func (s *Server) Close() {
	s.s.Close()

	if s.proxyTransport != nil {
		s.proxyTransport.CloseIdleConnections()
	}
}