File: LuaPromise.cpp

package info (click to toggle)
freespace2 24.0.2%2Brepack-1
  • links: PTS, VCS
  • area: non-free
  • in suites: trixie
  • size: 43,188 kB
  • sloc: cpp: 583,107; ansic: 21,729; python: 1,174; sh: 464; makefile: 248; xml: 181
file content (188 lines) | stat: -rw-r--r-- 6,051 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#include "LuaPromise.h"

#include "globalincs/pstypes.h"

#include "scripting/lua/LuaValue.h"

namespace scripting {
namespace api {

// Lifecycle of a LuaPromise's shared internal state.
enum class State : int {
	Invalid,  // default-constructed promise that is not connected to any operation
	Pending,  // waiting for its resolve context to report a result
	Resolved, // settled successfully; resolvedValue holds the result values
	Errored,  // settled with an error; resolvedValue holds the error values
};

// Out-of-line defaulted destructor of the resolve_context interface. Derived contexts (e.g.
// continuation_resolve_context below) mark their destructors `override`, so this destructor is virtual.
resolve_context::~resolve_context() = default;

// Resolve context that links an upstream promise to the promise returned by then()/catchError().
// The upstream promise calls resolve() once it settles; this context then settles the downstream promise,
// either by running the continuation or by forwarding the upstream outcome untouched.
class continuation_resolve_context : public resolve_context {
  public:
	// wantErrors == false: run the continuation on success ("then").
	// wantErrors == true:  run the continuation on error ("catch").
	continuation_resolve_context(bool wantErrors, LuaPromise::ContinuationFunction continuation)
		: _wantErrors(wantErrors), _continuation(std::move(continuation))
	{
	}
	~continuation_resolve_context() override = default;

	// Stores the callback that settles the downstream promise.
	void setResolver(Resolver resolver) override { _resolver = std::move(resolver); }

	// Settles the downstream promise from the upstream outcome. Intended to be called at most once; the resolver
	// and continuation are released afterwards, so a second call trips the assertion.
	void resolve(bool error, const luacpp::LuaValueList& resolveVals)
	{
		Assertion(_resolver, "Promise resolved without a resolver! Probably called twice...");

		if (error != _wantErrors) {
			// Not the outcome this continuation handles: pass it through unchanged.
			_resolver(error, resolveVals);
		} else {
			// The continuation consumes the outcome ("then" on success, "catch" on error), so whatever it
			// returns settles the downstream promise as a success.
			_resolver(false, _continuation(resolveVals));
		}

		// One-shot: drop the references we no longer need.
		_resolver     = nullptr;
		_continuation = nullptr;
	}

  private:
	bool _wantErrors = false;
	LuaPromise::ContinuationFunction _continuation;
	Resolver _resolver;
};

// Shared state behind every copy of a LuaPromise. Owned through std::shared_ptr (see the LuaPromise constructor),
// which is what makes shared_from_this() valid here.
struct LuaPromise::internal_state : std::enable_shared_from_this<LuaPromise::internal_state> {
	// Current lifecycle state of the promise.
	State state = State::Invalid;

	// Values the promise settled with; holds the error values when state == State::Errored.
	luacpp::LuaValueList resolvedValue;

	// Contexts of promises chained onto this one (via registerContinuation) while it was still pending.
	SCP_vector<std::shared_ptr<continuation_resolve_context>> continuationContexts;

	// Installs the resolver on resolveCtx so that this state is settled when the context reports a result.
	void registerResolveCallback(const std::shared_ptr<resolve_context>& resolveCtx)
	{
		// A *strong* reference is captured on purpose. Callers commonly keep only the promise returned by
		// then()/catchError() and drop this one. This state owns the continuation contexts of those chained
		// promises, so it must stay alive until the operation completes. The reference is released as soon as
		// the resolver has fired.
		// Only `self` is captured (no raw `this`), so a stray second call can never reach a destroyed object.
		auto self = shared_from_this();
		resolveCtx->setResolver([self](bool error, const luacpp::LuaValueList& resolveVals) mutable {
			Assertion(self != nullptr, "Resolver called multiple times!");
			if (self == nullptr) {
				// Defensive: if the assertion does not stop execution, ignore the repeated call.
				return;
			}

			// Move the reference out of the closure before resolving. A re-entrant call during resolveCb then
			// hits the assertion instead of resolving twice. After resolveCb returns, only this stack local is
			// touched, not the closure's capture, which the resolver's owner may already have disposed of.
			auto owner = std::move(self);
			owner->resolveCb(error, resolveVals);
		});
	}

