1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
|
--- Collect the upvalues of `f` into a table keyed by upvalue name.
---
--- Upvalues are read in index order with `debug.getupvalue` until it reports
--- no further name. An upvalue whose current value is nil produces no entry,
--- since assigning nil to a table key stores nothing.
---
--- @param f function
--- @return table<string,any>
local function get_upvalues(f)
  local captured = {} --- @type table<string,any>
  local idx = 1
  local name, value = debug.getupvalue(f, idx)
  while name do
    captured[name] = value
    idx = idx + 1
    name, value = debug.getupvalue(f, idx)
  end
  return captured
end
--- Overwrite the upvalues of `f` with the values found in `upvalues`.
---
--- Each upvalue of `f` is looked up by name in `upvalues`; when an entry
--- exists, it is written with `debug.setupvalue`. Upvalues with no entry in
--- the table are left untouched.
---
--- @param f function
--- @param upvalues table<string,any> New values keyed by upvalue name
local function set_upvalues(f, upvalues)
  local i = 1
  while true do
    local n = debug.getupvalue(f, i)
    if not n then
      break
    end
    local v = upvalues[n]
    -- Compare against nil explicitly: `false` is a legitimate upvalue value
    -- and a plain truthiness test would silently skip it.
    if v ~= nil then
      debug.setupvalue(f, i, v)
    end
    i = i + 1
  end
end
--- Append one message to `messages`, built from the given arguments.
---
--- Every argument (including nil ones) is passed through `tostring` and the
--- results are joined with a tab character.
---
--- @param messages string[] List that receives the formatted message
--- @param ... ...
local function add_print(messages, ...)
  local argc = select('#', ...)
  local parts = {} --- @type string[]
  for idx = 1, argc do
    -- Parentheses truncate select()'s results to the idx-th argument only
    local arg = (select(idx, ...))
    parts[#parts + 1] = tostring(arg)
  end
  messages[#messages + 1] = table.concat(parts, '\t')
end
--- Lua types that check_returns() rejects.
local invalid_types = {
  thread = true,
  ['function'] = true,
  userdata = true,
}

--- Raise an error if any value in `r` has a type listed in `invalid_types`.
---
--- Only the direct entries of `r` are inspected. The error is raised with
--- level 2, so it is attributed to the caller of this function.
---
--- @param r any[]
local function check_returns(r)
  for idx, value in pairs(r) do
    local kind = type(value)
    if invalid_types[kind] then
      local msg = string.format(
        "Return index %d with value '%s' of type '%s' cannot be serialized over RPC",
        idx,
        tostring(value),
        kind
      )
      error(msg, 2)
    end
  end
end
local M = {}

--- This is run in the context of the remote Nvim instance.
---
--- Loads `bytecode`, seeds the resulting function's upvalues from `upvalues`,
--- and calls it under pcall with the remaining arguments. While the function
--- runs, `_G.print` is wrapped so that every printed message is also recorded;
--- the original `print` is restored as soon as the call finishes.
---
--- Raises if `bytecode` cannot be loaded, or if any entry of the pcall result
--- has a type rejected by check_returns().
---
--- @param bytecode string
--- @param upvalues table<string,any>
--- @param ... any[]
--- @return any[] result pcall results: status first, then return values or the error
--- @return table<string,any> upvalues Upvalues of the function after the call
--- @return string[] messages Messages passed to print() during the call
function M.handler(bytecode, upvalues, ...)
  -- Load and prepare the function BEFORE overriding print: if loading fails,
  -- the assert must not leave _G.print replaced by the recording wrapper.
  local f = assert(loadstring(bytecode))
  set_upvalues(f, upvalues)

  local messages = {} --- @type string[]
  local orig_print = _G.print

  function _G.print(...)
    add_print(messages, ...)
    return orig_print(...)
  end

  -- Run in pcall so we can return any print messages
  local ret = { pcall(f, ...) } --- @type any[]

  _G.print = orig_print

  local new_upvalues = get_upvalues(f)

  -- Check return value types for better error messages
  check_returns(ret)

  return ret, new_upvalues, messages
end
--- Run `code` through an `nvim_exec_lua` request on `session`.
---
--- The function is sent as `string.dump` bytecode together with its current
--- upvalues and the extra arguments. Messages returned by the handler are
--- printed locally. If the request or the executed function failed, the error
--- is re-raised at level 2. Returned upvalues are written back into the
--- function found at stack level `lvl`.
---
--- @param session test.Session
--- @param lvl integer Stack level (as seen from this function) of the function whose upvalues are updated
--- @param code function
--- @param ... ...
local function run(session, lvl, code, ...)
  local ok, reply = session:request(
    'nvim_exec_lua',
    [[return { require('test.functional.testnvim.exec_lua').handler(...) }]],
    { string.dump(code), get_upvalues(code), ... }
  )

  if not ok then
    error(reply[2], 2)
  end

  --- @type any[], table<string,any>, string[]
  local results, remote_upvalues, printed = unpack(reply)

  for _, line in ipairs(printed) do
    print(line)
  end

  -- results[1] is the pcall status reported by the handler
  local succeeded = results[1]
  if not succeeded then
    error(results[2], 2)
  end

  -- Update upvalues
  if next(remote_upvalues) then
    local depth = lvl
    local frame = debug.getinfo(depth)

    -- On PUC-Lua, if the function is a tail call, then func will be nil.
    -- In this case we need to use the caller.
    while not frame.func do
      depth = depth + 1
      frame = debug.getinfo(depth)
    end

    set_upvalues(frame.func, remote_upvalues)
  end

  -- table.maxn keeps nil return values that precede a non-nil one
  return unpack(results, 2, table.maxn(results))
end
-- Calling the module table itself forwards every argument after the table
-- to `run`.
-- NOTE(review): `run` inspects the stack via debug.getinfo(lvl), so the
-- `return run(...)` tail-call form is presumably significant; confirm before
-- changing it.
local callable = {}

function callable.__call(_, ...)
  return run(...)
end

return setmetatable(M, callable)
|