fix(iter): add tag to packed table

If pack() is called with a single value, it does not create a table; it
simply returns the value it is passed. When unpack is called with a
table argument, it interprets that table as a list of values that were
packed together into a table.

This causes a problem when the single value being packed is _itself_ a
table. pack() will not place it into another table, but unpack() sees
the table argument and tries to unpack it.

To fix this, we add a simple "tag" to packed table values so that
unpack() only attempts to unpack tables that have this tag. Other tables
are left alone. The tag is simply the length of the table.
This commit is contained in:
Gregory Anders 2023-04-19 06:45:56 -06:00
parent 0a3645a723
commit 6b96122453
2 changed files with 47 additions and 5 deletions

View File

@ -28,16 +28,17 @@ end
---@private ---@private
local function unpack(t) local function unpack(t)
if type(t) == 'table' then if type(t) == 'table' and t.__n ~= nil then
return _G.unpack(t) return _G.unpack(t, 1, t.__n)
end end
return t return t
end end
---@private ---@private
local function pack(...) local function pack(...)
if select('#', ...) > 1 then local n = select('#', ...)
return { ... } if n > 1 then
return { __n = n, ... }
end end
return ... return ...
end end
@ -210,6 +211,12 @@ function Iter.totable(self)
if args == nil then if args == nil then
break break
end end
if type(args) == 'table' then
-- Removed packed table tag if it exists
args.__n = nil
end
t[#t + 1] = args t[#t + 1] = args
end end
return t return t
@ -218,6 +225,14 @@ end
---@private ---@private
function ListIter.totable(self) function ListIter.totable(self)
if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then
-- Remove any packed table tags
for i = 1, #self._table do
local v = self._table[i]
if type(v) == 'table' then
v.__n = nil
self._table[i] = v
end
end
return self._table return self._table
end end
@ -747,7 +762,7 @@ function ListIter.enumerate(self)
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do for i = self._head, self._tail - inc, inc do
local v = self._table[i] local v = self._table[i]
self._table[i] = { i, v } self._table[i] = pack(i, v)
end end
return self return self
end end

View File

@ -3381,6 +3381,33 @@ describe('lua stdlib', function()
end end
end) end)
eq({ A = 2, C = 6 }, it:totable()) eq({ A = 2, C = 6 }, it:totable())
it('handles table values mid-pipeline', function()
local map = {
item = {
file = 'test',
},
item_2 = {
file = 'test',
},
item_3 = {
file = 'test',
},
}
local output = vim.iter(map):map(function(key, value)
return { [key] = value.file }
end):totable()
table.sort(output, function(a, b)
return next(a) < next(b)
end)
eq({
{ item = 'test' },
{ item_2 = 'test' },
{ item_3 = 'test' },
}, output)
end) end)
end) end)
end) end)