2023-05-01 02:32:29 -07:00
local ts = vim.treesitter
2023-03-09 08:28:55 -07:00
local Range = require ( ' vim.treesitter._range ' )
2023-02-23 10:05:20 -07:00
local api = vim.api
2023-03-23 04:23:51 -07:00
--- Cached fold state for a single buffer.
---@class TS.FoldInfo
---@field levels string[] foldexpr result for every line
---@field levels0 integer[] raw (unclamped) fold level for every line
---@field edits? {[1]: integer, [2]: integer} rows edited since the callback scheduled in on_bytes last ran. 0-indexed, end-exclusive.
local FoldInfo = {}
FoldInfo.__index = FoldInfo

--- Create an empty FoldInfo with no cached levels and no pending edits.
---@private
function FoldInfo.new()
  local self = {}
  self.levels0 = {}
  self.levels = {}
  return setmetatable(self, FoldInfo)
end
2023-02-23 10:05:20 -07:00
2023-05-02 14:27:14 -07:00
--- Efficiently remove items from the middle of a list.
---
--- Calling table.remove() in a loop will re-index the tail of the table on
--- every iteration, instead this function will re-index the table exactly
--- once.
---
--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
---
---@param t any[] list to modify in place
---@param first integer index of the first item to remove (1-indexed, inclusive)
---@param last integer index of the last item to remove (1-indexed, inclusive)
local function list_remove(t, first, last)
  -- An empty range must leave the list untouched. Without this guard the loop
  -- below would overwrite each tail item and then set it to nil, wiping
  -- everything from `first` onwards.
  if last < first then
    return
  end

  local n = #t
  for i = 0, n - first do
    -- Pull the item from behind the removed range forward, then clear its old
    -- slot so the list ends up shortened with no stale tail.
    t[first + i] = t[last + 1 + i]
    t[last + 1 + i] = nil
  end
end
2023-03-23 04:23:51 -07:00
--- Drop the cached levels of rows [srow, erow) from both level lists.
---@package
---@param srow integer 0-indexed, inclusive
---@param erow integer 0-indexed, exclusive
function FoldInfo:remove_range(srow, erow)
  -- 0-indexed [srow, erow) is the 1-indexed inclusive range [srow + 1, erow].
  local first = srow + 1
  for _, list in ipairs({ self.levels, self.levels0 }) do
    list_remove(list, first, erow)
  end
end
--- Efficiently insert copies of a value into the middle of a list.
---
--- table.insert() in a loop re-indexes the tail of the table on every
--- iteration; this function moves the tail exactly once and then fills the
--- gap.
---
--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
---
---@param t any[] list to modify in place
---@param first integer first index to fill (1-indexed, inclusive)
---@param last integer last index to fill (1-indexed, inclusive)
---@param v any value stored at every index in [first, last]
local function list_insert(t, first, last, v)
  local count = last - first + 1

  -- Move the tail out of the way, back to front so nothing is overwritten
  -- before it has been copied.
  for src = #t, first, -1 do
    t[src + count] = t[src]
  end

  -- Fill the gap with the new value.
  local idx = first
  while idx <= last do
    t[idx] = v
    idx = idx + 1
  end
end
2023-02-23 10:05:20 -07:00
2023-03-23 04:23:51 -07:00
--- Insert placeholder levels for rows [srow, erow) into both level lists.
---@package
---@param srow integer 0-indexed, inclusive
---@param erow integer 0-indexed, exclusive
function FoldInfo:add_range(srow, erow)
  -- 0-indexed [srow, erow) is the 1-indexed inclusive range [srow + 1, erow].
  local first = srow + 1
  local levels, levels0 = self.levels, self.levels0
  list_insert(levels, first, erow, ' = ')
  list_insert(levels0, first, erow, -1)
end
2023-02-23 10:05:20 -07:00
2023-03-23 04:23:51 -07:00
--- Merge an edit into the pending edited row range (`self.edits`).
---@package
---@param srow integer first edited row
---@param erow_old integer end row of the edit before the change
---@param erow_new integer 0-indexed, exclusive
function FoldInfo:edit_range(srow, erow_old, erow_new)
  local pending = self.edits
  if not pending then
    -- First edit since the last flush: it is the pending range.
    self.edits = { srow, erow_new }
    return
  end

  pending[1] = math.min(srow, pending[1])

  local stop = pending[2]
  if erow_old <= stop then
    -- The edit ends at or before the pending end, so that end moves by the
    -- number of rows added or removed.
    stop = stop + (erow_new - erow_old)
  end
  pending[2] = math.max(stop, erow_new)
