fix(treesitter): correctly handle query quantifiers (#24738)

Query patterns can contain quantifiers (e.g. (foo)+ @bar), so a single
capture can map to multiple nodes. The iter_matches API cannot handle
this situation because the match table incorrectly maps capture indices
to a single node instead of to an array of nodes.
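
The difference between the two match table shapes can be sketched in plain Lua. This is only an illustration: strings stand in for TSNode objects, and the names `captured`, `old_match` and `new_match` are hypothetical, not part of the Neovim API.

```lua
-- Hypothetical stand-ins: strings play the role of TSNode objects.
-- Pattern: ((comment)+ @comment), matched against three consecutive comments.
local captured = {
  "-- TODO: This is a",
  "-- very long",
  "-- comment (just imagine it)",
}

-- Old shape: capture index -> a single node (only the last node is kept).
local old_match = { [1] = captured[#captured] }

-- New shape: capture index -> list of all nodes captured by the quantifier.
local new_match = { [1] = captured }

print(type(old_match[1])) -- string (one "node")
print(#new_match[1])      -- 3
```

With the old shape, a handler has no way to see the first two comments at all.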

The match table should be updated to map capture indices to an array of
nodes. However, this is a massively breaking change, so it must be done
with a proper deprecation period.

`iter_matches`, `add_predicate` and `add_directive` must opt in to the
correct behavior for backward compatibility. This is done with a new
"all" option. This option will become the default and be removed after
the 0.10 release.
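
A simplified sketch of the opt-in mechanism, in plain Lua, may help when reading the diff. It mirrors the wrapper added to `add_predicate` in this commit, but `register` and `handlers` are hypothetical names used only for illustration, and strings stand in for nodes.

```lua
-- Hypothetical registry; not the Neovim API.
local handlers = {}

local function register(name, handler, opts)
  -- Backward compatibility: the old signature took a boolean "force".
  if type(opts) == "boolean" then
    opts = { force = opts }
  end
  opts = opts or {}

  if handlers[name] and not opts.force then
    error(string.format("Overriding existing predicate %s", name))
  end

  if opts.all then
    -- Handler opted in: it receives capture ID -> list of nodes.
    handlers[name] = handler
  else
    -- Legacy handler: collapse each list to its last node.
    handlers[name] = function(match, ...)
      local m = {}
      for k, v in pairs(match) do
        if type(k) == "number" then
          m[k] = v[#v]
        end
      end
      return handler(m, ...)
    end
  end
end

local match = { [1] = { "a", "b", "c" } }
register("legacy?", function(m) return m[1] end)
register("modern?", function(m) return #m[1] end, { all = true })
print(handlers["legacy?"](match)) -- c
print(handlers["modern?"](match)) -- 3
```

Existing handlers keep seeing one node per capture, while handlers registered with `{ all = true }` see every captured node.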

Co-authored-by: Christian Clason <c.clason@uni-graz.at>
Co-authored-by: MDeiml <matthias@deiml.net>
Co-authored-by: Gregory Anders <greg@gpanders.com>
Thomas Vigouroux 2024-02-16 18:54:47 +01:00 committed by GitHub
parent 1ba3500abd
commit bd5008de07
9 changed files with 799 additions and 231 deletions

@ -427,6 +427,18 @@ The following changes to existing APIs or features add new behavior.
• |nvim_buf_call()| and |nvim_win_call()| now preserves any return value (NB: not multiple return values) • |nvim_buf_call()| and |nvim_win_call()| now preserves any return value (NB: not multiple return values)
• Treesitter
• |Query:iter_matches()|, |vim.treesitter.query.add_predicate()|, and
|vim.treesitter.query.add_directive()| accept a new `all` option which
ensures that all matching nodes are returned as a table. The default option
`all=false` returns only a single node, breaking captures with quantifiers
like `(comment)+ @comment`; it is only provided for backward compatibility
and will be removed after Nvim 0.10.
• |vim.treesitter.query.add_predicate()| and
|vim.treesitter.query.add_directive()| now accept an options table rather
than a boolean "force" argument. To force a predicate or directive to
override an existing predicate or directive, use `{ force = true }`.
============================================================================== ==============================================================================
REMOVED FEATURES *news-removed* REMOVED FEATURES *news-removed*
@ -480,7 +492,7 @@ release.
• `vim.loop` has been renamed to |vim.uv|. • `vim.loop` has been renamed to |vim.uv|.
• vim.treesitter.languagetree functions: • vim.treesitter functions:
- |LanguageTree:for_each_child()| Use |LanguageTree:children()| (non-recursive) instead. - |LanguageTree:for_each_child()| Use |LanguageTree:children()| (non-recursive) instead.
• The "term_background" UI option |ui-ext-options| is deprecated and no longer • The "term_background" UI option |ui-ext-options| is deprecated and no longer

@ -223,6 +223,10 @@ The following predicates are built in:
((identifier) @variable.builtin (#eq? @variable.builtin "self")) ((identifier) @variable.builtin (#eq? @variable.builtin "self"))
((node1) @left (node2) @right (#eq? @left @right)) ((node1) @left (node2) @right (#eq? @left @right))
< <
`any-eq?` *treesitter-predicate-any-eq?*
Like `eq?`, but for quantified patterns only one captured node must
match.
`match?` *treesitter-predicate-match?* `match?` *treesitter-predicate-match?*
`vim-match?` *treesitter-predicate-vim-match?* `vim-match?` *treesitter-predicate-vim-match?*
Match a |regexp| against the text corresponding to a node: >query Match a |regexp| against the text corresponding to a node: >query
@ -231,15 +235,28 @@ The following predicates are built in:
Note: The `^` and `$` anchors will match the start and end of the Note: The `^` and `$` anchors will match the start and end of the
node's text. node's text.
`any-match?` *treesitter-predicate-any-match?*
`any-vim-match?` *treesitter-predicate-any-vim-match?*
Like `match?`, but for quantified patterns only one captured node must
match.
`lua-match?` *treesitter-predicate-lua-match?* `lua-match?` *treesitter-predicate-lua-match?*
Match |lua-patterns| against the text corresponding to a node, Match |lua-patterns| against the text corresponding to a node,
similar to `match?` similar to `match?`
`any-lua-match?` *treesitter-predicate-any-lua-match?*
Like `lua-match?`, but for quantified patterns only one captured node
must match.
`contains?` *treesitter-predicate-contains?* `contains?` *treesitter-predicate-contains?*
Match a string against parts of the text corresponding to a node: >query Match a string against parts of the text corresponding to a node: >query
((identifier) @foo (#contains? @foo "foo")) ((identifier) @foo (#contains? @foo "foo"))
((identifier) @foo-bar (#contains? @foo-bar "foo" "bar")) ((identifier) @foo-bar (#contains? @foo-bar "foo" "bar"))
< <
`any-contains?` *treesitter-predicate-any-contains?*
Like `contains?`, but for quantified patterns only one captured node
must match.
`any-of?` *treesitter-predicate-any-of?* `any-of?` *treesitter-predicate-any-of?*
Match any of the given strings against the text corresponding to Match any of the given strings against the text corresponding to
a node: >query a node: >query
@ -265,6 +282,32 @@ The following predicates are built in:
Each predicate has a `not-` prefixed predicate that is just the negation of Each predicate has a `not-` prefixed predicate that is just the negation of
the predicate. the predicate.
*lua-treesitter-all-predicate*
*lua-treesitter-any-predicate*
Queries can use quantifiers to capture multiple nodes. When a capture contains
multiple nodes, predicates match only if ALL nodes contained by the capture
match the predicate. Some predicates (`eq?`, `match?`, `lua-match?`,
`contains?`) accept an `any-` prefix to instead match if ANY of the nodes
contained by the capture match the predicate.
As an example, consider the following Lua code: >lua
-- TODO: This is a
-- very long
-- comment (just imagine it)
<
using the following predicated query:
>query
(((comment)+ @comment)
(#match? @comment "TODO"))
<
This query will not match because not all of the nodes captured by @comment
match the predicate. Instead, use:
>query
(((comment)+ @comment)
(#any-match? @comment "TODO"))
<
Further predicates can be added via |vim.treesitter.query.add_predicate()|. Further predicates can be added via |vim.treesitter.query.add_predicate()|.
Use |vim.treesitter.query.list_predicates()| to list all available predicates. Use |vim.treesitter.query.list_predicates()| to list all available predicates.
@ -923,28 +966,35 @@ register({lang}, {filetype}) *vim.treesitter.language.register()*
Lua module: vim.treesitter.query *lua-treesitter-query* Lua module: vim.treesitter.query *lua-treesitter-query*
*vim.treesitter.query.add_directive()* *vim.treesitter.query.add_directive()*
add_directive({name}, {handler}, {force}) add_directive({name}, {handler}, {opts})
Adds a new directive to be used in queries Adds a new directive to be used in queries
Handlers can set match level data by setting directly on the metadata Handlers can set match level data by setting directly on the metadata
object `metadata.key = value`, additionally, handlers can set node level object `metadata.key = value`. Additionally, handlers can set node level
data by using the capture id on the metadata table data by using the capture id on the metadata table
`metadata[capture_id].key = value` `metadata[capture_id].key = value`
Parameters: ~ Parameters: ~
• {name} (`string`) Name of the directive, without leading # • {name} (`string`) Name of the directive, without leading #
• {handler} (`function`) • {handler} (`function`)
• match: see |treesitter-query| • match: A table mapping capture IDs to a list of captured
node-level data are accessible via `match[capture_id]` nodes
• pattern: the index of the matching pattern in the query
• pattern: see |treesitter-query| file
• predicate: list of strings containing the full directive • predicate: list of strings containing the full directive
being called, e.g. `(node (#set! conceal "-"))` would get being called, e.g. `(node (#set! conceal "-"))` would get
the predicate `{ "#set!", "conceal", "-" }` the predicate `{ "#set!", "conceal", "-" }`
• {force} (`boolean?`) • {opts} (`table<string, any>`) Optional options:
• force (boolean): Override an existing directive of the
same name
• all (boolean): Use the correct implementation of the
match table where capture IDs map to a list of nodes
instead of a single node. Defaults to false (for backward
compatibility). This option will eventually become the
default and be removed.
*vim.treesitter.query.add_predicate()* *vim.treesitter.query.add_predicate()*
add_predicate({name}, {handler}, {force}) add_predicate({name}, {handler}, {opts})
Adds a new predicate to be used in queries Adds a new predicate to be used in queries
Parameters: ~ Parameters: ~
@ -952,7 +1002,14 @@ add_predicate({name}, {handler}, {force})
• {handler} (`function`) • {handler} (`function`)
• see |vim.treesitter.query.add_directive()| for argument • see |vim.treesitter.query.add_directive()| for argument
meanings meanings
• {force} (`boolean?`) • {opts} (`table<string, any>`) Optional options:
• force (boolean): Override an existing predicate of the
same name
• all (boolean): Use the correct implementation of the
match table where capture IDs map to a list of nodes
instead of a single node. Defaults to false (for backward
compatibility). This option will eventually become the
default and be removed.
edit({lang}) *vim.treesitter.query.edit()* edit({lang}) *vim.treesitter.query.edit()*
Opens a live editor to query the buffer you started from. Opens a live editor to query the buffer you started from.
@ -1102,18 +1159,25 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts})
Iterate over all matches within a {node}. The arguments are the same as Iterate over all matches within a {node}. The arguments are the same as
for |Query:iter_captures()| but the iterated values are different: an for |Query:iter_captures()| but the iterated values are different: an
(1-based) index of the pattern in the query, a table mapping capture (1-based) index of the pattern in the query, a table mapping capture
indices to nodes, and metadata from any directives processing the match. indices to a list of nodes, and metadata from any directives processing
If the query has more than one pattern, the capture table might be sparse the match.
and e.g. `pairs()` method should be used over `ipairs`. Here is an example
iterating over all captures in every match: >lua WARNING: Set `all=true` to ensure all matching nodes in a match are
for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do returned, otherwise only the last node in a match is returned, breaking
for id, node in pairs(match) do captures involving quantifiers such as `(comment)+ @comment`. The default
option `all=false` is only provided for backward compatibility and will be
removed after Nvim 0.10.
Example: >lua
for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do
for id, nodes in pairs(match) do
local name = query.captures[id] local name = query.captures[id]
-- `node` was captured by the `name` capture in the match for _, node in ipairs(nodes) do
-- `node` was captured by the `name` capture in the match
local node_data = metadata[id] -- Node level metadata local node_data = metadata[id] -- Node level metadata
... use the info here ...
-- ... use the info here ... end
end end
end end
< <
@ -1129,9 +1193,14 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts})
• max_start_depth (integer) if non-zero, sets the maximum • max_start_depth (integer) if non-zero, sets the maximum
start depth for each match. This is used to prevent start depth for each match. This is used to prevent
traversing too deep into a tree. traversing too deep into a tree.
• all (boolean) When set, the returned match table maps
capture IDs to a list of nodes. Older versions of
iter_matches mapped capture IDs to a single node, which
breaks captures with quantifiers. This option will
eventually become the default and be removed.
Return: ~ Return: ~
(`fun(): integer, table<integer,TSNode>, table`) pattern id, match, (`fun(): integer, table<integer, TSNode[]>, table`) pattern id, match,
metadata metadata
set({lang}, {query_name}, {text}) *vim.treesitter.query.set()* set({lang}, {query_name}, {text}) *vim.treesitter.query.set()*
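
The ALL versus ANY semantics documented in the hunks above can be sketched in plain Lua. This is a rough model under stated assumptions: strings stand in for node text, a plain substring search replaces the Vim regex used by `match?`, and `check` is a hypothetical helper, not a Neovim function.

```lua
-- Returns true if the captured texts satisfy the predicate.
-- any == false: every text must contain `needle` (default semantics).
-- any == true:  at least one text must contain `needle` ("any-" prefix).
local function check(texts, needle, any)
  if #texts == 0 then
    return true -- an empty capture never rejects a match
  end
  for _, text in ipairs(texts) do
    local res = string.find(text, needle, 1, true) ~= nil
    if any and res then
      return true
    elseif not any and not res then
      return false
    end
  end
  -- Loop finished: ALL mode saw no failure, ANY mode saw no success.
  return not any
end

local comments = {
  "-- TODO: This is a",
  "-- very long",
  "-- comment (just imagine it)",
}

print(check(comments, "TODO", false)) -- false
print(check(comments, "TODO", true))  -- true
```

This corresponds to the documented example: the ALL form rejects the comment block because only its first line contains "TODO", while the ANY form accepts it.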

@ -39,7 +39,7 @@ local TSNode = {}
---@param start? integer ---@param start? integer
---@param end_? integer ---@param end_? integer
---@param opts? table ---@param opts? table
---@return fun(): integer, TSNode, any ---@return fun(): integer, TSNode, TSMatch
function TSNode:_rawquery(query, captures, start, end_, opts) end function TSNode:_rawquery(query, captures, start, end_, opts) end
---@param query TSQuery ---@param query TSQuery
@ -47,7 +47,7 @@ function TSNode:_rawquery(query, captures, start, end_, opts) end
---@param start? integer ---@param start? integer
---@param end_? integer ---@param end_? integer
---@param opts? table ---@param opts? table
---@return fun(): integer, any ---@return fun(): integer, TSMatch
function TSNode:_rawquery(query, captures, start, end_, opts) end function TSNode:_rawquery(query, captures, start, end_, opts) end
---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string) ---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)

@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
end) end)
--- @param buf integer --- @param buf integer
--- @param match table<integer,TSNode> --- @param match table<integer,TSNode[]>
--- @param query Query --- @param query Query
--- @param lang_context QueryLinterLanguageContext --- @param lang_context QueryLinterLanguageContext
--- @param diagnostics Diagnostic[] --- @param diagnostics Diagnostic[]
@ -130,20 +130,22 @@ local function lint_match(buf, match, query, lang_context, diagnostics)
local lang = lang_context.lang local lang = lang_context.lang
local parser_info = lang_context.parser_info local parser_info = lang_context.parser_info
for id, node in pairs(match) do for id, nodes in pairs(match) do
local cap_id = query.captures[id] for _, node in ipairs(nodes) do
local cap_id = query.captures[id]
-- perform language-independent checks only for first lang -- perform language-independent checks only for first lang
if lang_context.is_first_lang and cap_id == 'error' then if lang_context.is_first_lang and cap_id == 'error' then
local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
end end
-- other checks rely on Neovim parser introspection -- other checks rely on Neovim parser introspection
if lang and parser_info and cap_id == 'toplevel' then if lang and parser_info and cap_id == 'toplevel' then
local err = parse(node, buf, lang) local err = parse(node, buf, lang)
if err then if err then
add_lint_for_node(diagnostics, err.range, err.msg, lang) add_lint_for_node(diagnostics, err.range, err.msg, lang)
end
end end
end end
end end

@ -784,7 +784,7 @@ end
---@private ---@private
--- Extract injections according to: --- Extract injections according to:
--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection --- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
---@param match table<integer,TSNode> ---@param match table<integer,TSNode[]>
---@param metadata TSMetadata ---@param metadata TSMetadata
---@return string?, boolean, Range6[] ---@return string?, boolean, Range6[]
function LanguageTree:_get_injection(match, metadata) function LanguageTree:_get_injection(match, metadata)
@ -796,14 +796,16 @@ function LanguageTree:_get_injection(match, metadata)
or (injection_lang and resolve_lang(injection_lang)) or (injection_lang and resolve_lang(injection_lang))
local include_children = metadata['injection.include-children'] ~= nil local include_children = metadata['injection.include-children'] ~= nil
for id, node in pairs(match) do for id, nodes in pairs(match) do
local name = self._injection_query.captures[id] for _, node in ipairs(nodes) do
-- Lang should override any other language tag local name = self._injection_query.captures[id]
if name == 'injection.language' then -- Lang should override any other language tag
local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) if name == 'injection.language' then
lang = resolve_lang(text) local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'injection.content' then lang = resolve_lang(text)
ranges = get_node_ranges(node, self._source, metadata[id], include_children) elseif name == 'injection.content' then
ranges = get_node_ranges(node, self._source, metadata[id], include_children)
end
end end
end end
@ -844,7 +846,13 @@ function LanguageTree:_get_injections()
local start_line, _, end_line, _ = root_node:range() local start_line, _, end_line, _ = root_node:range()
for pattern, match, metadata in for pattern, match, metadata in
self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1) self._injection_query:iter_matches(
root_node,
self._source,
start_line,
end_line + 1,
{ all = true }
)
do do
local lang, combined, ranges = self:_get_injection(match, metadata) local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then if lang then

@ -290,47 +290,71 @@ function M.get_node_text(...)
return vim.treesitter.get_node_text(...) return vim.treesitter.get_node_text(...)
end end
---@alias TSMatch table<integer,TSNode> --- Implementations of predicates that can optionally be prefixed with "any-".
---
---@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean --- These functions contain the implementations for each predicate, correctly
--- handling the "any" vs "all" semantics. They are called from the
-- Predicate handler receive the following arguments --- predicate_handlers table with the appropriate arguments for each predicate.
-- (match, pattern, bufnr, predicate) local impl = {
---@type table<string,TSPredicate> --- @param match TSMatch
local predicate_handlers = { --- @param source integer|string
['eq?'] = function(match, _, source, predicate) --- @param predicate any[]
local node = match[predicate[2]] --- @param any boolean
if not node then ['eq'] = function(match, source, predicate, any)
local nodes = match[predicate[2]]
if not nodes or #nodes == 0 then
return true return true
end end
local node_text = vim.treesitter.get_node_text(node, source)
local str ---@type string for _, node in ipairs(nodes) do
if type(predicate[3]) == 'string' then local node_text = vim.treesitter.get_node_text(node, source)
-- (#eq? @aa "foo")
str = predicate[3] local str ---@type string
else if type(predicate[3]) == 'string' then
-- (#eq? @aa @bb) -- (#eq? @aa "foo")
str = vim.treesitter.get_node_text(match[predicate[3]], source) str = predicate[3]
else
-- (#eq? @aa @bb)
local other = assert(match[predicate[3]])
assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
str = vim.treesitter.get_node_text(other[1], source)
end
local res = str ~= nil and node_text == str
if any and res then
return true
elseif not any and not res then
return false
end
end end
if node_text ~= str or str == nil then return not any
return false
end
return true
end, end,
['lua-match?'] = function(match, _, source, predicate) --- @param match TSMatch
local node = match[predicate[2]] --- @param source integer|string
if not node then --- @param predicate any[]
--- @param any boolean
['lua-match'] = function(match, source, predicate, any)
local nodes = match[predicate[2]]
if not nodes or #nodes == 0 then
return true return true
end end
local regex = predicate[3]
return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil for _, node in ipairs(nodes) do
local regex = predicate[3]
local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
if any and res then
return true
elseif not any and not res then
return false
end
end
return not any
end, end,
['match?'] = (function() ['match'] = (function()
local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true } local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
local function check_magic(str) local function check_magic(str)
if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
@ -347,27 +371,120 @@ local predicate_handlers = {
end, end,
}) })
return function(match, _, source, pred) --- @param match TSMatch
---@cast match TSMatch --- @param source integer|string
local node = match[pred[2]] --- @param predicate any[]
if not node then --- @param any boolean
return function(match, source, predicate, any)
local nodes = match[predicate[2]]
if not nodes or #nodes == 0 then
return true return true
end end
---@diagnostic disable-next-line no-unknown
local regex = compiled_vim_regexes[pred[3]] for _, node in ipairs(nodes) do
return regex:match_str(vim.treesitter.get_node_text(node, source)) local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
local res = regex:match_str(vim.treesitter.get_node_text(node, source))
if any and res then
return true
elseif not any and not res then
return false
end
end
return not any
end end
end)(), end)(),
['contains?'] = function(match, _, source, predicate) --- @param match TSMatch
local node = match[predicate[2]] --- @param source integer|string
if not node then --- @param predicate any[]
--- @param any boolean
['contains'] = function(match, source, predicate, any)
local nodes = match[predicate[2]]
if not nodes or #nodes == 0 then
return true return true
end end
local node_text = vim.treesitter.get_node_text(node, source)
for i = 3, #predicate do for _, node in ipairs(nodes) do
if string.find(node_text, predicate[i], 1, true) then local node_text = vim.treesitter.get_node_text(node, source)
for i = 3, #predicate do
local res = string.find(node_text, predicate[i], 1, true)
if any and res then
return true
elseif not any and not res then
return false
end
end
end
return not any
end,
}
---@class TSMatch
---@field pattern? integer
---@field active? boolean
---@field [integer] TSNode[]
---@alias TSPredicate fun(match: TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean
-- Predicate handlers receive the following arguments
-- (match, pattern, bufnr, predicate)
---@type table<string,TSPredicate>
local predicate_handlers = {
['eq?'] = function(match, _, source, predicate)
return impl['eq'](match, source, predicate, false)
end,
['any-eq?'] = function(match, _, source, predicate)
return impl['eq'](match, source, predicate, true)
end,
['lua-match?'] = function(match, _, source, predicate)
return impl['lua-match'](match, source, predicate, false)
end,
['any-lua-match?'] = function(match, _, source, predicate)
return impl['lua-match'](match, source, predicate, true)
end,
['match?'] = function(match, _, source, predicate)
return impl['match'](match, source, predicate, false)
end,
['any-match?'] = function(match, _, source, predicate)
return impl['match'](match, source, predicate, true)
end,
['contains?'] = function(match, _, source, predicate)
return impl['contains'](match, source, predicate, false)
end,
['any-contains?'] = function(match, _, source, predicate)
return impl['contains'](match, source, predicate, true)
end,
['any-of?'] = function(match, _, source, predicate)
local nodes = match[predicate[2]]
if not nodes or #nodes == 0 then
return true
end
for _, node in ipairs(nodes) do
local node_text = vim.treesitter.get_node_text(node, source)
-- Since 'predicate' will not be used by callers of this function, use it
-- to store a string set built from the list of words to check against.
local string_set = predicate['string_set'] --- @type table<string, boolean>
if not string_set then
string_set = {}
for i = 3, #predicate do
string_set[predicate[i]] = true
end
predicate['string_set'] = string_set
end
if string_set[node_text] then
return true return true
end end
end end
@ -375,57 +492,39 @@ local predicate_handlers = {
return false return false
end, end,
['any-of?'] = function(match, _, source, predicate)
local node = match[predicate[2]]
if not node then
return true
end
local node_text = vim.treesitter.get_node_text(node, source)
-- Since 'predicate' will not be used by callers of this function, use it
-- to store a string set built from the list of words to check against.
local string_set = predicate['string_set']
if not string_set then
string_set = {}
for i = 3, #predicate do
---@diagnostic disable-next-line:no-unknown
string_set[predicate[i]] = true
end
predicate['string_set'] = string_set
end
return string_set[node_text]
end,
['has-ancestor?'] = function(match, _, _, predicate) ['has-ancestor?'] = function(match, _, _, predicate)
local node = match[predicate[2]] local nodes = match[predicate[2]]
if not node then if not nodes or #nodes == 0 then
return true return true
end end
local ancestor_types = {} for _, node in ipairs(nodes) do
for _, type in ipairs({ unpack(predicate, 3) }) do local ancestor_types = {} --- @type table<string, boolean>
ancestor_types[type] = true for _, type in ipairs({ unpack(predicate, 3) }) do
end ancestor_types[type] = true
end
node = node:parent()
while node do local cur = node:parent()
if ancestor_types[node:type()] then while cur do
return true if ancestor_types[cur:type()] then
return true
end
cur = cur:parent()
end end
node = node:parent()
end end
return false return false
end, end,
['has-parent?'] = function(match, _, _, predicate) ['has-parent?'] = function(match, _, _, predicate)
local node = match[predicate[2]] local nodes = match[predicate[2]]
if not node then if not nodes or #nodes == 0 then
return true return true
end end
if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then for _, node in ipairs(nodes) do
return true if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then
return true
end
end end
return false return false
end, end,
@ -433,6 +532,7 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match? -- As we provide lua-match? also expose vim-match?
predicate_handlers['vim-match?'] = predicate_handlers['match?'] predicate_handlers['vim-match?'] = predicate_handlers['match?']
predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
---@class TSMetadata ---@class TSMetadata
---@field range? Range ---@field range? Range
@ -468,13 +568,17 @@ local directive_handlers = {
-- Shifts the range of a node. -- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1) -- Example: (#offset! @_node 0 1 0 -1)
['offset!'] = function(match, _, _, pred, metadata) ['offset!'] = function(match, _, _, pred, metadata)
---@cast pred integer[] local capture_id = pred[2] --[[@as integer]]
local capture_id = pred[2] local nodes = match[capture_id]
assert(#nodes == 1, '#offset! does not support captures on multiple nodes')
local node = nodes[1]
if not metadata[capture_id] then if not metadata[capture_id] then
metadata[capture_id] = {} metadata[capture_id] = {}
end end
local range = metadata[capture_id].range or { match[capture_id]:range() } local range = metadata[capture_id].range or { node:range() }
local start_row_offset = pred[3] or 0 local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0 local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0 local end_row_offset = pred[5] or 0
@ -498,7 +602,9 @@ local directive_handlers = {
local id = pred[2] local id = pred[2]
assert(type(id) == 'number') assert(type(id) == 'number')
local node = match[id] local nodes = match[id]
assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
local node = nodes[1]
local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then if not metadata[id] then
@ -518,10 +624,9 @@ local directive_handlers = {
local capture_id = pred[2] local capture_id = pred[2]
assert(type(capture_id) == 'number') assert(type(capture_id) == 'number')
local node = match[capture_id] local nodes = match[capture_id]
if not node then assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
return local node = nodes[1]
end
local start_row, start_col, end_row, end_col = node:range() local start_row, start_col, end_row, end_col = node:range()
@ -552,38 +657,93 @@ local directive_handlers = {
--- Adds a new predicate to be used in queries --- Adds a new predicate to be used in queries
--- ---
---@param name string Name of the predicate, without leading # ---@param name string Name of the predicate, without leading #
---@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[]) ---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table)
--- - see |vim.treesitter.query.add_directive()| for argument meanings --- - see |vim.treesitter.query.add_directive()| for argument meanings
---@param force boolean|nil ---@param opts table<string, any> Optional options:
function M.add_predicate(name, handler, force) --- - force (boolean): Override an existing
if predicate_handlers[name] and not force then --- predicate of the same name
error(string.format('Overriding %s', name)) --- - all (boolean): Use the correct
--- implementation of the match table where
--- capture IDs map to a list of nodes instead
--- of a single node. Defaults to false (for
--- backward compatibility). This option will
--- eventually become the default and be removed.
function M.add_predicate(name, handler, opts)
-- Backward compatibility: old signature had "force" as boolean argument
if type(opts) == 'boolean' then
opts = { force = opts }
end end
predicate_handlers[name] = handler opts = opts or {}
if predicate_handlers[name] and not opts.force then
error(string.format('Overriding existing predicate %s', name))
end
if opts.all then
predicate_handlers[name] = handler
else
--- @param match table<integer, TSNode[]>
local function wrapper(match, ...)
local m = {} ---@type table<integer, TSNode>
for k, v in pairs(match) do
if type(k) == 'number' then
m[k] = v[#v]
end
end
return handler(m, ...)
end
predicate_handlers[name] = wrapper
end
end end
 --- Adds a new directive to be used in queries
 ---
 --- Handlers can set match level data by setting directly on the
---- metadata object `metadata.key = value`, additionally, handlers
+--- metadata object `metadata.key = value`. Additionally, handlers
 --- can set node level data by using the capture id on the
 --- metadata table `metadata[capture_id].key = value`
 ---
 ---@param name string Name of the directive, without leading #
----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table)
---- - match: see |treesitter-query|
---- - node-level data are accessible via `match[capture_id]`
---- - pattern: see |treesitter-query|
+---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table)
+--- - match: A table mapping capture IDs to a list of captured nodes
+--- - pattern: the index of the matching pattern in the query file
 --- - predicate: list of strings containing the full directive being called, e.g.
 --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }`
----@param force boolean|nil
-function M.add_directive(name, handler, force)
-  if directive_handlers[name] and not force then
-    error(string.format('Overriding %s', name))
+---@param opts table<string, any> Optional options:
+--- - force (boolean): Override an existing
+--- directive of the same name
+--- - all (boolean): Use the correct
+--- implementation of the match table where
+--- capture IDs map to a list of nodes instead
+--- of a single node. Defaults to false (for
+--- backward compatibility). This option will
+--- eventually become the default and be removed.
+function M.add_directive(name, handler, opts)
+  -- Backward compatibility: old signature had "force" as boolean argument
+  if type(opts) == 'boolean' then
+    opts = { force = opts }
   end
-  directive_handlers[name] = handler
+  opts = opts or {}
+  if directive_handlers[name] and not opts.force then
+    error(string.format('Overriding existing directive %s', name))
+  end
+  if opts.all then
+    directive_handlers[name] = handler
+  else
+    --- @param match table<integer, TSNode[]>
+    local function wrapper(match, ...)
+      local m = {} ---@type table<integer, TSNode>
+      for k, v in pairs(match) do
+        m[k] = v[#v]
+      end
+      handler(m, ...)
+    end
+    directive_handlers[name] = wrapper
+  end
 end
 --- Lists the currently available directives to use in queries.
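Both `add_predicate` and `add_directive` start by normalizing their third argument so that the old boolean "force" form keeps working. A minimal sketch of that step in isolation, where `normalize_opts` is a hypothetical helper name introduced only for this example:

```lua
-- Hypothetical helper mirroring the argument handling at the top of
-- add_predicate/add_directive; it is not a function exported by Neovim.
local function normalize_opts(opts)
  if type(opts) == 'boolean' then
    opts = { force = opts } -- old signature: third argument was "force"
  end
  return opts or {}
end

print(normalize_opts(true).force)           -- true
print(normalize_opts(nil).all)              -- nil, so the legacy wrapper is used
print(normalize_opts({ all = true }).force) -- nil, so an existing handler is not overridden
```

Because a missing `all` is falsy, callers that pass nothing or a bare boolean keep the single-node behavior.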
@@ -608,7 +768,7 @@ end
 ---@private
 ---@param match TSMatch
----@param pattern string
+---@param pattern integer
 ---@param source integer|string
 function Query:match_preds(match, pattern, source)
   local preds = self.info.patterns[pattern]
@@ -618,18 +778,14 @@ function Query:match_preds(match, pattern, source)
     -- continue on the other case. This way unknown predicates will not be considered,
     -- which allows some testing and easier user extensibility (#12173).
     -- Also, tree-sitter strips the leading # from predicates for us.
-    local pred_name ---@type string
-    local is_not ---@type boolean
+    local is_not = false
     -- Skip over directives... they will get processed after all the predicates.
     if not is_directive(pred[1]) then
-      if string.sub(pred[1], 1, 4) == 'not-' then
-        pred_name = string.sub(pred[1], 5)
+      local pred_name = pred[1]
+      if pred_name:match('^not%-') then
+        pred_name = pred_name:sub(5)
         is_not = true
-      else
-        pred_name = pred[1]
-        is_not = false
       end
       local handler = predicate_handlers[pred_name]
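The rewritten `not-` handling in `match_preds` relies on a Lua pattern in which `%-` escapes the hyphen (an unescaped `-` is a quantifier in Lua patterns). A small sketch of the same logic as a free function, where `split_negation` is a hypothetical name and not part of the commit:

```lua
-- Hypothetical helper reproducing the prefix check from match_preds.
local function split_negation(name)
  if name:match('^not%-') then
    -- 'not-' is four characters, so the real name starts at index 5
    return name:sub(5), true
  end
  return name, false
end

local name, is_not = split_negation('not-eq?')
assert(name == 'eq?' and is_not == true)

name, is_not = split_negation('eq?')
assert(name == 'eq?' and is_not == false)
```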
@@ -724,7 +880,7 @@ function Query:iter_captures(node, source, start, stop)
   start, stop = value_or_node_range(start, stop, node)
-  local raw_iter = node:_rawquery(self.query, true, start, stop)
+  local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, TSMatch
   local function iter(end_line)
     local capture, captured_node, match = raw_iter()
     local metadata = {}
@@ -748,27 +904,34 @@ end
 --- Iterates the matches of self on a given range.
 ---
---- Iterate over all matches within a {node}. The arguments are the same as
---- for |Query:iter_captures()| but the iterated values are different:
---- an (1-based) index of the pattern in the query, a table mapping
---- capture indices to nodes, and metadata from any directives processing the match.
---- If the query has more than one pattern, the capture table might be sparse
---- and e.g. `pairs()` method should be used over `ipairs`.
---- Here is an example iterating over all captures in every match:
+--- Iterate over all matches within a {node}. The arguments are the same as for
+--- |Query:iter_captures()| but the iterated values are different: an (1-based)
+--- index of the pattern in the query, a table mapping capture indices to a list
+--- of nodes, and metadata from any directives processing the match.
+---
+--- WARNING: Set `all=true` to ensure all matching nodes in a match are
+--- returned, otherwise only the last node in a match is returned, breaking captures
+--- involving quantifiers such as `(comment)+ @comment`. The default option
+--- `all=false` is only provided for backward compatibility and will be removed
+--- after Nvim 0.10.
+---
+--- Example:
 ---
 --- ```lua
---- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do
----   for id, node in pairs(match) do
+--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do
+---   for id, nodes in pairs(match) do
 ---     local name = query.captures[id]
----     -- `node` was captured by the `name` capture in the match
+---     for _, node in ipairs(nodes) do
+---       -- `node` was captured by the `name` capture in the match
 ---
----     local node_data = metadata[id] -- Node level metadata
----
----     -- ... use the info here ...
+---       local node_data = metadata[id] -- Node level metadata
+---       ... use the info here ...
+---     end
 ---   end
 --- end
 --- ```
 ---
+---
 ---@param node TSNode under which the search will occur
 ---@param source (integer|string) Source buffer or string to search
 ---@param start? integer Starting line for the search. Defaults to `node:start()`.
@@ -776,17 +939,20 @@ end
 ---@param opts? table Optional keyword arguments:
 --- - max_start_depth (integer) if non-zero, sets the maximum start depth
 --- for each match. This is used to prevent traversing too deep into a tree.
+--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
+--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is
+--- incorrect behavior. This option will eventually become the default and be removed.
 ---
----@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata
+---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
 function Query:iter_matches(node, source, start, stop, opts)
+  local all = opts and opts.all
   if type(source) == 'number' and source == 0 then
     source = api.nvim_get_current_buf()
   end
   start, stop = value_or_node_range(start, stop, node)
-  local raw_iter = node:_rawquery(self.query, false, start, stop, opts)
-  ---@cast raw_iter fun(): string, any
+  local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, TSMatch
   local function iter()
     local pattern, match = raw_iter()
     local metadata = {}
@@ -799,6 +965,18 @@ function Query:iter_matches(node, source, start, stop, opts)
       self:apply_directives(match, pattern, source, metadata)
     end
+    if not all then
+      -- Convert the match table into the old buggy version for backward
+      -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
+      -- option!
+      local old_match = {} ---@type table<integer, TSNode>
+      for k, v in pairs(match or {}) do
+        old_match[k] = v[#v]
+      end
+      return pattern, old_match, metadata
+    end
     return pattern, match, metadata
   end
   return iter
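Code that has to consume match tables in both shapes during the deprecation period could normalize each capture value before looping over it. The following is only a sketch under stated assumptions: `as_list` is a hypothetical helper, strings stand in for nodes, and it assumes a single node is never a plain Lua table (in Neovim a `TSNode` is expected to be userdata, which should be verified before relying on this).

```lua
-- Hypothetical helper: accept either match shape and always return a list.
-- Assumption: a lone node is not a Lua table, so `type` can tell them apart.
local function as_list(value)
  if type(value) == 'table' then
    return value -- already a list of nodes (`all = true` shape)
  end
  return { value } -- legacy shape: wrap the single node
end

local new_style = { [1] = { 'a', 'b' } } -- shape returned with `all = true`
local old_style = { [1] = 'b' }          -- legacy shape: last node only

assert(#as_list(new_style[1]) == 2)
assert(as_list(old_style[1])[1] == 'b')
```

Passing `{ all = true }` explicitly remains the simpler route when the caller controls the `iter_matches` call.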


@@ -1364,9 +1364,16 @@ static int node_equal(lua_State *L)
 /// assumes the match table being on top of the stack
 static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx)
 {
-  for (int i = 0; i < match->capture_count; i++) {
-    push_node(L, match->captures[i].node, nodeidx);
-    lua_rawseti(L, -2, (int)match->captures[i].index + 1);
+  // [match]
+  for (size_t i = 0; i < match->capture_count; i++) {
+    lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures]
+    if (lua_isnil(L, -1)) { // [match, nil]
+      lua_pop(L, 1); // [match]
+      lua_createtable(L, 1, 0); // [match, captures]
+    }
+    push_node(L, match->captures[i].node, nodeidx); // [match, captures, node]
+    lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures]
+    lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match]
   }
 }
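The new `set_match` appends every captured node to a per-capture list instead of overwriting one slot. The same get-or-create-then-append logic can be modeled in plain Lua. This is a sketch only: `build_match` and the `{ index = ..., node = ... }` record shape are hypothetical stand-ins for the C `TSQueryMatch` data, and strings stand in for nodes.

```lua
-- Hypothetical Lua model of the C loop in set_match.
local function build_match(captures)
  local match = {}
  for _, cap in ipairs(captures) do
    local id = cap.index + 1 -- capture indices are 0-based on the C side
    local nodes = match[id]
    if nodes == nil then
      nodes = {} -- first node for this capture: create its list
      match[id] = nodes
    end
    nodes[#nodes + 1] = cap.node -- append rather than overwrite
  end
  return match
end

local match = build_match({
  { index = 0, node = 'stmt_1' },
  { index = 0, node = 'stmt_2' },
  { index = 1, node = 'ident' },
})
assert(#match[1] == 2 and match[1][2] == 'stmt_2')
assert(#match[2] == 1 and match[2][1] == 'ident')
```

With the old loop, the second `index = 0` entry would have replaced the first, which is the loss the commit message describes.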
@@ -1379,7 +1386,7 @@ static int query_next_match(lua_State *L)
   TSQueryMatch match;
   if (ts_query_cursor_next_match(cursor, &match)) {
     lua_pushinteger(L, match.pattern_index + 1); // [index]
-    lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, match]
+    lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match]
     set_match(L, &match, lua_upvalueindex(2));
     return 2;
   }
@@ -1421,7 +1428,8 @@ static int query_next_capture(lua_State *L)
     if (n_pred > 0 && (ud->max_match_id < (int)match.id)) {
       ud->max_match_id = (int)match.id;
-      lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match]
+      // Create a new cleared match table
+      lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match]
       set_match(L, &match, lua_upvalueindex(2));
       lua_pushinteger(L, match.pattern_index + 1);
       lua_setfield(L, -2, "pattern");
@@ -1431,6 +1439,10 @@ static int query_next_capture(lua_State *L)
         lua_pushboolean(L, false);
         lua_setfield(L, -2, "active");
       }
+      // Set current_match to the new match
+      lua_replace(L, lua_upvalueindex(4)); // [index, node]
+      lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match]
       return 3;
     }
     return 2;


@@ -731,6 +731,31 @@ describe('treesitter highlighting (C)', function()
     eq(3, get_hl '@foo.missing.exists.bar')
     eq(nil, get_hl '@total.nonsense.but.a.lot.of.dots')
   end)
+  it('supports multiple nodes assigned to the same capture #17060', function()
+    insert([[
+      int x = 4;
+      int y = 5;
+      int z = 6;
+    ]])
+    exec_lua([[
+      local query = '((declaration)+ @string)'
+      vim.treesitter.query.set('c', 'highlights', query)
+      vim.treesitter.highlighter.new(vim.treesitter.get_parser(0, 'c'))
+    ]])
+    screen:expect {
+      grid = [[
+        {5:int x = 4;} |
+        {5:int y = 5;} |
+        {5:int z = 6;} |
+        ^ |
+        {1:~ }|*13
+        |
+      ]],
+    }
+  end)
 end)
 describe('treesitter highlighting (help)', function()


@@ -8,6 +8,8 @@ local exec_lua = helpers.exec_lua
 local pcall_err = helpers.pcall_err
 local feed = helpers.feed
 local is_os = helpers.is_os
+local api = helpers.api
+local fn = helpers.fn
 describe('treesitter parser API', function()
   before_each(function()
@@ -171,7 +173,7 @@ void ui_refresh(void)
     assert(res_fail)
   end)
-  local query = [[
+  local test_query = [[
     ((call_expression function: (identifier) @minfunc (argument_list (identifier) @min_id)) (eq? @minfunc "MIN"))
     "for" @keyword
     (primitive_type) @type
@@ -187,7 +189,7 @@ void ui_refresh(void)
   end)
   it('supports caching queries', function()
-    local long_query = query:rep(100)
+    local long_query = test_query:rep(100)
     local function q(n)
       return exec_lua(
         [[
@@ -230,7 +232,7 @@ void ui_refresh(void)
       end
       return res
       ]],
-      query
+      test_query
     )
     eq({
@@ -256,17 +258,19 @@ void ui_refresh(void)
       parser = vim.treesitter.get_parser(0, "c")
       tree = parser:parse()[1]
       res = {}
-      for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14) do
+      for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14, { all = true }) do
         -- can't transmit node over RPC. just check the name and range
         local mrepr = {}
-        for cid,node in pairs(match) do
-          table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
+        for cid, nodes in pairs(match) do
+          for _, node in ipairs(nodes) do
+            table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
+          end
         end
         table.insert(res, {pattern, mrepr})
       end
       return res
       ]],
-      query
+      test_query
     )
     eq({
@@ -287,6 +291,67 @@ void ui_refresh(void)
     }, res)
   end)
it('support query and iter by capture for quantifiers', function()
insert(test_text)
local res = exec_lua(
[[
cquery = vim.treesitter.query.parse("c", ...)
parser = vim.treesitter.get_parser(0, "c")
tree = parser:parse()[1]
res = {}
for cid, node in cquery:iter_captures(tree:root(), 0, 7, 14) do
-- can't transmit node over RPC. just check the name and range
table.insert(res, {cquery.captures[cid], node:type(), node:range()})
end
return res
]],
'(expression_statement (assignment_expression (call_expression)))+ @funccall'
)
eq({
{ 'funccall', 'expression_statement', 11, 4, 11, 34 },
{ 'funccall', 'expression_statement', 12, 4, 12, 37 },
{ 'funccall', 'expression_statement', 13, 4, 13, 34 },
}, res)
end)
it('support query and iter by match for quantifiers', function()
insert(test_text)
local res = exec_lua(
[[
cquery = vim.treesitter.query.parse("c", ...)
parser = vim.treesitter.get_parser(0, "c")
tree = parser:parse()[1]
res = {}
for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14, { all = true }) do
-- can't transmit node over RPC. just check the name and range
local mrepr = {}
for cid, nodes in pairs(match) do
for _, node in ipairs(nodes) do
table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
end
end
table.insert(res, {pattern, mrepr})
end
return res
]],
'(expression_statement (assignment_expression (call_expression)))+ @funccall'
)
eq({
{
1,
{
{ 'funccall', 'expression_statement', 11, 4, 11, 34 },
{ 'funccall', 'expression_statement', 12, 4, 12, 37 },
{ 'funccall', 'expression_statement', 13, 4, 13, 34 },
},
},
}, res)
end)
it('supports getting text of multiline node', function() it('supports getting text of multiline node', function()
insert(test_text) insert(test_text)
local res = exec_lua([[ local res = exec_lua([[
@@ -365,11 +430,13 @@ end]]
       parser = vim.treesitter.get_parser(0, "c")
       tree = parser:parse()[1]
       res = {}
-      for pattern, match in cquery:iter_matches(tree:root(), 0) do
+      for pattern, match in cquery:iter_matches(tree:root(), 0, 0, -1, { all = true }) do
         -- can't transmit node over RPC. just check the name and range
         local mrepr = {}
-        for cid,node in pairs(match) do
-          table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
+        for cid, nodes in pairs(match) do
+          for _, node in ipairs(nodes) do
+            table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
+          end
         end
         table.insert(res, {pattern, mrepr})
       end
@@ -457,11 +524,13 @@ end]]
       parser = vim.treesitter.get_parser(0, "c")
       tree = parser:parse()[1]
       res = {}
-      for pattern, match in cquery:iter_matches(tree:root(), 0) do
+      for pattern, match in cquery:iter_matches(tree:root(), 0, 0, -1, { all = true }) do
         -- can't transmit node over RPC. just check the name and range
         local mrepr = {}
-        for cid,node in pairs(match) do
-          table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
+        for cid, nodes in pairs(match) do
+          for _, node in ipairs(nodes) do
+            table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
+          end
         end
         table.insert(res, {pattern, mrepr})
       end
@ -486,55 +555,248 @@ end]]
local custom_query = '((identifier) @main (#is-main? @main))' local custom_query = '((identifier) @main (#is-main? @main))'
local res = exec_lua( do
[[ local res = exec_lua(
local query = vim.treesitter.query [[
local query = vim.treesitter.query
local function is_main(match, pattern, bufnr, predicate) local function is_main(match, pattern, bufnr, predicate)
local node = match[ predicate[2] ] local nodes = match[ predicate[2] ]
for _, node in ipairs(nodes) do
if query.get_node_text(node, bufnr) == 'main' then
return true
end
end
return false
end
return query.get_node_text(node, bufnr) local parser = vim.treesitter.get_parser(0, "c")
-- Time bomb: update this in 0.12
if vim.fn.has('nvim-0.12') == 1 then
return 'Update this test to remove this message and { all = true } from add_predicate'
end
query.add_predicate("is-main?", is_main, { all = true })
local query = query.parse("c", ...)
local nodes = {}
for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do
table.insert(nodes, {node:range()})
end
return nodes
]],
custom_query
)
eq({ { 0, 4, 0, 8 } }, res)
end end
local parser = vim.treesitter.get_parser(0, "c") -- Once with the old API. Remove this whole 'do' block in 0.12
do
local res = exec_lua(
[[
local query = vim.treesitter.query
query.add_predicate("is-main?", is_main) local function is_main(match, pattern, bufnr, predicate)
local node = match[ predicate[2] ]
local query = query.parse("c", ...) return query.get_node_text(node, bufnr) == 'main'
end
local nodes = {} local parser = vim.treesitter.get_parser(0, "c")
for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do
table.insert(nodes, {node:range()}) query.add_predicate("is-main?", is_main, true)
local query = query.parse("c", ...)
local nodes = {}
for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do
table.insert(nodes, {node:range()})
end
return nodes
]],
custom_query
)
-- Remove this 'do' block in 0.12
eq(0, fn.has('nvim-0.12'))
eq({ { 0, 4, 0, 8 } }, res)
end end
return nodes do
]], local res = exec_lua [[
custom_query local query = vim.treesitter.query
)
eq({ { 0, 4, 0, 8 } }, res) local t = {}
for _, v in ipairs(query.list_predicates()) do
t[v] = true
end
local res_list = exec_lua [[ return t
local query = vim.treesitter.query ]]
local list = query.list_predicates() eq(true, res['is-main?'])
end
end)
table.sort(list) it('supports "all" and "any" semantics for predicates on quantified captures #24738', function()
local query_all = [[
return list (((comment (comment_content))+) @bar
(#lua-match? @bar "Yes"))
]] ]]
local query_any = [[
(((comment (comment_content))+) @bar
(#any-lua-match? @bar "Yes"))
]]
local function test(input, query)
api.nvim_buf_set_lines(0, 0, -1, true, vim.split(dedent(input), '\n'))
return exec_lua(
[[
local parser = vim.treesitter.get_parser(0, "lua")
local query = vim.treesitter.query.parse("lua", ...)
local nodes = {}
for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do
nodes[#nodes+1] = { node:range() }
end
return nodes
]],
query
)
end
eq(
{},
test(
[[
-- Yes
-- No
-- Yes
]],
query_all
)
)
eq(
{
{ 0, 2, 0, 8 },
{ 1, 2, 1, 8 },
{ 2, 2, 2, 8 },
},
test(
[[
-- Yes
-- Yes
-- Yes
]],
query_all
)
)
eq(
{},
test(
[[
-- No
-- No
-- No
]],
query_any
)
)
eq(
{
{ 0, 2, 0, 7 },
{ 1, 2, 1, 8 },
{ 2, 2, 2, 7 },
},
test(
[[
-- No
-- Yes
-- No
]],
query_any
)
)
end)
it('supports any- prefix to match any capture when using quantifiers #24738', function()
insert([[
-- Comment
-- Comment
-- Comment
]])
local query = [[
(((comment (comment_content))+) @bar
(#lua-match? @bar "Comment"))
]]
local result = exec_lua(
[[
local parser = vim.treesitter.get_parser(0, "lua")
local query = vim.treesitter.query.parse("lua", ...)
local nodes = {}
for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do
nodes[#nodes+1] = { node:range() }
end
return nodes
]],
query
)
eq({ eq({
'any-of?', { 0, 2, 0, 12 },
'contains?', { 1, 2, 1, 12 },
'eq?', { 2, 2, 2, 12 },
'has-ancestor?', }, result)
'has-parent?', end)
'is-main?',
'lua-match?', it('supports the old broken version of iter_matches #24738', function()
'match?', -- Delete this test in 0.12 when iter_matches is removed
'vim-match?', eq(0, fn.has('nvim-0.12'))
}, res_list)
insert(test_text)
local res = exec_lua(
[[
cquery = vim.treesitter.query.parse("c", ...)
parser = vim.treesitter.get_parser(0, "c")
tree = parser:parse()[1]
res = {}
for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14) do
local mrepr = {}
for cid, node in pairs(match) do
table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()})
end
table.insert(res, {pattern, mrepr})
end
return res
]],
test_query
)
eq({
{ 3, { { 'type', 'primitive_type', 8, 2, 8, 6 } } },
{ 2, { { 'keyword', 'for', 9, 2, 9, 5 } } },
{ 3, { { 'type', 'primitive_type', 9, 7, 9, 13 } } },
{ 4, { { 'fieldarg', 'identifier', 11, 16, 11, 18 } } },
{
1,
{ { 'minfunc', 'identifier', 11, 12, 11, 15 }, { 'min_id', 'identifier', 11, 27, 11, 32 } },
},
{ 4, { { 'fieldarg', 'identifier', 12, 17, 12, 19 } } },
{
1,
{ { 'minfunc', 'identifier', 12, 13, 12, 16 }, { 'min_id', 'identifier', 12, 29, 12, 35 } },
},
{ 4, { { 'fieldarg', 'identifier', 13, 14, 13, 16 } } },
}, res)
end) end)
it('allows to set simple ranges', function() it('allows to set simple ranges', function()
@@ -866,7 +1128,7 @@ int x = INT_MAX;
       query = vim.treesitter.query.parse("c", '((number_literal) @number (#set! "key" "value"))')
       parser = vim.treesitter.get_parser(0, "c")
-      for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do
+      for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, -1, { all = true }) do
         result = metadata.key
       end
@@ -889,7 +1151,7 @@ int x = INT_MAX;
       query = vim.treesitter.query.parse("c", '((number_literal) @number (#set! @number "key" "value"))')
       parser = vim.treesitter.get_parser(0, "c")
-      for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do
+      for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, -1, { all = true }) do
         for _, nested_tbl in pairs(metadata) do
           return nested_tbl.key
         end
@@ -911,7 +1173,7 @@ int x = INT_MAX;
       query = vim.treesitter.query.parse("c", '((number_literal) @number (#set! @number "key" "value") (#set! @number "key2" "value2"))')
       parser = vim.treesitter.get_parser(0, "c")
-      for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do
+      for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, -1, { all = true }) do
         for _, nested_tbl in pairs(metadata) do
           return nested_tbl
         end