diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index 4dee958108..23bb6d4343 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -168,6 +168,9 @@ The following new APIs or features were added. `vim.treesitter.language.require_language`. • `require'bit'` is now always available |lua-bit| + +• |vim.treesitter.foldexpr()| can be used for 'foldexpr' to use treesitter for folding. + ============================================================================== CHANGED FEATURES *news-changes* diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index 3f505e5d19..ccb3c445df 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -481,6 +481,19 @@ library. ============================================================================== Lua module: vim.treesitter *lua-treesitter-core* +foldexpr({lnum}) *vim.treesitter.foldexpr()* + Returns the fold level for {lnum} in the current buffer. Can be set + directly to 'foldexpr': >lua + + vim.wo.foldexpr = 'v:lua.vim.treesitter.foldexpr()' +< + + Parameters: ~ + • {lnum} (integer|nil) Line number to calculate fold level for + + Return: ~ + (string) + *vim.treesitter.get_captures_at_cursor()* get_captures_at_cursor({winnr}) Returns a list of highlight capture names under the cursor diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index 44922bbc4d..fead7b7b1b 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -115,6 +115,16 @@ function M.get_parser(bufnr, lang, opts) return parsers[bufnr] end +---@private +---@param bufnr (integer|nil) Buffer number +---@return boolean +function M._has_parser(bufnr) + if bufnr == nil or bufnr == 0 then + bufnr = a.nvim_get_current_buf() + end + return parsers[bufnr] ~= nil +end + --- Returns a string parser --- ---@param str string Text to parse @@ -612,4 +622,14 @@ function M.show_tree(opts) }) end +--- Returns the fold level for {lnum} in the current buffer. Can be set directly to 'foldexpr': +---
lua
+--- vim.wo.foldexpr = 'v:lua.vim.treesitter.foldexpr()'
+--- 
+---@param lnum integer|nil Line number to calculate fold level for +---@return string +function M.foldexpr(lnum) + return require('vim.treesitter._fold').foldexpr(lnum) +end + return M diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua new file mode 100644 index 0000000000..a66cc6d543 --- /dev/null +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -0,0 +1,173 @@ +local api = vim.api + +local M = {} + +--- Memoizes a function based on the buffer tick of the provided bufnr. +--- The cache entry is cleared when the buffer is detached to avoid memory leaks. +---@generic F: function +---@param fn F fn to memoize, taking the bufnr as first argument +---@return F +local function memoize_by_changedtick(fn) + ---@type table + local cache = {} + + ---@param bufnr integer + return function(bufnr, ...) + local tick = api.nvim_buf_get_changedtick(bufnr) + + if cache[bufnr] then + if cache[bufnr].last_tick == tick then + return cache[bufnr].result + end + else + local function detach_handler() + cache[bufnr] = nil + end + + -- Clean up logic only! + api.nvim_buf_attach(bufnr, false, { + on_detach = detach_handler, + on_reload = detach_handler, + }) + end + + cache[bufnr] = { + result = fn(bufnr, ...), + last_tick = tick, + } + + return cache[bufnr].result + end +end + +---@param bufnr integer +---@param capture string +---@param query_name string +---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata) +local function iter_matches_with_capture(bufnr, capture, query_name, callback) + local parser = vim.treesitter.get_parser(bufnr) + + if not parser then + return + end + + parser:for_each_tree(function(tree, lang_tree) + local lang = lang_tree:lang() + local query = vim.treesitter.query.get_query(lang, query_name) + if query then + local root = tree:root() + local start, _, stop = root:range() + for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do + for id, node in pairs(match) do + if query.captures[id] == capture then + callback(id, node, metadata) + end + end + end + end + end) +end + +---@private +--- TODO(lewis6991): copied from languagetree.lua. Consolidate +---@param node TSNode +---@param id integer +---@param metadata TSMetadata +---@return Range +local function get_range_from_metadata(node, id, metadata) + if metadata[id] and metadata[id].range then + return metadata[id].range --[[@as Range]] + end + return { node:range() } +end + +-- This is cached on buf tick to avoid computing that multiple times +-- Especially not for every line in the file when `zx` is hit +---@param bufnr integer +---@return table +local folds_levels = memoize_by_changedtick(function(bufnr) + local max_fold_level = vim.wo.foldnestmax + local function trim_level(level) + if level > max_fold_level then + return max_fold_level + end + return level + end + + -- start..stop is an inclusive range + local start_counts = {} ---@type table + local stop_counts = {} ---@type table + + local prev_start = -1 + local prev_stop = -1 + + local min_fold_lines = vim.wo.foldminlines + + iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata) + local range = get_range_from_metadata(node, id, metadata) + local start, stop, stop_col = range[1], range[3], range[4] + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then + start_counts[start] = (start_counts[start] or 0) + 1 + stop_counts[stop] = (stop_counts[stop] or 0) + 1 + prev_start = start + prev_stop = stop + end + end) + + ---@type table + local levels = {} + local current_level = 0 + + -- We now have the list of fold opening and closing, fill the gaps and mark where fold start + for lnum = 0, api.nvim_buf_line_count(bufnr) do + local last_trimmed_level = trim_level(current_level) + current_level = current_level + (start_counts[lnum] or 0) + local trimmed_level = trim_level(current_level) + current_level = current_level - (stop_counts[lnum] or 0) + + -- Determine if it's the start/end of a fold + -- NB: vim's fold-expr interface does not have a mechanism to indicate that + -- two (or more) folds start at this line, so it cannot distinguish between + -- ( \n ( \n )) \n (( \n ) \n ) + -- versus + -- ( \n ( \n ) \n ( \n ) \n ) + -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) + -- would be the correct number of starts to pass on. + local prefix = '' + if trimmed_level - last_trimmed_level > 0 then + prefix = '>' + end + + levels[lnum + 1] = prefix .. tostring(trimmed_level) + end + + return levels +end) + +---@param lnum integer|nil +---@return string +function M.foldexpr(lnum) + lnum = lnum or vim.v.lnum + local bufnr = api.nvim_get_current_buf() + + ---@diagnostic disable-next-line:invisible + if not vim.treesitter._has_parser(bufnr) or not lnum then + return '0' + end + + local levels = folds_levels(bufnr) or {} + + return levels[lnum] or '0' +end + +return M diff --git a/runtime/queries/c/folds.scm b/runtime/queries/c/folds.scm new file mode 100644 index 0000000000..80c3039b6b --- /dev/null +++ b/runtime/queries/c/folds.scm @@ -0,0 +1,19 @@ +[ + (for_statement) + (if_statement) + (while_statement) + (switch_statement) + (case_statement) + (function_definition) + (struct_specifier) + (enum_specifier) + (comment) + (preproc_if) + (preproc_elif) + (preproc_else) + (preproc_ifdef) + (initializer_list) +] @fold + + (compound_statement + (compound_statement) @fold) diff --git a/runtime/queries/lua/folds.scm b/runtime/queries/lua/folds.scm new file mode 100644 index 0000000000..d8f0b42df3 --- /dev/null +++ b/runtime/queries/lua/folds.scm @@ -0,0 +1,10 @@ +[ + (do_statement) + (while_statement) + (repeat_statement) + (if_statement) + (for_statement) + (function_declaration) + (function_definition) + (table_constructor) +] @fold diff --git a/runtime/queries/vim/folds.scm b/runtime/queries/vim/folds.scm new file mode 100644 index 0000000000..4c99735836 --- /dev/null +++ b/runtime/queries/vim/folds.scm @@ -0,0 +1,4 @@ +[ + (if_statement) + (function_definition) +] @fold diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index fdd6403859..67aad4d1b7 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -871,4 +871,38 @@ int x = INT_MAX; end) end) end) + + it("can fold via foldexpr", function() + insert(test_text) + exec_lua([[vim.treesitter.get_parser(0, "c")]]) + + local levels = exec_lua([[ + local res = {} + for i = 1, vim.api.nvim_buf_line_count(0) do + res[i] = vim.treesitter.foldexpr(i) + end + return res + ]]) + + eq({ + [1] = '>1', + [2] = '1', + [3] = '1', + [4] = '1', + [5] = '>2', + [6] = '2', + [7] = '2', + [8] = '1', + [9] = '1', + [10] = '>2', + [11] = '2', + [12] = '2', + [13] = '2', + [14] = '2', + [15] = '>3', + [16] = '3', + [17] = '3', + [18] = '2', + [19] = '1' }, levels) + end) end)