diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index c5d5ef835b..2545853b41 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -18,10 +18,13 @@ ListIter.__call = function(self) return self:next() end +--- Packed tables use this as their metatable +local packedmt = {} + ---@private local function unpack(t) - if type(t) == 'table' and t.__n ~= nil then - return _G.unpack(t, 1, t.__n) + if getmetatable(t) == packedmt then + return _G.unpack(t, 1, t.n) end return t end @@ -30,11 +33,20 @@ end local function pack(...) local n = select('#', ...) if n > 1 then - return { __n = n, ... } + return setmetatable({ n = n, ... }, packedmt) end return ... end +---@private +local function sanitize(t) + if getmetatable(t) == packedmt then + -- Remove length tag + t.n = nil + end + return t +end + --- Add a filter step to the iterator pipeline. --- --- Example: @@ -208,12 +220,7 @@ function Iter.totable(self) break end - if type(args) == 'table' then - -- Removed packed table tag if it exists - args.__n = nil - end - - t[#t + 1] = args + t[#t + 1] = sanitize(args) end return t end @@ -221,12 +228,10 @@ end ---@private function ListIter.totable(self) if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then - -- Remove any packed table tags - for i = 1, #self._table do - local v = self._table[i] - if type(v) == 'table' then - v.__n = nil - self._table[i] = v + -- Sanitize packed table values + if getmetatable(self._table[1]) == packedmt then + for i = 1, #self._table do + self._table[i] = sanitize(self._table[i]) end end return self._table diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua index b8cc15b2ca..42927f7e1b 100644 --- a/test/functional/lua/vim_spec.lua +++ b/test/functional/lua/vim_spec.lua @@ -3416,6 +3416,41 @@ describe('lua stdlib', function() { item_3 = 'test' }, }, output) end) + + it('handles nil values', function() + local t = {1, 2, 3, 4, 5} + do + local it = vim.iter(t):enumerate():map(function(i, v) + if i % 2 == 0 then + return nil, v*v + end + return v, nil + end) + eq({ + { [1] = 1 }, + { [2] = 4 }, + { [1] = 3 }, + { [2] = 16 }, + { [1] = 5 }, + }, it:totable()) + end + + do + local it = vim.iter(ipairs(t)):map(function(i, v) + if i % 2 == 0 then + return nil, v*v + end + return v, nil + end) + eq({ + { [1] = 1 }, + { [2] = 4 }, + { [1] = 3 }, + { [2] = 16 }, + { [1] = 5 }, + }, it:totable()) + end + end) end) end)