This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

gather_nd vs. take #21195

Open
mureva opened this issue May 1, 2023 · 0 comments

mureva commented May 1, 2023

I have an observation and I'm hoping someone can advise.

I have a scenario where I maintain a large table of vectors: a basic (n, m) array holding n vectors of size m. Some other system generates indices into this table, and I want to pull out the rows at those indices. (More background: what we're building is a hash-table version of NeRF.)
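To make the setup concrete, here is a minimal NumPy sketch (not MXNet code) of the table and the row gather. The sizes are made up, and the index generator is an assumption: an Instant-NGP-style spatial hash, used only to stand in for "some system that generates indices". The names hash_coords and PRIMES are hypothetical.

```python
import numpy as np

# Hypothetical sizes: n table rows, m features per row (the (n, m) table).
n, m = 2**14, 8
rng = np.random.default_rng(0)
table = rng.standard_normal((n, m)).astype(np.float32)

# Assumption: an Instant-NGP-style spatial hash (XOR of each integer grid
# coordinate times a per-axis prime, modulo the table size) generates the ids.
PRIMES = np.array([1, 2654435761, 805459861], dtype=np.uint64)

def hash_coords(coords, table_size):
    """Map (K, 3) non-negative integer grid coordinates to K row ids."""
    c = coords.astype(np.uint64) * PRIMES  # uint64 products wrap mod 2**64
    h = c[:, 0] ^ c[:, 1] ^ c[:, 2]
    return (h % np.uint64(table_size)).astype(np.int64)

coords = rng.integers(0, 128, size=(100_000, 3))
idx = hash_coords(coords, n)   # shape (K,), values in [0, n)
rows = table[idx]              # the gather: shape (K, m)
```

With K = 100k ids into a 16k-row table, many ids necessarily repeat, which matters once gradients have to flow back into the table.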

So I have a set of indices, and I want to gather the corresponding rows from the table to use elsewhere. There are two operators in MXNet that will do the job: gather_nd and take.
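The two operators expect the indices in different layouts. A NumPy emulation (not MXNet code; take_rows and gather_nd_like are hypothetical helpers) shows that take wants a flat vector of row ids, while gather_nd wants an (M, ...) array whose first axis indexes the first M dimensions of the data, so gathering rows means passing the ids as a (1, K) array. Clipping of out-of-range ids in take_rows is an assumption meant to mirror nd.take's mode='clip'.

```python
import numpy as np

def take_rows(table, idx):
    """Emulates take(table, idx, axis=0), clipping out-of-range ids."""
    return table[np.clip(idx, 0, table.shape[0] - 1)]

def gather_nd_like(data, indices):
    """Emulates gather_nd: indices has shape (M, ...); row j of indices
    supplies the ids for dimension j of data."""
    return data[tuple(indices)]

table = np.arange(12, dtype=np.float32).reshape(4, 3)  # n=4 rows, m=3
idx = np.array([2, 0, 2, 3])

a = take_rows(table, idx)             # take: flat vector of row ids
b = gather_nd_like(table, idx[None])  # gather_nd: same ids as a (1, K) array
```

Both calls return the same (K, m) result; only the index layout differs.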

I could have more than 100k indices, even 1000k (1M):

  • At 100k indices, take does a forward pass in less than 1 ms, but its backward pass takes about 45 ms. Meanwhile, gather_nd does a forward pass in about 16 ms and a backward pass in under 1 ms.
  • At 1000k indices, take needs 4 ms forward and 400 ms backward; gather_nd needs 170 ms forward and 1 ms backward.
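Some context on why the backward pass is the harder half: the gradient of a row gather is a scatter-add, and ids that repeat must have their gradient rows summed, so the backward cannot be a plain parallel write. The NumPy sketch below (not MXNet's kernels, which it does not reproduce; gather_rows_backward is a hypothetical name) shows the required semantics and the duplicate-dropping pitfall of a buffered assignment.

```python
import numpy as np

def gather_rows_backward(grad_out, idx, n):
    """Gradient of rows = table[idx] w.r.t. table: scatter-add each output-row
    gradient back into the table row it came from."""
    grad_table = np.zeros((n, grad_out.shape[1]), dtype=grad_out.dtype)
    np.add.at(grad_table, idx, grad_out)  # unbuffered: duplicate ids accumulate
    return grad_table

idx = np.array([1, 3, 1, 1])
grad_out = np.ones((4, 2), dtype=np.float32)
g = gather_rows_backward(grad_out, idx, n=5)  # row 1 gets three contributions

# A buffered fancy-index update keeps only one write per duplicate id:
wrong = np.zeros((5, 2), dtype=np.float32)
wrong[idx] += grad_out
```

On a GPU, this accumulation is typically done either with atomic adds or by sorting the ids first and reducing each run, and the two strategies can have very different costs; which one each MXNet operator uses is not something this sketch establishes.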

So, the obvious question: is there a way to get the best of both worlds here, the fast forward pass of take together with the fast backward pass of gather_nd?
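One way to frame "best of both worlds" is a custom op that pairs a plain-indexing forward with a separately chosen backward. The sketch below is NumPy only and hypothetical (the RowGather class is not an MXNet API); it mirrors the forward/backward split that a custom autograd function would have, assuming something like MXNet 1.x's mx.autograd.Function. Its backward uses a sort-then-segment-sum strategy, checked against the reference scatter-add. Whether such an op would beat either built-in operator on a GPU is untested.

```python
import numpy as np

class RowGather:
    """Hypothetical op: plain indexing forward, sort-based backward."""

    def forward(self, table, idx):
        self.idx, self.n = idx, table.shape[0]
        return table[idx]

    def backward(self, grad_out):
        grad_table = np.zeros((self.n, grad_out.shape[1]), dtype=grad_out.dtype)
        if self.idx.size == 0:
            return grad_table
        order = np.argsort(self.idx, kind="stable")
        sorted_idx = self.idx[order]
        # Start offset of each run of equal row ids in the sorted order.
        is_start = np.concatenate(([True], sorted_idx[1:] != sorted_idx[:-1]))
        starts = np.flatnonzero(is_start)
        sums = np.add.reduceat(grad_out[order], starts, axis=0)
        grad_table[sorted_idx[starts]] = sums  # ids are unique here, so no races
        return grad_table

rng = np.random.default_rng(0)
table = rng.standard_normal((50, 4))
idx = rng.integers(0, 50, size=1000)
op = RowGather()
out = op.forward(table, idx)
g = rng.standard_normal(out.shape)
ref = np.zeros_like(table)
np.add.at(ref, idx, g)  # reference scatter-add
```

The sort costs O(K log K) but turns the accumulation into contiguous reductions with no write conflicts, which is the usual motivation for this design.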

Is there a better operator for gathering rows from the table? I also tried Embedding: in my test it looked like the best of both worlds, but in the real app its backward pass was slow.
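A related idea is to keep the table gradient in row-sparse form: only the ids that were actually gathered, plus one accumulated gradient row per id, rather than a dense (n, m) array that must be zero-filled every step. MXNet's Embedding has a sparse_grad option documented to produce a row_sparse gradient; the NumPy sketch below (hypothetical row_sparse_grad helper, not MXNet code) only illustrates what that storage holds, and whether it would help the real app is untested.

```python
import numpy as np

def row_sparse_grad(idx, grad_out):
    """Gradient of table[idx] in row-sparse form: the unique touched row ids
    and one summed gradient row per id."""
    rows, inverse = np.unique(idx, return_inverse=True)
    values = np.zeros((rows.size, grad_out.shape[1]), dtype=grad_out.dtype)
    np.add.at(values, inverse.reshape(-1), grad_out)  # merge duplicate ids
    return rows, values

idx = np.array([7, 2, 7])
grad_out = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
rows, values = row_sparse_grad(idx, grad_out)
```

The result's size scales with the number of distinct ids, not with n, which is the main appeal when n is large and each step touches only part of the table.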

mureva closed this as completed May 1, 2023
mureva reopened this May 1, 2023