perf(treesitter): smarter languagetree invalidation

Problem:
  Treesitter injections are slow because all injected trees are invalidated on every change.

Solution:
    Implement smarter invalidation to avoid reparsing injected regions.

    - In on_bytes, try and update self._regions as best we can. This PR just offsets any regions after the change.
    - Add valid flags for each region in self._regions.
    - Call on_bytes recursively for all children.
       - We still need to run the query every time for the top level tree. I don't know how to avoid this. However, if the new injection ranges don't change, then we re-use the old trees and avoid reparsing children.

This should result in roughly a 2-3x reduction in tree parsing when the comment injections are enabled.
This commit is contained in:
Lewis Russell 2023-02-23 15:19:52 +00:00 committed by GitHub
parent 8680715743
commit 75e53341f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 326 additions and 88 deletions

View File

@ -0,0 +1,126 @@
local api = vim.api
local M = {}
---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
---@private
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
if a_row == b_row then
if a_col > b_col then
return 1
elseif a_col < b_col then
return -1
else
return 0
end
elseif a_row > b_row then
return 1
end
return -1
end
M.cmp_pos = {
lt = function(...)
return cmp_pos(...) == -1
end,
le = function(...)
return cmp_pos(...) ~= 1
end,
gt = function(...)
return cmp_pos(...) == 1
end,
ge = function(...)
return cmp_pos(...) ~= -1
end,
eq = function(...)
return cmp_pos(...) == 0
end,
ne = function(...)
return cmp_pos(...) ~= 0
end,
}
setmetatable(M.cmp_pos, { __call = cmp_pos })
---@private
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean
function M.intercepts(r1, r2)
local off_1 = #r1 == 6 and 1 or 0
local off_2 = #r1 == 6 and 1 or 0
local srow_1, scol_1, erow_1, ecol_1 = r1[1], r2[2], r1[3 + off_1], r1[4 + off_1]
local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
-- r1 is above r2
if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
return false
end
-- r1 is below r2
if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
return false
end
return true
end
---@private
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
local off_1 = #r1 == 6 and 1 or 0
local off_2 = #r1 == 6 and 1 or 0
local srow_1, scol_1, erow_1, ecol_1 = r1[1], r2[2], r1[3 + off_1], r1[4 + off_1]
local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
-- start doesn't fit
if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
return false
end
-- end doesn't fit
if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
return false
end
return true
end
---@private
---@param source integer|string
---@param range Range4
---@return Range6
function M.add_bytes(source, range)
local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
local start_byte = 0
local end_byte = 0
-- TODO(vigoux): proper byte computation here, and account for EOL ?
if type(source) == 'number' then
-- Easy case, this is a buffer parser
start_byte = api.nvim_buf_get_offset(source, start_row) + start_col
end_byte = api.nvim_buf_get_offset(source, end_row) + end_col
elseif type(source) == 'string' then
-- string parser, single `\n` delimited string
start_byte = vim.fn.byteidx(source, start_col)
end_byte = vim.fn.byteidx(source, end_col)
end
return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end
return M

View File