	// Records the outcome and forwards it to every promise chained onto this one.
	void resolveCb(bool error, const luacpp::LuaValueList& resolveVals)
	{
		resolvedValue = resolveVals;
		state         = error ? State::Errored : State::Resolved;

		// Continuations are only needed once. Detach them before running any, so the loop iterates a local
		// container that nothing else can modify while continuation code runs.
		auto continuations = std::move(continuationContexts);
		continuationContexts.clear();

		// Notify every registered continuation so that the chained promises also resolve.
		for (const auto& cont : continuations) {
			cont->resolve(error, resolveVals);
		}
	}
};

// Creates an invalid promise: fresh shared state in State::Invalid that is not connected to any operation.
LuaPromise::LuaPromise()
	: m_state(std::make_shared<internal_state>())
{
}

// Creates a pending promise that is settled when resolveContext reports a result.
LuaPromise::LuaPromise(const std::shared_ptr<resolve_context>& resolveContext) : LuaPromise()
{
	auto& internal = *m_state;
	internal.state = State::Pending;

	// Promises execute eagerly: registering the resolve callback is what kicks off the async operation.
	internal.registerResolveCallback(resolveContext);
}

// Copies are shallow: the defaulted copy operations copy m_state, so every copy shares the same internal state and
// observes the same resolution.
LuaPromise::LuaPromise(const LuaPromise&) = default;
LuaPromise& LuaPromise::operator=(const LuaPromise&) = default;

// NOTE(review): after a move, the source's m_state is presumably an empty shared_ptr (it is created with
// std::make_shared in the default constructor). isValid()/then()/etc. dereference m_state unconditionally, so a
// moved-from LuaPromise must not be used except to assign to or destroy it — confirm callers respect this.
LuaPromise::LuaPromise(LuaPromise&&) noexcept = default;
LuaPromise& LuaPromise::operator=(LuaPromise&&) noexcept = default;

// Chains a continuation that runs when this promise resolves successfully. Errors bypass the continuation and
// propagate to the returned promise. NOT THREAD SAFE: the state is inspected without synchronization.
LuaPromise LuaPromise::then(LuaPromise::ContinuationFunction continuation)
{
	switch (m_state->state) {
	case State::Invalid:
		// Chaining onto an invalid promise yields another invalid promise.
		return LuaPromise();
	case State::Resolved:
		// The value is already available, so run the continuation right away.
		return LuaPromise::resolved(continuation(m_state->resolvedValue));
	case State::Errored:
		// The continuation does not handle errors; forward the error unchanged.
		return LuaPromise::errored(m_state->resolvedValue);
	case State::Pending:
		break;
	}

	// Still pending: defer the continuation until this promise settles.
	return registerContinuation(false, std::move(continuation));
}

// Chains an error handler. If this promise errors, the continuation receives the error values, and its return
// value *successfully* resolves the returned promise (the error counts as handled). Successful values bypass the
// continuation and pass through unchanged. NOT THREAD SAFE: the state is inspected without synchronization.
LuaPromise LuaPromise::catchError(LuaPromise::ContinuationFunction continuation)
{
	if (m_state->state == State::Invalid) {
		return LuaPromise();
	}

	// Nothing to catch: the successfully resolved value just passes through.
	if (m_state->state == State::Resolved) {
		return LuaPromise::resolved(m_state->resolvedValue);
	}

	// The error is handled by the continuation, so its result resolves the new promise as a success. This matches
	// the deferred path (continuation_resolve_context::resolve), which always forwards the continuation's result
	// as a non-error. Returning an errored promise here made an already-errored promise behave differently from a
	// pending one that errors later.
	if (m_state->state == State::Errored) {
		return LuaPromise::resolved(continuation(m_state->resolvedValue));
	}

	// Still pending: run the handler only if this promise ends up errored.
	return registerContinuation(true, std::move(continuation));
}

bool LuaPromise::isValid() const { return m_state->state != State::Invalid; }
bool LuaPromise::isResolved() const { return m_state->state == State::Resolved; }
bool LuaPromise::isErrored() const { return m_state->state == State::Errored; }

// Returns the values this promise resolved with. Asserts unless isResolved() is true.
const luacpp::LuaValueList& LuaPromise::resolveValue() const
{
	const bool settledOk = isResolved();
	Assertion(settledOk, "Tried to get value from unresolved promise!");
	return m_state->resolvedValue;
}
// Returns the error values of a promise that settled with an error. Asserts unless isErrored() is true, i.e. also
// when the promise resolved successfully, so the message says "has not errored" rather than "unresolved".
const luacpp::LuaValueList& LuaPromise::errorValue() const
{
	Assertion(isErrored(), "Tried to get error value from a promise that has not errored!");

	return m_state->resolvedValue;
}
// Creates a promise that is already settled successfully with resolveValue. then() on it runs its continuation
// immediately.
LuaPromise LuaPromise::resolved(luacpp::LuaValueList resolveValue)
{
	LuaPromise promise;
	auto& internal         = *promise.m_state;
	internal.resolvedValue = std::move(resolveValue);
	internal.state         = State::Resolved;
	return promise;
}

// Creates a promise that is already settled with an error; resolveValue becomes its error values.
LuaPromise LuaPromise::errored(luacpp::LuaValueList resolveValue)
{
	LuaPromise promise;
	auto& internal         = *promise.m_state;
	internal.resolvedValue = std::move(resolveValue);
	internal.state         = State::Errored;
	return promise;
}
// Defers a continuation until this (pending) promise settles and returns the promise it will resolve.
// catchErrors selects whether the continuation runs on error (true) or on success (false).
LuaPromise LuaPromise::registerContinuation(bool catchErrors, LuaPromise::ContinuationFunction continuation)
{
	// The context is the link between both promises. This state keeps it so it can be resolved later, and the
	// returned promise registers its resolver on it.
	std::shared_ptr<continuation_resolve_context> link =
		std::make_shared<continuation_resolve_context>(catchErrors, std::move(continuation));

	m_state->continuationContexts.push_back(link);
	return LuaPromise(link);
}
// Defaulted: destroying a LuaPromise only drops this handle's reference to the shared internal state. The state
// itself lives on while other copies, or a pending resolver that captured it, still reference it.
LuaPromise::~LuaPromise() = default;

} // namespace api
} // namespace scripting