
Fuse consecutive Elemwise subgraphs with multiple clients #1242

Draft · wants to merge 7 commits into main from multiple_output_composite
Conversation

@ricardoV94 (Contributor) commented Oct 6, 2022

Closes #1237

Todo

  • Figure out what the deal was with the empty array test values
  • Refactor the specialized add_mul_fusion
  • Do not split fusion ops based on C-code presence in non-C backends
    • Similarly, get rid of elemwise_max_input_fct
  • Get better profile metrics
  • Check that debugprint looks reasonable for multiple outputs
  • Profile expected performance gains!
  • Test multiple outputs in the Numba/JAX backends (does JAX need it?)

@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 2 times, most recently from 074f125 to fff0d3c Compare October 6, 2022 17:41
@ricardoV94 ricardoV94 changed the title Fuse consecutive Elemwise subgraphs with multiple outputs Fuse consecutive Elemwise subgraphs with multiple clients Oct 6, 2022
@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 3 times, most recently from c84e827 to fb6abe0 Compare October 6, 2022 19:16
@ricardoV94 (Contributor, Author) commented Oct 6, 2022

The TestFusion.test_big_fusion seems to run considerably slower, suggesting my first attempt may be too expensive for large graphs.

Before:
[profiling screenshot]

After:
[profiling screenshot]

Edit: With the latest changes it is now more reasonable:
[profiling screenshot]

@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 6 times, most recently from d4301b1 to 0246fce Compare October 8, 2022 00:11
@brandonwillard (Member) commented Oct 8, 2022

> The TestFusion.test_big_fusion seems to run considerably slower, suggesting my first attempt may be too expensive for large graphs.

It's hard to tell what's going on using those numbers alone. For example, the extra time could be spent in compilation, and the run-time could be significantly reduced. Regardless, the difference is alarming.

Situations like this are another reason we should get #718 in place sooner than later.

@ricardoV94 (Contributor, Author) commented Oct 10, 2022

The logic for in-placing will have to be rethought, since an in-placed output could overwrite an input that is still needed to compute another output. Essentially, we need something that reasons about the inner graph the way we already do for the outer function.

Edit: For now I just restricted inplace to single-output Composites
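The hazard can be sketched in plain NumPy (a toy illustration, not Aesara's actual in-place machinery; `fused_inplace` and its outputs are hypothetical names):

```python
import numpy as np

# Toy two-output "fused kernel": out1 = x + 1, out2 = x * 2.
# If out1 is computed in place over x's buffer, the original values of
# x are destroyed before out2 gets to read them.
def fused_inplace(x):
    out1 = x           # out1 aliases x's buffer (the "inplace" output)
    out1 += 1          # overwrites x in place
    out2 = x * 2       # reads the already-overwritten buffer
    return out1, out2

out1, out2 = fused_inplace(np.array([1.0, 2.0, 3.0]))
# out1 is correct ([2., 3., 4.]) but out2 is [4., 6., 8.]
# instead of the intended [2., 4., 6.]
```

This is why an in-place output is only safe when no other output of the same composite still needs the original input buffer.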

@ricardoV94 (Contributor, Author) commented Oct 11, 2022

Another, more interesting, issue I am finding is that some FunctionGraph.replace_all_validate calls are probably introducing cyclical dependencies (the optimization just hangs, as does dprint-ing the graph). This happens with tensor.test_basic.test_tile.

Edit: It was a bug in the subgraph algorithm. Fixed!
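A hang like this is the classic symptom of a fusion-induced cycle: if an external node lies on a path between two nodes being fused, collapsing them into one composite makes the graph cyclic. A minimal sketch of the check, on a toy dict-based DAG rather than Aesara's FunctionGraph (`fusion_creates_cycle` is an illustrative helper, not PR code):

```python
# Fusing nodes {a, b} into one composite creates a cycle whenever a
# path leaves the fused set and re-enters it through an external node:
#   a -> ext -> b  plus  a -> b   =>   fused{a,b} <-> ext
def fusion_creates_cycle(succ, fuse_set):
    """succ maps each node to its successors in a DAG."""
    def reenters_fused(start):
        stack, seen = [start], set()
        while stack:
            n = stack.pop()
            if n in seen:
                continue
            seen.add(n)
            for m in succ.get(n, ()):
                if m in fuse_set:
                    return True
                stack.append(m)
        return False

    # Check every edge that leaves the fused set.
    return any(
        reenters_fused(m)
        for n in fuse_set
        for m in succ.get(n, ())
        if m not in fuse_set
    )

# a -> ext -> b alongside a -> b: fusing {a, b} would create a cycle.
assert fusion_creates_cycle({"a": ["ext", "b"], "ext": ["b"], "b": []},
                            {"a", "b"})
# Without the external path, the fusion is safe.
assert not fusion_creates_cycle({"a": ["b"], "b": []}, {"a", "b"})
```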

@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 2 times, most recently from 76b30b4 to f000323 Compare October 11, 2022 16:12
Ricardo Vieira added 3 commits October 11, 2022 18:12
@brandonwillard brandonwillard changed the title Fuse consecutive Elemwise subgraphs with multiple clients Fuse consecutive Elemwise subgraphs with multiple clients Oct 11, 2022
@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 3 times, most recently from 762061b to e70ea5f Compare October 18, 2022 08:32
@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 3 times, most recently from 233f68b to 1bfbf24 Compare October 18, 2022 12:42
@ricardoV94 (Contributor, Author) commented Oct 18, 2022

This seems to be working now (more often than not) on the C backend. It provides a smaller speedup than I was expecting:

```python
import aesara
import aesara.tensor as at
import numpy as np

x = at.dvector("x")
mu = at.dvector("mu")
logp = -((x - mu) ** 2) / 2
grad = at.grad(logp.sum(), x)
func = aesara.function([mu, x], [logp, grad])
func.trust_input = True
aesara.dprint(func)

rng = np.random.default_rng(123)
size = 100_000
xv = rng.normal(size=size)
muv = rng.normal(size=size)

%timeit func(xv, muv)
```

The speedup depends on the size.

  • size=1,000: the old method is 1.11x slower
  • size=10,000: the old method is 1.18x slower
  • size=100,000: the old method is 1.34x slower
  • size=1,000,000: the old method is 1.45x slower
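For reference, the single-pass computation the fused multi-output Composite is meant to perform can be sketched in plain NumPy (illustrative only; `fused_logp_grad` is a hypothetical name):

```python
import numpy as np

# Both outputs share the intermediate d = x - mu, computed once,
# and both are produced in a single pass over the data.
def fused_logp_grad(x, mu):
    d = x - mu
    logp = -(d * d) / 2.0   # elementwise log-density term
    grad = -d               # d(logp.sum())/dx
    return logp, grad

x = np.array([0.0, 1.0, 3.0])
mu = np.array([0.0, 0.0, 1.0])
logp, grad = fused_logp_grad(x, mu)
```

The size-dependent speedup is consistent with this being a memory-bandwidth win: the larger the arrays, the more a single fused pass saves over separate kernels.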

I couldn't test the effects on the Numba backend, because multi-output Elemwises are disabled there (we could try https://numba.pydata.org/numba-doc/latest/user/vectorize.html#the-guvectorize-decorator).

The JAX backend also errors out, but I haven't investigated why yet.

@brandonwillard do you know of an easy way to retrieve the C code generated by the whole function? It would help to see what it is doing inside the Elemwise.

@ricardoV94 ricardoV94 force-pushed the multiple_output_composite branch 3 times, most recently from ddc83a4 to e00c125 Compare October 18, 2022 15:02
@brandonwillard (Member) commented

> @brandonwillard do you know of an easy way to retrieve the c_code generated by the whole function? Would help to see what it is trying to do inside the Elemwise

Which function exactly? All the C code generated during an aesara.function call? I believe one can get the cache/compiled module paths from the _CThunk objects in the Function returned by aesara.function, and the C source files are in those directories.

@brandonwillard (Member) commented

> This seems to be now working (more often than not) on the C-backend. It provides less speedups than I was expecting:

It's possible that this new feature has to sometimes trade off between the benefits of "merging"/CSE and fusion. Your example in #1237 illustrates this possibility with the exp node; that's why we should first clarify how we expect fusion to work in these scenarios (#1237 (comment)).
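That trade-off can be sketched in NumPy terms (a hypothetical example in the spirit of #1237, not the actual graph from that issue):

```python
import numpy as np

def merged(x):
    # Merging/CSE: exp(x) is computed once, but each output is then
    # produced by a separate pass over the data.
    e = np.exp(x)
    return e + 1.0, e * 2.0

def fused_separately(x):
    # Per-client fusion: each output is one fused pass, but exp(x)
    # is redundantly evaluated inside both kernels.
    return np.exp(x) + 1.0, np.exp(x) * 2.0

x = np.linspace(0.0, 1.0, 5)
a1, b1 = merged(x)
a2, b2 = fused_separately(x)
```

A multi-output Composite aims to get both benefits at once: exp(x) computed a single time, with both outputs emitted in one pass.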

@ricardoV94 (Contributor, Author) commented

@brandonwillard I extended the motivation behind this PR in the original issue: #1237 (comment)

Development

Successfully merging this pull request may close this issue:

  • Allow merging graphs with multiple clients in FusionOptimizer