end
2023-03-23 04:23:51 -07:00
--- Take the pending edited row range and clear it.
--- Returns nothing when no edit is pending.
---@package
---@return integer? srow
---@return integer? erow 0-indexed, exclusive
function FoldInfo:flush_edit()
  local pending = self.edits
  if not pending then
    return
  end
  self.edits = nil
  return pending[1], pending[2]
end
2023-02-23 10:05:20 -07:00
2023-05-01 02:32:29 -07:00
--- Clip an end row to the number of lines in the buffer.
---
--- If a parser doesn't have any ranges explicitly set, treesitter will
--- return a range with end_row and end_bytes with a value of UINT32_MAX,
--- so end_row has to be clipped to the max buffer line.
---
--- TODO(lewis6991): Handle this generally
---
--- @param bufnr integer
--- @param erow integer? 0-indexed, exclusive; the buffer line count is used when omitted
--- @return integer
local function normalise_erow(bufnr, erow)
  local line_count = api.nvim_buf_line_count(bufnr)
  if not erow or erow > line_count then
    return line_count
  end
  return erow
end
2023-08-10 06:21:56 -07:00
-- TODO(lewis6991): Setup a decor provider so injections folds can be parsed
-- as the window is redrawn

--- Recompute the fold levels of rows [srow, erow) and store them in `info`.
---
--- Runs the ' folds ' query of every tree of the buffer's parser, counts how
--- many folds start and end on each line, and turns those counts into the raw
--- levels (`info.levels0`) and foldexpr strings (`info.levels`).
---
---@param bufnr integer
---@param info TS.FoldInfo updated in place
---@param srow integer? first row to recompute, 0-indexed; defaults to 0
---@param erow integer? 0-indexed, exclusive; clipped to the buffer line count, which is also the default
---@param parse_injections? boolean when truthy, the row range is passed to parser:parse()
local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
  srow = srow or 0
  erow = normalise_erow(bufnr, erow)

  local parser = ts.get_parser(bufnr)

  parser:parse(parse_injections and { srow, erow } or nil)

  local enter_counts = {} ---@type table<integer, integer>
  local leave_counts = {} ---@type table<integer, integer>
  local prev_start = -1
  local prev_stop = -1

  -- Loop-invariant: read the option once instead of once per fold capture.
  local foldminlines = vim.wo.foldminlines

  parser:for_each_tree(function(tree, ltree)
    local query = ts.query.get(ltree:lang(), ' folds ')
    if not query then
      return
    end

    -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
    -- srow - 1 from the level of srow - 1 to get accurate level of srow.
    for _, match, metadata in
      query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow, { all = true })
    do
      for id, nodes in pairs(match) do
        if query.captures[id] == ' fold ' then
          -- A capture may hold several nodes; fold over the union of their ranges.
          local range = ts.get_range(nodes[1], bufnr, metadata[id])
          local start, _, stop, stop_col = Range.unpack4(range)
          for i = 2, #nodes, 1 do
            local node_range = ts.get_range(nodes[i], bufnr, metadata[id])
            local node_start, _, node_stop, node_stop_col = Range.unpack4(node_range)
            if node_start < start then
              start = node_start
            end
            if node_stop > stop then
              stop = node_stop
              stop_col = node_stop_col
            end
          end

          -- A range ending at column 0 does not cover any text of its last row.
          if stop_col == 0 then
            stop = stop - 1
          end

          local fold_length = stop - start + 1

          -- Fold only multiline nodes that are not exactly the same as previously met folds
          -- Checking against just the previously found fold is sufficient if nodes
          -- are returned in preorder or postorder when traversing tree
          if fold_length > foldminlines and not (start == prev_start and stop == prev_stop) then
            enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
            leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
            prev_start = start
            prev_stop = stop
          end
        end
      end
    end
  end)

  local nestmax = vim.wo.foldnestmax
  local level0_prev = info.levels0[srow] or 0
  local leave_prev = leave_counts[srow] or 0

  -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
  for lnum = srow + 1, erow do
    local enter_line = enter_counts[lnum] or 0
    local leave_line = leave_counts[lnum] or 0
    local level0 = level0_prev - leave_prev + enter_line

    -- Determine if it's the start/end of a fold
    -- NB: vim's fold-expr interface does not have a mechanism to indicate that
    -- two (or more) folds start at this line, so it cannot distinguish between
    --  ( \n ( \n )) \n (( \n ) \n )
    -- versus
    --  ( \n ( \n ) \n ( \n ) \n )
    -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
    -- vim interprets as the second case.
    -- If it did have such a mechanism, (clamped - clamped_prev)
    -- would be the correct number of starts to pass on.
    local adjusted = level0 ---@type integer
    local prefix = ' '
    if enter_line > 0 then
      prefix = ' > '
      if leave_line > 0 then
        -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
        -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
        -- foldminlines, but we don't handle it for simplicity.
        adjusted = level0 - leave_line
        leave_line = 0
      end
    end

    -- Clamp at foldnestmax.
    local clamped = adjusted
    if adjusted > nestmax then
      prefix = ' '
      clamped = nestmax
    end

    -- Record the "real" level, so that it can be used as "base" of later get_folds_levels().
    info.levels0[lnum] = adjusted
    info.levels[lnum] = prefix .. tostring(clamped)

    leave_prev = leave_line
    level0_prev = adjusted
  end
