Merge pull request #22613 from lewis6991/feat/tsqueryutil

This commit is contained in:
Lewis Russell 2023-03-11 17:13:20 +00:00 committed by GitHub
commit b55b8ddf81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 125 additions and 123 deletions

View File

@ -50,6 +50,11 @@ The following changes may require adaptations in user config or plugins.
• Unsaved changes are now preserved rather than discarded when |channel-stdio| • Unsaved changes are now preserved rather than discarded when |channel-stdio|
is closed. is closed.
• Changes to |vim.treesitter.get_node_text()|:
- It now returns `string`, as opposed to `string|string[]|nil`.
- The `concat` option has been removed as it was not consistently applied.
- Invalid ranges now cause an error instead of returning `nil`.
============================================================================== ==============================================================================
NEW FEATURES *news-features* NEW FEATURES *news-features*

View File

@ -789,14 +789,12 @@ get_node_text({node}, {source}, {opts})
• {source} (integer|string) Buffer or string from which the {node} is • {source} (integer|string) Buffer or string from which the {node} is
extracted extracted
• {opts} (table|nil) Optional parameters. • {opts} (table|nil) Optional parameters.
• concat: (boolean) Concatenate result in a string (default
true)
• metadata (table) Metadata of a specific capture. This • metadata (table) Metadata of a specific capture. This
would be set to `metadata[capture_id]` when using would be set to `metadata[capture_id]` when using
|vim.treesitter.add_directive()|. |vim.treesitter.add_directive()|.
Return: ~ Return: ~
(string[]|string|nil) (string)
get_query({lang}, {query_name}) *vim.treesitter.get_query()* get_query({lang}, {query_name}) *vim.treesitter.get_query()*
Returns the runtime query {query_name} for {lang}. Returns the runtime query {query_name} for {lang}.
@ -822,6 +820,19 @@ get_query_files({lang}, {query_name}, {is_included})
string[] query_files List of files to load for given query and string[] query_files List of files to load for given query and
language language
get_range({node}, {source}, {metadata}) *vim.treesitter.get_range()*
Get the range of a |TSNode|. Can also supply {source} and {metadata} to
get the range with directives applied.
Parameters: ~
• {node} |TSNode|
• {source} integer|string|nil Buffer or string from which the {node}
is extracted
• {metadata} TSMetadata|nil
Return: ~
(table)
list_directives() *vim.treesitter.list_directives()* list_directives() *vim.treesitter.list_directives()*
Lists the currently available directives to use in queries. Lists the currently available directives to use in queries.
@ -887,7 +898,8 @@ Query:iter_captures({self}, {node}, {source}, {start}, {stop})
• {self} • {self}
Return: ~ Return: ~
(fun(): integer, TSNode, TSMetadata ): capture id, capture node, metadata (fun(): integer, TSNode, TSMetadata): capture id, capture node,
metadata
*Query:iter_matches()* *Query:iter_matches()*
Query:iter_matches({self}, {node}, {source}, {start}, {stop}) Query:iter_matches({self}, {node}, {source}, {start}, {stop})

View File

@ -1,4 +1,5 @@
local Range = require('vim.treesitter._range') local Range = require('vim.treesitter._range')
local Query = require('vim.treesitter.query')
local api = vim.api local api = vim.api
@ -74,18 +75,6 @@ function FoldInfo:get_stop(lnum)
return self.stop_counts[lnum] or 0 return self.stop_counts[lnum] or 0
end end
---@private
--- TODO(lewis6991): copied from languagetree.lua. Consolidate
---@param node TSNode
---@param metadata TSMetadata
---@return Range4
local function get_range_from_metadata(node, metadata)
if metadata and metadata.range then
return metadata.range --[[@as Range4]]
end
return { node:range() }
end
local function trim_level(level) local function trim_level(level)
local max_fold_level = vim.wo.foldnestmax local max_fold_level = vim.wo.foldnestmax
if level > max_fold_level then if level > max_fold_level then
@ -118,7 +107,7 @@ local function get_folds_levels(bufnr, info, srow, erow)
for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do
if query.captures[id] == 'fold' then if query.captures[id] == 'fold' then
local range = get_range_from_metadata(node, metadata[id]) local range = Query.get_range(node, bufnr, metadata[id])
local start, _, stop, stop_col = Range.unpack4(range) local start, _, stop, stop_col = Range.unpack4(range)
if stop_col == 0 then if stop_col == 0 then

