Skip to content

Commit

Permalink
perf(treesitter): rewrite has-ancestor? in C
Browse files Browse the repository at this point in the history
  • Loading branch information
vanaigr authored and clason committed May 16, 2024
1 parent 2294f5f commit 2f89f59
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
13 changes: 2 additions & 11 deletions runtime/lua/vim/treesitter/query.lua
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,8 @@ local predicate_handlers = {
end

for _, node in ipairs(nodes) do
local ancestor_types = {} --- @type table<string, boolean>
for _, type in ipairs({ unpack(predicate, 3) }) do
ancestor_types[type] = true
end

local cur = node:tree():root()
while cur do
if ancestor_types[cur:type()] then
return true
end
cur = cur:child_containing_descendant(node)
if node:__has_ancestor(predicate) then
return true
end
end
return false
Expand Down
35 changes: 35 additions & 0 deletions src/nvim/lua/treesitter.c
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ static struct luaL_Reg node_meta[] = {
{ "descendant_for_range", node_descendant_for_range },
{ "named_descendant_for_range", node_named_descendant_for_range },
{ "parent", node_parent },
{ "__has_ancestor", __has_ancestor },
{ "child_containing_descendant", node_child_containing_descendant },
{ "iter_children", node_iter_children },
{ "next_sibling", node_next_sibling },
Expand Down Expand Up @@ -1053,6 +1054,40 @@ static int node_parent(lua_State *L)
return 1;
}

static int __has_ancestor(lua_State *L)
{
TSNode descendant = node_check(L, 1);
if (lua_type(L, 2) != LUA_TTABLE) {
lua_pushboolean(L, false);
return 1;
}
int const pred_len = (int)lua_objlen(L, 2);

TSNode node = ts_tree_root_node(descendant.tree);
while (!ts_node_is_null(node)) {
char const *node_type = ts_node_type(node);
size_t node_type_len = strlen(node_type);

for (int i = 3; i <= pred_len; i++) {
lua_rawgeti(L, 2, i);
if (lua_type(L, -1) == LUA_TSTRING) {
size_t check_len;
char const *check_str = lua_tolstring(L, -1, &check_len);
if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) {
lua_pushboolean(L, true);
return 1;
}
}
lua_pop(L, 1);
}

node = ts_node_child_containing_descendant(node, descendant);
}

lua_pushboolean(L, false);
return 1;
}

static int node_child_containing_descendant(lua_State *L)
{
TSNode node = node_check(L, 1);
Expand Down

0 comments on commit 2f89f59

Please sign in to comment.