diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 496193c6ed..17543bc787 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -5,14 +5,16 @@ local Range = require('vim.treesitter._range') ---@alias TSHlIter fun(end_line: integer|nil): integer, TSNode, TSMetadata ---@class TSHighlightState +---@field tstree TSTree ---@field next_row integer ---@field iter TSHlIter|nil +---@field highlighter_query TSHighlighterQuery ---@class TSHighlighter ---@field active table ---@field bufnr integer ---@field orig_spelloptions string ----@field _highlight_states table +---@field _highlight_states TSHighlightState[] ---@field _queries table ---@field tree LanguageTree ---@field redraw_count integer @@ -157,18 +159,47 @@ function TSHighlighter:destroy() end end ----@package ----@param tstree TSTree ----@return TSHighlightState -function TSHighlighter:get_highlight_state(tstree) - if not self._highlight_states[tstree] then - self._highlight_states[tstree] = { +---@param srow integer +---@param erow integer exclusive +---@private +function TSHighlighter:prepare_highlight_states(srow, erow) + self.tree:for_each_tree(function(tstree, tree) + if not tstree then + return + end + + local root_node = tstree:root() + local root_start_row, _, root_end_row, _ = root_node:range() + + -- Only worry about trees within the visible range + if root_start_row > erow or root_end_row < srow then + return + end + + local highlighter_query = self:get_query(tree:lang()) + + -- Some injected languages may not have highlight queries. + if not highlighter_query:query() then + return + end + + -- _highlight_states should be a list so that the highlights are added in the same order as + -- for_each_tree traversal. This ensures that parents' highlight don't override children's. + table.insert(self._highlight_states, { + tstree = tstree, next_row = 0, iter = nil, - } - end + highlighter_query = highlighter_query, + }) + end) +end - return self._highlight_states[tstree] +---@param fn fun(state: TSHighlightState) +---@package +function TSHighlighter:for_each_highlight_state(fn) + for _, state in ipairs(self._highlight_states) do + fn(state) + end end ---@private @@ -214,12 +245,8 @@ end ---@param line integer ---@param is_spell_nav boolean local function on_line_impl(self, buf, line, is_spell_nav) - self.tree:for_each_tree(function(tstree, tree) - if not tstree then - return - end - - local root_node = tstree:root() + self:for_each_highlight_state(function(state) + local root_node = state.tstree:root() local root_start_row, _, root_end_row, _ = root_node:range() -- Only worry about trees within the line range @@ -227,17 +254,9 @@ local function on_line_impl(self, buf, line, is_spell_nav) return end - local state = self:get_highlight_state(tstree) - local highlighter_query = self:get_query(tree:lang()) - - -- Some injected languages may not have highlight queries. - if not highlighter_query:query() then - return - end - if state.iter == nil or state.next_row < line then state.iter = - highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) + state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end while line >= state.next_row do @@ -250,9 +269,9 @@ local function on_line_impl(self, buf, line, is_spell_nav) local start_row, start_col, end_row, end_col = Range.unpack4(range) if capture then - local hl = highlighter_query.hl_cache[capture] + local hl = state.highlighter_query.hl_cache[capture] - local capture_name = highlighter_query:query().captures[capture] + local capture_name = state.highlighter_query:query().captures[capture] local spell = nil ---@type boolean? if capture_name == 'spell' then spell = true @@ -327,6 +346,7 @@ function TSHighlighter._on_win(_, _win, buf, topline, botline) end self.tree:parse({ topline, botline + 1 }) self:reset_highlight_state() + self:prepare_highlight_states(topline, botline + 1) self.redraw_count = self.redraw_count + 1 return true end