View File

@ -2,8 +2,21 @@ local api = vim.api
local M = {} local M = {}
---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer} ---@class Range4
---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer} ---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer end row
---@field [4] integer end column
---@class Range6
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer start bytes
---@field [4] integer end row
---@field [5] integer end column
---@field [6] integer end bytes
---@alias Range Range4|Range6
---@private ---@private
---@param a_row integer ---@param a_row integer
@ -74,8 +87,8 @@ function M.validate(r)
end end
---@private ---@private
---@param r1 Range4|Range6 ---@param r1 Range
---@param r2 Range4|Range6 ---@param r2 Range
---@return boolean ---@return boolean
function M.intercepts(r1, r2) function M.intercepts(r1, r2)
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
@ -95,7 +108,7 @@ function M.intercepts(r1, r2)
end end
---@private ---@private
---@param r Range4|Range6 ---@param r Range
---@return integer, integer, integer, integer ---@return integer, integer, integer, integer
function M.unpack4(r) function M.unpack4(r)
local off_1 = #r == 6 and 1 or 0 local off_1 = #r == 6 and 1 or 0
@ -110,8 +123,8 @@ function M.unpack6(r)
end end
---@private ---@private
---@param r1 Range4|Range6 ---@param r1 Range
---@param r2 Range4|Range6 ---@param r2 Range
---@return boolean whether r1 contains r2 ---@return boolean whether r1 contains r2
function M.contains(r1, r2) function M.contains(r1, r2)
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
@ -132,7 +145,7 @@ end
---@private ---@private
---@param source integer|string ---@param source integer|string
---@param range Range4|Range6 ---@param range Range
---@return Range6 ---@return Range6
function M.add_bytes(source, range) function M.add_bytes(source, range)
if type(range) == 'table' and #range == 6 then if type(range) == 'table' and #range == 6 then

View File