@ -1,9 +1,8 @@
local a = vim.api
local query = require('vim.treesitter.query')
local language = require('vim.treesitter.language')
local Range = require('vim.treesitter._range')
---@alias Range {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
--
---@alias TSCallbackName
---| 'changedtree'
---| 'bytes'
@ -24,11 +23,13 @@ local language = require('vim.treesitter.language')
---@field private _injection_query Query Queries defining injected languages
---@field private _opts table Options
---@field private _parser TSParser Parser for language
---@field private _regions Range[][] List of regions this tree should manage and parse
---@field private _regions Range6[][] List of regions this tree should manage and parse
---@field private _lang string Language name
---@field private _source (integer|string) Buffer or string to parse
---@field private _trees TSTree[] Reference to parsed tree (one for each language)
---@field private _valid boolean If the parsed tree is valid
---@field private _valid boolean|table<integer,true> If the parsed tree is valid
--- TODO(lewis6991): combine _regions, _valid and _trees
---@field private _is_child boolean
local LanguageTree = {}
---@class LanguageTreeOpts
@ -114,6 +115,9 @@ end
--- If the tree is invalid, call `parse()`.
--- This will return the updated tree.
function LanguageTree:is_valid()
if type(self._valid) == 'table' then
return #self._valid == #self._regions
end
return self._valid
end
@ -127,6 +131,16 @@ function LanguageTree:source()
return self._source
end
---@private
---This is only exposed so it can be wrapped for profiling
---@param old_tree TSTree
---@return TSTree, integer[]
function LanguageTree:_parse_tree(old_tree)
local tree, tree_changes = self._parser:parse(old_tree, self._source)
self:_do_callback('changedtree', tree_changes, tree)
return tree, tree_changes
end
--- Parses all defined regions using a treesitter parser
--- for the language this tree represents.
--- This will run the injection query for this language to
@ -135,35 +149,27 @@ end
---@return TSTree[]
---@return table|nil Change list
function LanguageTree:parse()
if self._valid then
if self:is_valid() then
return self._trees
end
local parser = self._parser
local changes = {}
local old_trees = self._trees
self._trees = {}
-- If there are no ranges, set to an empty list
-- so the included ranges in the parser are cleared.
if self._regions and #self._regions > 0 then
if #self._regions > 0 then
for i, ranges in ipairs(self._regions) do
local old_tree = old_trees[i]
parser:set_included_ranges(ranges)
local tree, tree_changes = parser:parse(old_tree, self._source)
self:_do_callback('changedtree', tree_changes, tree)
table.insert(self._trees, tree)
vim.list_extend(changes, tree_changes)
if not self._valid or not self._valid[i] then
self._parser:set_included_ranges(ranges)
local tree, tree_changes = self:_parse_tree(self._trees[i])
self._trees[i] = tree
vim.list_extend(changes, tree_changes)
end
end
else
local tree, tree_changes = parser:parse(old_trees[1], self._source)
self:_do_callback('changedtree', tree_changes, tree)
table.insert(self._trees, tree)
vim.list_extend(changes, tree_changes)
local tree, tree_changes = self:_parse_tree(self._trees[1])
self._trees = { tree }
changes = tree_changes
end
local injections_by_lang = self:_get_injections()
@ -249,6 +255,7 @@ function LanguageTree:add_child(lang)
end
self._children[lang] = LanguageTree.new(self._source, lang, self._opts)
self._children[lang]._is_child = true
self:invalidate()
self:_do_callback('child_added', self._children[lang])
@ -298,43 +305,35 @@ end
--- This allows for embedded languages to be parsed together across different
--- nodes, which is useful for templating languages like ERB and EJS.
---
--- Note: This call invalidates the tree and requires it to be parsed again.
---
---@private
---@param regions integer[][][] List of regions this tree should manage and parse.
---@param regions Range4[][] List of regions this tree should manage and parse.
function LanguageTree:set_included_regions(regions)
-- Transform the tables from 4 element long to 6 element long (with byte offset)
for _, region in ipairs(regions) do
for i, range in ipairs(region) do
if type(range) == 'table' and #range == 4 then
---@diagnostic disable-next-line:no-unknown
local start_row, start_col, end_row, end_col = unpack(range)
local start_byte = 0
local end_byte = 0
local source = self._source
-- TODO(vigoux): proper byte computation here, and account for EOL ?
if type(source) == 'number' then
-- Easy case, this is a buffer parser
start_byte = a.nvim_buf_get_offset(source, start_row) + start_col
end_byte = a.nvim_buf_get_offset(source, end_row) + end_col
elseif type(self._source) == 'string' then
-- string parser, single `\n` delimited string
start_byte = vim.fn.byteidx(self._source, start_col)
end_byte = vim.fn.byteidx(self._source, end_col)
end
region[i] = Range.add_bytes(self._source, range)
end
end
end
region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
if #self._regions ~= #regions then
self._trees = {}
self:invalidate()
elseif self._valid ~= false then
if self._valid == true then
self._valid = {}
end
for i = 1, #regions do
self._valid[i] = true
if not vim.deep_equal(self._regions[i], regions[i]) then
self._valid[i] = nil
self._trees[i] = nil
end
end
end
self._regions = regions
-- Trees are no longer valid now that we have changed regions.
-- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the
-- old trees for incremental parsing. Currently, this only
-- affects injected languages.
self._trees = {}
self:invalidate()
end
--- Gets the set of included regions
@ -346,10 +345,10 @@ end
---@param node TSNode
---@param id integer
---@param metadata TSMetadata
---@return Range
---@return Range4
local function get_range_from_metadata(node, id, metadata)
if metadata[id] and metadata[id].range then
return metadata[id].range --[[@as Range]]
return metadata[id].range --[[@as Range4]]
end
return { node:range() }
end
@ -378,7 +377,7 @@ function LanguageTree:_get_injections()
self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
do
local lang = nil ---@type string
local ranges = {} ---@type Range[]
local ranges = {} ---@type Range4[]
local combined = metadata.combined ---@type boolean
-- Directives can configure how injections are captured as well as actual node captures.
@ -408,6 +407,7 @@ function LanguageTree:_get_injections()
-- Lang should override any other language tag
if name == 'language' and not lang then
---@diagnostic disable-next-line
lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'combined' then
combined = true
@ -426,6 +426,8 @@ function LanguageTree:_get_injections()
end
end
assert(type(lang) == 'string')
-- Each tree index should be isolated from the other nodes.
if not injections[tree_index] then
injections[tree_index] = {}
@ -446,7 +448,7 @@ function LanguageTree:_get_injections()
end
end
---@type table<string,Range[][]>
---@type table<string,Range4[][]>
local result = {}
-- Generate a map by lang of node lists.
@ -485,6 +487,45 @@ function LanguageTree:_do_callback(cb_name, ...)
end
end
---@private
---@param regions Range6[][]
---@param old_range Range6
---@param new_range Range6
---@return table<integer, true> region indices to invalidate
local function update_regions(regions, old_range, new_range)
---@type table<integer,true>
local valid = {}
for i, ranges in ipairs(regions or {}) do
valid[i] = true
for j, r in ipairs(ranges) do
if Range.intercepts(r, old_range) then
valid[i] = nil
break
end
-- Range after change. Adjust
if Range.cmp_pos.gt(r[1], r[2], old_range[4], old_range[5]) then
local byte_offset = new_range[6] - old_range[6]
local row_offset = new_range[4] - old_range[4]
-- Update the range to avoid invalidation in set_included_regions()
-- which will compare the regions against the parsed injection regions
ranges[j] = {
r[1] + row_offset,
r[2],
r[3] + byte_offset,
r[4] + row_offset,
r[5],
r[6] + byte_offset,
}
end
end
end
return valid
end
---@private
---@param bufnr integer
---@param changed_tick integer
@ -510,14 +551,53 @@ function LanguageTree:_on_bytes(
new_col,
new_byte
)
self:invalidate()
local old_end_col = old_col + ((old_row == 0) and start_col or 0)
local new_end_col = new_col + ((new_row == 0) and start_col or 0)
-- Edit all trees recursively, together BEFORE emitting a bytes callback.
-- In most cases this callback should only be called from the root tree.
self:for_each_tree(function(tree)
local old_range = {
start_row,
start_col,
start_byte,
start_row + old_row,
old_end_col,
start_byte + old_byte,
}
local new_range = {
start_row,
start_col,
start_byte,
start_row + new_row,
new_end_col,
start_byte + new_byte,
}
local valid_regions = update_regions(self._regions, old_range, new_range)
if #self._regions == 0 or #valid_regions == 0 then
self._valid = false
else
self._valid = valid_regions
end
for _, child in pairs(self._children) do
child:_on_bytes(
bufnr,
changed_tick,
start_row,
start_col,
start_byte,
old_row,
old_col,
old_byte,
new_row,
new_col,
new_byte
)
end
-- Edit trees together BEFORE emitting a bytes callback.
for _, tree in ipairs(self._trees) do
tree:edit(
start_byte,
start_byte + old_byte,
@ -529,22 +609,24 @@ function LanguageTree:_on_bytes(
start_row + new_row,
new_end_col
)
end)
end
self:_do_callback(
'bytes',
bufnr,
changed_tick,
start_row,
start_col,
start_byte,
old_row,
old_col,
old_byte,
new_row,
new_col,
new_byte
)
if not self._is_child then
self:_do_callback(
'bytes',
bufnr,
changed_tick,
start_row,
start_col,
start_byte,
old_row,
old_col,
old_byte,
new_row,
new_col,
new_byte
)
end
end
---@private
@ -595,19 +677,15 @@ end
---@private
---@param tree TSTree
---@param range Range
---@param range Range4
---@return boolean
local function tree_contains(tree, range)
local start_row, start_col, end_row, end_col = tree:root():range()
local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2])
local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4])
return start_fits and end_fits
return Range.contains({ tree:root():range() }, range)
end
--- Determines whether {range} is contained in the |LanguageTree|.
---
---@param range Range `{ start_line, start_col, end_line, end_col }`
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@return boolean
function LanguageTree:contains(range)
for _, tree in pairs(self._trees) do
@ -621,7 +699,7 @@ end
--- Gets the tree that contains {range}.
---
---@param range Range `{ start_line, start_col, end_line, end_col }`
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@param opts table|nil Optional keyword arguments:
--- - ignore_injections boolean Ignore injected languages (default true)
---@return TSTree|nil
@ -631,10 +709,9 @@ function LanguageTree:tree_for_range(range, opts)
if not ignore then
for _, child in pairs(self._children) do
for _, tree in pairs(child:trees()) do
if tree_contains(tree, range) then
return tree
end
local tree = child:tree_for_range(range, opts)
if tree then
return tree
end
end
end
@ -650,7 +727,7 @@ end
--- Gets the smallest named node that contains {range}.
---
---@param range Range `{ start_line, start_col, end_line, end_col }`
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@param opts table|nil Optional keyword arguments:
--- - ignore_injections boolean Ignore injected languages (default true)
---@return TSNode | nil Found node
@ -663,7 +740,7 @@ end
--- Gets the appropriate language that contains {range}.
---
---@param range Range `{ start_line, start_col, end_line, end_col }`
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@return LanguageTree Managing {range}
function LanguageTree:language_for_range(range)
for _, child in pairs(self._children) do

