diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 5c1cc06908..735627d29f 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -5,35 +5,20 @@ local Range = require('vim.treesitter._range') local api = vim.api ---@class TS.FoldInfo ----@field levels table ----@field levels0 table ----@field private start_counts table ----@field private stop_counts table +---@field levels string[] the foldexpr result for each line +---@field levels0 integer[] the raw fold levels +---@field edits? {[1]: integer, [2]: integer} line range edited since the last invocation of the callback scheduled in on_bytes. 0-indexed, end-exclusive. local FoldInfo = {} FoldInfo.__index = FoldInfo ---@private function FoldInfo.new() return setmetatable({ - start_counts = {}, - stop_counts = {}, levels0 = {}, levels = {}, }, FoldInfo) end ----@package ----@param srow integer ----@param erow integer -function FoldInfo:invalidate_range(srow, erow) - for i = srow, erow do - self.start_counts[i + 1] = nil - self.stop_counts[i + 1] = nil - self.levels0[i + 1] = nil - self.levels[i + 1] = nil - end -end - --- Efficiently remove items from middle of a list a list. --- --- Calling table.remove() in a loop will re-index the tail of the table on @@ -55,12 +40,10 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:remove_range(srow, erow) list_remove(self.levels, srow + 1, erow) list_remove(self.levels0, srow + 1, erow) - list_remove(self.start_counts, srow + 1, erow) - list_remove(self.stop_counts, srow + 1, erow) end --- Efficiently insert items into the middle of a list. @@ -91,46 +74,37 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:add_range(srow, erow) - list_insert(self.levels, srow + 1, erow, '-1') + list_insert(self.levels, srow + 1, erow, '=') list_insert(self.levels0, srow + 1, erow, -1) - list_insert(self.start_counts, srow + 1, erow, nil) - list_insert(self.stop_counts, srow + 1, erow, nil) end ---@package ----@param lnum integer -function FoldInfo:add_start(lnum) - self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 -end - ----@package ----@param lnum integer -function FoldInfo:add_stop(lnum) - self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 -end - ----@package ----@param lnum integer ----@return integer -function FoldInfo:get_start(lnum) - return self.start_counts[lnum] or 0 -end - ----@package ----@param lnum integer ----@return integer -function FoldInfo:get_stop(lnum) - return self.stop_counts[lnum] or 0 -end - -local function trim_level(level) - local max_fold_level = vim.wo.foldnestmax - if level > max_fold_level then - return max_fold_level +---@param srow integer +---@param erow_old integer +---@param erow_new integer 0-indexed, exclusive +function FoldInfo:edit_range(srow, erow_old, erow_new) + if self.edits then + self.edits[1] = math.min(srow, self.edits[1]) + if erow_old <= self.edits[2] then + self.edits[2] = self.edits[2] + (erow_new - erow_old) + end + self.edits[2] = math.max(self.edits[2], erow_new) + else + self.edits = { srow, erow_new } + end +end + +---@package +---@return integer? srow +---@return integer? erow 0-indexed, exclusive +function FoldInfo:flush_edit() + if self.edits then + local srow, erow = self.edits[1], self.edits[2] + self.edits = nil + return srow, erow end - return level end --- If a parser doesn't have any ranges explicitly set, treesitter will @@ -140,10 +114,10 @@ end --- TODO(lewis6991): Handle this generally --- --- @param bufnr integer ---- @param erow integer? +--- @param erow integer? 0-indexed, exclusive --- @return integer local function normalise_erow(bufnr, erow) - local max_erow = api.nvim_buf_line_count(bufnr) - 1 + local max_erow = api.nvim_buf_line_count(bufnr) return math.min(erow or max_erow, max_erow) end @@ -152,31 +126,30 @@ end ---@param bufnr integer ---@param info TS.FoldInfo ---@param srow integer? ----@param erow integer? +---@param erow integer? 0-indexed, exclusive ---@param parse_injections? boolean local function get_folds_levels(bufnr, info, srow, erow, parse_injections) srow = srow or 0 erow = normalise_erow(bufnr, erow) - info:invalidate_range(srow, erow) - - local prev_start = -1 - local prev_stop = -1 - local parser = ts.get_parser(bufnr) parser:parse(parse_injections and { srow, erow } or nil) + local enter_counts = {} ---@type table + local leave_counts = {} ---@type table + local prev_start = -1 + local prev_stop = -1 + parser:for_each_tree(function(tree, ltree) local query = ts.query.get(ltree:lang(), 'folds') if not query then return end - -- erow in query is end-exclusive - local q_erow = erow and erow + 1 or -1 - - for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do + -- Collect folds starting from srow - 1, because we should first subtract the folds that end at + -- srow - 1 from the level of srow - 1 to get accurate level of srow. + for id, node, metadata in query:iter_captures(tree:root(), bufnr, math.max(srow - 1, 0), erow) do if query.captures[id] == 'fold' then local range = ts.get_range(node, bufnr, metadata[id]) local start, _, stop, stop_col = Range.unpack4(range) @@ -193,8 +166,8 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) if fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) then - info:add_start(start + 1) - info:add_stop(stop + 1) + enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 + leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 prev_start = start prev_stop = stop end @@ -202,16 +175,15 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) end end) - local current_level = info.levels0[srow] or 0 + local nestmax = vim.wo.foldnestmax + local level0_prev = info.levels0[srow] or 0 + local leave_prev = leave_counts[srow] or 0 -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = srow + 1, erow + 1 do - local last_trimmed_level = trim_level(current_level) - current_level = current_level + info:get_start(lnum) - info.levels0[lnum] = current_level - - local trimmed_level = trim_level(current_level) - current_level = current_level - info:get_stop(lnum) + for lnum = srow + 1, erow do + local enter_line = enter_counts[lnum] or 0 + local leave_line = leave_counts[lnum] or 0 + local level0 = level0_prev - leave_prev + enter_line -- Determine if it's the start/end of a fold -- NB: vim's fold-expr interface does not have a mechanism to indicate that @@ -219,14 +191,36 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) -- ( \n ( \n )) \n (( \n ) \n ) -- versus -- ( \n ( \n ) \n ( \n ) \n ) - -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) + -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and + -- vim interprets as the second case. + -- If it did have such a mechanism, (clamped - clamped_prev) -- would be the correct number of starts to pass on. + local adjusted = level0 ---@type integer local prefix = '' - if trimmed_level - last_trimmed_level > 0 then + if enter_line > 0 then prefix = '>' + if leave_line > 0 then + -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line + -- so that f2 gets the correct level on this line. This may reduce the size of f1 below + -- foldminlines, but we don't handle it for simplicity. + adjusted = level0 - leave_line + leave_line = 0 + end end - info.levels[lnum] = prefix .. tostring(trimmed_level) + -- Clamp at foldnestmax. + local clamped = adjusted + if adjusted > nestmax then + prefix = '' + clamped = nestmax + end + + -- Record the "real" level, so that it can be used as "base" of later get_folds_levels(). + info.levels0[lnum] = adjusted + info.levels[lnum] = prefix .. tostring(clamped) + + leave_prev = leave_line + level0_prev = adjusted end end @@ -296,8 +290,12 @@ end local function on_changedtree(bufnr, foldinfo, tree_changes) schedule_if_loaded(bufnr, function() for _, change in ipairs(tree_changes) do - local srow, _, erow = Range.unpack4(change) - get_folds_levels(bufnr, foldinfo, srow, erow) + local srow, _, erow, ecol = Range.unpack4(change) + if ecol > 0 then + erow = erow + 1 + end + -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. + get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) end if #tree_changes > 0 then foldupdate(bufnr) @@ -309,19 +307,46 @@ end ---@param foldinfo TS.FoldInfo ---@param start_row integer ---@param old_row integer +---@param old_col integer ---@param new_row integer -local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) - local end_row_old = start_row + old_row - local end_row_new = start_row + new_row +---@param new_col integer +local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col) + -- extend the end to fully include the range + local end_row_old = start_row + old_row + 1 + local end_row_new = start_row + new_row + 1 if new_row ~= old_row then + -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the + -- outdated levels, which may spuriously open the folds that didn't change. So we should shift + -- folds as accurately as possible. For this to be perfectly accurate, we should track the + -- actual TSNodes that account for each fold, and compare the node's range with the edited + -- range. But for simplicity, we just check whether the start row is completely removed (e.g., + -- `dd`) or shifted (e.g., `o`). if new_row < old_row then - foldinfo:remove_range(end_row_new, end_row_old) + if start_col == 0 and new_row == 0 and new_col == 0 then + foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new)) + else + foldinfo:remove_range(end_row_new, end_row_old) + end else - foldinfo:add_range(start_row, end_row_new) + if start_col == 0 and old_row == 0 and old_col == 0 then + foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old)) + else + foldinfo:add_range(end_row_old, end_row_new) + end end + foldinfo:edit_range(start_row, end_row_old, end_row_new) + + -- This callback must not use on_bytes arguments, because they can be outdated when the callback + -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing + -- the scheduled callback. So we should collect the edits. schedule_if_loaded(bufnr, function() - get_folds_levels(bufnr, foldinfo, start_row, end_row_new) + local srow, erow = foldinfo:flush_edit() + if not srow then + return + end + -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. + get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) foldupdate(bufnr) end) end @@ -348,8 +373,8 @@ function M.foldexpr(lnum) on_changedtree(bufnr, foldinfos[bufnr], tree_changes) end, - on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) - on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) + on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _) + on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col) end, on_detach = function() @@ -361,6 +386,18 @@ function M.foldexpr(lnum) return foldinfos[bufnr].levels[lnum] or '0' end +api.nvim_create_autocmd('OptionSet', { + pattern = { 'foldminlines', 'foldnestmax' }, + desc = 'Refresh treesitter folds', + callback = function() + for _, bufnr in ipairs(vim.tbl_keys(foldinfos)) do + foldinfos[bufnr] = FoldInfo.new() + get_folds_levels(bufnr, foldinfos[bufnr]) + foldupdate(bufnr) + end + end, +}) + ---@package ---@return { [1]: string, [2]: string[] }[]|string function M.foldtext() diff --git a/test/functional/treesitter/fold_spec.lua b/test/functional/treesitter/fold_spec.lua index bcace366bd..1482be9637 100644 --- a/test/functional/treesitter/fold_spec.lua +++ b/test/functional/treesitter/fold_spec.lua @@ -5,6 +5,7 @@ local insert = helpers.insert local exec_lua = helpers.exec_lua local command = helpers.command local feed = helpers.feed +local poke_eventloop = helpers.poke_eventloop local Screen = require('test.functional.ui.screen') before_each(clear) @@ -12,6 +13,11 @@ before_each(clear) describe('treesitter foldexpr', function() clear() + before_each(function() + -- open folds to avoid deleting entire folded region + exec_lua([[vim.opt.foldlevel = 9]]) + end) + local test_text = [[ void ui_refresh(void) { @@ -33,6 +39,10 @@ void ui_refresh(void) } }]] + local function parse(lang) + exec_lua(([[vim.treesitter.get_parser(0, %s):parse()]]):format(lang and '"' .. lang .. '"' or 'nil')) + end + local function get_fold_levels() return exec_lua([[ local res = {} @@ -46,7 +56,7 @@ void ui_refresh(void) it("can compute fold levels", function() insert(test_text) - exec_lua([[vim.treesitter.get_parser(0, "c")]]) + parse('c') eq({ [1] = '>1', @@ -67,16 +77,18 @@ void ui_refresh(void) [16] = '3', [17] = '3', [18] = '2', - [19] = '1' }, get_fold_levels()) + [19] = '1', + }, get_fold_levels()) end) it("recomputes fold levels after lines are added/removed", function() insert(test_text) - exec_lua([[vim.treesitter.get_parser(0, "c")]]) + parse('c') command('1,2d') + poke_eventloop() eq({ [1] = '0', @@ -95,9 +107,11 @@ void ui_refresh(void) [14] = '2', [15] = '2', [16] = '1', - [17] = '0' }, get_fold_levels()) + [17] = '0', + }, get_fold_levels()) command('1put!') + poke_eventloop() eq({ [1] = '>1', @@ -118,7 +132,274 @@ void ui_refresh(void) [16] = '3', [17] = '3', [18] = '2', - [19] = '1' }, get_fold_levels()) + [19] = '1', + }, get_fold_levels()) + end) + + it("handles changes close to start/end of folds", function() + insert([[ +# h1 +t1 +# h2 +t2]]) + + exec_lua([[vim.treesitter.query.set('markdown', 'folds', '(section) @fold')]]) + parse('markdown') + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>1', + [4] = '1', + }, get_fold_levels()) + + feed('2ggo') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + [4] = '>1', + [5] = '1', + }, get_fold_levels()) + + feed('dd') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>1', + [4] = '1', + }, get_fold_levels()) + + feed('2ggdd') + poke_eventloop() + + eq({ + [1] = '0', + [2] = '>1', + [3] = '1', + }, get_fold_levels()) + + feed('u') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>1', + [4] = '1', + }, get_fold_levels()) + + feed('3ggdd') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + }, get_fold_levels()) + + feed('u') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>1', + [4] = '1', + }, get_fold_levels()) + + feed('3ggI#') + parse() + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>2', + [4] = '2', + }, get_fold_levels()) + + feed('x') + parse() + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>1', + [4] = '1', + }, get_fold_levels()) + + end) + + it("handles changes that trigger multiple on_bytes", function() + insert([[ +function f() + asdf() + asdf() +end +-- comment]]) + + exec_lua([[vim.treesitter.query.set('lua', 'folds', '[(function_declaration) (parameters) (arguments)] @fold')]]) + parse('lua') + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + [4] = '1', + [5] = '0', + }, get_fold_levels()) + + command('1,4join') + poke_eventloop() + + eq({ + [1] = '0', + [2] = '0', + }, get_fold_levels()) + + feed('u') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + [4] = '1', + [5] = '0', + }, get_fold_levels()) + + end) + + it("handles multiple folds that overlap at the end and start", function() + insert([[ +function f() + g( + function() + asdf() + end, function() + end + ) +end]]) + + exec_lua([[vim.treesitter.query.set('lua', 'folds', '[(function_declaration) (function_definition) (parameters) (arguments)] @fold')]]) + parse('lua') + + -- If fold1.stop = fold2.start, then move fold1's stop up so that fold2.start gets proper level. + eq({ + [1] = '>1', + [2] = '>2', + [3] = '>3', + [4] = '3', + [5] = '>3', + [6] = '3', + [7] = '2', + [8] = '1', + }, get_fold_levels()) + + command('1,8join') + feed('u') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '>2', + [3] = '>3', + [4] = '3', + [5] = '>3', + [6] = '3', + [7] = '2', + [8] = '1', + }, get_fold_levels()) + + end) + + it("handles multiple folds that start at the same line", function() + insert([[ +function f(a) + if #(g({ + k = v, + })) > 0 then + return + end +end]]) + + exec_lua([[vim.treesitter.query.set('lua', 'folds', '[(if_statement) (function_declaration) (parameters) (arguments) (table_constructor)] @fold')]]) + parse('lua') + + eq({ + [1] = '>1', + [2] = '>3', + [3] = '3', + [4] = '3', + [5] = '2', + [6] = '2', + [7] = '1', + }, get_fold_levels()) + + command('2,6join') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + }, get_fold_levels()) + + feed('u') + poke_eventloop() + + eq({ + [1] = '>1', + [2] = '>3', + [3] = '3', + [4] = '3', + [5] = '2', + [6] = '2', + [7] = '1', + }, get_fold_levels()) + + end) + + it("takes account of relevant options", function() + insert([[ +# h1 +t1 +## h2 +t2 +### h3 +t3]]) + + exec_lua([[vim.treesitter.query.set('markdown', 'folds', '(section) @fold')]]) + parse('markdown') + + command([[set foldminlines=2]]) + + eq({ + [1] = '>1', + [2] = '1', + [3] = '>2', + [4] = '2', + [5] = '2', + [6] = '2', + }, get_fold_levels()) + + command([[set foldminlines=1 foldnestmax=1]]) + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + [4] = '1', + [5] = '1', + [6] = '1', + }, get_fold_levels()) + end) it("updates folds in all windows", function() @@ -131,8 +412,8 @@ void ui_refresh(void) [4] = {reverse = true}; }) - exec_lua([[vim.treesitter.get_parser(0, "c")]]) - command([[set foldmethod=expr foldexpr=v:lua.vim.treesitter.foldexpr() foldcolumn=1 foldlevel=9]]) + parse("c") + command([[set foldmethod=expr foldexpr=v:lua.vim.treesitter.foldexpr() foldcolumn=1]]) command('split') insert(test_text) @@ -282,7 +563,7 @@ void ui_refresh(void) local screen = Screen.new(60, 36) screen:attach() - exec_lua([[vim.treesitter.get_parser(0, "c")]]) + parse("c") command([[set foldmethod=expr foldexpr=v:lua.vim.treesitter.foldexpr() foldcolumn=1 foldlevel=9]]) insert(test_text) command('16d') @@ -330,6 +611,58 @@ void ui_refresh(void) }} end) + it("doesn't open folds that are not touched", function() + local screen = Screen.new(40, 8) + screen:set_default_attr_ids({ + [1] = {foreground = Screen.colors.DarkBlue, background = Screen.colors.Gray}; + [2] = {foreground = Screen.colors.DarkBlue, background = Screen.colors.LightGray}; + [3] = {foreground = Screen.colors.Blue1, bold = true}; + [4] = {bold = true}; + }) + screen:attach() + + insert([[ +# h1 +t1 +# h2 +t2]]) + exec_lua([[vim.treesitter.query.set('markdown', 'folds', '(section) @fold')]]) + parse('markdown') + command([[set foldmethod=expr foldexpr=v:lua.vim.treesitter.foldexpr() foldcolumn=1 foldlevel=0]]) + + + feed('ggzojo') + poke_eventloop() + + screen:expect{grid=[[ + {1:-}# h1 | + {1:│}t1 | + {1:│}^ | + {1:+}{2:+-- 2 lines: # h2·····················}| + {3:~ }| + {3:~ }| + {3:~ }| + {4:-- INSERT --} | + ]]} + + feed('u') + -- TODO(tomtomjhj): `u` spuriously opens the fold (#26499). + feed('zMggzo') + + feed('dd') + poke_eventloop() + + screen:expect{grid=[[ + {1:-}^t1 | + {1:-}# h2 | + {1:│}t2 | + {3:~ }| + {3:~ }| + {3:~ }| + {3:~ }| + 1 line less; before #2 0 seconds ago | + ]]} + end) end) describe('treesitter foldtext', function()