@ -448,7 +448,7 @@ end
--- nodes, which is useful for templating languages like ERB and EJS. --- nodes, which is useful for templating languages like ERB and EJS.
--- ---
---@private ---@private
---@param new_regions Range4[][] List of regions this tree should manage and parse. ---@param new_regions Range6[][] List of regions this tree should manage and parse.
function LanguageTree:set_included_regions(new_regions) function LanguageTree:set_included_regions(new_regions)
-- Transform the tables from 4 element long to 6 element long (with byte offset) -- Transform the tables from 4 element long to 6 element long (with byte offset)
for _, region in ipairs(new_regions) do for _, region in ipairs(new_regions) do
@ -478,24 +478,11 @@ end
---@private ---@private
---@param node TSNode ---@param node TSNode
---@param source integer|string
---@param metadata TSMetadata
---@return Range6
local function get_range_from_metadata(node, source, metadata)
if metadata and metadata.range then
return Range.add_bytes(source, metadata.range --[[@as Range4|Range6]])
end
return { node:range(true) }
end
---@private
--- TODO(lewis6991): cleanup of the node_range interface
---@param node TSNode
---@param source string|integer ---@param source string|integer
---@param metadata TSMetadata ---@param metadata TSMetadata
---@return Range6[] ---@return Range6[]
local function get_node_ranges(node, source, metadata, include_children) local function get_node_ranges(node, source, metadata, include_children)
local range = get_range_from_metadata(node, source, metadata) local range = query.get_range(node, source, metadata)
if include_children then if include_children then
return { range } return { range }
@ -535,7 +522,7 @@ end
---@param pattern integer ---@param pattern integer
---@param lang string ---@param lang string
---@param combined boolean ---@param combined boolean
---@param ranges Range4[] ---@param ranges Range6[]
local function add_injection(t, tree_index, pattern, lang, combined, ranges) local function add_injection(t, tree_index, pattern, lang, combined, ranges)
assert(type(lang) == 'string') assert(type(lang) == 'string')
@ -558,31 +545,16 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges)
table.insert(t[tree_index][lang][pattern].regions, ranges) table.insert(t[tree_index][lang][pattern].regions, ranges)
end end
---@private
---Get node text
---
---Note: `query.get_node_text` returns string|string[]|nil so use this simple alias function
---to annotate it returns string.
---
---TODO(lewis6991): use [at]overload annotations on `query.get_node_text`
---@param node TSNode
---@param source integer|string
---@param metadata table
---@return string
local function get_node_text(node, source, metadata)
return query.get_node_text(node, source, { metadata = metadata }) --[[@as string]]
end
---@private ---@private
--- Extract injections according to: --- Extract injections according to:
--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection --- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
---@param match table<integer,TSNode> ---@param match table<integer,TSNode>
---@param metadata table ---@param metadata TSMetadata
---@return string, boolean, Range4[] ---@return string?, boolean, Range6[]
function LanguageTree:_get_injection(match, metadata) function LanguageTree:_get_injection(match, metadata)
local ranges = {} ---@type Range4[] local ranges = {} ---@type Range6[]
local combined = metadata['injection.combined'] ~= nil local combined = metadata['injection.combined'] ~= nil
local lang = metadata['injection.language'] ---@type string local lang = metadata['injection.language'] --[[@as string?]]
local include_children = metadata['injection.include-children'] ~= nil local include_children = metadata['injection.include-children'] ~= nil
for id, node in pairs(match) do for id, node in pairs(match) do
@ -590,7 +562,7 @@ function LanguageTree:_get_injection(match, metadata)
-- Lang should override any other language tag -- Lang should override any other language tag
if name == 'injection.language' then if name == 'injection.language' then
lang = get_node_text(node, self._source, metadata[id]) lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'injection.content' then elseif name == 'injection.content' then
ranges = get_node_ranges(node, self._source, metadata[id], include_children) ranges = get_node_ranges(node, self._source, metadata[id], include_children)
end end
@ -601,11 +573,11 @@ end
---@private ---@private
---@param match table<integer,TSNode> ---@param match table<integer,TSNode>
---@param metadata table ---@param metadata TSMetadata
---@return string, boolean, Range4[] ---@return string, boolean, Range6[]
function LanguageTree:_get_injection_deprecated(match, metadata) function LanguageTree:_get_injection_deprecated(match, metadata)
local lang = nil ---@type string local lang = nil ---@type string
local ranges = {} ---@type Range4[] local ranges = {} ---@type Range6[]
local combined = metadata.combined ~= nil local combined = metadata.combined ~= nil
-- Directives can configure how injections are captured as well as actual node captures. -- Directives can configure how injections are captured as well as actual node captures.
@ -623,8 +595,10 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end end
end end
if metadata.language then local mlang = metadata.language
lang = metadata.language ---@type string if mlang ~= nil then
assert(type(mlang) == 'string')
lang = mlang
end end
-- You can specify the content and language together -- You can specify the content and language together
@ -635,11 +609,11 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
-- Lang should override any other language tag -- Lang should override any other language tag
if name == 'language' and not lang then if name == 'language' and not lang then
lang = get_node_text(node, self._source, metadata[id]) lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'combined' then elseif name == 'combined' then
combined = true combined = true
elseif name == 'content' and #ranges == 0 then elseif name == 'content' and #ranges == 0 then
ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id]) ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id])
-- Ignore any tags that start with "_" -- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches -- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= '_' then elseif string.sub(name, 1, 1) ~= '_' then
@ -648,7 +622,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end end
if #ranges == 0 then if #ranges == 0 then
ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id]) ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id])
end end
end end
end end
@ -926,7 +900,7 @@ end
---@private ---@private
---@param tree TSTree ---@param tree TSTree
---@param range Range4 ---@param range Range
---@return boolean ---@return boolean
local function tree_contains(tree, range) local function tree_contains(tree, range)
return Range.contains({ tree:root():range() }, range) return Range.contains({ tree:root():range() }, range)

View File

