neovim/runtime/lua/vim/snippet.lua
2023-10-21 08:51:26 +02:00

536 lines
16 KiB
Lua

-- LSP snippet grammar; `G.parse` and `G.NodeType` are used by `M.expand` below.
local G = require('vim.lsp._snippet_grammar')
-- Group for the buffer-local autocommands that drive a snippet session.
local snippet_group = vim.api.nvim_create_augroup('vim/snippet', {})
-- Namespace holding the tabstop extmarks; it is cleared in one go when a session exits.
local snippet_ns = vim.api.nvim_create_namespace('vim/snippet')
--- Reads the current window's cursor and converts its row to 0-based indexing.
---
--- `nvim_win_get_cursor` reports a 1-based row and a 0-based column, so only the row is shifted.
---
--- @return integer, integer
local function cursor_pos()
  local position = vim.api.nvim_win_get_cursor(0)
  local row, col = position[1], position[2]
  return row - 1, col
end
--- Resolves variables (like `$name` or `${name:default}`) as follows:
--- - When a variable is unknown (i.e.: its name is not recognized in any of the cases below), return `nil`.
--- - When a variable isn't set, return its default (if any) or an empty string.
---
--- Note that in some cases, the default is ignored since it's not clear how to distinguish an empty
--- value from an unset value (e.g.: `TM_CURRENT_LINE`).
---
--- @param var string Variable name, without the leading `$`.
--- @param default string Value used when the variable's expansion is empty.
--- @return string? value The resolved value, or `nil` when `var` is not a recognized variable.
local function resolve_variable(var, default)
  --- Expands `str` with `vim.fn.expand`, falling back to `default` when the result is empty.
  ---
  --- @param str string
  --- @return string
  local function expand_or_default(str)
    local expansion = vim.fn.expand(str) --[[@as string]]
    return expansion == '' and default or expansion
  end

  if var == 'TM_SELECTED_TEXT' then
    -- Snippets are expanded in insert mode only, so there's no selection.
    return default
  elseif var == 'TM_CURRENT_LINE' then
    return vim.api.nvim_get_current_line()
  elseif var == 'TM_CURRENT_WORD' then
    return expand_or_default('<cword>')
  elseif var == 'TM_LINE_INDEX' then
    return tostring(vim.fn.line('.') - 1)
  elseif var == 'TM_LINE_NUMBER' then
    return tostring(vim.fn.line('.'))
  elseif var == 'TM_FILENAME' then
    return expand_or_default('%:t')
  elseif var == 'TM_FILENAME_BASE' then
    -- `:r` removes only the last extension (`foo.tar.gz` -> `foo.tar`) and leaves names that
    -- consist of just an extension (`.bashrc`) intact instead of emptying them. When the buffer
    -- has no name, the default is returned unchanged rather than having its own suffix stripped.
    return expand_or_default('%:t:r')
  elseif var == 'TM_DIRECTORY' then
    return expand_or_default('%:p:h:t')
  elseif var == 'TM_FILEPATH' then
    return expand_or_default('%:p')
  end

  -- Unknown variable.
  return nil
end
--- Splits text into lines so that none of the returned strings contains a `\n`.
---
--- A list of strings is first joined without a separator, so its elements may themselves
--- contain newlines (or end mid-line).
---
--- @param text string|string[]
--- @return string[]
local function text_to_lines(text)
  local joined = text
  if type(text) ~= 'string' then
    --- @cast text string[]
    joined = table.concat(text)
  end
  return vim.split(joined, '\n', { plain = true })
