From 99e0facf3a001608287ec6db69b01c77443c7b9d Mon Sep 17 00:00:00 2001 From: Christian Clason Date: Sun, 15 Sep 2024 14:19:08 +0200 Subject: [PATCH] feat(treesitter)!: use return values in `language.add()` Problem: No clear way to check whether parsers are available for a given language. Solution: Make `language.add()` return `true` if a parser was successfully added and `nil` otherwise. Use explicit `assert` instead of relying on thrown errors. --- runtime/doc/news.txt | 4 +- runtime/doc/treesitter.txt | 11 ++++- runtime/lua/vim/treesitter/language.lua | 45 ++++++++++++++------ runtime/lua/vim/treesitter/languagetree.lua | 4 +- runtime/lua/vim/treesitter/query.lua | 3 +- test/functional/treesitter/language_spec.lua | 14 ++---- 6 files changed, 49 insertions(+), 32 deletions(-) diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index 77345e6b19..6a29efdd9c 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -97,12 +97,12 @@ TREESITTER capture IDs to a list of nodes that need to be iterated over. For backwards compatibility, an option `all=false` (only return the last matching node) is provided that will be removed in a future release. - • |vim.treesitter.language.get_filetypes()| always includes the {language} argument in addition to explicitly registered filetypes. - • |vim.treesitter.language.get_lang()| falls back to the {filetype} argument if no languages are explicitly registered. +• |vim.treesitter.language.add()| returns `true` if a parser was loaded + successfully and `nil,errmsg` otherwise instead of throwing an error. TUI diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index e0f7536712..a83b80d9ad 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -964,7 +964,12 @@ add({lang}, {opts}) *vim.treesitter.language.add()* Load parser with name {lang} Parsers are searched in the `parser` runtime directory, or the provided - {path} + {path}. Can be used to check for available parsers before enabling + treesitter features, e.g., >lua + if vim.treesitter.language.add('markdown') then + vim.treesitter.start(bufnr, 'markdown') + end +< Parameters: ~ • {lang} (`string`) Name of the parser (alphanumerical and `_` only) @@ -973,6 +978,10 @@ add({lang}, {opts}) *vim.treesitter.language.add()* • {symbol_name}? (`string`) Internal symbol name for the language to load + Return (multiple): ~ + (`boolean?`) True if parser is loaded + (`string?`) Error if parser cannot be loaded + get_filetypes({lang}) *vim.treesitter.language.get_filetypes()* Returns the filetypes for which a parser named {lang} is used. diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index e4e1c2fff0..9f7807e036 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -62,8 +62,22 @@ function M.require_language(lang, path, silent, symbol_name) return installed end - M.add(lang, opts) - return true + return M.add(lang, opts) +end + +--- Load wasm or native parser (wrapper) +--- todo(clason): move to C +--- +---@param path string Path of parser library +---@param lang string Language name +---@param symbol_name? string Internal symbol name for the language to load (default lang) +---@return boolean? True if parser is loaded +local function loadparser(path, lang, symbol_name) + if vim.endswith(path, '.wasm') then + return vim._ts_add_language_from_wasm and vim._ts_add_language_from_wasm(path, lang) + else + return vim._ts_add_language_from_object(path, lang, symbol_name) + end end ---@class vim.treesitter.language.add.Opts @@ -77,10 +91,18 @@ end --- Load parser with name {lang} --- ---- Parsers are searched in the `parser` runtime directory, or the provided {path} +--- Parsers are searched in the `parser` runtime directory, or the provided {path}. +--- Can be used to check for available parsers before enabling treesitter features, e.g., +--- ```lua +--- if vim.treesitter.language.add('markdown') then +--- vim.treesitter.start(bufnr, 'markdown') +--- end +--- ``` --- ---@param lang string Name of the parser (alphanumerical and `_` only) ---@param opts? vim.treesitter.language.add.Opts Options: +---@return boolean? True if parser is loaded +---@return string? Error if parser cannot be loaded function M.add(lang, opts) opts = opts or {} local path = opts.path @@ -96,30 +118,25 @@ function M.add(lang, opts) lang = lang:lower() if vim._ts_has_language(lang) then - return + return true end if path == nil then + -- allow only safe language names when looking for libraries to load if not (lang and lang:match('[%w_]+') == lang) then - error("'" .. lang .. "' is not a valid language name") + return nil, string.format('Invalid language name "%s"', lang) end local fname = 'parser/' .. lang .. '.*' local paths = api.nvim_get_runtime_file(fname, false) if #paths == 0 then - error("no parser for '" .. lang .. "' language, see :help treesitter-parsers") + return nil, string.format('No parser for language "%s"', lang) end path = paths[1] end - if vim.endswith(path, '.wasm') then - if not vim._ts_add_language_from_wasm then - error(string.format("Unable to load wasm parser '%s': not built with ENABLE_WASMTIME ", path)) - end - vim._ts_add_language_from_wasm(path, lang) - else - vim._ts_add_language_from_object(path, lang, symbol_name) - end + return loadparser(path, lang, symbol_name) or nil, + string.format('Cannot load parser %s for language "%s"', path, lang) end --- @param x string|string[] diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index cc9ffeaa29..fd68c2b910 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -108,7 +108,7 @@ LanguageTree.__index = LanguageTree ---@param opts vim.treesitter.LanguageTree.new.Opts? ---@return vim.treesitter.LanguageTree parser object function LanguageTree.new(source, lang, opts) - language.add(lang) + assert(language.add(lang)) opts = opts or {} if source == 0 then @@ -734,7 +734,7 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges) table.insert(t[tree_index][lang][pattern].regions, ranges) end --- TODO(clason): replace by refactored `ts.has_parser` API (without registering) +-- TODO(clason): replace by refactored `ts.has_parser` API (without side effects) --- The result of this function is cached to prevent nvim_get_runtime_file from being --- called too often --- @param lang string parser name diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 135250578e..4614967799 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -247,8 +247,7 @@ end) --- ---@see [vim.treesitter.query.get()] M.parse = memoize('concat-2', function(lang, query) - language.add(lang) - + assert(language.add(lang)) local ts_query = vim._ts_parse_query(lang, query) return Query.new(lang, ts_query) end) diff --git a/test/functional/treesitter/language_spec.lua b/test/functional/treesitter/language_spec.lua index 3947ab23b2..e1e34fcecc 100644 --- a/test/functional/treesitter/language_spec.lua +++ b/test/functional/treesitter/language_spec.lua @@ -31,29 +31,21 @@ describe('treesitter language API', function() ) ) - eq(false, exec_lua("return pcall(vim.treesitter.language.add, 'borklang')")) + eq(NIL, exec_lua("return vim.treesitter.language.add('borklang')")) eq( false, exec_lua("return pcall(vim.treesitter.language.add, 'borklang', { path = 'borkbork.so' })") ) - eq( - ".../language.lua:0: no parser for 'borklang' language, see :help treesitter-parsers", - pcall_err(exec_lua, "parser = vim.treesitter.language.inspect('borklang')") - ) - matches( 'Failed to load parser: uv_dlsym: .+', pcall_err(exec_lua, 'vim.treesitter.language.add("c", { symbol_name = "borklang" })') ) end) - it('shows error for invalid language name', function() - eq( - ".../language.lua:0: '/foo/' is not a valid language name", - pcall_err(exec_lua, 'vim.treesitter.language.add("/foo/")') - ) + it('does not load parser for invalid language name', function() + eq(NIL, exec_lua('vim.treesitter.language.add("/foo/")')) end) it('inspects language', function()