diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index 757cfecb46..8c6ba38d3c 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -427,6 +427,18 @@ The following changes to existing APIs or features add new behavior. • |nvim_buf_call()| and |nvim_win_call()| now preserves any return value (NB: not multiple return values) +• Treesitter + • |Query:iter_matches()|, |vim.treesitter.query.add_predicate()|, and + |vim.treesitter.query.add_directive()| accept a new `all` option which + ensures that all matching nodes are returned as a table. The default option + `all=false` returns only a single node, breaking captures with quantifiers + like `(comment)+ @comment; it is only provided for backward compatibility + and will be removed after Nvim 0.10. + • |vim.treesitter.query.add_predicate()| and + |vim.treesitter.query.add_directive()| now accept an options table rather + than a boolean "force" argument. To force a predicate or directive to + override an existing predicate or directive, use `{ force = true }`. + ============================================================================== REMOVED FEATURES *news-removed* @@ -480,7 +492,7 @@ release. • `vim.loop` has been renamed to |vim.uv|. -• vim.treesitter.languagetree functions: +• vim.treesitter functions: - |LanguageTree:for_each_child()| Use |LanguageTree:children()| (non-recursive) instead. • The "term_background" UI option |ui-ext-options| is deprecated and no longer diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index f6ee2ef425..f92955ee48 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -223,6 +223,10 @@ The following predicates are built in: ((identifier) @variable.builtin (#eq? @variable.builtin "self")) ((node1) @left (node2) @right (#eq? @left @right)) < + `any-eq?` *treesitter-predicate-any-eq?* + Like `eq?`, but for quantified patterns only one captured node must + match. + `match?` *treesitter-predicate-match?* `vim-match?` *treesitter-predicate-vim-match?* Match a |regexp| against the text corresponding to a node: >query @@ -231,15 +235,28 @@ The following predicates are built in: Note: The `^` and `$` anchors will match the start and end of the node's text. + `any-match?` *treesitter-predicate-any-match?* + `any-vim-match?` *treesitter-predicate-any-vim-match?* + Like `match?`, but for quantified patterns only one captured node must + match. + `lua-match?` *treesitter-predicate-lua-match?* Match |lua-patterns| against the text corresponding to a node, similar to `match?` + `any-lua-match?` *treesitter-predicate-any-lua-match?* + Like `lua-match?`, but for quantified patterns only one captured node + must match. + `contains?` *treesitter-predicate-contains?* Match a string against parts of the text corresponding to a node: >query ((identifier) @foo (#contains? @foo "foo")) ((identifier) @foo-bar (#contains? @foo-bar "foo" "bar")) < + `any-contains?` *treesitter-predicate-any-contains?* + Like `contains?`, but for quantified patterns only one captured node + must match. + `any-of?` *treesitter-predicate-any-of?* Match any of the given strings against the text corresponding to a node: >query @@ -265,6 +282,32 @@ The following predicates are built in: Each predicate has a `not-` prefixed predicate that is just the negation of the predicate. + *lua-treesitter-all-predicate* + *lua-treesitter-any-predicate* +Queries can use quantifiers to capture multiple nodes. When a capture contains +multiple nodes, predicates match only if ALL nodes contained by the capture +match the predicate. Some predicates (`eq?`, `match?`, `lua-match?`, +`contains?`) accept an `any-` prefix to instead match if ANY of the nodes +contained by the capture match the predicate. + +As an example, consider the following Lua code: >lua + + -- TODO: This is a + -- very long + -- comment (just imagine it) +< +using the following predicated query: +>query + (((comment)+ @comment) + (#match? @comment "TODO")) +< +This query will not match because not all of the nodes captured by @comment +match the predicate. Instead, use: +>query + (((comment)+ @comment) + (#any-match? @comment "TODO")) +< + Further predicates can be added via |vim.treesitter.query.add_predicate()|. Use |vim.treesitter.query.list_predicates()| to list all available predicates. @@ -923,28 +966,35 @@ register({lang}, {filetype}) *vim.treesitter.language.register()* Lua module: vim.treesitter.query *lua-treesitter-query* *vim.treesitter.query.add_directive()* -add_directive({name}, {handler}, {force}) +add_directive({name}, {handler}, {opts}) Adds a new directive to be used in queries Handlers can set match level data by setting directly on the metadata - object `metadata.key = value`, additionally, handlers can set node level + object `metadata.key = value`. Additionally, handlers can set node level data by using the capture id on the metadata table `metadata[capture_id].key = value` Parameters: ~ • {name} (`string`) Name of the directive, without leading # • {handler} (`function`) - • match: see |treesitter-query| - • node-level data are accessible via `match[capture_id]` - - • pattern: see |treesitter-query| + • match: A table mapping capture IDs to a list of captured + nodes + • pattern: the index of the matching pattern in the query + file • predicate: list of strings containing the full directive being called, e.g. `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` - • {force} (`boolean?`) + • {opts} (`table`) Optional options: + • force (boolean): Override an existing predicate of the + same name + • all (boolean): Use the correct implementation of the + match table where capture IDs map to a list of nodes + instead of a single node. Defaults to false (for backward + compatibility). This option will eventually become the + default and removed. *vim.treesitter.query.add_predicate()* -add_predicate({name}, {handler}, {force}) +add_predicate({name}, {handler}, {opts}) Adds a new predicate to be used in queries Parameters: ~ @@ -952,7 +1002,14 @@ add_predicate({name}, {handler}, {force}) • {handler} (`function`) • see |vim.treesitter.query.add_directive()| for argument meanings - • {force} (`boolean?`) + • {opts} (`table`) Optional options: + • force (boolean): Override an existing predicate of the + same name + • all (boolean): Use the correct implementation of the + match table where capture IDs map to a list of nodes + instead of a single node. Defaults to false (for backward + compatibility). This option will eventually become the + default and removed. edit({lang}) *vim.treesitter.query.edit()* Opens a live editor to query the buffer you started from. @@ -1102,18 +1159,25 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts}) Iterate over all matches within a {node}. The arguments are the same as for |Query:iter_captures()| but the iterated values are different: an (1-based) index of the pattern in the query, a table mapping capture - indices to nodes, and metadata from any directives processing the match. - If the query has more than one pattern, the capture table might be sparse - and e.g. `pairs()` method should be used over `ipairs`. Here is an example - iterating over all captures in every match: >lua - for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do - for id, node in pairs(match) do + indices to a list of nodes, and metadata from any directives processing + the match. + + WARNING: Set `all=true` to ensure all matching nodes in a match are + returned, otherwise only the last node in a match is returned, breaking + captures involving quantifiers such as `(comment)+ @comment`. The default + option `all=false` is only provided for backward compatibility and will be + removed after Nvim 0.10. + + Example: >lua + for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do + for id, nodes in pairs(match) do local name = query.captures[id] - -- `node` was captured by the `name` capture in the match + for _, node in ipairs(nodes) do + -- `node` was captured by the `name` capture in the match - local node_data = metadata[id] -- Node level metadata - - -- ... use the info here ... + local node_data = metadata[id] -- Node level metadata + ... use the info here ... + end end end < @@ -1129,9 +1193,14 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts}) • max_start_depth (integer) if non-zero, sets the maximum start depth for each match. This is used to prevent traversing too deep into a tree. + • all (boolean) When set, the returned match table maps + capture IDs to a list of nodes. Older versions of + iter_matches incorrectly mapped capture IDs to a single + node, which is incorrect behavior. This option will + eventually become the default and removed. Return: ~ - (`fun(): integer, table, table`) pattern id, match, + (`fun(): integer, table, table`) pattern id, match, metadata set({lang}, {query_name}, {text}) *vim.treesitter.query.set()* diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua index 6a714de052..0b285d2d7f 100644 --- a/runtime/lua/vim/treesitter/_meta.lua +++ b/runtime/lua/vim/treesitter/_meta.lua @@ -39,7 +39,7 @@ local TSNode = {} ---@param start? integer ---@param end_? integer ---@param opts? table ----@return fun(): integer, TSNode, any +---@return fun(): integer, TSNode, TSMatch function TSNode:_rawquery(query, captures, start, end_, opts) end ---@param query TSQuery @@ -47,7 +47,7 @@ function TSNode:_rawquery(query, captures, start, end_, opts) end ---@param start? integer ---@param end_? integer ---@param opts? table ----@return fun(): integer, any +---@return fun(): integer, TSMatch function TSNode:_rawquery(query, captures, start, end_, opts) end ---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string) diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua index 8651e187c2..378e9c67aa 100644 --- a/runtime/lua/vim/treesitter/_query_linter.lua +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang) end) --- @param buf integer ---- @param match table +--- @param match table --- @param query Query --- @param lang_context QueryLinterLanguageContext --- @param diagnostics Diagnostic[] @@ -130,20 +130,22 @@ local function lint_match(buf, match, query, lang_context, diagnostics) local lang = lang_context.lang local parser_info = lang_context.parser_info - for id, node in pairs(match) do - local cap_id = query.captures[id] + for id, nodes in pairs(match) do + for _, node in ipairs(nodes) do + local cap_id = query.captures[id] - -- perform language-independent checks only for first lang - if lang_context.is_first_lang and cap_id == 'error' then - local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') - add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) - end + -- perform language-independent checks only for first lang + if lang_context.is_first_lang and cap_id == 'error' then + local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') + add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) + end - -- other checks rely on Neovim parser introspection - if lang and parser_info and cap_id == 'toplevel' then - local err = parse(node, buf, lang) - if err then - add_lint_for_node(diagnostics, err.range, err.msg, lang) + -- other checks rely on Neovim parser introspection + if lang and parser_info and cap_id == 'toplevel' then + local err = parse(node, buf, lang) + if err then + add_lint_for_node(diagnostics, err.range, err.msg, lang) + end end end end diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 971c4449e8..79566f5eeb 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -784,7 +784,7 @@ end ---@private --- Extract injections according to: --- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection ----@param match table +---@param match table ---@param metadata TSMetadata ---@return string?, boolean, Range6[] function LanguageTree:_get_injection(match, metadata) @@ -796,14 +796,16 @@ function LanguageTree:_get_injection(match, metadata) or (injection_lang and resolve_lang(injection_lang)) local include_children = metadata['injection.include-children'] ~= nil - for id, node in pairs(match) do - local name = self._injection_query.captures[id] - -- Lang should override any other language tag - if name == 'injection.language' then - local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) - lang = resolve_lang(text) - elseif name == 'injection.content' then - ranges = get_node_ranges(node, self._source, metadata[id], include_children) + for id, nodes in pairs(match) do + for _, node in ipairs(nodes) do + local name = self._injection_query.captures[id] + -- Lang should override any other language tag + if name == 'injection.language' then + local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) + lang = resolve_lang(text) + elseif name == 'injection.content' then + ranges = get_node_ranges(node, self._source, metadata[id], include_children) + end end end @@ -844,7 +846,13 @@ function LanguageTree:_get_injections() local start_line, _, end_line, _ = root_node:range() for pattern, match, metadata in - self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1) + self._injection_query:iter_matches( + root_node, + self._source, + start_line, + end_line + 1, + { all = true } + ) do local lang, combined, ranges = self:_get_injection(match, metadata) if lang then diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index cd65c0d7f6..5bb9e07a82 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -290,47 +290,71 @@ function M.get_node_text(...) return vim.treesitter.get_node_text(...) end ----@alias TSMatch table - ----@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean - --- Predicate handler receive the following arguments --- (match, pattern, bufnr, predicate) ----@type table -local predicate_handlers = { - ['eq?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then +--- Implementations of predicates that can optionally be prefixed with "any-". +--- +--- These functions contain the implementations for each predicate, correctly +--- handling the "any" vs "all" semantics. They are called from the +--- predicate_handlers table with the appropriate arguments for each predicate. +local impl = { + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['eq'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - local str ---@type string - if type(predicate[3]) == 'string' then - -- (#eq? @aa "foo") - str = predicate[3] - else - -- (#eq? @aa @bb) - str = vim.treesitter.get_node_text(match[predicate[3]], source) + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + local str ---@type string + if type(predicate[3]) == 'string' then + -- (#eq? @aa "foo") + str = predicate[3] + else + -- (#eq? @aa @bb) + local other = assert(match[predicate[3]]) + assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes') + str = vim.treesitter.get_node_text(other[1], source) + end + + local res = str ~= nil and node_text == str + if any and res then + return true + elseif not any and not res then + return false + end end - if node_text ~= str or str == nil then - return false - end - - return true + return not any end, - ['lua-match?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['lua-match'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local regex = predicate[3] - return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil + + for _, node in ipairs(nodes) do + local regex = predicate[3] + local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil + if any and res then + return true + elseif not any and not res then + return false + end + end + + return not any end, - ['match?'] = (function() + ['match'] = (function() local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true } local function check_magic(str) if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then @@ -347,27 +371,120 @@ local predicate_handlers = { end, }) - return function(match, _, source, pred) - ---@cast match TSMatch - local node = match[pred[2]] - if not node then + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + return function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - ---@diagnostic disable-next-line no-unknown - local regex = compiled_vim_regexes[pred[3]] - return regex:match_str(vim.treesitter.get_node_text(node, source)) + + for _, node in ipairs(nodes) do + local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex + local res = regex:match_str(vim.treesitter.get_node_text(node, source)) + if any and res then + return true + elseif not any and not res then + return false + end + end + return not any end end)(), - ['contains?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['contains'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - for i = 3, #predicate do - if string.find(node_text, predicate[i], 1, true) then + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + for i = 3, #predicate do + local res = string.find(node_text, predicate[i], 1, true) + if any and res then + return true + elseif not any and not res then + return false + end + end + end + + return not any + end, +} + +---@class TSMatch +---@field pattern? integer +---@field active? boolean +---@field [integer] TSNode[] + +---@alias TSPredicate fun(match: TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean + +-- Predicate handler receive the following arguments +-- (match, pattern, bufnr, predicate) +---@type table +local predicate_handlers = { + ['eq?'] = function(match, _, source, predicate) + return impl['eq'](match, source, predicate, false) + end, + + ['any-eq?'] = function(match, _, source, predicate) + return impl['eq'](match, source, predicate, true) + end, + + ['lua-match?'] = function(match, _, source, predicate) + return impl['lua-match'](match, source, predicate, false) + end, + + ['any-lua-match?'] = function(match, _, source, predicate) + return impl['lua-match'](match, source, predicate, true) + end, + + ['match?'] = function(match, _, source, predicate) + return impl['match'](match, source, predicate, false) + end, + + ['any-match?'] = function(match, _, source, predicate) + return impl['match'](match, source, predicate, true) + end, + + ['contains?'] = function(match, _, source, predicate) + return impl['contains'](match, source, predicate, false) + end, + + ['any-contains?'] = function(match, _, source, predicate) + return impl['contains'](match, source, predicate, true) + end, + + ['any-of?'] = function(match, _, source, predicate) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then + return true + end + + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + -- Since 'predicate' will not be used by callers of this function, use it + -- to store a string set built from the list of words to check against. + local string_set = predicate['string_set'] --- @type table + if not string_set then + string_set = {} + for i = 3, #predicate do + string_set[predicate[i]] = true + end + predicate['string_set'] = string_set + end + + if string_set[node_text] then return true end end @@ -375,57 +492,39 @@ local predicate_handlers = { return false end, - ['any-of?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then - return true - end - local node_text = vim.treesitter.get_node_text(node, source) - - -- Since 'predicate' will not be used by callers of this function, use it - -- to store a string set built from the list of words to check against. - local string_set = predicate['string_set'] - if not string_set then - string_set = {} - for i = 3, #predicate do - ---@diagnostic disable-next-line:no-unknown - string_set[predicate[i]] = true - end - predicate['string_set'] = string_set - end - - return string_set[node_text] - end, - ['has-ancestor?'] = function(match, _, _, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local ancestor_types = {} - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end - - node = node:parent() - while node do - if ancestor_types[node:type()] then - return true + for _, node in ipairs(nodes) do + local ancestor_types = {} --- @type table + for _, type in ipairs({ unpack(predicate, 3) }) do + ancestor_types[type] = true + end + + local cur = node:parent() + while cur do + if ancestor_types[cur:type()] then + return true + end + cur = cur:parent() end - node = node:parent() end return false end, ['has-parent?'] = function(match, _, _, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then - return true + for _, node in ipairs(nodes) do + if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then + return true + end end return false end, @@ -433,6 +532,7 @@ local predicate_handlers = { -- As we provide lua-match? also expose vim-match? predicate_handlers['vim-match?'] = predicate_handlers['match?'] +predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@class TSMetadata ---@field range? Range @@ -468,13 +568,17 @@ local directive_handlers = { -- Shifts the range of a node. -- Example: (#offset! @_node 0 1 0 -1) ['offset!'] = function(match, _, _, pred, metadata) - ---@cast pred integer[] - local capture_id = pred[2] + local capture_id = pred[2] --[[@as integer]] + local nodes = match[capture_id] + assert(#nodes == 1, '#offset! does not support captures on multiple nodes') + + local node = nodes[1] + if not metadata[capture_id] then metadata[capture_id] = {} end - local range = metadata[capture_id].range or { match[capture_id]:range() } + local range = metadata[capture_id].range or { node:range() } local start_row_offset = pred[3] or 0 local start_col_offset = pred[4] or 0 local end_row_offset = pred[5] or 0 @@ -498,7 +602,9 @@ local directive_handlers = { local id = pred[2] assert(type(id) == 'number') - local node = match[id] + local nodes = match[id] + assert(#nodes == 1, '#gsub! does not support captures on multiple nodes') + local node = nodes[1] local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' if not metadata[id] then @@ -518,10 +624,9 @@ local directive_handlers = { local capture_id = pred[2] assert(type(capture_id) == 'number') - local node = match[capture_id] - if not node then - return - end + local nodes = match[capture_id] + assert(#nodes == 1, '#trim! does not support captures on multiple nodes') + local node = nodes[1] local start_row, start_col, end_row, end_col = node:range() @@ -552,38 +657,93 @@ local directive_handlers = { --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match:table, pattern:string, bufnr:integer, predicate:string[]) +---@param handler function(match: table, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - see |vim.treesitter.query.add_directive()| for argument meanings ----@param force boolean|nil -function M.add_predicate(name, handler, force) - if predicate_handlers[name] and not force then - error(string.format('Overriding %s', name)) +---@param opts table Optional options: +--- - force (boolean): Override an existing +--- predicate of the same name +--- - all (boolean): Use the correct +--- implementation of the match table where +--- capture IDs map to a list of nodes instead +--- of a single node. Defaults to false (for +--- backward compatibility). This option will +--- eventually become the default and removed. +function M.add_predicate(name, handler, opts) + -- Backward compatibility: old signature had "force" as boolean argument + if type(opts) == 'boolean' then + opts = { force = opts } end - predicate_handlers[name] = handler + opts = opts or {} + + if predicate_handlers[name] and not opts.force then + error(string.format('Overriding existing predicate %s', name)) + end + + if opts.all then + predicate_handlers[name] = handler + else + --- @param match table + local function wrapper(match, ...) + local m = {} ---@type table + for k, v in pairs(match) do + if type(k) == 'number' then + m[k] = v[#v] + end + end + return handler(m, ...) + end + predicate_handlers[name] = wrapper + end end --- Adds a new directive to be used in queries --- --- Handlers can set match level data by setting directly on the ---- metadata object `metadata.key = value`, additionally, handlers +--- metadata object `metadata.key = value`. Additionally, handlers --- can set node level data by using the capture id on the --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match:table, pattern:string, bufnr:integer, predicate:string[], metadata:table) ---- - match: see |treesitter-query| ---- - node-level data are accessible via `match[capture_id]` ---- - pattern: see |treesitter-query| +---@param handler function(match: table, pattern: integer, source: integer|string, predicate: any[], metadata: table) +--- - match: A table mapping capture IDs to a list of captured nodes +--- - pattern: the index of the matching pattern in the query file --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` ----@param force boolean|nil -function M.add_directive(name, handler, force) - if directive_handlers[name] and not force then - error(string.format('Overriding %s', name)) +---@param opts table Optional options: +--- - force (boolean): Override an existing +--- predicate of the same name +--- - all (boolean): Use the correct +--- implementation of the match table where +--- capture IDs map to a list of nodes instead +--- of a single node. Defaults to false (for +--- backward compatibility). This option will +--- eventually become the default and removed. +function M.add_directive(name, handler, opts) + -- Backward compatibility: old signature had "force" as boolean argument + if type(opts) == 'boolean' then + opts = { force = opts } end - directive_handlers[name] = handler + opts = opts or {} + + if directive_handlers[name] and not opts.force then + error(string.format('Overriding existing directive %s', name)) + end + + if opts.all then + directive_handlers[name] = handler + else + --- @param match table + local function wrapper(match, ...) + local m = {} ---@type table + for k, v in pairs(match) do + m[k] = v[#v] + end + handler(m, ...) + end + directive_handlers[name] = wrapper + end end --- Lists the currently available directives to use in queries. @@ -608,7 +768,7 @@ end ---@private ---@param match TSMatch ----@param pattern string +---@param pattern integer ---@param source integer|string function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] @@ -618,18 +778,14 @@ function Query:match_preds(match, pattern, source) -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). -- Also, tree-sitter strips the leading # from predicates for us. - local pred_name ---@type string - - local is_not ---@type boolean + local is_not = false -- Skip over directives... they will get processed after all the predicates. if not is_directive(pred[1]) then - if string.sub(pred[1], 1, 4) == 'not-' then - pred_name = string.sub(pred[1], 5) + local pred_name = pred[1] + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) is_not = true - else - pred_name = pred[1] - is_not = false end local handler = predicate_handlers[pred_name] @@ -724,7 +880,7 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) + local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, TSMatch local function iter(end_line) local capture, captured_node, match = raw_iter() local metadata = {} @@ -748,27 +904,34 @@ end --- Iterates the matches of self on a given range. --- ---- Iterate over all matches within a {node}. The arguments are the same as ---- for |Query:iter_captures()| but the iterated values are different: ---- an (1-based) index of the pattern in the query, a table mapping ---- capture indices to nodes, and metadata from any directives processing the match. ---- If the query has more than one pattern, the capture table might be sparse ---- and e.g. `pairs()` method should be used over `ipairs`. ---- Here is an example iterating over all captures in every match: +--- Iterate over all matches within a {node}. The arguments are the same as for +--- |Query:iter_captures()| but the iterated values are different: an (1-based) +--- index of the pattern in the query, a table mapping capture indices to a list +--- of nodes, and metadata from any directives processing the match. +--- +--- WARNING: Set `all=true` to ensure all matching nodes in a match are +--- returned, otherwise only the last node in a match is returned, breaking captures +--- involving quantifiers such as `(comment)+ @comment`. The default option +--- `all=false` is only provided for backward compatibility and will be removed +--- after Nvim 0.10. +--- +--- Example: --- --- ```lua ---- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do ---- for id, node in pairs(match) do +--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do +--- for id, nodes in pairs(match) do --- local name = query.captures[id] ---- -- `node` was captured by the `name` capture in the match +--- for _, node in ipairs(nodes) do +--- -- `node` was captured by the `name` capture in the match --- ---- local node_data = metadata[id] -- Node level metadata ---- ---- -- ... use the info here ... +--- local node_data = metadata[id] -- Node level metadata +--- ... use the info here ... +--- end --- end --- end --- ``` --- +--- ---@param node TSNode under which the search will occur ---@param source (integer|string) Source buffer or string to search ---@param start? integer Starting line for the search. Defaults to `node:start()`. @@ -776,17 +939,20 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. +--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is +--- incorrect behavior. This option will eventually become the default and removed. --- ----@return (fun(): integer, table, table): pattern id, match, metadata +---@return (fun(): integer, table, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) + local all = opts and opts.all if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) - ---@cast raw_iter fun(): string, any + local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, TSMatch local function iter() local pattern, match = raw_iter() local metadata = {} @@ -799,6 +965,18 @@ function Query:iter_matches(node, source, start, stop, opts) self:apply_directives(match, pattern, source, metadata) end + + if not all then + -- Convert the match table into the old buggy version for backward + -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" + -- option! + local old_match = {} ---@type table + for k, v in pairs(match or {}) do + old_match[k] = v[#v] + end + return pattern, old_match, metadata + end + return pattern, match, metadata end return iter diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index c1816a8860..25a753b179 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -1364,9 +1364,16 @@ static int node_equal(lua_State *L) /// assumes the match table being on top of the stack static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx) { - for (int i = 0; i < match->capture_count; i++) { - push_node(L, match->captures[i].node, nodeidx); - lua_rawseti(L, -2, (int)match->captures[i].index + 1); + // [match] + for (size_t i = 0; i < match->capture_count; i++) { + lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures] + if (lua_isnil(L, -1)) { // [match, nil] + lua_pop(L, 1); // [match] + lua_createtable(L, 1, 0); // [match, captures] + } + push_node(L, match->captures[i].node, nodeidx); // [match, captures, node] + lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures] + lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match] } } @@ -1379,7 +1386,7 @@ static int query_next_match(lua_State *L) TSQueryMatch match; if (ts_query_cursor_next_match(cursor, &match)) { lua_pushinteger(L, match.pattern_index + 1); // [index] - lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, match] + lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match] set_match(L, &match, lua_upvalueindex(2)); return 2; } @@ -1421,7 +1428,8 @@ static int query_next_capture(lua_State *L) if (n_pred > 0 && (ud->max_match_id < (int)match.id)) { ud->max_match_id = (int)match.id; - lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match] + // Create a new cleared match table + lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match] set_match(L, &match, lua_upvalueindex(2)); lua_pushinteger(L, match.pattern_index + 1); lua_setfield(L, -2, "pattern"); @@ -1431,6 +1439,10 @@ static int query_next_capture(lua_State *L) lua_pushboolean(L, false); lua_setfield(L, -2, "active"); } + + // Set current_match to the new match + lua_replace(L, lua_upvalueindex(4)); // [index, node] + lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match] return 3; } return 2; diff --git a/test/functional/treesitter/highlight_spec.lua b/test/functional/treesitter/highlight_spec.lua index 932af0332b..10fe08c549 100644 --- a/test/functional/treesitter/highlight_spec.lua +++ b/test/functional/treesitter/highlight_spec.lua @@ -731,6 +731,31 @@ describe('treesitter highlighting (C)', function() eq(3, get_hl '@foo.missing.exists.bar') eq(nil, get_hl '@total.nonsense.but.a.lot.of.dots') end) + + it('supports multiple nodes assigned to the same capture #17060', function() + insert([[ + int x = 4; + int y = 5; + int z = 6; + ]]) + + exec_lua([[ + local query = '((declaration)+ @string)' + vim.treesitter.query.set('c', 'highlights', query) + vim.treesitter.highlighter.new(vim.treesitter.get_parser(0, 'c')) + ]]) + + screen:expect { + grid = [[ + {5:int x = 4;} | + {5:int y = 5;} | + {5:int z = 6;} | + ^ | + {1:~ }|*13 + | + ]], + } + end) end) describe('treesitter highlighting (help)', function() diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index e63d424622..6ae8c97f85 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -8,6 +8,8 @@ local exec_lua = helpers.exec_lua local pcall_err = helpers.pcall_err local feed = helpers.feed local is_os = helpers.is_os +local api = helpers.api +local fn = helpers.fn describe('treesitter parser API', function() before_each(function() @@ -171,7 +173,7 @@ void ui_refresh(void) assert(res_fail) end) - local query = [[ + local test_query = [[ ((call_expression function: (identifier) @minfunc (argument_list (identifier) @min_id)) (eq? @minfunc "MIN")) "for" @keyword (primitive_type) @type @@ -187,7 +189,7 @@ void ui_refresh(void) end) it('supports caching queries', function() - local long_query = query:rep(100) + local long_query = test_query:rep(100) local function q(n) return exec_lua( [[ @@ -230,7 +232,7 @@ void ui_refresh(void) end return res ]], - query + test_query ) eq({ @@ -256,17 +258,19 @@ void ui_refresh(void) parser = vim.treesitter.get_parser(0, "c") tree = parser:parse()[1] res = {} - for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14) do + for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14, { all = true }) do -- can't transmit node over RPC. just check the name and range local mrepr = {} - for cid,node in pairs(match) do - table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + for cid, nodes in pairs(match) do + for _, node in ipairs(nodes) do + table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + end end table.insert(res, {pattern, mrepr}) end return res ]], - query + test_query ) eq({ @@ -287,6 +291,67 @@ void ui_refresh(void) }, res) end) + it('support query and iter by capture for quantifiers', function() + insert(test_text) + + local res = exec_lua( + [[ + cquery = vim.treesitter.query.parse("c", ...) + parser = vim.treesitter.get_parser(0, "c") + tree = parser:parse()[1] + res = {} + for cid, node in cquery:iter_captures(tree:root(), 0, 7, 14) do + -- can't transmit node over RPC. just check the name and range + table.insert(res, {cquery.captures[cid], node:type(), node:range()}) + end + return res + ]], + '(expression_statement (assignment_expression (call_expression)))+ @funccall' + ) + + eq({ + { 'funccall', 'expression_statement', 11, 4, 11, 34 }, + { 'funccall', 'expression_statement', 12, 4, 12, 37 }, + { 'funccall', 'expression_statement', 13, 4, 13, 34 }, + }, res) + end) + + it('support query and iter by match for quantifiers', function() + insert(test_text) + + local res = exec_lua( + [[ + cquery = vim.treesitter.query.parse("c", ...) + parser = vim.treesitter.get_parser(0, "c") + tree = parser:parse()[1] + res = {} + for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14, { all = true }) do + -- can't transmit node over RPC. just check the name and range + local mrepr = {} + for cid, nodes in pairs(match) do + for _, node in ipairs(nodes) do + table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + end + end + table.insert(res, {pattern, mrepr}) + end + return res + ]], + '(expression_statement (assignment_expression (call_expression)))+ @funccall' + ) + + eq({ + { + 1, + { + { 'funccall', 'expression_statement', 11, 4, 11, 34 }, + { 'funccall', 'expression_statement', 12, 4, 12, 37 }, + { 'funccall', 'expression_statement', 13, 4, 13, 34 }, + }, + }, + }, res) + end) + it('supports getting text of multiline node', function() insert(test_text) local res = exec_lua([[ @@ -365,11 +430,13 @@ end]] parser = vim.treesitter.get_parser(0, "c") tree = parser:parse()[1] res = {} - for pattern, match in cquery:iter_matches(tree:root(), 0) do + for pattern, match in cquery:iter_matches(tree:root(), 0, 0, -1, { all = true }) do -- can't transmit node over RPC. just check the name and range local mrepr = {} - for cid,node in pairs(match) do - table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + for cid, nodes in pairs(match) do + for _, node in ipairs(nodes) do + table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + end end table.insert(res, {pattern, mrepr}) end @@ -457,11 +524,13 @@ end]] parser = vim.treesitter.get_parser(0, "c") tree = parser:parse()[1] res = {} - for pattern, match in cquery:iter_matches(tree:root(), 0) do + for pattern, match in cquery:iter_matches(tree:root(), 0, 0, -1, { all = true }) do -- can't transmit node over RPC. just check the name and range local mrepr = {} - for cid,node in pairs(match) do - table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + for cid, nodes in pairs(match) do + for _, node in ipairs(nodes) do + table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + end end table.insert(res, {pattern, mrepr}) end @@ -486,55 +555,248 @@ end]] local custom_query = '((identifier) @main (#is-main? @main))' - local res = exec_lua( - [[ - local query = vim.treesitter.query + do + local res = exec_lua( + [[ + local query = vim.treesitter.query - local function is_main(match, pattern, bufnr, predicate) - local node = match[ predicate[2] ] + local function is_main(match, pattern, bufnr, predicate) + local nodes = match[ predicate[2] ] + for _, node in ipairs(nodes) do + if query.get_node_text(node, bufnr) == 'main' then + return true + end + end + return false + end - return query.get_node_text(node, bufnr) + local parser = vim.treesitter.get_parser(0, "c") + + -- Time bomb: update this in 0.12 + if vim.fn.has('nvim-0.12') == 1 then + return 'Update this test to remove this message and { all = true } from add_predicate' + end + query.add_predicate("is-main?", is_main, { all = true }) + + local query = query.parse("c", ...) + + local nodes = {} + for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do + table.insert(nodes, {node:range()}) + end + + return nodes + ]], + custom_query + ) + + eq({ { 0, 4, 0, 8 } }, res) end - local parser = vim.treesitter.get_parser(0, "c") + -- Once with the old API. Remove this whole 'do' block in 0.12 + do + local res = exec_lua( + [[ + local query = vim.treesitter.query - query.add_predicate("is-main?", is_main) + local function is_main(match, pattern, bufnr, predicate) + local node = match[ predicate[2] ] - local query = query.parse("c", ...) + return query.get_node_text(node, bufnr) == 'main' + end - local nodes = {} - for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do - table.insert(nodes, {node:range()}) + local parser = vim.treesitter.get_parser(0, "c") + + query.add_predicate("is-main?", is_main, true) + + local query = query.parse("c", ...) + + local nodes = {} + for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do + table.insert(nodes, {node:range()}) + end + + return nodes + ]], + custom_query + ) + + -- Remove this 'do' block in 0.12 + eq(0, fn.has('nvim-0.12')) + eq({ { 0, 4, 0, 8 } }, res) end - return nodes - ]], - custom_query - ) + do + local res = exec_lua [[ + local query = vim.treesitter.query - eq({ { 0, 4, 0, 8 } }, res) + local t = {} + for _, v in ipairs(query.list_predicates()) do + t[v] = true + end - local res_list = exec_lua [[ - local query = vim.treesitter.query + return t + ]] - local list = query.list_predicates() + eq(true, res['is-main?']) + end + end) - table.sort(list) - - return list + it('supports "all" and "any" semantics for predicates on quantified captures #24738', function() + local query_all = [[ + (((comment (comment_content))+) @bar + (#lua-match? @bar "Yes")) ]] + local query_any = [[ + (((comment (comment_content))+) @bar + (#any-lua-match? @bar "Yes")) + ]] + + local function test(input, query) + api.nvim_buf_set_lines(0, 0, -1, true, vim.split(dedent(input), '\n')) + return exec_lua( + [[ + local parser = vim.treesitter.get_parser(0, "lua") + local query = vim.treesitter.query.parse("lua", ...) + local nodes = {} + for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do + nodes[#nodes+1] = { node:range() } + end + return nodes + ]], + query + ) + end + + eq( + {}, + test( + [[ + -- Yes + -- No + -- Yes + ]], + query_all + ) + ) + + eq( + { + { 0, 2, 0, 8 }, + { 1, 2, 1, 8 }, + { 2, 2, 2, 8 }, + }, + test( + [[ + -- Yes + -- Yes + -- Yes + ]], + query_all + ) + ) + + eq( + {}, + test( + [[ + -- No + -- No + -- No + ]], + query_any + ) + ) + + eq( + { + { 0, 2, 0, 7 }, + { 1, 2, 1, 8 }, + { 2, 2, 2, 7 }, + }, + test( + [[ + -- No + -- Yes + -- No + ]], + query_any + ) + ) + end) + + it('supports any- prefix to match any capture when using quantifiers #24738', function() + insert([[ + -- Comment + -- Comment + -- Comment + ]]) + + local query = [[ + (((comment (comment_content))+) @bar + (#lua-match? @bar "Comment")) + ]] + + local result = exec_lua( + [[ + local parser = vim.treesitter.get_parser(0, "lua") + local query = vim.treesitter.query.parse("lua", ...) + local nodes = {} + for _, node in query:iter_captures(parser:parse()[1]:root(), 0) do + nodes[#nodes+1] = { node:range() } + end + return nodes + ]], + query + ) + eq({ - 'any-of?', - 'contains?', - 'eq?', - 'has-ancestor?', - 'has-parent?', - 'is-main?', - 'lua-match?', - 'match?', - 'vim-match?', - }, res_list) + { 0, 2, 0, 12 }, + { 1, 2, 1, 12 }, + { 2, 2, 2, 12 }, + }, result) + end) + + it('supports the old broken version of iter_matches #24738', function() + -- Delete this test in 0.12 when iter_matches is removed + eq(0, fn.has('nvim-0.12')) + + insert(test_text) + local res = exec_lua( + [[ + cquery = vim.treesitter.query.parse("c", ...) + parser = vim.treesitter.get_parser(0, "c") + tree = parser:parse()[1] + res = {} + for pattern, match in cquery:iter_matches(tree:root(), 0, 7, 14) do + local mrepr = {} + for cid, node in pairs(match) do + table.insert(mrepr, {cquery.captures[cid], node:type(), node:range()}) + end + table.insert(res, {pattern, mrepr}) + end + return res + ]], + test_query + ) + + eq({ + { 3, { { 'type', 'primitive_type', 8, 2, 8, 6 } } }, + { 2, { { 'keyword', 'for', 9, 2, 9, 5 } } }, + { 3, { { 'type', 'primitive_type', 9, 7, 9, 13 } } }, + { 4, { { 'fieldarg', 'identifier', 11, 16, 11, 18 } } }, + { + 1, + { { 'minfunc', 'identifier', 11, 12, 11, 15 }, { 'min_id', 'identifier', 11, 27, 11, 32 } }, + }, + { 4, { { 'fieldarg', 'identifier', 12, 17, 12, 19 } } }, + { + 1, + { { 'minfunc', 'identifier', 12, 13, 12, 16 }, { 'min_id', 'identifier', 12, 29, 12, 35 } }, + }, + { 4, { { 'fieldarg', 'identifier', 13, 14, 13, 16 } } }, + }, res) end) it('allows to set simple ranges', function() @@ -866,7 +1128,7 @@ int x = INT_MAX; query = vim.treesitter.query.parse("c", '((number_literal) @number (#set! "key" "value"))') parser = vim.treesitter.get_parser(0, "c") - for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do + for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, -1, { all = true }) do result = metadata.key end @@ -889,7 +1151,7 @@ int x = INT_MAX; query = vim.treesitter.query.parse("c", '((number_literal) @number (#set! @number "key" "value"))') parser = vim.treesitter.get_parser(0, "c") - for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do + for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, -1, { all = true }) do for _, nested_tbl in pairs(metadata) do return nested_tbl.key end @@ -911,7 +1173,7 @@ int x = INT_MAX; query = vim.treesitter.query.parse("c", '((number_literal) @number (#set! @number "key" "value") (#set! @number "key2" "value2"))') parser = vim.treesitter.get_parser(0, "c") - for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do + for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0, 0, -1, { all = true }) do for _, nested_tbl in pairs(metadata) do return nested_tbl end