refactor(iter): use metatable as packed table tag (#23254)

This is a more robust method for tagging a packed table as it completely
eliminates the possibility of mistaking an actual table key as the
packed table tag.
This commit is contained in:
Gregory Anders 2023-04-21 16:13:39 -06:00 committed by GitHub
parent ef92b5a994
commit f68af3c3bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 15 deletions

View File

@ -18,10 +18,13 @@ ListIter.__call = function(self)
return self:next() return self:next()
end end
--- Packed tables use this as their metatable
local packedmt = {}
---@private ---@private
local function unpack(t) local function unpack(t)
if type(t) == 'table' and t.__n ~= nil then if getmetatable(t) == packedmt then
return _G.unpack(t, 1, t.__n) return _G.unpack(t, 1, t.n)
end end
return t return t
end end
@ -30,11 +33,20 @@ end
local function pack(...) local function pack(...)
local n = select('#', ...) local n = select('#', ...)
if n > 1 then if n > 1 then
return { __n = n, ... } return setmetatable({ n = n, ... }, packedmt)
end end
return ... return ...
end end
---@private
local function sanitize(t)
if getmetatable(t) == packedmt then
-- Remove length tag
t.n = nil
end
return t
end
--- Add a filter step to the iterator pipeline. --- Add a filter step to the iterator pipeline.
--- ---
--- Example: --- Example:
@ -208,12 +220,7 @@ function Iter.totable(self)
break break
end end
if type(args) == 'table' then t[#t + 1] = sanitize(args)
-- Removed packed table tag if it exists
args.__n = nil
end
t[#t + 1] = args
end end
return t return t
end end
@ -221,12 +228,10 @@ end
---@private ---@private
function ListIter.totable(self) function ListIter.totable(self)
if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then
-- Remove any packed table tags -- Sanitize packed table values
for i = 1, #self._table do if getmetatable(self._table[1]) == packedmt then
local v = self._table[i] for i = 1, #self._table do
if type(v) == 'table' then self._table[i] = sanitize(self._table[i])
v.__n = nil
self._table[i] = v
end end
end end
return self._table return self._table

View File

@ -3416,6 +3416,41 @@ describe('lua stdlib', function()
{ item_3 = 'test' }, { item_3 = 'test' },
}, output) }, output)
end) end)
it('handles nil values', function()
local t = {1, 2, 3, 4, 5}
do
local it = vim.iter(t):enumerate():map(function(i, v)
if i % 2 == 0 then
return nil, v*v
end
return v, nil
end)
eq({
{ [1] = 1 },
{ [2] = 4 },
{ [1] = 3 },
{ [2] = 16 },
{ [1] = 5 },
}, it:totable())
end
do
local it = vim.iter(ipairs(t)):map(function(i, v)
if i % 2 == 0 then
return nil, v*v
end
return v, nil
end)
eq({
{ [1] = 1 },
{ [2] = 4 },
{ [1] = 3 },
{ [2] = 16 },
{ [1] = 5 },
}, it:totable())
end
end)
end) end)
end) end)