local Range = require('vim.treesitter._range') local api = vim.api ---@class TS.FoldInfo ---@field levels table ---@field levels0 table ---@field private start_counts table ---@field private stop_counts table local FoldInfo = {} FoldInfo.__index = FoldInfo ---@private function FoldInfo.new() return setmetatable({ start_counts = {}, stop_counts = {}, levels0 = {}, levels = {}, }, FoldInfo) end ---@package ---@param srow integer ---@param erow integer function FoldInfo:invalidate_range(srow, erow) for i = srow, erow do self.start_counts[i + 1] = nil self.stop_counts[i + 1] = nil self.levels0[i + 1] = nil self.levels[i + 1] = nil end end ---@package ---@param srow integer ---@param erow integer function FoldInfo:remove_range(srow, erow) for i = erow - 1, srow, -1 do table.remove(self.levels, i + 1) table.remove(self.levels0, i + 1) table.remove(self.start_counts, i + 1) table.remove(self.stop_counts, i + 1) end end ---@package ---@param srow integer ---@param erow integer function FoldInfo:add_range(srow, erow) for i = srow, erow - 1 do table.insert(self.levels, i + 1, '-1') table.insert(self.levels0, i + 1, -1) table.insert(self.start_counts, i + 1, nil) table.insert(self.stop_counts, i + 1, nil) end end ---@package ---@param lnum integer function FoldInfo:add_start(lnum) self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 end ---@package ---@param lnum integer function FoldInfo:add_stop(lnum) self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 end ---@package ---@param lnum integer ---@return integer function FoldInfo:get_start(lnum) return self.start_counts[lnum] or 0 end ---@package ---@param lnum integer ---@return integer function FoldInfo:get_stop(lnum) return self.stop_counts[lnum] or 0 end local function trim_level(level) local max_fold_level = vim.wo.foldnestmax if level > max_fold_level then return max_fold_level end return level end ---@param bufnr integer ---@param info TS.FoldInfo ---@param srow integer? ---@param erow integer? local function get_folds_levels(bufnr, info, srow, erow) srow = srow or 0 erow = erow or api.nvim_buf_line_count(bufnr) info:invalidate_range(srow, erow) local prev_start = -1 local prev_stop = -1 vim.treesitter.get_parser(bufnr):for_each_tree(function(tree, ltree) local query = vim.treesitter.query.get(ltree:lang(), 'folds') if not query then return end -- erow in query is end-exclusive local q_erow = erow and erow + 1 or -1 for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do if query.captures[id] == 'fold' then local range = vim.treesitter.get_range(node, bufnr, metadata[id]) local start, _, stop, stop_col = Range.unpack4(range) if stop_col == 0 then stop = stop - 1 end local fold_length = stop - start + 1 -- Fold only multiline nodes that are not exactly the same as previously met folds -- Checking against just the previously found fold is sufficient if nodes -- are returned in preorder or postorder when traversing tree if fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) then info:add_start(start + 1) info:add_stop(stop + 1) prev_start = start prev_stop = stop end end end end) local current_level = info.levels0[srow] or 0 -- We now have the list of fold opening and closing, fill the gaps and mark where fold start for lnum = srow + 1, erow + 1 do local last_trimmed_level = trim_level(current_level) current_level = current_level + info:get_start(lnum) info.levels0[lnum] = current_level local trimmed_level = trim_level(current_level) current_level = current_level - info:get_stop(lnum) -- Determine if it's the start/end of a fold -- NB: vim's fold-expr interface does not have a mechanism to indicate that -- two (or more) folds start at this line, so it cannot distinguish between -- ( \n ( \n )) \n (( \n ) \n ) -- versus -- ( \n ( \n ) \n ( \n ) \n ) -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) -- would be the correct number of starts to pass on. local prefix = '' if trimmed_level - last_trimmed_level > 0 then prefix = '>' end info.levels[lnum] = prefix .. tostring(trimmed_level) end end local M = {} ---@type table local foldinfos = {} local function recompute_folds() if api.nvim_get_mode().mode == 'i' then -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave api.nvim_create_autocmd('InsertLeave', { once = true, callback = vim._foldupdate, }) return end vim._foldupdate() end ---@param bufnr integer ---@param foldinfo TS.FoldInfo ---@param tree_changes Range4[] local function on_changedtree(bufnr, foldinfo, tree_changes) -- For some reason, queries seem to use the old buffer state in on_bytes. -- Get around this by scheduling and manually updating folds. vim.schedule(function() for _, change in ipairs(tree_changes) do local srow, _, erow = Range.unpack4(change) get_folds_levels(bufnr, foldinfo, srow, erow) end recompute_folds() end) end ---@param bufnr integer ---@param foldinfo TS.FoldInfo ---@param start_row integer ---@param old_row integer ---@param new_row integer local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) local end_row_old = start_row + old_row local end_row_new = start_row + new_row if new_row < old_row then foldinfo:remove_range(end_row_new, end_row_old) elseif new_row > old_row then foldinfo:add_range(start_row, end_row_new) vim.schedule(function() get_folds_levels(bufnr, foldinfo, start_row, end_row_new) recompute_folds() end) end end ---@package ---@param lnum integer|nil ---@return string function M.foldexpr(lnum) lnum = lnum or vim.v.lnum local bufnr = api.nvim_get_current_buf() if not vim.treesitter._has_parser(bufnr) or not lnum then return '0' end if not foldinfos[bufnr] then foldinfos[bufnr] = FoldInfo.new() get_folds_levels(bufnr, foldinfos[bufnr]) local parser = vim.treesitter.get_parser(bufnr) parser:register_cbs({ on_changedtree = function(tree_changes) on_changedtree(bufnr, foldinfos[bufnr], tree_changes) end, on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) end, on_detach = function() foldinfos[bufnr] = nil end, }) end return foldinfos[bufnr].levels[lnum] or '0' end return M