View File

@ -406,7 +406,7 @@ predicate_handlers['vim-match?'] = predicate_handlers['match?']
---@class TSMetadata
---@field [integer] TSMetadata
---@field [string] integer|string
---@field range Range
---@field range Range4
---@alias TSDirective fun(match: TSMatch, _, _, predicate: any[], metadata: TSMetadata)

View File

@ -291,7 +291,7 @@ local types = { 'integer', 'number', 'string', 'table', 'list', 'boolean', 'func
local tagged_types = { 'TSNode', 'LanguageTree' }
-- Document these as 'table'
local alias_types = { 'Range' }
local alias_types = { 'Range4', 'Range6' }
--! \brief run the filter
function TLua2DoX_filter.readfile(this, AppStamp, Filename)

View File

@ -639,6 +639,17 @@ int x = INT_MAX;
{1, 26, 1, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
{2, 29, 2, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
helpers.feed('ggo<esc>')
eq(5, exec_lua("return #parser:children().c:trees()"))
eq({
{0, 0, 8, 0}, -- root tree
{4, 14, 4, 17}, -- VALUE 123
{5, 15, 5, 18}, -- VALUE1 123
{6, 15, 6, 18}, -- VALUE2 123
{2, 26, 2, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
{3, 29, 3, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
end)
end)
@ -660,6 +671,18 @@ int x = INT_MAX;
{1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
helpers.feed('ggo<esc>')
eq("table", exec_lua("return type(parser:children().c)"))
eq(2, exec_lua("return #parser:children().c:trees()"))
eq({
{0, 0, 8, 0}, -- root tree
{4, 14, 6, 18}, -- VALUE 123
-- VALUE1 123
-- VALUE2 123
{2, 26, 3, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
end)
end)
@ -688,6 +711,18 @@ int x = INT_MAX;
{1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
helpers.feed('ggo<esc>')
eq("table", exec_lua("return type(parser:children().c)"))
eq(2, exec_lua("return #parser:children().c:trees()"))
eq({
{0, 0, 8, 0}, -- root tree
{4, 14, 6, 18}, -- VALUE 123
-- VALUE1 123
-- VALUE2 123
{2, 26, 3, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
end)
it("should not inject bad languages", function()