feat(func): allow manual cache invalidation for _memoize

This commit also adds some tests for the existing memoization
functionality.
This commit is contained in:
Riley Bruins 2024-09-01 16:54:30 -07:00
parent b4599acbf8
commit 4ca7626462
4 changed files with 191 additions and 25 deletions

View File

@ -3,9 +3,6 @@ local M = {}
-- TODO(lewis6991): Private for now until:
-- - There are other places in the codebase that could benefit from this
-- (e.g. LSP), but might require other changes to accommodate.
-- - Invalidation of the cache needs to be controllable. Using weak tables
-- is an acceptable invalidation policy, but it shouldn't be the only
-- one.
-- - I don't think the story around `hash` is completely thought out. We
-- may be able to have a good default hash by hashing each argument,
-- so basically a better 'concat'.
@ -17,6 +14,10 @@ local M = {}
--- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the
--- cache will be invalidated whenever Lua does garbage collection.
---
--- The cache can also be manually invalidated by calling `:clear()` on the returned object.
--- Calling this function with no arguments clears the entire cache; otherwise, the arguments will
--- be interpreted as function inputs, and only the cache entry at their hash will be cleared.
---
--- The memoized function returns shared references so be wary about
--- mutating return values.
---
@ -32,11 +33,12 @@ local M = {}
--- first n arguments passed to {fn}.
---
--- @param fn F Function to memoize.
--- @param strong? boolean Do not use a weak table
--- @param weak? boolean Use a weak table (default `true`)
--- @return F # Memoized version of {fn}
--- @nodoc
function M._memoize(hash, fn, strong)
return require('vim.func._memoize')(hash, fn, strong)
function M._memoize(hash, fn, weak)
-- this is wrapped in a function to lazily require the module
return require('vim.func._memoize')(hash, fn, weak)
end
return M

View File

@ -1,5 +1,7 @@
--- Module for private utility functions
--- @alias vim.func.MemoObj { _hash: (fun(...): any), _weak: boolean?, _cache: table<any> }
--- @param argc integer?
--- @return fun(...): any
local function concat_hash(argc)
@ -33,29 +35,49 @@ local function resolve_hash(hash)
return hash
end
--- @param weak boolean?
--- @return table
local create_cache = function(weak)
return setmetatable({}, {
__mode = weak ~= false and 'kv',
})
end
--- @generic F: function
--- @param hash integer|string|fun(...): any
--- @param fn F
--- @param strong? boolean
--- @param weak? boolean
--- @return F
return function(hash, fn, strong)
return function(hash, fn, weak)
vim.validate('hash', hash, { 'number', 'string', 'function' })
vim.validate('fn', fn, 'function')
vim.validate('weak', weak, 'boolean', true)
---@type table<any,table<any,any>>
local cache = {}
if not strong then
setmetatable(cache, { __mode = 'kv' })
end
--- @type vim.func.MemoObj
local obj = {
_cache = create_cache(weak),
_hash = resolve_hash(hash),
_weak = weak,
--- @param self vim.func.MemoObj
clear = function(self, ...)
if select('#', ...) == 0 then
self._cache = create_cache(self._weak)
return
end
local key = self._hash(...)
self._cache[key] = nil
end,
}
hash = resolve_hash(hash)
return function(...)
local key = hash(...)
if cache[key] == nil then
cache[key] = vim.F.pack_len(fn(...))
end
return vim.F.unpack_len(cache[key])
end
return setmetatable(obj, {
--- @param self vim.func.MemoObj
__call = function(self, ...)
local key = self._hash(...)
local cache = self._cache
if cache[key] == nil then
cache[key] = vim.F.pack_len(fn(...))
end
return vim.F.unpack_len(cache[key])
end,
})
end

View File

@ -860,8 +860,8 @@ function Query:iter_captures(node, source, start, stop)
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
local apply_directives = memoize(match_id_hash, self.apply_directives, true)
local match_preds = memoize(match_id_hash, self.match_preds, true)
local apply_directives = memoize(match_id_hash, self.apply_directives, false)
local match_preds = memoize(match_id_hash, self.match_preds, false)
local function iter(end_line)
local capture, captured_node, match = cursor:next_capture()

View File

@ -0,0 +1,142 @@
local t = require('test.testutil')
local n = require('test.functional.testnvim')()
local clear = n.clear
local exec_lua = n.exec_lua
local eq = t.eq
describe('vim.func._memoize', function()
before_each(clear)
it('caches function results based on their parameters', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat', function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
collectgarbage('stop')
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
collectgarbage('restart')
]])
eq(1, exec_lua([[return _G.count]]))
end)
it('caches function results using a weak table by default', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat-2', function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
adder(3, -4)
collectgarbage()
adder(3, -4)
collectgarbage()
adder(3, -4)
]])
eq(3, exec_lua([[return _G.count]]))
end)
it('can cache using a strong table', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat-2', function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end, false)
adder(3, -4)
collectgarbage()
adder(3, -4)
collectgarbage()
adder(3, -4)
]])
eq(1, exec_lua([[return _G.count]]))
end)
it('can clear a single cache entry', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize(function(arg1, arg2)
return tostring(arg1) .. '%%' .. tostring(arg2)
end, function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
collectgarbage('stop')
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder:clear(3, -4)
adder(3, -4)
collectgarbage('restart')
]])
eq(2, exec_lua([[return _G.count]]))
end)
it('can clear the entire cache', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize(function(arg1, arg2)
return tostring(arg1) .. '%%' .. tostring(arg2)
end, function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
collectgarbage('stop')
adder(1, 2)
adder(3, -4)
adder(1, 2)
adder(3, -4)
adder(1, 2)
adder(3, -4)
adder:clear()
adder(1, 2)
adder(3, -4)
collectgarbage('restart')
]])
eq(4, exec_lua([[return _G.count]]))
end)
it('can cache functions that return nil', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat', function(arg1, arg2)
_G.count = _G.count + 1
return nil
end)
collectgarbage('stop')
adder(1, 2)
adder(1, 2)
adder(1, 2)
adder(1, 2)
adder:clear()
adder(1, 2)
collectgarbage('restart')
]])
eq(2, exec_lua([[return _G.count]]))
end)
end)