CP function #470

Open · JeanKossaifi opened this issue Dec 29, 2022 · 8 comments

@JeanKossaifi (Member)
CP via ALS is probably the most used function in TensorLy and comes with lots of options. One issue is that due to these successive additions, bugs (see e.g. this commit) and undue complexity are slowly creeping in while the code is becoming increasingly hard to read.

Another thing is efficiency: previously it was possible to have a fast version by setting the tolerance to 0 (i.e. no convergence test) but now the decomposition has become increasingly slow.

It might be good to review where we're at and which features are actually needed, and to simplify the code a little.

@cohenjer (Contributor) commented Jan 3, 2023

+1. We will probably tinker heavily with the CP function for the dimension-tree work, so we should discuss how to proceed when we reach the implementation stage.

@cohenjer (Contributor) commented Jan 4, 2023

I ran a quick test to quantify more precisely how much slower the function has become.

I decomposed ten 100x100x100 tensors using parafac with rank 10, default options except tol=0 and n_iter_max=100. This does not work in some TensorLy versions, so when the code failed I used return_errors=True (I believe this was a since-fixed bug).
The total time is reported using time.perf_counter, with the NumPy backend and default contractions; a sketch of the benchmark is below.
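
For reference, the benchmark looks roughly like this (an untested sketch; the keyword names follow recent TensorLy releases and varied slightly across old versions):

    # Rough sketch of the benchmark; keyword names follow recent TensorLy releases.
    import time
    import numpy as np
    import tensorly as tl
    from tensorly.decomposition import parafac

    tl.set_backend("numpy")
    rng = np.random.default_rng(0)
    tensors = [tl.tensor(rng.standard_normal((100, 100, 100))) for _ in range(10)]

    start = time.perf_counter()
    for tensor in tensors:
        # return_errors=True was the workaround on versions where tol=0 failed
        parafac(tensor, rank=10, n_iter_max=100, tol=0)
    print(f"Runtime for 10 CPs is {time.perf_counter() - start}")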

I tried all TensorLy versions from the current dev down to 0.3.0; here are the results:

| TensorLy version | Runtime for 10 CPs (s) |
| --- | --- |
| latest (dev) | 16.67 |
| 0.7.0 | 14.15 |
| 0.6.0 | 15.71 |
| 0.5.1 | 13.24 |
| 0.5.0 | 15.43 |
| 0.4.5 | 13.76 |
| 0.4.4 | 16.11 |
| 0.4.3 | 5.18 |
| 0.4.2 | 8.26 |
| 0.4.0 | 9.18 |
| 0.3.0 | 8.93 |

Notice how 0.4.3 is by far the fastest: it is the only version (besides the current latest) that actually skips the error computation when tol=0. For the other versions, it is less clear to me why the more recent ones are slower.

With this in mind, we should profile the most recent version: it should not do anything more than 0.4.3 does (the error is not computed if tol=0), yet it is the slowest... My take is that it is related to #442, but I have not investigated this seriously. Version 0.4.3 uses the naive MTTKRP while recent versions use a less memory-intensive but slower MTTKRP.
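
To make the tol=0 point concrete, here is a toy illustration (not TensorLy's actual code) of the guard in question: the error should only be computed when a convergence test or error reporting is requested.

    # Toy illustration (not TensorLy's code): skip the expensive error
    # computation when neither tol > 0 nor return_errors asks for it.
    import numpy as np

    def toy_als(tensor, update_step, n_iter_max=100, tol=0.0, return_errors=False):
        errors = []
        estimate = tensor
        for _ in range(n_iter_max):
            estimate = update_step(estimate)
            if tol > 0 or return_errors:
                errors.append(np.linalg.norm(tensor - estimate))
                if tol > 0 and len(errors) > 1 and abs(errors[-2] - errors[-1]) < tol:
                    break
        return (estimate, errors) if return_errors else estimate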

The bash script and Python script are on my GitHub if you want to cross-check.

EDIT: I ran the profiling on the latest TensorLy version (current main); no surprises:

    ncalls          tottime  percall  cumtime  percall  filename:lineno(function)
    66000           0.074    0.000    6.372    0.000    <__array_function__ internals>:177(dot)
    60030           0.063    0.000    1.216    0.000    <__array_function__ internals>:177(moveaxis)
    129030          0.127    0.000    8.115    0.000    <__array_function__ internals>:177(reshape)
    30              0.000    0.000    4.192    0.140    <__array_function__ internals>:177(svd)
    53              0.003    0.000    1.318    0.025    __init__.py:1(<module>)
    567650/48650    1.022    0.000    22.925   0.000    __init__.py:202(wrapped_backend_method)
    10              0.147    0.015    23.106   2.311    _cp.py:230(parafac)
    10              0.001    0.000    4.222    0.422    _cp.py:26(initialize_cp)
    60030           0.124    0.000    9.311    0.000    base.py:39(unfold)
    129030          0.128    0.000    7.827    0.000    fromnumeric.py:198(reshape)
    144060          0.119    0.000    7.722    0.000    fromnumeric.py:51(_wrapfunc)
    30              4.191    0.140    4.192    0.140    linalg.py:1477(svd)
    3000            0.134    0.000    18.107   0.006    mttkrp.py:7(unfolding_dot_khatri_rao)
    60000           0.357    0.000    16.946   0.000    n_mode_product.py:5(mode_dot)
    30000           0.216    0.000    17.393   0.001    n_mode_product.py:81(multi_mode_dot)
    60030           0.334    0.000    1.056    0.000    numeric.py:1410(moveaxis)
    30              0.001    0.000    4.193    0.140    svd.py:202(truncated_svd)
    30              0.000    0.000    4.204    0.140    svd.py:357(svd_interface)
    1               0.013    0.013    23.734   23.734   test_zerotol.py:1(<module>)
    590/1           0.004    0.000    23.734   23.734   {built-in method builtins.exec}
    285402/282402   6.578    0.000    20.052   0.000    {built-in method numpy.core._multiarray_umath.implement_array_function}
    129030          7.558    0.000    7.558    0.000    {method 'reshape' of 'numpy.ndarray' objects}

The most time-consuming part is the MTTKRP: lots of reshapes and inefficient matrix-matrix products.
Bottom line: yes, the CP code has become hard to maintain and bug-prone, but the slowness comes from the MTTKRP and its many reshapes.

@JeanKossaifi (Member, Author)

Thanks, it's great to have numbers. I think the last profiling result is a little misleading: we see that the algorithm is significantly slower for newer versions. It might be more informative to just look at line-by-line profiling (%lprun in a notebook).

I do think we should skip the error computation when we don't need it (e.g. for a fixed number of iterations), as it slows things down quite a bit (that's why I had added the tol=0 option).
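
For reference, line profiling in a notebook would look something like this (a sketch; requires the line_profiler package):

    # Sketch: line-by-line profiling of parafac with line_profiler's %lprun magic.
    %load_ext line_profiler

    import numpy as np
    import tensorly as tl
    from tensorly.decomposition import parafac

    tensor = tl.tensor(np.random.random_sample((100, 100, 100)))
    %lprun -f parafac parafac(tensor, rank=10, n_iter_max=100, tol=0)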

@cohenjer (Contributor) commented Jan 12, 2023

Not sure I understand your comment on the profiling being misleading; the profiling I showed is only for the latest version anyway.

About error computation: after manual inspection, it turns out we compute the error even if tol=0 in all versions > 0.3 except:

  • the latest one (the main branch)! (I don't remember which PR changed this; we still compute the error if line search is on or if return_errors=True)
  • version 0.4.5
  • version 0.4.4
  • version 0.4.3

In a bunch of versions (<=0.7, >0.4.5), tol=0 actually raises an error: we don't want to compute the loss, but the loss still appears in some part of the code anyway.

So indeed I think 0.4.3 is faster than, say, 0.4.2 because the error is not computed when tol=0, but that does not explain why the latest version is so bad (in fact it makes things look even worse!).

On the other hand, 0.4.4 is much slower than 0.4.3, and the big change between the two was factoring the MTTKRP out of the main loop: version 0.4.3 had explicit matrix-matrix products inside parafac, while 0.4.4 has the unfolding_dot_khatri_rao function in kruskal_tensor; but we also changed the way the MTTKRP is computed. This is the code from parafac in 0.4.3:

    # Hadamard product of the factor Grams, then MTTKRP via unfolding +
    # Khatri-Rao, and finally solve the normal equations for this mode.
    pseudo_inverse = tl.tensor(np.ones((rank, rank)), **tl.context(tensor))
    for i, factor in enumerate(factors):
        if i != mode:
            pseudo_inverse = pseudo_inverse * tl.dot(tl.transpose(factor), factor)
    factor = tl.dot(unfold(tensor, mode), khatri_rao(factors, skip_matrix=mode))
    factor = tl.transpose(tl.solve(tl.transpose(pseudo_inverse), tl.transpose(factor)))
    factors[mode] = factor

but this is the MTTKRP in kruskal_tensor in 0.4.4:

    mttkrp_parts = []
    _, rank = _validate_kruskal_tensor(kruskal_tensor)
    weights, factors = kruskal_tensor
    for r in range(rank):
        # one MTTKRP column at a time: contract the tensor with the r-th
        # column of every factor except the one for `mode`
        component = multi_mode_dot(tensor, [f[:, r] for f in factors], skip=mode)
        mttkrp_parts.append(component)

    if weights is None:
        return T.stack(mttkrp_parts, axis=1)
    else:
        return T.stack(mttkrp_parts, axis=1) * T.reshape(weights, (1, -1))

And we have already discussed that this second version is much less efficient; see the great issue #442 by @yngvem.

So again, yes, the CP function is hard to maintain and we should make it less cluttered, but speed is mostly an issue because of our MTTKRP computation, which is bad speed-wise (though not so bad memory-wise).

@JeanKossaifi (Member, Author)

I agree with your points @cohenjer. As a side note, an easy way to speed up the MTTKRP is to use the einsum tenalg backend. For the core one, as I mentioned in the issue you linked, the change was made so TensorLy could be sparse-safe. One solution is to redefine it in a sparse-specific tensor backend. We also probably need to revisit how we handle sparse tensors; I'm not sure how much the sparse backend is currently used.
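
For concreteness, and independent of TensorLy's internals, the mode-0 MTTKRP of a third-order tensor is a single einsum contraction; the tenalg backend switch is sketched below, hedged since the exact call has moved around between releases:

    # The mode-0 MTTKRP of a third-order tensor as one einsum contraction:
    # result[i, r] = sum_{j, k} tensor[i, j, k] * B[j, r] * C[k, r]
    import numpy as np

    def mttkrp_mode0(tensor, B, C):
        # tensor: (I, J, K); B: (J, R); C: (K, R) -> result: (I, R)
        return np.einsum("ijk,jr,kr->ir", tensor, B, C)

    # In recent TensorLy versions, something like the following selects the
    # einsum tenalg backend (double-check the exact API for your version):
    # import tensorly.tenalg as tenalg
    # tenalg.set_backend("einsum")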

Re profiling: the timings are great; I was just saying that line profiling is probably more informative/easier to read than this report.

For error computation: I fixed it in the commit I mentioned in the original comment. It's still needed in a few cases, as you mentioned.

@cohenjer (Contributor)

You are right @JeanKossaifi; in fact it is quite a bit of work to test all backend/tenalg-backend combinations on all versions (but probably necessary to get a better picture).

Also, 100% agreed that the current sparse backend needs a lot of improvement.

About the einsum backend: the thing is, last time I checked, the einsum tenalg backend was horrible with NumPy.

Here are a few solutions to the speed issue of CP (cluttered code is a different story):

  1. use the einsum backend as default --> not sure how much it would actually help, in particular for NumPy.
  2. revert to a faster default MTTKRP computation --> requires moving the current version to the sparse backend in some way (see the sketch after this list).
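
For reference, option 2 would look roughly like this (a sketch using TensorLy's public unfold and khatri_rao, essentially the 0.4.3 approach):

    # Sketch of option 2: the "naive" MTTKRP via unfolding + Khatri-Rao product,
    # essentially what parafac did in 0.4.3. Fast, but it materializes a large
    # Khatri-Rao matrix, hence the memory cost discussed in #442.
    import tensorly as tl
    from tensorly.tenalg import khatri_rao

    def naive_mttkrp(tensor, factors, mode):
        return tl.dot(tl.unfold(tensor, mode), khatri_rao(factors, skip_matrix=mode))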

I would rather go with 2, but I will run some tests when I find the time to see which is simpler/brings the bigger improvement.

Any other ideas?

@JeanKossaifi (Member, Author)

I agree, I think 2 is the way to go, ideally coordinated with some work on the sparse backend.

@cohenjer (Contributor)

Just to update: I have not forgotten this issue; I am working on a fairly large PR that will include this change (but probably not coordinated with the sparse backend).
