From 81a1b3a462df0eb0becc8f128b04224cffd8434e Mon Sep 17 00:00:00 2001 From: vanaigr Date: Wed, 18 Dec 2024 01:06:41 -0600 Subject: [PATCH] refactor: split predicates and directives --- runtime/lua/vim/treesitter/highlighter.lua | 4 +- runtime/lua/vim/treesitter/query.lua | 133 +++++++++++++-------- 2 files changed, 85 insertions(+), 52 deletions(-) diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 8ce8652f7d..96503c38ea 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -299,6 +299,8 @@ local function on_line_impl(self, buf, line, is_spell_nav) state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end + local captures = state.highlighter_query:query().captures + while line >= state.next_row do local capture, node, metadata, match = state.iter(line) @@ -311,7 +313,7 @@ local function on_line_impl(self, buf, line, is_spell_nav) if capture then local hl = state.highlighter_query:get_hl_from_capture(capture) - local capture_name = state.highlighter_query:query().captures[capture] + local capture_name = captures[capture] local spell, spell_pri_offset = get_spell(capture_name) diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index ae22e9e385..3e930f3eae 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -4,6 +4,56 @@ local memoize = vim.func._memoize local M = {} +local function is_directive(name) + return string.sub(name, -1) == '!' +end + +---@alias Pattern (integer|string)[] + +---@nodoc +---@class ProcessedPredicate +---@field [1] string predicate name +---@field [2] boolean should match +---@field [3] Pattern + +---@alias ProcessedDirective Pattern + +---@private +---@param patterns table +---@return table +local function process_patterns(patterns) + ---@type table + local processed_patterns = {} + + for k, pattern_list in pairs(patterns) do + ---@type ProcessedPredicate[] + local predicates = {} + ---@type ProcessedDirective[] + local directives = {} + + for _, pattern in ipairs(pattern_list) do + -- Note: ree-sitter strips the leading # from predicates for us. + local pred_name = pattern[1] + ---@cast pred_name string + + if is_directive(pred_name) then + table.insert(directives, pattern) + else + local should_match = true + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) + should_match = false + end + table.insert(predicates, { pred_name, should_match, pattern }) + end + end + + processed_patterns[k] = { preds = predicates, directives = directives } + end + + return processed_patterns +end + ---@nodoc ---Parsed query, see |vim.treesitter.query.parse()| --- @@ -12,6 +62,7 @@ local M = {} ---@field captures string[] list of (unique) capture names defined in query ---@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives) ---@field query TSQuery userdata query object +---@field processed_patterns table local Query = {} Query.__index = Query @@ -30,6 +81,7 @@ function Query.new(lang, ts_query) patterns = query_info.patterns, } self.captures = self.info.captures + self.processed_patterns = process_patterns(self.info.patterns) return self end @@ -740,67 +792,52 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end -local function xor(x, y) - return (x or y) and not (x and y) -end - -local function is_directive(name) - return string.sub(name, -1) == '!' -end - ---@private +---@param preds ProcessedPredicate[] ---@param match TSQueryMatch ---@param source integer|string -function Query:match_preds(preds, pattern, captures, source) - for _, pred in pairs(preds) do +function Query:match_preds(preds, pattern_i, captures, source) + for _, pred in ipairs(preds) do -- Here we only want to return if a predicate DOES NOT match, and -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). - -- Also, tree-sitter strips the leading # from predicates for us. - local is_not = false - -- Skip over directives... they will get processed after all the predicates. - if not is_directive(pred[1]) then - local pred_name = pred[1] - if pred_name:match('^not%-') then - pred_name = pred_name:sub(5) - is_not = true - end + local processed_name = pred[1] + local should_match = pred[2] + local pattern = pred[3] - local handler = predicate_handlers[pred_name] + local handler = predicate_handlers[processed_name] - if not handler then - error(string.format('No handler for %s', pred[1])) - return false - end + if not handler then + error(string.format('No handler for %s', pattern[1])) + return false + end - local pred_matches = handler(captures, pattern, source, pred) + local pred_matches = handler(captures, pattern_i, source, pattern) - if not xor(is_not, pred_matches) then - return false - end + if pred_matches ~= should_match then + return false end end return true end ---@private +---@param directives ProcessedDirective[] ---@param match TSQueryMatch ---@return vim.treesitter.query.TSMetadata metadata -function Query:apply_directives(preds, pattern, captures, source) +function Query:apply_directives(directives, pattern_i, captures, source) ---@type vim.treesitter.query.TSMetadata local metadata = {} - for _, pred in pairs(preds) do - if is_directive(pred[1]) then - local handler = directive_handlers[pred[1]] + for _, directive in pairs(directives) do + local handler = directive_handlers[directive[1]] - if not handler then - error(string.format('No handler for %s', pred[1])) - end - - handler(captures, pattern, source, pred, metadata) + if not handler then + error(string.format('No handler for %s', directive[1])) end + + handler(captures, pattern_i, source, directive, metadata) end return metadata @@ -824,12 +861,6 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- @param match TSQueryMatch ---- @return integer -local function match_id_hash(_, match) - return (match:info()) -end - --- Iterate over all captures from all matches inside {node} --- --- {source} is needed if the query contains predicates; then the caller @@ -892,11 +923,11 @@ function Query:iter_captures(node, source, start, stop) end if not metadata then - local preds = self.info.patterns[pattern] - if preds then + local patterns = self.processed_patterns[pattern] + if patterns then local captures = match:captures() - if not self:match_preds(preds, pattern, captures, source) then + if not self:match_preds(patterns.preds, pattern, captures, source) then cursor:remove_match(match_id) if end_line and captured_node:range() > end_line then return nil, captured_node, nil, nil @@ -904,7 +935,7 @@ function Query:iter_captures(node, source, start, stop) return iter(end_line) -- tail call: try next match end - metadata = self:apply_directives(preds, pattern, captures, source) + metadata = self:apply_directives(patterns.directives, pattern, captures, source) else metadata = {} end @@ -975,17 +1006,17 @@ function Query:iter_matches(node, source, start, stop, opts) end local match_id, pattern = match:info() - local preds = self.info.patterns[pattern] + local patterns = self.processed_patterns[pattern] local captures = match:captures() --- @type vim.treesitter.query.TSMetadata local metadata - if preds then - if not self:match_preds(preds, pattern, captures, source) then + if patterns then + if not self:match_preds(patterns.preds, pattern, captures, source) then cursor:remove_match(match_id) return iter() -- tail call: try next match end - metadata = self:apply_directives(preds, pattern, captures, source) + metadata = self:apply_directives(patterns.directives, pattern, captures, source) else metadata = {} end