@ -1,6 +1,8 @@
local a = vim.api local a = vim.api
local language = require('vim.treesitter.language') local language = require('vim.treesitter.language')
local Range = require('vim.treesitter._range')
---@class Query ---@class Query
---@field captures string[] List of captures used in query ---@field captures string[] List of captures used in query
---@field info TSQueryInfo Contains used queries, predicates, directives ---@field info TSQueryInfo Contains used queries, predicates, directives
@ -56,35 +58,21 @@ local function add_included_lang(base_langs, lang, ilang)
end end
---@private ---@private
---@param buf (integer) ---@param buf integer
---@param range (table) ---@param range Range
---@param concat (boolean) ---@returns string
---@returns (string[]|string|nil) local function buf_range_get_text(buf, range)
local function buf_range_get_text(buf, range, concat) local start_row, start_col, end_row, end_col = Range.unpack4(range)
local lines
local start_row, start_col, end_row, end_col = unpack(range)
local eof_row = a.nvim_buf_line_count(buf)
if start_row >= eof_row then
return nil
end
if end_col == 0 then if end_col == 0 then
lines = a.nvim_buf_get_lines(buf, start_row, end_row, true) if start_row == end_row then
end_col = -1 start_col = -1
else start_row = start_row - 1
lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true)
end
if #lines > 0 then
if #lines == 1 then
lines[1] = string.sub(lines[1], start_col + 1, end_col)
else
lines[1] = string.sub(lines[1], start_col + 1)
lines[#lines] = string.sub(lines[#lines], 1, end_col)
end end
end_col = -1
end_row = end_row - 1
end end
local lines = a.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {})
return concat and table.concat(lines, '\n') or lines return table.concat(lines, '\n')
end end
--- Gets the list of files used to make up a query --- Gets the list of files used to make up a query
@ -256,14 +244,28 @@ function M.parse_query(lang, query)
local cached = query_cache[lang][query] local cached = query_cache[lang][query]
if cached then if cached then
return cached return cached
else
local self = setmetatable({}, Query)
self.query = vim._ts_parse_query(lang, query)
self.info = self.query:inspect()
self.captures = self.info.captures
query_cache[lang][query] = self
return self
end end
local self = setmetatable({}, Query)
self.query = vim._ts_parse_query(lang, query)
self.info = self.query:inspect()
self.captures = self.info.captures
query_cache[lang][query] = self
return self
end
---Get the range of a |TSNode|. Can also supply {source} and {metadata}
---to get the range with directives applied.
---@param node TSNode
---@param source integer|string|nil Buffer or string from which the {node} is extracted
---@param metadata TSMetadata|nil
---@return Range6
function M.get_range(node, source, metadata)
if metadata and metadata.range then
assert(source)
return Range.add_bytes(source, metadata.range)
end
return { node:range(true) }
end end
--- Gets the text corresponding to a given node --- Gets the text corresponding to a given node
@ -271,24 +273,22 @@ end
---@param node TSNode ---@param node TSNode
---@param source (integer|string) Buffer or string from which the {node} is extracted ---@param source (integer|string) Buffer or string from which the {node} is extracted
---@param opts (table|nil) Optional parameters. ---@param opts (table|nil) Optional parameters.
--- - concat: (boolean) Concatenate result in a string (default true)
--- - metadata (table) Metadata of a specific capture. This would be --- - metadata (table) Metadata of a specific capture. This would be
--- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|. --- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|.
---@return (string[]|string|nil) ---@return string
function M.get_node_text(node, source, opts) function M.get_node_text(node, source, opts)
opts = opts or {} opts = opts or {}
-- TODO(lewis6991): concat only works when source is number.
local concat = vim.F.if_nil(opts.concat, true)
local metadata = opts.metadata or {} local metadata = opts.metadata or {}
if metadata.text then if metadata.text then
return metadata.text return metadata.text
elseif type(source) == 'number' then elseif type(source) == 'number' then
return metadata.range and buf_range_get_text(source, metadata.range, concat) local range = M.get_range(node, source, metadata)
or buf_range_get_text(source, { node:range() }, concat) return buf_range_get_text(source, range)
elseif type(source) == 'string' then
return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
end end
---@cast source string
return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
end end
---@alias TSMatch table<integer,TSNode> ---@alias TSMatch table<integer,TSNode>
@ -312,7 +312,7 @@ local predicate_handlers = {
str = predicate[3] str = predicate[3]
else else
-- (#eq? @aa @bb) -- (#eq? @aa @bb)
str = M.get_node_text(match[predicate[3]], source) --[[@as string]] str = M.get_node_text(match[predicate[3]], source)
end end
if node_text ~= str or str == nil then if node_text ~= str or str == nil then
@ -328,7 +328,7 @@ local predicate_handlers = {
return true return true
end end
local regex = predicate[3] local regex = predicate[3]
return string.find(M.get_node_text(node, source) --[[@as string]], regex) ~= nil return string.find(M.get_node_text(node, source), regex) ~= nil
end, end,
['match?'] = (function() ['match?'] = (function()
@ -366,7 +366,7 @@ local predicate_handlers = {
if not node then if not node then
return true return true
end end
local node_text = M.get_node_text(node, source) --[[@as string]] local node_text = M.get_node_text(node, source)
for i = 3, #predicate do for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then if string.find(node_text, predicate[i], 1, true) then
@ -404,9 +404,9 @@ local predicate_handlers = {
predicate_handlers['vim-match?'] = predicate_handlers['match?'] predicate_handlers['vim-match?'] = predicate_handlers['match?']
---@class TSMetadata ---@class TSMetadata
---@field range Range
---@field [integer] TSMetadata ---@field [integer] TSMetadata
---@field [string] integer|string ---@field [string] integer|string
---@field range Range4
---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata) ---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata)
@ -465,13 +465,20 @@ local directive_handlers = {
assert(#pred == 4) assert(#pred == 4)
local id = pred[2] local id = pred[2]
assert(type(id) == 'number')
local node = match[id] local node = match[id]
local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then if not metadata[id] then
metadata[id] = {} metadata[id] = {}
end end
metadata[id].text = text:gsub(pred[3], pred[4])
local pattern, replacement = pred[3], pred[3]
assert(type(pattern) == 'string')
assert(type(replacement) == 'string')
metadata[id].text = text:gsub(pattern, replacement)
end, end,
} }

View File

@ -302,7 +302,7 @@ local types = { 'integer', 'number', 'string', 'table', 'list', 'boolean', 'func
local tagged_types = { 'TSNode', 'LanguageTree' } local tagged_types = { 'TSNode', 'LanguageTree' }
-- Document these as 'table' -- Document these as 'table'
local alias_types = { 'Range4', 'Range6' } local alias_types = { 'Range', 'Range4', 'Range6', 'TSMetadata' }
-- Processes the file and writes filtered output to stdout. -- Processes the file and writes filtered output to stdout.
function TLua2DoX_filter.filter(this, AppStamp, Filename) function TLua2DoX_filter.filter(this, AppStamp, Filename)

View File

@ -196,7 +196,7 @@ void ui_refresh(void)
local manyruns = q(100) local manyruns = q(100)
-- First run should be at least 400x slower than an 100 subsequent runs. -- First run should be at least 400x slower than an 100 subsequent runs.
local factor = is_os('win') and 300 or 400 local factor = is_os('win') and 200 or 400
assert(factor * manyruns < firstrun, ('firstrun: %f ms, manyruns: %f ms'):format(firstrun / 1e6, manyruns / 1e6)) assert(factor * manyruns < firstrun, ('firstrun: %f ms, manyruns: %f ms'):format(firstrun / 1e6, manyruns / 1e6))
end) end)
@ -277,13 +277,13 @@ void ui_refresh(void)
eq('void', res2) eq('void', res2)
end) end)
it('support getting text where start of node is past EOF', function() it('support getting text where start of node is one past EOF', function()
local text = [[ local text = [[
def run def run
a = <<~E a = <<~E
end]] end]]
insert(text) insert(text)
local result = exec_lua([[ eq('', exec_lua[[
local fake_node = {} local fake_node = {}
function fake_node:start() function fake_node:start()
return 3, 0, 23 return 3, 0, 23
@ -291,12 +291,14 @@ end]]
function fake_node:end_() function fake_node:end_()
return 3, 0, 23 return 3, 0, 23
end end
function fake_node:range() function fake_node:range(bytes)
if bytes then
return 3, 0, 23, 3, 0, 23
end
return 3, 0, 3, 0 return 3, 0, 3, 0
end end
return vim.treesitter.get_node_text(fake_node, 0) == nil return vim.treesitter.get_node_text(fake_node, 0)
]]) ]])
eq(true, result)
end) end)
it('support getting empty text if node range is zero width', function() it('support getting empty text if node range is zero width', function()