1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
#include "LuaPromise.h"
#include "globalincs/pstypes.h"
#include "scripting/lua/LuaValue.h"
namespace scripting {
namespace api {
/// Lifecycle of a promise's shared state. Values are spelled out explicitly; they match declaration order.
enum class State {
	Invalid = 0,  ///< default state of fresh shared state: no value and nothing pending
	Pending = 1,  ///< waiting for its resolve_context to deliver a result
	Resolved = 2, ///< settled successfully; the result is stored in resolvedValue
	Errored = 3   ///< settled with an error; the error value is stored in resolvedValue
};
// Out-of-line defaulted destructor of the resolve_context interface. It is virtual: derived contexts such as
// continuation_resolve_context mark their destructor `override`.
resolve_context::~resolve_context() = default;
/**
 * Resolve context that links a continuation (registered through then() or catchError()) to the promise it was
 * chained onto.
 *
 * The parent promise calls resolve() once it settles. If the kind of settlement (error vs. success) is the one this
 * context wants, the continuation runs and its return value resolves the chained promise as a success. Otherwise the
 * parent's values and error flag are forwarded unchanged.
 */
class continuation_resolve_context : public resolve_context {
  public:
	/**
	 * @param wantErrors true to run the continuation on errors (catchError), false to run it on success (then)
	 * @param continuation callback producing the chained promise's value
	 */
	continuation_resolve_context(bool wantErrors, LuaPromise::ContinuationFunction continuation)
		: _wantErrors(wantErrors), _continuation(std::move(continuation))
	{
	}

	~continuation_resolve_context() override = default;

	void setResolver(Resolver resolver) override { _resolver = std::move(resolver); }

	/**
	 * @brief Settle the chained promise from the parent promise's result. Must be called at most once.
	 * @param error true if the parent promise errored
	 * @param resolveVals the parent's resolved (or error) values
	 */
	void resolve(bool error, const luacpp::LuaValueList& resolveVals)
	{
		Assertion(_resolver, "Promise resolved without a resolver! Probably called twice...");

		// Take the callbacks out of the members *before* invoking them. A nested (re-entrant) call of resolve() then
		// trips the assertion above instead of running them a second time. The references are also released even if
		// a callback throws.
		auto resolver = std::move(_resolver);
		_resolver = nullptr;
		auto continuation = std::move(_continuation);
		_continuation = nullptr;

		// If the error value is what we want then we need to call our continuation. Otherwise the value just passes
		// through this instance without the continuation being called.
		if (error == _wantErrors) {
			// The next value is never an error since either we are in the "then" case or we are "catching" the
			// exception and so the return value is no longer an error.
			resolver(false, continuation(resolveVals));
		} else {
			resolver(error, resolveVals);
		}
	}

  private:
	bool _wantErrors = false;
	LuaPromise::ContinuationFunction _continuation;
	Resolver _resolver;
};
/**
 * Shared state behind every LuaPromise handle that refers to the same promise.
 *
 * Holds the settlement state and value, plus the contexts of continuations that were chained on while the promise
 * was still pending.
 */
struct LuaPromise::internal_state : std::enable_shared_from_this<LuaPromise::internal_state> {
	State state = State::Invalid;
	// Result of the promise; holds the error value when state == State::Errored
	luacpp::LuaValueList resolvedValue;
	// Continuations registered while pending; each is resolved exactly once from resolveCb()
	SCP_vector<std::shared_ptr<continuation_resolve_context>> continuationContexts;

	/**
	 * @brief Install the resolver on @p resolveCtx so that settling it settles this state.
	 *
	 * The resolver deliberately holds a *strong* reference to this state. This state owns the contexts of everything
	 * chained onto it, so it must stay alive until the resolver runs, even if every LuaPromise handle to it has
	 * been dropped. Otherwise chained promises would never be resolved.
	 */
	void registerResolveCallback(const std::shared_ptr<resolve_context>& resolveCtx)
	{
		auto self = shared_from_this();
		resolveCtx->setResolver([self](bool error, const luacpp::LuaValueList& resolveVals) mutable {
			Assertion(self != nullptr, "Resolver called multiple times!");
			// Move the reference out of the closure before doing any work. A nested call is then detected by the
			// assertion, and the state is released when this call returns, even if the resolver does not dispose
			// of this function object.
			auto keepAlive = std::move(self);
			keepAlive->resolveCb(error, resolveVals);
		});
	}

