From 6b96122453fda22dc44a581af1d536988c1adf41 Mon Sep 17 00:00:00 2001 From: Gregory Anders Date: Wed, 19 Apr 2023 06:45:56 -0600 Subject: [PATCH] fix(iter): add tag to packed table If pack() is called with a single value, it does not create a table; it simply returns the value it is passed. When unpack is called with a table argument, it interprets that table as a list of values that were packed together into a table. This causes a problem when the single value being packed is _itself_ a table. pack() will not place it into another table, but unpack() sees the table argument and tries to unpack it. To fix this, we add a simple "tag" to packed table values so that unpack() only attempts to unpack tables that have this tag. Other tables are left alone. The tag is simply the length of the table. --- runtime/lua/vim/iter.lua | 25 ++++++++++++++++++++----- test/functional/lua/vim_spec.lua | 27 +++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index b73c03ba9a..fff7644b6a 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -28,16 +28,17 @@ end ---@private local function unpack(t) - if type(t) == 'table' then - return _G.unpack(t) + if type(t) == 'table' and t.__n ~= nil then + return _G.unpack(t, 1, t.__n) end return t end ---@private local function pack(...) - if select('#', ...) > 1 then - return { ... } + local n = select('#', ...) + if n > 1 then + return { __n = n, ... } end return ... end @@ -210,6 +211,12 @@ function Iter.totable(self) if args == nil then break end + + if type(args) == 'table' then + -- Removed packed table tag if it exists + args.__n = nil + end + t[#t + 1] = args end return t @@ -218,6 +225,14 @@ end ---@private function ListIter.totable(self) if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then + -- Remove any packed table tags + for i = 1, #self._table do + local v = self._table[i] + if type(v) == 'table' then + v.__n = nil + self._table[i] = v + end + end return self._table end @@ -747,7 +762,7 @@ function ListIter.enumerate(self) local inc = self._head < self._tail and 1 or -1 for i = self._head, self._tail - inc, inc do local v = self._table[i] - self._table[i] = { i, v } + self._table[i] = pack(i, v) end return self end diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua index e5caf6f6f7..07b0f0340a 100644 --- a/test/functional/lua/vim_spec.lua +++ b/test/functional/lua/vim_spec.lua @@ -3381,6 +3381,33 @@ describe('lua stdlib', function() end end) eq({ A = 2, C = 6 }, it:totable()) + + it('handles table values mid-pipeline', function() + local map = { + item = { + file = 'test', + }, + item_2 = { + file = 'test', + }, + item_3 = { + file = 'test', + }, + } + + local output = vim.iter(map):map(function(key, value) + return { [key] = value.file } + end):totable() + + table.sort(output, function(a, b) + return next(a) < next(b) + end) + + eq({ + { item = 'test' }, + { item_2 = 'test' }, + { item_3 = 'test' }, + }, output) end) end) end)