refactor: split predicates and directives

This commit is contained in:
vanaigr 2024-12-18 01:06:41 -06:00
parent 67a1640767
commit 81a1b3a462
2 changed files with 85 additions and 52 deletions

View File

@ -299,6 +299,8 @@ local function on_line_impl(self, buf, line, is_spell_nav)
state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
end end
local captures = state.highlighter_query:query().captures
while line >= state.next_row do while line >= state.next_row do
local capture, node, metadata, match = state.iter(line) local capture, node, metadata, match = state.iter(line)
@ -311,7 +313,7 @@ local function on_line_impl(self, buf, line, is_spell_nav)
if capture then if capture then
local hl = state.highlighter_query:get_hl_from_capture(capture) local hl = state.highlighter_query:get_hl_from_capture(capture)
local capture_name = state.highlighter_query:query().captures[capture] local capture_name = captures[capture]
local spell, spell_pri_offset = get_spell(capture_name) local spell, spell_pri_offset = get_spell(capture_name)

View File

@ -4,6 +4,56 @@ local memoize = vim.func._memoize
local M = {} local M = {}
local function is_directive(name)
return string.sub(name, -1) == '!'
end
---@alias Pattern (integer|string)[]
---@nodoc
---@class ProcessedPredicate
---@field [1] string predicate name
---@field [2] boolean should match
---@field [3] Pattern
---@alias ProcessedDirective Pattern
---@private
---@param patterns table<integer, Pattern[]>
---@return table<integer, { preds: ProcessedPredicate[], directives: ProcessedDirective[] }>
local function process_patterns(patterns)
---@type table<integer, { preds: ProcessedPredicate[], directives: ProcessedDirective[] }>
local processed_patterns = {}
for k, pattern_list in pairs(patterns) do
---@type ProcessedPredicate[]
local predicates = {}
---@type ProcessedDirective[]
local directives = {}
for _, pattern in ipairs(pattern_list) do
-- Note: ree-sitter strips the leading # from predicates for us.
local pred_name = pattern[1]
---@cast pred_name string
if is_directive(pred_name) then
table.insert(directives, pattern)
else
local should_match = true
if pred_name:match('^not%-') then
pred_name = pred_name:sub(5)
should_match = false
end
table.insert(predicates, { pred_name, should_match, pattern })
end
end
processed_patterns[k] = { preds = predicates, directives = directives }
end
return processed_patterns
end
---@nodoc ---@nodoc
---Parsed query, see |vim.treesitter.query.parse()| ---Parsed query, see |vim.treesitter.query.parse()|
--- ---
@ -12,6 +62,7 @@ local M = {}
---@field captures string[] list of (unique) capture names defined in query ---@field captures string[] list of (unique) capture names defined in query
---@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives) ---@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives)
---@field query TSQuery userdata query object ---@field query TSQuery userdata query object
---@field processed_patterns table<integer, { preds: ProcessedPredicate[], directives: ProcessedDirective[] }>
local Query = {} local Query = {}
Query.__index = Query Query.__index = Query
@ -30,6 +81,7 @@ function Query.new(lang, ts_query)
patterns = query_info.patterns, patterns = query_info.patterns,
} }
self.captures = self.info.captures self.captures = self.info.captures
self.processed_patterns = process_patterns(self.info.patterns)
return self return self
end end
@ -740,67 +792,52 @@ function M.list_predicates()
return vim.tbl_keys(predicate_handlers) return vim.tbl_keys(predicate_handlers)
end end
local function xor(x, y)
return (x or y) and not (x and y)
end
local function is_directive(name)
return string.sub(name, -1) == '!'
end
---@private ---@private
---@param preds ProcessedPredicate[]
---@param match TSQueryMatch ---@param match TSQueryMatch
---@param source integer|string ---@param source integer|string
function Query:match_preds(preds, pattern, captures, source) function Query:match_preds(preds, pattern_i, captures, source)
for _, pred in pairs(preds) do for _, pred in ipairs(preds) do
-- Here we only want to return if a predicate DOES NOT match, and -- Here we only want to return if a predicate DOES NOT match, and
-- continue on the other case. This way unknown predicates will not be considered, -- continue on the other case. This way unknown predicates will not be considered,
-- which allows some testing and easier user extensibility (#12173). -- which allows some testing and easier user extensibility (#12173).
-- Also, tree-sitter strips the leading # from predicates for us.
local is_not = false
-- Skip over directives... they will get processed after all the predicates. local processed_name = pred[1]
if not is_directive(pred[1]) then local should_match = pred[2]
local pred_name = pred[1] local pattern = pred[3]
if pred_name:match('^not%-') then
pred_name = pred_name:sub(5)
is_not = true
end
local handler = predicate_handlers[pred_name] local handler = predicate_handlers[processed_name]
if not handler then if not handler then
error(string.format('No handler for %s', pred[1])) error(string.format('No handler for %s', pattern[1]))
return false return false
end end
local pred_matches = handler(captures, pattern, source, pred) local pred_matches = handler(captures, pattern_i, source, pattern)
if not xor(is_not, pred_matches) then if pred_matches ~= should_match then
return false return false
end end
end end
end
return true return true
end end
---@private ---@private
---@param directives ProcessedDirective[]
---@param match TSQueryMatch ---@param match TSQueryMatch
---@return vim.treesitter.query.TSMetadata metadata ---@return vim.treesitter.query.TSMetadata metadata
function Query:apply_directives(preds, pattern, captures, source) function Query:apply_directives(directives, pattern_i, captures, source)
---@type vim.treesitter.query.TSMetadata ---@type vim.treesitter.query.TSMetadata
local metadata = {} local metadata = {}
for _, pred in pairs(preds) do for _, directive in pairs(directives) do
if is_directive(pred[1]) then local handler = directive_handlers[directive[1]]
local handler = directive_handlers[pred[1]]
if not handler then if not handler then
error(string.format('No handler for %s', pred[1])) error(string.format('No handler for %s', directive[1]))
end end
handler(captures, pattern, source, pred, metadata) handler(captures, pattern_i, source, directive, metadata)
end
end end
return metadata return metadata
@ -824,12 +861,6 @@ local function value_or_node_range(start, stop, node)
return start, stop return start, stop
end end
--- @param match TSQueryMatch
--- @return integer
local function match_id_hash(_, match)
return (match:info())
end
--- Iterate over all captures from all matches inside {node} --- Iterate over all captures from all matches inside {node}
--- ---
--- {source} is needed if the query contains predicates; then the caller --- {source} is needed if the query contains predicates; then the caller
@ -892,11 +923,11 @@ function Query:iter_captures(node, source, start, stop)
end end
if not metadata then if not metadata then
local preds = self.info.patterns[pattern] local patterns = self.processed_patterns[pattern]
if preds then if patterns then
local captures = match:captures() local captures = match:captures()
if not self:match_preds(preds, pattern, captures, source) then if not self:match_preds(patterns.preds, pattern, captures, source) then
cursor:remove_match(match_id) cursor:remove_match(match_id)
if end_line and captured_node:range() > end_line then if end_line and captured_node:range() > end_line then
return nil, captured_node, nil, nil return nil, captured_node, nil, nil
@ -904,7 +935,7 @@ function Query:iter_captures(node, source, start, stop)
return iter(end_line) -- tail call: try next match return iter(end_line) -- tail call: try next match
end end
metadata = self:apply_directives(preds, pattern, captures, source) metadata = self:apply_directives(patterns.directives, pattern, captures, source)
else else
metadata = {} metadata = {}
end end
@ -975,17 +1006,17 @@ function Query:iter_matches(node, source, start, stop, opts)
end end
local match_id, pattern = match:info() local match_id, pattern = match:info()
local preds = self.info.patterns[pattern] local patterns = self.processed_patterns[pattern]
local captures = match:captures() local captures = match:captures()
--- @type vim.treesitter.query.TSMetadata --- @type vim.treesitter.query.TSMetadata
local metadata local metadata
if preds then if patterns then
if not self:match_preds(preds, pattern, captures, source) then if not self:match_preds(patterns.preds, pattern, captures, source) then
cursor:remove_match(match_id) cursor:remove_match(match_id)
return iter() -- tail call: try next match return iter() -- tail call: try next match
end end
metadata = self:apply_directives(preds, pattern, captures, source) metadata = self:apply_directives(patterns.directives, pattern, captures, source)
else else
metadata = {} metadata = {}
end end