end
--- Computes the 0-based range a tabstop will occupy once `snippet` (the text accumulated so far)
--- is inserted at the cursor: it starts right after that text and spans `placeholder`, if given.
---
--- @param snippet string[]
--- @param placeholder string?
--- @return Range4
local function compute_tabstop_range(snippet, placeholder)
  local row0, col0 = cursor_pos()
  local lines = text_to_lines(snippet)
  local placeholder_lines = text_to_lines(placeholder or '')
  local line_count = #lines

  -- The tabstop begins where the snippet text so far ends.
  local start_row = row0 + line_count - 1
  local start_col = #(lines[line_count] or '')
  if line_count == 1 then
    -- Still on the cursor's line, where the snippet text begins at the cursor column.
    start_col = col0 + start_col
  end

  -- The tabstop ends after the placeholder text.
  local end_row = start_row + #placeholder_lines - 1
  local end_col = #(placeholder_lines[#placeholder_lines] or '')
  if end_row == start_row then
    end_col = start_col + end_col
  end

  return { start_row, start_col, end_row, end_col }
end
--- A snippet tabstop. Its range is stored as an extmark, so it follows edits to the buffer.
---
--- @class vim.snippet.Tabstop
--- @field extmark_id integer
--- @field index integer
--- @field bufnr integer
local Tabstop = {}

--- Creates a new tabstop and its highlighted extmark.
---
--- @package
--- @param index integer
--- @param bufnr integer
--- @param range Range4
--- @return vim.snippet.Tabstop
function Tabstop.new(index, bufnr, range)
  local start_row, start_col = range[1], range[2]
  local mark_opts = {
    -- Left gravity at the start and right gravity at the end: text typed at either edge of the
    -- tabstop ends up inside it.
    right_gravity = false,
    end_right_gravity = true,
    end_line = range[3],
    end_col = range[4],
    hl_group = 'SnippetTabstop',
  }
  local mark = vim.api.nvim_buf_set_extmark(bufnr, snippet_ns, start_row, start_col, mark_opts)
  return setmetatable({ index = index, bufnr = bufnr, extmark_id = mark }, { __index = Tabstop })
end

--- Returns the tabstop's current range, read back from its extmark.
---
--- @package
--- @return Range4
function Tabstop:get_range()
  local mark =
    vim.api.nvim_buf_get_extmark_by_id(self.bufnr, snippet_ns, self.extmark_id, { details = true })
  local details = mark[3]
  --- @diagnostic disable-next-line: undefined-field
  return { mark[1], mark[2], details.end_row, details.end_col }
end

--- Returns the text spanned by the tabstop, with lines joined by `\n`.
---
--- @package
--- @return string
function Tabstop:get_text()
  local r = self:get_range()
  local lines = vim.api.nvim_buf_get_text(self.bufnr, r[1], r[2], r[3], r[4], {})
  return table.concat(lines, '\n')
end

--- Replaces the text spanned by the tabstop.
---
--- @package
--- @param text string
function Tabstop:set_text(text)
  local r = self:get_range()
  local lines = text_to_lines(text)
  vim.api.nvim_buf_set_text(self.bufnr, r[1], r[2], r[3], r[4], lines)
end
--- State of an in-progress snippet expansion.
---
--- @class vim.snippet.Session
--- @field bufnr integer
--- @field tabstops table<integer, vim.snippet.Tabstop[]>
--- @field current_tabstop vim.snippet.Tabstop
local Session = {}

--- Creates a new snippet session in the current buffer.
---
--- The session starts on a placeholder tabstop with index 0 at the buffer's origin;
--- `get_dest_index` moves from it to the next index in sorted order.
---
--- @package
--- @return vim.snippet.Session
function Session.new()
  local bufnr = vim.api.nvim_get_current_buf()
  local initial = Tabstop.new(0, bufnr, { 0, 0, 0, 0 })
  return setmetatable(
    { bufnr = bufnr, tabstops = {}, current_tabstop = initial },
    { __index = Session }
  )
end