end
-- Module table; returned at the end of the file.
local M = { }
2023-03-23 04:23:51 -07:00
-- Fold state per buffer, keyed by buffer number.
---@type table<integer,TS.FoldInfo>
2023-03-09 08:28:55 -07:00
local foldinfos = { }
2023-08-24 01:32:43 -07:00
-- Autocmd group that foldupdate() uses for its deferred InsertLeave autocmd.
local group = api.nvim_create_augroup ( ' treesitter/fold ' , { } )
2023-07-07 03:12:46 -07:00
--- Refresh the folds of every window that shows the buffer and uses the expr foldmethod
--- (assuming that the user doesn't use different foldexpr for the same buffer).
---
--- Nvim usually updates folds automatically when text changes, but that doesn't work here
--- because the FoldInfo update is scheduled. So it is done manually.
---@param bufnr integer
local function foldupdate(bufnr)
  --- Runs inside a window: refresh its folds if it uses the expr foldmethod.
  local function refresh_current_window()
    if vim.wo.foldmethod == ' expr ' then
      vim._foldupdate()
    end
  end

  local function do_update()
    for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
      api.nvim_win_call(win, refresh_current_window)
    end
  end

  if api.nvim_get_mode().mode ~= ' i ' then
    do_update()
    return
  end

  -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
  local pending = api.nvim_get_autocmds({
    group = group,
    buffer = bufnr,
  })
  if #pending > 0 then
    -- An update is already queued for this buffer.
    return
  end

  api.nvim_create_autocmd(' InsertLeave ', {
    group = group,
    buffer = bufnr,
    once = true,
    callback = do_update,
  })
end
2023-07-07 03:12:46 -07:00
--- Defer `fn` with vim.schedule() and run it only if bufnr is still loaded by then.
--- Fold level computation is scheduled for the following reasons:
--- * queries seem to use the old buffer state in on_bytes for some unknown reason;
--- * to avoid textlock;
--- * to avoid infinite recursion:
---   get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels.
---@param bufnr integer
---@param fn function
local function schedule_if_loaded(bufnr, fn)
  local function run_if_loaded()
    if api.nvim_buf_is_loaded(bufnr) then
      fn()
    end
  end
  vim.schedule(run_if_loaded)
end
2023-03-09 08:28:55 -07:00
--- Schedule a fold level recomputation for every changed tree range, followed by a fold refresh.
---@param bufnr integer
---@param foldinfo TS.FoldInfo
---@param tree_changes Range4[]
local function on_changedtree(bufnr, foldinfo, tree_changes)
  schedule_if_loaded(bufnr, function()
    if #tree_changes == 0 then
      -- Nothing changed, so there is nothing to recompute or redraw.
      return
    end

    for _, change in ipairs(tree_changes) do
      local srow, _, erow, ecol = Range.unpack4(change)
      -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
      local first = math.max(srow - vim.wo.foldminlines, 0)
      -- A change ending mid-line still touches that row, so include it.
      local last = ecol > 0 and erow + 1 or erow
      get_folds_levels(bufnr, foldinfo, first, last)
    end

    foldupdate(bufnr)
  end)
end
--- React to a byte-level edit: shift the cached levels right away when the number of rows
--- changed, record the edited range, and schedule a recomputation of the affected rows.
---@param bufnr integer
---@param foldinfo TS.FoldInfo
---@param start_row integer
---@param start_col integer
---@param old_row integer
---@param old_col integer
---@param new_row integer
---@param new_col integer
local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col)
  -- Edits that keep the row count need no bookkeeping here.
  if new_row == old_row then
    return
  end

  -- extend the end to fully include the range
  local end_row_old = start_row + old_row + 1
  local end_row_new = start_row + new_row + 1
  local at_line_start = start_col == 0

  -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the
  -- outdated levels, which may spuriously open the folds that didn't change. So we should shift
  -- folds as accurately as possible. For this to be perfectly accurate, we should track the
  -- actual TSNodes that account for each fold, and compare the node's range with the edited
  -- range. But for simplicity, we just check whether the start row is completely removed (e.g.,
  -- `dd`) or shifted (e.g., `o`).
  if new_row < old_row then
    local start_row_removed = at_line_start and new_row == 0 and new_col == 0
    if start_row_removed then
      foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new))
    else
      foldinfo:remove_range(end_row_new, end_row_old)
    end
  else
    local start_row_shifted = at_line_start and old_row == 0 and old_col == 0
    if start_row_shifted then
      foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old))
    else
      foldinfo:add_range(end_row_old, end_row_new)
    end
  end

  foldinfo:edit_range(start_row, end_row_old, end_row_new)

  -- This callback must not use on_bytes arguments, because they can be outdated when the callback
  -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
  -- the scheduled callback. So we should collect the edits.
  schedule_if_loaded(bufnr, function()
    local srow, erow = foldinfo:flush_edit()
    if srow then
      -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
      get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow)
      foldupdate(bufnr)
    end
  end)
