diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index 801b74df45..78cd05b7d5 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -50,6 +50,11 @@ The following changes may require adaptations in user config or plugins. • Unsaved changes are now preserved rather than discarded when |channel-stdio| is closed. +• Changes to |vim.treesitter.get_node_text()|: + - It now returns `string`, as opposed to `string|string[]|nil`. + - The `concat` option has been removed as it was not consistently applied. + - Invalid ranges now cause an error instead of returning `nil`. + ============================================================================== NEW FEATURES *news-features* diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index a8f25d2ff9..ddca307e74 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -789,14 +789,12 @@ get_node_text({node}, {source}, {opts}) • {source} (integer|string) Buffer or string from which the {node} is extracted • {opts} (table|nil) Optional parameters. - • concat: (boolean) Concatenate result in a string (default - true) • metadata (table) Metadata of a specific capture. This would be set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|. Return: ~ - (string[]|string|nil) + (string) get_query({lang}, {query_name}) *vim.treesitter.get_query()* Returns the runtime query {query_name} for {lang}. @@ -822,6 +820,19 @@ get_query_files({lang}, {query_name}, {is_included}) string[] query_files List of files to load for given query and language +get_range({node}, {source}, {metadata}) *vim.treesitter.get_range()* + Get the range of a |TSNode|. Can also supply {source} and {metadata} to + get the range with directives applied. + + Parameters: ~ + • {node} |TSNode| + • {source} integer|string|nil Buffer or string from which the {node} + is extracted + • {metadata} TSMetadata|nil + + Return: ~ + (table) + list_directives() *vim.treesitter.list_directives()* Lists the currently available directives to use in queries. @@ -887,7 +898,8 @@ Query:iter_captures({self}, {node}, {source}, {start}, {stop}) • {self} Return: ~ - (fun(): integer, TSNode, TSMetadata ): capture id, capture node, metadata + (fun(): integer, TSNode, TSMetadata): capture id, capture node, + metadata *Query:iter_matches()* Query:iter_matches({self}, {node}, {source}, {start}, {stop}) diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 435cb9fdb6..fd2c707d17 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -1,4 +1,5 @@ local Range = require('vim.treesitter._range') +local Query = require('vim.treesitter.query') local api = vim.api @@ -74,18 +75,6 @@ function FoldInfo:get_stop(lnum) return self.stop_counts[lnum] or 0 end ----@private ---- TODO(lewis6991): copied from languagetree.lua. Consolidate ----@param node TSNode ----@param metadata TSMetadata ----@return Range4 -local function get_range_from_metadata(node, metadata) - if metadata and metadata.range then - return metadata.range --[[@as Range4]] - end - return { node:range() } -end - local function trim_level(level) local max_fold_level = vim.wo.foldnestmax if level > max_fold_level then @@ -118,7 +107,7 @@ local function get_folds_levels(bufnr, info, srow, erow) for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do if query.captures[id] == 'fold' then - local range = get_range_from_metadata(node, metadata[id]) + local range = Query.get_range(node, bufnr, metadata[id]) local start, _, stop, stop_col = Range.unpack4(range) if stop_col == 0 then diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua index 02918da23f..f4db5016ac 100644 --- a/runtime/lua/vim/treesitter/_range.lua +++ b/runtime/lua/vim/treesitter/_range.lua @@ -2,8 +2,21 @@ local api = vim.api local M = {} ----@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer} ----@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer} +---@class Range4 +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer end row +---@field [4] integer end column + +---@class Range6 +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer start bytes +---@field [4] integer end row +---@field [5] integer end column +---@field [6] integer end bytes + +---@alias Range Range4|Range6 ---@private ---@param a_row integer @@ -74,8 +87,8 @@ function M.validate(r) end ---@private ----@param r1 Range4|Range6 ----@param r2 Range4|Range6 +---@param r1 Range +---@param r2 Range ---@return boolean function M.intercepts(r1, r2) local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) @@ -95,7 +108,7 @@ function M.intercepts(r1, r2) end ---@private ----@param r Range4|Range6 +---@param r Range ---@return integer, integer, integer, integer function M.unpack4(r) local off_1 = #r == 6 and 1 or 0 @@ -110,8 +123,8 @@ function M.unpack6(r) end ---@private ----@param r1 Range4|Range6 ----@param r2 Range4|Range6 +---@param r1 Range +---@param r2 Range ---@return boolean whether r1 contains r2 function M.contains(r1, r2) local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) @@ -132,7 +145,7 @@ end ---@private ---@param source integer|string ----@param range Range4|Range6 +---@param range Range ---@return Range6 function M.add_bytes(source, range) if type(range) == 'table' and #range == 6 then diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 26321cd1f4..bdfe281a5b 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -448,7 +448,7 @@ end --- nodes, which is useful for templating languages like ERB and EJS. --- ---@private ----@param new_regions Range4[][] List of regions this tree should manage and parse. +---@param new_regions Range6[][] List of regions this tree should manage and parse. function LanguageTree:set_included_regions(new_regions) -- Transform the tables from 4 element long to 6 element long (with byte offset) for _, region in ipairs(new_regions) do @@ -478,24 +478,11 @@ end ---@private ---@param node TSNode ----@param source integer|string ----@param metadata TSMetadata ----@return Range6 -local function get_range_from_metadata(node, source, metadata) - if metadata and metadata.range then - return Range.add_bytes(source, metadata.range --[[@as Range4|Range6]]) - end - return { node:range(true) } -end - ----@private ---- TODO(lewis6991): cleanup of the node_range interface ----@param node TSNode ---@param source string|integer ---@param metadata TSMetadata ---@return Range6[] local function get_node_ranges(node, source, metadata, include_children) - local range = get_range_from_metadata(node, source, metadata) + local range = query.get_range(node, source, metadata) if include_children then return { range } @@ -535,7 +522,7 @@ end ---@param pattern integer ---@param lang string ---@param combined boolean ----@param ranges Range4[] +---@param ranges Range6[] local function add_injection(t, tree_index, pattern, lang, combined, ranges) assert(type(lang) == 'string') @@ -558,31 +545,16 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges) table.insert(t[tree_index][lang][pattern].regions, ranges) end ----@private ----Get node text ---- ----Note: `query.get_node_text` returns string|string[]|nil so use this simple alias function ----to annotate it returns string. ---- ----TODO(lewis6991): use [at]overload annotations on `query.get_node_text` ----@param node TSNode ----@param source integer|string ----@param metadata table ----@return string -local function get_node_text(node, source, metadata) - return query.get_node_text(node, source, { metadata = metadata }) --[[@as string]] -end - ---@private --- Extract injections according to: --- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection ---@param match table ----@param metadata table ----@return string, boolean, Range4[] +---@param metadata TSMetadata +---@return string?, boolean, Range6[] function LanguageTree:_get_injection(match, metadata) - local ranges = {} ---@type Range4[] + local ranges = {} ---@type Range6[] local combined = metadata['injection.combined'] ~= nil - local lang = metadata['injection.language'] ---@type string + local lang = metadata['injection.language'] --[[@as string?]] local include_children = metadata['injection.include-children'] ~= nil for id, node in pairs(match) do @@ -590,7 +562,7 @@ function LanguageTree:_get_injection(match, metadata) -- Lang should override any other language tag if name == 'injection.language' then - lang = get_node_text(node, self._source, metadata[id]) + lang = query.get_node_text(node, self._source, { metadata = metadata[id] }) elseif name == 'injection.content' then ranges = get_node_ranges(node, self._source, metadata[id], include_children) end @@ -601,11 +573,11 @@ end ---@private ---@param match table ----@param metadata table ----@return string, boolean, Range4[] +---@param metadata TSMetadata +---@return string, boolean, Range6[] function LanguageTree:_get_injection_deprecated(match, metadata) local lang = nil ---@type string - local ranges = {} ---@type Range4[] + local ranges = {} ---@type Range6[] local combined = metadata.combined ~= nil -- Directives can configure how injections are captured as well as actual node captures. @@ -623,8 +595,10 @@ function LanguageTree:_get_injection_deprecated(match, metadata) end end - if metadata.language then - lang = metadata.language ---@type string + local mlang = metadata.language + if mlang ~= nil then + assert(type(mlang) == 'string') + lang = mlang end -- You can specify the content and language together @@ -635,11 +609,11 @@ function LanguageTree:_get_injection_deprecated(match, metadata) -- Lang should override any other language tag if name == 'language' and not lang then - lang = get_node_text(node, self._source, metadata[id]) + lang = query.get_node_text(node, self._source, { metadata = metadata[id] }) elseif name == 'combined' then combined = true elseif name == 'content' and #ranges == 0 then - ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id]) + ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id]) -- Ignore any tags that start with "_" -- Allows for other tags to be used in matches elseif string.sub(name, 1, 1) ~= '_' then @@ -648,7 +622,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata) end if #ranges == 0 then - ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id]) + ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id]) end end end @@ -926,7 +900,7 @@ end ---@private ---@param tree TSTree ----@param range Range4 +---@param range Range ---@return boolean local function tree_contains(tree, range) return Range.contains({ tree:root():range() }, range) diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index e7cf42283d..f4e038b2d8 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,6 +1,8 @@ local a = vim.api local language = require('vim.treesitter.language') +local Range = require('vim.treesitter._range') + ---@class Query ---@field captures string[] List of captures used in query ---@field info TSQueryInfo Contains used queries, predicates, directives @@ -56,35 +58,21 @@ local function add_included_lang(base_langs, lang, ilang) end ---@private ----@param buf (integer) ----@param range (table) ----@param concat (boolean) ----@returns (string[]|string|nil) -local function buf_range_get_text(buf, range, concat) - local lines - local start_row, start_col, end_row, end_col = unpack(range) - local eof_row = a.nvim_buf_line_count(buf) - if start_row >= eof_row then - return nil - end - +---@param buf integer +---@param range Range +---@returns string +local function buf_range_get_text(buf, range) + local start_row, start_col, end_row, end_col = Range.unpack4(range) if end_col == 0 then - lines = a.nvim_buf_get_lines(buf, start_row, end_row, true) - end_col = -1 - else - lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true) - end - - if #lines > 0 then - if #lines == 1 then - lines[1] = string.sub(lines[1], start_col + 1, end_col) - else - lines[1] = string.sub(lines[1], start_col + 1) - lines[#lines] = string.sub(lines[#lines], 1, end_col) + if start_row == end_row then + start_col = -1 + start_row = start_row - 1 end + end_col = -1 + end_row = end_row - 1 end - - return concat and table.concat(lines, '\n') or lines + local lines = a.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) + return table.concat(lines, '\n') end --- Gets the list of files used to make up a query @@ -256,14 +244,28 @@ function M.parse_query(lang, query) local cached = query_cache[lang][query] if cached then return cached - else - local self = setmetatable({}, Query) - self.query = vim._ts_parse_query(lang, query) - self.info = self.query:inspect() - self.captures = self.info.captures - query_cache[lang][query] = self - return self end + + local self = setmetatable({}, Query) + self.query = vim._ts_parse_query(lang, query) + self.info = self.query:inspect() + self.captures = self.info.captures + query_cache[lang][query] = self + return self +end + +---Get the range of a |TSNode|. Can also supply {source} and {metadata} +---to get the range with directives applied. +---@param node TSNode +---@param source integer|string|nil Buffer or string from which the {node} is extracted +---@param metadata TSMetadata|nil +---@return Range6 +function M.get_range(node, source, metadata) + if metadata and metadata.range then + assert(source) + return Range.add_bytes(source, metadata.range) + end + return { node:range(true) } end --- Gets the text corresponding to a given node @@ -271,24 +273,22 @@ end ---@param node TSNode ---@param source (integer|string) Buffer or string from which the {node} is extracted ---@param opts (table|nil) Optional parameters. ---- - concat: (boolean) Concatenate result in a string (default true) --- - metadata (table) Metadata of a specific capture. This would be --- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|. ----@return (string[]|string|nil) +---@return string function M.get_node_text(node, source, opts) opts = opts or {} - -- TODO(lewis6991): concat only works when source is number. - local concat = vim.F.if_nil(opts.concat, true) local metadata = opts.metadata or {} if metadata.text then return metadata.text elseif type(source) == 'number' then - return metadata.range and buf_range_get_text(source, metadata.range, concat) - or buf_range_get_text(source, { node:range() }, concat) - elseif type(source) == 'string' then - return source:sub(select(3, node:start()) + 1, select(3, node:end_())) + local range = M.get_range(node, source, metadata) + return buf_range_get_text(source, range) end + + ---@cast source string + return source:sub(select(3, node:start()) + 1, select(3, node:end_())) end ---@alias TSMatch table @@ -312,7 +312,7 @@ local predicate_handlers = { str = predicate[3] else -- (#eq? @aa @bb) - str = M.get_node_text(match[predicate[3]], source) --[[@as string]] + str = M.get_node_text(match[predicate[3]], source) end if node_text ~= str or str == nil then @@ -328,7 +328,7 @@ local predicate_handlers = { return true end local regex = predicate[3] - return string.find(M.get_node_text(node, source) --[[@as string]], regex) ~= nil + return string.find(M.get_node_text(node, source), regex) ~= nil end, ['match?'] = (function() @@ -366,7 +366,7 @@ local predicate_handlers = { if not node then return true end - local node_text = M.get_node_text(node, source) --[[@as string]] + local node_text = M.get_node_text(node, source) for i = 3, #predicate do if string.find(node_text, predicate[i], 1, true) then @@ -404,9 +404,9 @@ local predicate_handlers = { predicate_handlers['vim-match?'] = predicate_handlers['match?'] ---@class TSMetadata +---@field range Range ---@field [integer] TSMetadata ---@field [string] integer|string ----@field range Range4 ---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata) @@ -465,13 +465,20 @@ local directive_handlers = { assert(#pred == 4) local id = pred[2] + assert(type(id) == 'number') + local node = match[id] local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' if not metadata[id] then metadata[id] = {} end - metadata[id].text = text:gsub(pred[3], pred[4]) + + local pattern, replacement = pred[3], pred[3] + assert(type(pattern) == 'string') + assert(type(replacement) == 'string') + + metadata[id].text = text:gsub(pattern, replacement) end, } diff --git a/scripts/lua2dox.lua b/scripts/lua2dox.lua index b99cd955f4..de9f2926f2 100644 --- a/scripts/lua2dox.lua +++ b/scripts/lua2dox.lua @@ -302,7 +302,7 @@ local types = { 'integer', 'number', 'string', 'table', 'list', 'boolean', 'func local tagged_types = { 'TSNode', 'LanguageTree' } -- Document these as 'table' -local alias_types = { 'Range4', 'Range6' } +local alias_types = { 'Range', 'Range4', 'Range6', 'TSMetadata' } -- Processes the file and writes filtered output to stdout. function TLua2DoX_filter.filter(this, AppStamp, Filename) diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index 0f00fcfe0d..c6ca65f9a1 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -196,7 +196,7 @@ void ui_refresh(void) local manyruns = q(100) -- First run should be at least 400x slower than an 100 subsequent runs. - local factor = is_os('win') and 300 or 400 + local factor = is_os('win') and 200 or 400 assert(factor * manyruns < firstrun, ('firstrun: %f ms, manyruns: %f ms'):format(firstrun / 1e6, manyruns / 1e6)) end) @@ -277,13 +277,13 @@ void ui_refresh(void) eq('void', res2) end) - it('support getting text where start of node is past EOF', function() + it('support getting text where start of node is one past EOF', function() local text = [[ def run a = <<~E end]] insert(text) - local result = exec_lua([[ + eq('', exec_lua[[ local fake_node = {} function fake_node:start() return 3, 0, 23 @@ -291,12 +291,14 @@ end]] function fake_node:end_() return 3, 0, 23 end - function fake_node:range() + function fake_node:range(bytes) + if bytes then + return 3, 0, 23, 3, 0, 23 + end return 3, 0, 3, 0 end - return vim.treesitter.get_node_text(fake_node, 0) == nil + return vim.treesitter.get_node_text(fake_node, 0) ]]) - eq(true, result) end) it('support getting empty text if node range is zero width', function()