diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt
index 4dee958108..23bb6d4343 100644
--- a/runtime/doc/news.txt
+++ b/runtime/doc/news.txt
@@ -168,6 +168,9 @@ The following new APIs or features were added.
`vim.treesitter.language.require_language`.
• `require'bit'` is now always available |lua-bit|
+
+• |vim.treesitter.foldexpr()| can be used for 'foldexpr' to use treesitter for folding.
+
==============================================================================
CHANGED FEATURES *news-changes*
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt
index 3f505e5d19..ccb3c445df 100644
--- a/runtime/doc/treesitter.txt
+++ b/runtime/doc/treesitter.txt
@@ -481,6 +481,19 @@ library.
==============================================================================
Lua module: vim.treesitter *lua-treesitter-core*
+foldexpr({lnum}) *vim.treesitter.foldexpr()*
+ Returns the fold level for {lnum} in the current buffer. Can be set
+ directly to 'foldexpr': >lua
+
+ vim.wo.foldexpr = 'v:lua.vim.treesitter.foldexpr()'
+<
+
+ Parameters: ~
+ • {lnum} (integer|nil) Line number to calculate fold level for
+
+ Return: ~
+ (string)
+
*vim.treesitter.get_captures_at_cursor()*
get_captures_at_cursor({winnr})
Returns a list of highlight capture names under the cursor
diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua
index 44922bbc4d..fead7b7b1b 100644
--- a/runtime/lua/vim/treesitter.lua
+++ b/runtime/lua/vim/treesitter.lua
@@ -115,6 +115,16 @@ function M.get_parser(bufnr, lang, opts)
return parsers[bufnr]
end
+---@private
+---@param bufnr (integer|nil) Buffer number
+---@return boolean
+function M._has_parser(bufnr)
+ if bufnr == nil or bufnr == 0 then
+ bufnr = a.nvim_get_current_buf()
+ end
+ return parsers[bufnr] ~= nil
+end
+
--- Returns a string parser
---
---@param str string Text to parse
@@ -612,4 +622,14 @@ function M.show_tree(opts)
})
end
+--- Returns the fold level for {lnum} in the current buffer. Can be set directly to 'foldexpr':
+---
lua
+--- vim.wo.foldexpr = 'v:lua.vim.treesitter.foldexpr()'
+---
+---@param lnum integer|nil Line number to calculate fold level for
+---@return string
+function M.foldexpr(lnum)
+ return require('vim.treesitter._fold').foldexpr(lnum)
+end
+
return M
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
new file mode 100644
index 0000000000..a66cc6d543
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -0,0 +1,173 @@
+local api = vim.api
+
+local M = {}
+
+--- Memoizes a function based on the buffer tick of the provided bufnr.
+--- The cache entry is cleared when the buffer is detached to avoid memory leaks.
+---@generic F: function
+---@param fn F fn to memoize, taking the bufnr as first argument
+---@return F
+local function memoize_by_changedtick(fn)
+ ---@type table
+ local cache = {}
+
+ ---@param bufnr integer
+ return function(bufnr, ...)
+ local tick = api.nvim_buf_get_changedtick(bufnr)
+
+ if cache[bufnr] then
+ if cache[bufnr].last_tick == tick then
+ return cache[bufnr].result
+ end
+ else
+ local function detach_handler()
+ cache[bufnr] = nil
+ end
+
+ -- Clean up logic only!
+ api.nvim_buf_attach(bufnr, false, {
+ on_detach = detach_handler,
+ on_reload = detach_handler,
+ })
+ end
+
+ cache[bufnr] = {
+ result = fn(bufnr, ...),
+ last_tick = tick,
+ }
+
+ return cache[bufnr].result
+ end
+end
+
+---@param bufnr integer
+---@param capture string
+---@param query_name string
+---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata)
+local function iter_matches_with_capture(bufnr, capture, query_name, callback)
+ local parser = vim.treesitter.get_parser(bufnr)
+
+ if not parser then
+ return
+ end
+
+ parser:for_each_tree(function(tree, lang_tree)
+ local lang = lang_tree:lang()
+ local query = vim.treesitter.query.get_query(lang, query_name)
+ if query then
+ local root = tree:root()
+ local start, _, stop = root:range()
+ for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do
+ for id, node in pairs(match) do
+ if query.captures[id] == capture then
+ callback(id, node, metadata)
+ end
+ end
+ end
+ end
+ end)
+end
+
+---@private
+--- TODO(lewis6991): copied from languagetree.lua. Consolidate
+---@param node TSNode
+---@param id integer
+---@param metadata TSMetadata
+---@return Range
+local function get_range_from_metadata(node, id, metadata)
+ if metadata[id] and metadata[id].range then
+ return metadata[id].range --[[@as Range]]
+ end
+ return { node:range() }
+end
+
+-- This is cached on buf tick to avoid computing that multiple times
+-- Especially not for every line in the file when `zx` is hit
+---@param bufnr integer
+---@return table
+local folds_levels = memoize_by_changedtick(function(bufnr)
+ local max_fold_level = vim.wo.foldnestmax
+ local function trim_level(level)
+ if level > max_fold_level then
+ return max_fold_level
+ end
+ return level
+ end
+
+ -- start..stop is an inclusive range
+ local start_counts = {} ---@type table
+ local stop_counts = {} ---@type table
+
+ local prev_start = -1
+ local prev_stop = -1
+
+ local min_fold_lines = vim.wo.foldminlines
+
+ iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata)
+ local range = get_range_from_metadata(node, id, metadata)
+ local start, stop, stop_col = range[1], range[3], range[4]
+
+ if stop_col == 0 then
+ stop = stop - 1
+ end
+
+ local fold_length = stop - start + 1
+
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then
+ start_counts[start] = (start_counts[start] or 0) + 1
+ stop_counts[stop] = (stop_counts[stop] or 0) + 1
+ prev_start = start
+ prev_stop = stop
+ end
+ end)
+
+ ---@type table
+ local levels = {}
+ local current_level = 0
+
+ -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
+ for lnum = 0, api.nvim_buf_line_count(bufnr) do
+ local last_trimmed_level = trim_level(current_level)
+ current_level = current_level + (start_counts[lnum] or 0)
+ local trimmed_level = trim_level(current_level)
+ current_level = current_level - (stop_counts[lnum] or 0)
+
+ -- Determine if it's the start/end of a fold
+ -- NB: vim's fold-expr interface does not have a mechanism to indicate that
+ -- two (or more) folds start at this line, so it cannot distinguish between
+ -- ( \n ( \n )) \n (( \n ) \n )
+ -- versus
+ -- ( \n ( \n ) \n ( \n ) \n )
+ -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
+ -- would be the correct number of starts to pass on.
+ local prefix = ''
+ if trimmed_level - last_trimmed_level > 0 then
+ prefix = '>'
+ end
+
+ levels[lnum + 1] = prefix .. tostring(trimmed_level)
+ end
+
+ return levels
+end)
+
+---@param lnum integer|nil
+---@return string
+function M.foldexpr(lnum)
+ lnum = lnum or vim.v.lnum
+ local bufnr = api.nvim_get_current_buf()
+
+ ---@diagnostic disable-next-line:invisible
+ if not vim.treesitter._has_parser(bufnr) or not lnum then
+ return '0'
+ end
+
+ local levels = folds_levels(bufnr) or {}
+
+ return levels[lnum] or '0'
+end
+
+return M
diff --git a/runtime/queries/c/folds.scm b/runtime/queries/c/folds.scm
new file mode 100644
index 0000000000..80c3039b6b
--- /dev/null
+++ b/runtime/queries/c/folds.scm
@@ -0,0 +1,19 @@
+[
+ (for_statement)
+ (if_statement)
+ (while_statement)
+ (switch_statement)
+ (case_statement)
+ (function_definition)
+ (struct_specifier)
+ (enum_specifier)
+ (comment)
+ (preproc_if)
+ (preproc_elif)
+ (preproc_else)
+ (preproc_ifdef)
+ (initializer_list)
+] @fold
+
+ (compound_statement
+ (compound_statement) @fold)
diff --git a/runtime/queries/lua/folds.scm b/runtime/queries/lua/folds.scm
new file mode 100644
index 0000000000..d8f0b42df3
--- /dev/null
+++ b/runtime/queries/lua/folds.scm
@@ -0,0 +1,10 @@
+[
+ (do_statement)
+ (while_statement)
+ (repeat_statement)
+ (if_statement)
+ (for_statement)
+ (function_declaration)
+ (function_definition)
+ (table_constructor)
+] @fold
diff --git a/runtime/queries/vim/folds.scm b/runtime/queries/vim/folds.scm
new file mode 100644
index 0000000000..4c99735836
--- /dev/null
+++ b/runtime/queries/vim/folds.scm
@@ -0,0 +1,4 @@
+[
+ (if_statement)
+ (function_definition)
+] @fold
diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua
index fdd6403859..67aad4d1b7 100644
--- a/test/functional/treesitter/parser_spec.lua
+++ b/test/functional/treesitter/parser_spec.lua
@@ -871,4 +871,38 @@ int x = INT_MAX;
end)
end)
end)
+
+ it("can fold via foldexpr", function()
+ insert(test_text)
+ exec_lua([[vim.treesitter.get_parser(0, "c")]])
+
+ local levels = exec_lua([[
+ local res = {}
+ for i = 1, vim.api.nvim_buf_line_count(0) do
+ res[i] = vim.treesitter.foldexpr(i)
+ end
+ return res
+ ]])
+
+ eq({
+ [1] = '>1',
+ [2] = '1',
+ [3] = '1',
+ [4] = '1',
+ [5] = '>2',
+ [6] = '2',
+ [7] = '2',
+ [8] = '1',
+ [9] = '1',
+ [10] = '>2',
+ [11] = '2',
+ [12] = '2',
+ [13] = '2',
+ [14] = '2',
+ [15] = '>3',
+ [16] = '3',
+ [17] = '3',
+ [18] = '2',
+ [19] = '1' }, levels)
+ end)
end)