end
2023-02-23 10:05:20 -07:00
2023-03-23 04:23:51 -07:00
--- Fold expression backed by treesitter.
--- On the first call for a buffer, computes the levels of the whole buffer and registers parser
--- callbacks that keep them up to date.
---@package
---@param lnum integer|nil line to evaluate; vim.v.lnum is used when omitted
---@return string
function M.foldexpr(lnum)
  lnum = lnum or vim.v.lnum
  local bufnr = api.nvim_get_current_buf()

  local parser = vim.F.npcall(ts.get_parser, bufnr)
  if not parser then
    return ' 0 '
  end

  if foldinfos[bufnr] == nil then
    local info = FoldInfo.new()
    foldinfos[bufnr] = info
    get_folds_levels(bufnr, info)

    -- The callbacks look up foldinfos[bufnr] when they fire, not the table captured here.
    local callbacks = {}
    function callbacks.on_changedtree(tree_changes)
      on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
    end
    function callbacks.on_bytes(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _)
      on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col)
    end
    function callbacks.on_detach()
      foldinfos[bufnr] = nil
    end
    parser:register_cbs(callbacks)
  end

  return foldinfos[bufnr].levels[lnum] or ' 0 '
end
2023-12-10 06:18:48 -07:00
-- Both options feed into get_folds_levels(), so every cached FoldInfo is rebuilt from scratch
-- when either of them is set.
local function refresh_all_folds()
  -- Snapshot the keys first, since entries are replaced while looping.
  local bufnrs = vim.tbl_keys(foldinfos)
  for _, bufnr in ipairs(bufnrs) do
    local info = FoldInfo.new()
    foldinfos[bufnr] = info
    get_folds_levels(bufnr, foldinfos[bufnr])
    foldupdate(bufnr)
  end
end

api.nvim_create_autocmd(' OptionSet ', {
  pattern = { ' foldminlines ', ' foldnestmax ' },
  desc = ' Refresh treesitter folds ',
  callback = refresh_all_folds,
})
2023-02-23 10:05:20 -07:00
return M