	/**
	 * @brief Record the result and propagate it to every chained continuation.
	 * @param error true if the promise settled with an error
	 * @param resolveVals the resolved (or error) values
	 */
	void resolveCb(bool error, const luacpp::LuaValueList& resolveVals)
	{
		resolvedValue = resolveVals;
		// Set before notifying: then()/catchError() called from inside a continuation take their eager paths
		// instead of registering new continuations.
		state = error ? State::Errored : State::Resolved;

		// Continuations are needed only once. Detach them first so the references are dropped when this function
		// exits, even if one of them throws.
		auto pending = std::move(continuationContexts);
		continuationContexts.clear();

		// Call everyone who has registered on our promise so that those promises also resolve
		for (const auto& cont : pending) {
			cont->resolve(error, resolveVals);
		}
	}
};
// Creates a promise with fresh shared state in the Invalid state (no value, nothing pending).
LuaPromise::LuaPromise() : m_state(std::make_shared<LuaPromise::internal_state>()) {}
/**
 * @brief Creates a pending promise that settles when @p resolveContext invokes its resolver.
 *
 * Promises run eagerly: installing the resolver is what kicks off the asynchronous operation. The state is therefore
 * marked Pending first, in case the context resolves synchronously.
 */
LuaPromise::LuaPromise(const std::shared_ptr<resolve_context>& resolveContext) : LuaPromise()
{
	auto& shared = *m_state;
	shared.state = State::Pending;
	shared.registerResolveCallback(resolveContext);
}
// Copies refer to the same shared state (assuming m_state is a std::shared_ptr, as its make_shared initialisation
// suggests), so every copy observes the same settlement.
LuaPromise::LuaPromise(const LuaPromise&) = default;
LuaPromise& LuaPromise::operator=(const LuaPromise&) = default;
// NOTE(review): the defaulted moves presumably leave the moved-from promise with an empty m_state. Several members
// dereference m_state unconditionally, so only assign to or destroy a moved-from promise.
LuaPromise::LuaPromise(LuaPromise&&) noexcept = default;
LuaPromise& LuaPromise::operator=(LuaPromise&&) noexcept = default;
/**
 * @brief Chain a success continuation onto this promise.
 *
 * An already-resolved promise runs @p continuation immediately. An errored promise propagates its error without
 * calling it. A pending promise runs it once settled. An invalid promise yields another invalid promise.
 * Not thread safe.
 */
LuaPromise LuaPromise::then(LuaPromise::ContinuationFunction continuation)
{
	switch (m_state->state) {
	case State::Invalid:
		return LuaPromise();
	case State::Resolved:
		// Value is already available, no need to defer
		return LuaPromise::resolved(continuation(m_state->resolvedValue));
	case State::Errored:
		// A success continuation is irrelevant for an errored promise; the error passes through
		return LuaPromise::errored(m_state->resolvedValue);
	case State::Pending:
	default:
		break;
	}
	return registerContinuation(false, std::move(continuation));
}
/**
 * @brief Chain an error handler onto this promise.
 *
 * A resolved promise passes its value through untouched. An errored promise has its error handled by
 * @p continuation; the handler's return value becomes a regular (resolved) value. A pending promise runs the handler
 * once it settles with an error. An invalid promise yields another invalid promise. Not thread safe.
 */
LuaPromise LuaPromise::catchError(LuaPromise::ContinuationFunction continuation)
{
	if (m_state->state == State::Invalid) {
		return LuaPromise();
	}
	// If we want to catch errors then the value from this resolved promise just passes through
	if (m_state->state == State::Resolved) {
		return LuaPromise::resolved(m_state->resolvedValue);
	}
	// The error is handled here, so the handler's result is no longer an error. This matches the deferred path in
	// continuation_resolve_context::resolve, which always resolves the chained promise with error == false.
	// (Previously this returned an *errored* promise, so eager and deferred catching behaved differently.)
	if (m_state->state == State::Errored) {
		return LuaPromise::resolved(continuation(m_state->resolvedValue));
	}
	// Still pending: run the handler once this promise settles with an error
	return registerContinuation(true, std::move(continuation));
}
// State queries. A moved-from promise has no shared state; it reports itself as invalid (neither resolved nor
// errored) instead of dereferencing an empty pointer.
bool LuaPromise::isValid() const { return m_state != nullptr && m_state->state != State::Invalid; }
bool LuaPromise::isResolved() const { return m_state != nullptr && m_state->state == State::Resolved; }
bool LuaPromise::isErrored() const { return m_state != nullptr && m_state->state == State::Errored; }
/// Values the promise resolved with. Asserts unless the promise is in the Resolved state.
const luacpp::LuaValueList& LuaPromise::resolveValue() const
{
	Assertion(isResolved(), "Tried to get value from unresolved promise!");
	const auto& shared = *m_state;
	return shared.resolvedValue;
}
/// Error values of the promise. Asserts unless the promise is in the Errored state.
const luacpp::LuaValueList& LuaPromise::errorValue() const
{
	// The message previously said "unresolved promise", which was misleading when the promise had resolved
	// successfully; the real precondition is that the promise errored.
	Assertion(isErrored(), "Tried to get error value from a promise that is not errored!");
	return m_state->resolvedValue;
}
/// Creates a promise that is already resolved with @p resolveValue.
LuaPromise LuaPromise::resolved(luacpp::LuaValueList resolveValue)
{
	LuaPromise promise;
	auto& shared = *promise.m_state;
	shared.resolvedValue = std::move(resolveValue);
	shared.state = State::Resolved;
	return promise;
}
/// Creates a promise that has already errored with @p resolveValue as its error value.
LuaPromise LuaPromise::errored(luacpp::LuaValueList resolveValue)
{
	LuaPromise promise;
	auto& shared = *promise.m_state;
	shared.resolvedValue = std::move(resolveValue);
	shared.state = State::Errored;
	return promise;
}
/**
 * @brief Defer @p continuation until this (pending) promise settles.
 * @param catchErrors true to run the continuation on errors, false to run it on success
 * @return the chained promise, settled through the new continuation context
 */
LuaPromise LuaPromise::registerContinuation(bool catchErrors, LuaPromise::ContinuationFunction continuation)
{
	// The context is the link between the two promises. We store it so resolveCb() can feed our result into it, and
	// the returned promise installs its resolver on it.
	auto link = std::make_shared<continuation_resolve_context>(catchErrors, std::move(continuation));
	m_state->continuationContexts.emplace_back(link);
	return LuaPromise(link);
}
// Defaulted: drops only this handle's reference to the shared state. A pending resolver installed by
// registerResolveCallback keeps its own reference, so the promise can still settle after this handle is gone.
LuaPromise::~LuaPromise() = default;
} // namespace api
} // namespace scripting
|