diff --git a/src/nvim/lua/stdlib.c b/src/nvim/lua/stdlib.c
index 2d969357b4..788819ab03 100644
--- a/src/nvim/lua/stdlib.c
+++ b/src/nvim/lua/stdlib.c
@@ -216,6 +216,45 @@ static int nlua_str_utf_pos(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
   return 1;
 }
 
+/// Return the offset from the 1-indexed byte position to the first byte of the
+/// current character.
+///
+/// Expects a string and an int; raises "index out of range" unless the int is
+/// a byte position between 1 and the length of the string.
+///
+/// Returns the byte offset to the first byte of the character at that position.
+static int nlua_str_utf_start(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
+{
+  size_t s1_len;
+  const char *s1 = luaL_checklstring(lstate, 1, &s1_len);
+  intptr_t offset = luaL_checkinteger(lstate, 2);
+  // Position 0 is rejected too: it would address the byte before the string.
+  if (offset < 1 || offset > (intptr_t)s1_len) {
+    return luaL_error(lstate, "index out of range");
+  }
+  lua_pushinteger(lstate, mb_head_off((char_u *)s1, (char_u *)s1 + offset - 1));
+  return 1;
+}
+
+/// Return the offset from the 1-indexed byte position to the last
+/// byte of the current character.
+///
+/// Expects a string and an int; raises "index out of range" unless the int is
+/// a byte position between 1 and the length of the string.
+///
+/// Returns the byte offset to the last byte of the character at that position.
+static int nlua_str_utf_end(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
+{
+  size_t s1_len;
+  const char *s1 = luaL_checklstring(lstate, 1, &s1_len);
+  intptr_t offset = luaL_checkinteger(lstate, 2);
+  // Position 0 is rejected too: it would address the byte before the string.
+  if (offset < 1 || offset > (intptr_t)s1_len) {
+    return luaL_error(lstate, "index out of range");
+  }
+  lua_pushinteger(lstate, mb_tail_off((char_u *)s1, (char_u *)s1 + offset - 1));
+  return 1;
+}
 
 /// convert UTF-32 or UTF-16 indices to byte index.
///
@@ -439,6 +478,12 @@ void nlua_state_add_stdlib(lua_State *const lstate)
   // str_utf_pos
   lua_pushcfunction(lstate, &nlua_str_utf_pos);
   lua_setfield(lstate, -2, "str_utf_pos");
+  // str_utf_start
+  lua_pushcfunction(lstate, &nlua_str_utf_start);
+  lua_setfield(lstate, -2, "str_utf_start");
+  // str_utf_end
+  lua_pushcfunction(lstate, &nlua_str_utf_end);
+  lua_setfield(lstate, -2, "str_utf_end");
   // regex
   lua_pushcfunction(lstate, &nlua_regex);
   lua_setfield(lstate, -2, "regex");
diff --git a/src/nvim/mbyte.c b/src/nvim/mbyte.c
index bd680330ca..7ce4e2b4f5 100644
--- a/src/nvim/mbyte.c
+++ b/src/nvim/mbyte.c
@@ -1883,6 +1883,40 @@ int mb_tail_off(char_u *base, char_u *p)
   return i;
 }
 
+
+/// Return the offset from "p" to the first byte of the character it points
+/// into. Can start anywhere in a stream of bytes.
+///
+/// @param[in] base Pointer to start of string
+/// @param[in] p Pointer to the byte for which to return the offset to its lead byte
+///
+/// @return 0 if "p" is at NUL or in an invalid sequence, else the offset (<= 0) to the lead byte
+int mb_head_off(char_u *base, char_u *p)
+{
+  int i;
+  int j;
+
+  if (*p == NUL) {
+    return 0;
+  }
+
+  // Find the lead byte (not 10xx.xxxx); "i" counts down and p + i must not drop below base.
+  for (i = 0; p + i > base; i--) {
+    if ((p[i] & 0xc0) != 0x80) {
+      break;
+    }
+  }
+
+  // Find the last character that is 10xx.xxxx
+  for (j = 0; (p[j + 1] & 0xc0) == 0x80; j++) {}
+
+  // Illegal sequence unless the lead byte announces exactly the bytes p[i] .. p[j].
+  if (utf8len_tab[p[i]] != j - i + 1) {
+    return 0;
+  }
+  return i;
+}
+
 /*
  * Find the next illegal byte sequence.
*/
diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua
index 8c691b1121..5f903e7d5f 100644
--- a/test/functional/lua/vim_spec.lua
+++ b/test/functional/lua/vim_spec.lua
@@ -179,6 +179,31 @@ describe('lua stdlib', function()
     end
   end)
 
+  it("vim.str_utf_start", function()  -- calls vim.str_utf_start at every byte index of test_text
+    exec_lua([[_G.test_text = "xy åäö ɧ 汉语 ↥ 🤦x🦄 å بِيَّ"]])
+    local expected_positions = {0,0,0,0,-1,0,-1,0,-1,0,0,-1,0,0,-1,-2,0,-1,-2,0,0,-1,-2,0,0,-1,-2,-3,0,0,-1,-2,-3,0,0,0,-1,0,0,-1,0,-1,0,-1,0,-1,0,-1}  -- one offset per byte index, each zero or negative
+    eq(expected_positions, exec_lua([[
+      local start_codepoint_positions = {}
+      for idx = 1, #_G.test_text do
+        table.insert(start_codepoint_positions, vim.str_utf_start(_G.test_text, idx))
+      end
+      return start_codepoint_positions
+    ]]))
+  end)
+
+  it("vim.str_utf_end", function()  -- calls vim.str_utf_end at every byte index of test_text
+    exec_lua([[_G.test_text = "xy åäö ɧ 汉语 ↥ 🤦x🦄 å بِيَّ"]])
+    local expected_positions = {0,0,0,1,0,1,0,1,0,0,1,0,0,2,1,0,2,1,0,0,2,1,0,0,3,2,1,0,0,3,2,1,0,0,0,1,0,0,1,0,1,0,1,0,1,0,1,0 }  -- one offset per byte index, each zero or positive
+    eq(expected_positions, exec_lua([[
+      local end_codepoint_positions = {}
+      for idx = 1, #_G.test_text do
+        table.insert(end_codepoint_positions, vim.str_utf_end(_G.test_text, idx))
+      end
+      return end_codepoint_positions
+    ]]))
+  end)
+
+
   it("vim.str_utf_pos", function()
     exec_lua([[_G.test_text = "xy åäö ɧ 汉语 ↥ 🤦x🦄 å بِيَّ"]])
     local expected_positions = { 1,2,3,4,6,8,10,11,13,14,17,20,21,24,25,29,30,34,35,36,38,39,41,43,45,47 }