--- Creates the session tabstops, grouped by tabstop index.
---
--- @package
--- @param tabstop_ranges table<integer, Range4[]>
function Session:set_tabstops(tabstop_ranges)
  for index, ranges in pairs(tabstop_ranges) do
    for _, range in ipairs(ranges) do
      local group = self.tabstops[index]
      if not group then
        group = {}
        self.tabstops[index] = group
      end
      group[#group + 1] = Tabstop.new(index, self.bufnr, range)
    end
  end
end

--- Returns the index of the tabstop to jump to from the current one, or `nil` when there is none.
---
--- @package
--- @param direction vim.snippet.Direction
--- @return integer?
function Session:get_dest_index(direction)
  local indexes = vim.tbl_keys(self.tabstops) --- @type integer[]
  table.sort(indexes)
  local current = self.current_tabstop.index
  for pos = 1, #indexes do
    if indexes[pos] == current then
      local dest = indexes[pos + direction] --- @type integer?
      if dest == nil and direction == 1 then
        -- Past the last numbered tabstop: $0 is the final stop.
        return 0
      elseif dest == 0 and direction == -1 then
        -- $0 sorts first but is the final stop, so nothing precedes the first tabstop.
        return nil
      end
      return dest
    end
  end
  return nil
end
--- Module table. `_session` holds the snippet session being navigated, or `nil` when there is
--- none; the key matches the annotated field and every use in this module.
---
--- @class vim.snippet.Snippet
--- @field private _session? vim.snippet.Session
local M = { _session = nil }
--- Select the given tabstop range.
---
--- Empty tabstops and the final tabstop (index 0) are entered in insert mode with the cursor on
--- the tabstop's start; any other tabstop gets its text selected in select mode.
--- Most steps are queued with `nvim_feedkeys` (no 'x' flag), so they are processed after this
--- function returns, in the order they are queued here.
---
--- @param tabstop vim.snippet.Tabstop
local function select_tabstop(tabstop)
--- Queues `keys` as typed input: `<...>` notation is translated to key codes and, because of
--- the 'n' mode flag, the keys are not remapped.
---
--- @param keys string
local function feedkeys(keys)
keys = vim.api.nvim_replace_termcodes(keys, true, false, true)
vim.api.nvim_feedkeys(keys, 'n', true)
end
--- NOTE: We don't use `vim.api.nvim_win_set_cursor` here because it causes the cursor to end
--- at the end of the selection instead of the start.
---
--- Queues `{row}G0` followed by one `<Right>` per character: the number of characters in the
--- first `col` bytes of the line, minus one (never negative).
---
--- @param row integer 1-based line number (used as the count for `G`).
--- @param col integer Byte count into the line.
local function move_cursor_to(row, col)
local line = vim.fn.getline(row) --[[ @as string ]]
col = math.max(vim.fn.strchars(line:sub(1, col)) - 1, 0)
feedkeys(string.format('%sG0%s', row, string.rep('<Right>', col)))
end
local range = tabstop:get_range()
-- Mode before any of the queued keys below run.
local mode = vim.fn.mode()
-- Move the cursor to the start of the tabstop.
vim.api.nvim_win_set_cursor(0, { range[1] + 1, range[2] })
-- For empty tabstops and the final tabstop, start insert mode at the cursor, i.e. at the
-- tabstop's start (which is also its end when the tabstop is empty).
if tabstop.index == 0 or (range[1] == range[3] and range[2] == range[4]) then
-- Already in insert mode: moving the cursor is all that's needed.
if mode ~= 'i' then
if mode == 's' then
feedkeys('<Esc>')
end
-- `:startinsert!` appends like `A`; use it when the range ends at or past the end of the line,
-- where a plain `:startinsert` would insert before the last character.
vim.cmd.startinsert({ bang = range[4] >= #vim.api.nvim_get_current_line() })
end
else
-- Else, select the tabstop's text.
if mode ~= 'n' then
feedkeys('<Esc>')
end
-- Visually select from the first to the last character of the range...
move_cursor_to(range[1] + 1, range[2] + 1)
feedkeys('v')
move_cursor_to(range[3] + 1, range[4])
-- ...then `o` puts the cursor back on the selection's start and `<C-g>` turns Visual mode
-- into Select mode.
feedkeys('o<c-g>')
end
end
--- Sets up the buffer-local autocommands that drive an active snippet session:
--- - When the cursor moves in insert or select mode, track the tabstop that contains it. Exit the
---   session once the cursor is outside every tabstop, or inside only the final tabstop (`$0`).
--- - When the text changes, copy the current tabstop's text into the other tabstops that share
---   its index.
---
--- Both callbacks delete themselves (by returning `true`) once the session is no longer active in
--- the current buffer. For example, a newer snippet may have been expanded in another buffer.
---
--- @param bufnr integer
local function setup_autocmds(bufnr)
  vim.api.nvim_create_autocmd({ 'CursorMoved', 'CursorMovedI' }, {
    group = snippet_group,
    desc = 'Update snippet state when the cursor moves',
    buffer = bufnr,
    callback = function()
      -- The session may now belong to another buffer. Comparing this buffer's cursor with that
      -- session's ranges would corrupt its current tabstop, so detach instead.
      if not M.active() then
        return true
      end
      -- Just update the tabstop in insert and select modes.
      if not vim.fn.mode():match('^[isS]') then
        return
      end
      -- Update the current tabstop to be the one containing the cursor.
      local session = M._session
      local cursor_row, cursor_col = cursor_pos()
      for tabstop_index, tabstops in pairs(session.tabstops) do
        for _, tabstop in ipairs(tabstops) do
          local range = tabstop:get_range()
          if
            (cursor_row > range[1] or (cursor_row == range[1] and cursor_col >= range[2]))
            and (cursor_row < range[3] or (cursor_row == range[3] and cursor_col <= range[4]))
          then
            session.current_tabstop = tabstop
            -- Keep scanning after a `$0` hit: another tabstop containing the cursor wins.
            if tabstop_index ~= 0 then
              return
            end
          end
        end
      end
      -- The cursor is either not on a tabstop or we reached the end, so exit the session.
      M.exit()
      return true
    end,
  })

  vim.api.nvim_create_autocmd({ 'TextChanged', 'TextChangedI' }, {
    group = snippet_group,
    desc = 'Update active tabstops when buffer text changes',
    buffer = bufnr,
    callback = function()
      if not M.active() then
        return true
      end
      -- Sync the tabstops in the current group.
      local current_tabstop = M._session.current_tabstop
      local current_text = current_tabstop:get_text()
      for _, tabstop in ipairs(M._session.tabstops[current_tabstop.index] or {}) do
        -- Skip tabstops that already match. Rewriting identical text is still a buffer
        -- modification, which would add undo noise and signal yet another text change.
        if
          tabstop.extmark_id ~= current_tabstop.extmark_id
          and tabstop:get_text() ~= current_text
        then
          tabstop:set_text(current_text)
        end
      end
    end,
  })
end
--- Expands the given snippet text.
--- Refer to https://microsoft.github.io/language-server-protocol/specification/#snippet_syntax
--- for the specification of valid input.
---
--- Tabstops are highlighted with hl-SnippetTabstop.
---
--- The snippet is inserted at the cursor in the current buffer. A new session replaces any
--- previous one, and the cursor jumps to the first tabstop. Variables that aren't recognized
--- become tabstops numbered after every explicit tabstop, with the variable name as placeholder.
---
--- Raises an error if the snippet gives different placeholders to the same tabstop or has more
--- than one `$0`. In that case no session is started.
---
--- @param input string
function M.expand(input)
  local snippet = G.parse(input)
  local snippet_text = {}

  -- Get the placeholders we should use for each tabstop index. Also find the highest explicit
  -- tabstop index, so that tabstops created for unknown variables never reuse (and get linked
  -- to) an index that appears later in the snippet.
  --- @type table<integer, string>
  local placeholders = {}
  local max_tabstop_index = 0
  for _, child in ipairs(snippet.data.children) do
    local node_type, data = child.type, child.data
    if node_type == G.NodeType.Placeholder then
      --- @cast data vim.snippet.PlaceholderData
      local tabstop, value = data.tabstop, tostring(data.value)
      if placeholders[tabstop] and placeholders[tabstop] ~= value then
        error('Snippet has multiple placeholders for tabstop $' .. tabstop)
      end
      placeholders[tabstop] = value
      max_tabstop_index = math.max(max_tabstop_index, tabstop)
    elseif node_type == G.NodeType.Tabstop then
      --- @cast data vim.snippet.TabstopData
      max_tabstop_index = math.max(max_tabstop_index, data.tabstop)
    end
  end

  -- Keep track of tabstop nodes during expansion.
  --- @type table<integer, Range4[]>
  local tabstop_ranges = {}

  --- Records a tabstop with the given index at the current end of the snippet text.
  ---
  --- @param index integer
  --- @param placeholder string?
  local function add_tabstop(index, placeholder)
    tabstop_ranges[index] = tabstop_ranges[index] or {}
    table.insert(tabstop_ranges[index], compute_tabstop_range(snippet_text, placeholder))
  end

  --- Appends the given text to the snippet, taking care of indentation.
  ---
  --- @param text string|string[]
  local function append_to_snippet(text)
    -- Base indentation: the current line's leading whitespace plus the leading blanks of the
    -- last appended chunk. Only spaces and tabs are taken from the chunk. `%s*` would also
    -- match a leading newline and inject a line break into every continuation line.
    local base_indent = vim.api.nvim_get_current_line():match('^%s*') or ''
    if #snippet_text > 0 then
      base_indent = base_indent .. (snippet_text[#snippet_text]:match('^[ \t]*') or '') --- @type string
    end
    local lines = vim.iter.map(function(i, line)
      -- Replace tabs by spaces.
      if vim.o.expandtab then
        line = line:gsub('\t', (' '):rep(vim.fn.shiftwidth())) --- @type string
      end
      -- Add the base indentation to every line except the first, which continues the
      -- current line.
      if i > 1 then
        line = base_indent .. line
      end
      return line
    end, ipairs(text_to_lines(text)))
    table.insert(snippet_text, table.concat(lines, '\n'))
  end

  local next_unknown_index = max_tabstop_index + 1
  for _, child in ipairs(snippet.data.children) do
    local node_type, data = child.type, child.data
    if node_type == G.NodeType.Tabstop then
      --- @cast data vim.snippet.TabstopData
      local placeholder = placeholders[data.tabstop]
      add_tabstop(data.tabstop, placeholder)
      if placeholder then
        append_to_snippet(placeholder)
      end
    elseif node_type == G.NodeType.Placeholder then
      --- @cast data vim.snippet.PlaceholderData
      local value = placeholders[data.tabstop]
      add_tabstop(data.tabstop, value)
      append_to_snippet(value)
    elseif node_type == G.NodeType.Choice then
      --- @cast data vim.snippet.ChoiceData
      append_to_snippet(data.values[1])
    elseif node_type == G.NodeType.Variable then
      --- @cast data vim.snippet.VariableData
      -- Try to get the variable's value.
      local value = resolve_variable(data.name, data.default and tostring(data.default) or '')
      if not value then
        -- Unknown variable, make this a tabstop and use the variable name as a placeholder.
        value = data.name
        add_tabstop(next_unknown_index, value)
        next_unknown_index = next_unknown_index + 1
      end
      append_to_snippet(value)
    elseif node_type == G.NodeType.Text then
      --- @cast data vim.snippet.TextData
      append_to_snippet(data.text)
    end
  end

  -- $0, which defaults to the end of the snippet, defines the final cursor position.
  -- Make sure the snippet has exactly one of these.
  if tabstop_ranges[0] then
    assert(#tabstop_ranges[0] == 1, 'Snippet has multiple $0 tabstops')
  else
    add_tabstop(0)
  end

  -- Start the session only now that the snippet is known to be valid. An error above leaves the
  -- previous state untouched instead of a half-initialized session.
  M._session = Session.new()

  -- Insert the snippet text.
  local cursor_row, cursor_col = cursor_pos()
  vim.api.nvim_buf_set_text(
    M._session.bufnr,
    cursor_row,
    cursor_col,
    cursor_row,
    cursor_col,
    text_to_lines(snippet_text)
  )

  -- Create the tabstops.
  M._session:set_tabstops(tabstop_ranges)

  -- Jump to the first tabstop.
  M.jump(1)
end
--- @alias vim.snippet.Direction -1 | 1

--- Returns `true` if there is an active snippet which can be jumped in the given direction.
--- You can use this function to navigate a snippet as follows:
---
--- ```lua
--- vim.keymap.set({ 'i', 's' }, '<Tab>', function()
---    if vim.snippet.jumpable(1) then
---      return '<cmd>lua vim.snippet.jump(1)<cr>'
---    else
---      return '<Tab>'
---    end
---  end, { expr = true })
--- ```
---
--- @param direction (vim.snippet.Direction) Navigation direction. -1 for previous, 1 for next.
--- @return boolean
function M.jumpable(direction)
  -- `M.active()` is a boolean, so this always yields `true` or `false`.
  return M.active() and M._session:get_dest_index(direction) ~= nil
end
--- Jumps within the active snippet in the given direction.
--- If the jump isn't possible, the function call does nothing.
---
--- You can use this function to navigate a snippet as follows:
---
--- ```lua
--- vim.keymap.set({ 'i', 's' }, '<Tab>', function()
---    if vim.snippet.jumpable(1) then
---      return '<cmd>lua vim.snippet.jump(1)<cr>'
---    else
---      return '<Tab>'
---    end
---  end, { expr = true })
--- ```
---
--- @param direction (vim.snippet.Direction) Navigation direction. -1 for previous, 1 for next.
function M.jump(direction)
  -- Like `jumpable`, only act on a session that belongs to the current buffer. Jumping in
  -- another buffer's session would move this window's cursor to that buffer's positions.
  if not M.active() then
    return
  end
  local session = M._session

  -- Get the tabstop index to jump to.
  local dest_index = session:get_dest_index(direction)
  if not dest_index then
    return
  end

  -- Among the tabstops sharing that index, pick the one that starts first in the buffer.
  local tabstops = session.tabstops[dest_index]
  local dest = tabstops[1]
  for _, tabstop in ipairs(tabstops) do
    local dest_range, range = dest:get_range(), tabstop:get_range()
    if (range[1] < dest_range[1]) or (range[1] == dest_range[1] and range[2] < dest_range[2]) then
      dest = tabstop
    end
  end

  -- Clear the autocommands so that we can move the cursor freely while selecting the tabstop.
  vim.api.nvim_clear_autocmds({ group = snippet_group, buffer = session.bufnr })
  session.current_tabstop = dest
  select_tabstop(dest)

  -- Restore the autocommands.
  setup_autocmds(session.bufnr)
end
--- Reports whether a snippet session exists and belongs to the current buffer.
---
--- @return boolean
function M.active()
  local session = M._session
  if session == nil then
    return false
  end
  return session.bufnr == vim.api.nvim_get_current_buf()
end
--- Exits the current snippet.
---
--- Does nothing unless a session is active in the current buffer. Otherwise it removes the
--- session's autocommands and every extmark in the snippet namespace of that buffer (which also
--- drops the tabstop highlights), then forgets the session.
function M.exit()
  if M.active() then
    local bufnr = M._session.bufnr
    vim.api.nvim_clear_autocmds({ group = snippet_group, buffer = bufnr })
    vim.api.nvim_buf_clear_namespace(bufnr, snippet_ns, 0, -1)
    M._session = nil
  end
end

return M