From 534544cbf7ac92aef44336cc9da1bfc02a441e6e Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Mon, 18 Nov 2024 17:15:05 +0000 Subject: [PATCH] test: move exec_lua logic to separate module By making it a separate module, the embedded Nvim session can require this module directly instead of setup code sending over the module via RPC. Also make exec_lua wrap _G.print so messages can be seen in the test output immediately as the exec_lua returns. --- test/functional/testnvim.lua | 122 +-------------------- test/functional/testnvim/exec_lua.lua | 148 ++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 121 deletions(-) create mode 100644 test/functional/testnvim/exec_lua.lua diff --git a/test/functional/testnvim.lua b/test/functional/testnvim.lua index 60b2f872fc..43c38d18c0 100644 --- a/test/functional/testnvim.lua +++ b/test/functional/testnvim.lua @@ -800,81 +800,6 @@ function M.exec_capture(code) return M.api.nvim_exec2(code, { output = true }).output end ---- @param f function ---- @return table -local function get_upvalues(f) - local i = 1 - local upvalues = {} --- @type table - while true do - local n, v = debug.getupvalue(f, i) - if not n then - break - end - upvalues[n] = v - i = i + 1 - end - return upvalues -end - ---- @param f function ---- @param upvalues table -local function set_upvalues(f, upvalues) - local i = 1 - while true do - local n = debug.getupvalue(f, i) - if not n then - break - end - if upvalues[n] then - debug.setupvalue(f, i, upvalues[n]) - end - i = i + 1 - end -end - ---- @type fun(f: function): table -_G.__get_upvalues = nil - ---- @type fun(f: function, upvalues: table) -_G.__set_upvalues = nil - ---- @param self table ---- @param bytecode string ---- @param upvalues table ---- @param ... any[] ---- @return any[] result ---- @return table upvalues -local function exec_lua_handler(self, bytecode, upvalues, ...) - local f = assert(loadstring(bytecode)) - self.set_upvalues(f, upvalues) - local ret = { f(...) } --- @type any[] - --- @type table - local new_upvalues = self.get_upvalues(f) - - do -- Check return value types for better error messages - local invalid_types = { - ['thread'] = true, - ['function'] = true, - ['userdata'] = true, - } - - for k, v in pairs(ret) do - if invalid_types[type(v)] then - error( - string.format( - "Return index %d with value '%s' of type '%s' cannot be serialized over RPC", - k, - tostring(v), - type(v) - ) - ) - end - end - end - - return ret, new_upvalues -end - --- Execute Lua code in the wrapped Nvim session. --- --- When `code` is passed as a function, it is converted into Lua byte code. @@ -921,52 +846,7 @@ function M.exec_lua(code, ...) end assert(session, 'no Nvim session') - - if not session.exec_lua_setup then - assert( - session:request( - 'nvim_exec_lua', - [[ - _G.__test_exec_lua = { - get_upvalues = loadstring((select(1,...))), - set_upvalues = loadstring((select(2,...))), - handler = loadstring((select(3,...))) - } - setmetatable(_G.__test_exec_lua, { __index = _G.__test_exec_lua }) - ]], - { string.dump(get_upvalues), string.dump(set_upvalues), string.dump(exec_lua_handler) } - ) - ) - session.exec_lua_setup = true - end - - local stat, rv = session:request( - 'nvim_exec_lua', - 'return { _G.__test_exec_lua:handler(...) }', - { string.dump(code), get_upvalues(code), ... } - ) - - if not stat then - error(rv[2]) - end - - --- @type any[], table - local ret, upvalues = unpack(rv) - - -- Update upvalues - if next(upvalues) then - local caller = debug.getinfo(2) - local f = caller.func - -- On PUC-Lua, if the function is a tail call, then func will be nil. - -- In this case we need to use the current function. - if not f then - assert(caller.source == '=(tail call)') - f = debug.getinfo(1).func - end - set_upvalues(f, upvalues) - end - - return unpack(ret, 1, table.maxn(ret)) + return require('test.functional.testnvim.exec_lua')(session, 2, code, ...) end function M.get_pathsep() diff --git a/test/functional/testnvim/exec_lua.lua b/test/functional/testnvim/exec_lua.lua new file mode 100644 index 0000000000..ddd9905ce7 --- /dev/null +++ b/test/functional/testnvim/exec_lua.lua @@ -0,0 +1,148 @@ +--- @param f function +--- @return table +local function get_upvalues(f) + local i = 1 + local upvalues = {} --- @type table + while true do + local n, v = debug.getupvalue(f, i) + if not n then + break + end + upvalues[n] = v + i = i + 1 + end + return upvalues +end + +--- @param f function +--- @param upvalues table +local function set_upvalues(f, upvalues) + local i = 1 + while true do + local n = debug.getupvalue(f, i) + if not n then + break + end + if upvalues[n] then + debug.setupvalue(f, i, upvalues[n]) + end + i = i + 1 + end +end + +--- @param messages string[] +--- @param ... ... +local function add_print(messages, ...) + local msg = {} --- @type string[] + for i = 1, select('#', ...) do + msg[#msg + 1] = tostring(select(i, ...)) + end + table.insert(messages, table.concat(msg, '\t')) +end + +local invalid_types = { + ['thread'] = true, + ['function'] = true, + ['userdata'] = true, +} + +--- @param r any[] +local function check_returns(r) + for k, v in pairs(r) do + if invalid_types[type(v)] then + error( + string.format( + "Return index %d with value '%s' of type '%s' cannot be serialized over RPC", + k, + tostring(v), + type(v) + ), + 2 + ) + end + end +end + +local M = {} + +--- This is run in the context of the remote Nvim instance. +--- @param bytecode string +--- @param upvalues table +--- @param ... any[] +--- @return any[] result +--- @return table upvalues +--- @return string[] messages +function M.handler(bytecode, upvalues, ...) + local messages = {} --- @type string[] + local orig_print = _G.print + + function _G.print(...) + add_print(messages, ...) + return orig_print(...) + end + + local f = assert(loadstring(bytecode)) + + set_upvalues(f, upvalues) + + -- Run in pcall so we can return any print messages + local ret = { pcall(f, ...) } --- @type any[] + + _G.print = orig_print + + local new_upvalues = get_upvalues(f) + + -- Check return value types for better error messages + check_returns(ret) + + return ret, new_upvalues, messages +end + +--- @param session test.Session +--- @param lvl integer +--- @param code function +--- @param ... ... +local function run(session, lvl, code, ...) + local stat, rv = session:request( + 'nvim_exec_lua', + [[return { require('test.functional.testnvim.exec_lua').handler(...) }]], + { string.dump(code), get_upvalues(code), ... } + ) + + if not stat then + error(rv[2], 2) + end + + --- @type any[], table, string[] + local ret, upvalues, messages = unpack(rv) + + for _, m in ipairs(messages) do + print(m) + end + + if not ret[1] then + error(ret[2], 2) + end + + -- Update upvalues + if next(upvalues) then + local caller = debug.getinfo(lvl) + local i = 0 + + -- On PUC-Lua, if the function is a tail call, then func will be nil. + -- In this case we need to use the caller. + while not caller.func do + i = i + 1 + caller = debug.getinfo(lvl + i) + end + set_upvalues(caller.func, upvalues) + end + + return unpack(ret, 2, table.maxn(ret)) +end + +return setmetatable(M, { + __call = function(_, ...) + return run(...) + end, +})