File: LuaCoroutineRunner.cpp

package info (click to toggle)
freespace2 24.0.2%2Brepack-1
  • links: PTS, VCS
  • area: non-free
  • in suites: trixie
  • size: 43,188 kB
  • sloc: cpp: 583,107; ansic: 21,729; python: 1,174; sh: 464; makefile: 248; xml: 181
file content (154 lines) | stat: -rw-r--r-- 5,041 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#include "LuaCoroutineRunner.h"

#include "scripting/ade_args.h"
#include "scripting/api/objs/promise.h"

namespace scripting {
namespace api {

namespace {

/**
 * @brief A run context which resumes a coroutine until it is finished
 *
 * This will manage the passed coroutine and resume it until it is completed. For every yielded promise it registers
 * itself as the continuation which then resumes the coroutine.
 */
class run_resolve_context : public resolve_context, public std::enable_shared_from_this<run_resolve_context> {
  public:
	run_resolve_context(luacpp::LuaThread coroutine,
		std::shared_ptr<executor::Executor> executor,
		std::shared_ptr<executor::IExecutionContext> executionContext)
		: _coroutine(std::move(coroutine)), _executor(std::move(executor)),
		  _executionContext(std::move(executionContext))
	{
	}

	void setResolver(Resolver resolver) override
	{
		m_resolver = std::move(resolver);

		// Kick off the coroutine once we know that someone cares about its result
		scheduleResume(luacpp::LuaValueList());
	}

  private:
	void postToExecutor(executor::Executor::Callback cb)
	{
		if (_executor.get() == executor::currentExecutor()) {
			// We are already in the right executor so we can invoke the callback directly
			const auto ret = cb();
			// It is possible that we get a reschedule here if the execution state is currently suspended
			if (ret == executor::Executor::CallbackResult::Reschedule) {
				_executor->post(std::move(cb));
			}
		} else {
			_executor->post(std::move(cb));
		}
	}

	void scheduleResume(const luacpp::LuaValueList& resumeParams)
	{
		if (!_executor) {
			// If we have no executor we just execute the resume directly
			resumeCoroutine(resumeParams);
			return;
		}

		auto self = shared_from_this();
		if (_executionContext) {
			// If we have a context, wrap our resumer in that so that we only execute in our context
			postToExecutor(executor::runInContext(_executionContext,
				[this, self, resumeParams](executor::IExecutionContext::State state) {
					if (state == executor::IExecutionContext::State::Invalid) {
						// State became invalid while waiting
						auto errorMessage = luacpp::LuaValue::createValue(_coroutine.getLuaState(),
							"Coroutine context became invalid.");
						m_resolver(true, {errorMessage});
						m_resolver = nullptr;
						return executor::Executor::CallbackResult::Done;
					}

					resumeCoroutine(resumeParams);
					// We only need to run this once
					return executor::Executor::CallbackResult::Done;
				}));
			return;
		}

		postToExecutor([this, self, resumeParams]() {
			resumeCoroutine(resumeParams);
			// We only need to run this once
			return executor::Executor::CallbackResult::Done;
		});
	}

	void resumeCoroutine(const luacpp::LuaValueList& resumeParams)
	{
		const auto result = _coroutine.resume(resumeParams);

		if (result.completed) {
			// Thread is finished! We can call our resolver
			m_resolver(false, result.returnVals);

			// Clean up reference since we do not need this anymore
			m_resolver = nullptr;
			return;
		}

		// The coroutine suspended so the return value must be a promise
		Assertion(result.returnVals.size() == 1,
			"Wrong number of yielded values. Should be 1 but is " SIZE_T_ARG,
			result.returnVals.size());

		auto promiseStackStart = lua_gettop(_coroutine.getLuaState());
		result.returnVals.front().pushValue(_coroutine.getLuaState());

		internal::Ade_get_args_skip      = promiseStackStart;
		internal::Ade_get_args_lfunction = true;
		LuaPromise* promise              = nullptr;
		if (!ade_get_args(_coroutine.getLuaState(), "o", l_Promise.GetPtr(&promise))) {
			LuaError(_coroutine.getLuaState(),
				"Failed to get promise after coroutine yielded. Make sure you only use async.await in async "
				"coroutines.");
			return;
		}

		lua_settop(_coroutine.getLuaState(), promiseStackStart);

		// Register ourself to be called when the promise resolves so that we can resume our coroutine
		auto self = shared_from_this();
		promise->then([this, self](const luacpp::LuaValueList& resolveVals) {
			// Since "self" keeps "this" alive it is safe to access that here
			scheduleResume(resolveVals);

			return luacpp::LuaValueList();
		});
		promise->catchError([this, self](const luacpp::LuaValueList& resolveVals) {
			// If the awaited coroutine causes an error we stop the coroutine and propagate that error to the
			// promise
			m_resolver(true, resolveVals);
			m_resolver = nullptr;

			return luacpp::LuaValueList();
		});
	}

	Resolver m_resolver;
	luacpp::LuaThread _coroutine;
	std::shared_ptr<executor::Executor> _executor;
	std::shared_ptr<executor::IExecutionContext> _executionContext;
};

} // namespace

LuaPromise runAsyncCoroutine(luacpp::LuaThread luaThread,
	std::shared_ptr<executor::Executor> executor,
	std::shared_ptr<executor::IExecutionContext> executionContext)
{
	return LuaPromise(
		std::make_shared<run_resolve_context>(std::move(luaThread), std::move(executor), std::move(executionContext)));
}

} // namespace api
} // namespace scripting