CP function #470
+1. We will probably tinker heavily with the CP function for the dimension-tree work, so we should discuss how to proceed when we reach the implementation stage.
I ran a quick test to quantify more precisely how much slower the function has become. I decomposed ten 100x100x100 tensors using parafac with rank 10 and default options, except tol=0 and n_iter_max=100. This does not work for some TensorLy versions, so when the code fails I used return_errors=True (I believe this is a now-fixed bug). I tried every TensorLy version from the current dev down to 0.3.0; here are the results.

Notice how 0.4.3 is by far the fastest: it is the only version (except the current latest) where we actually do not compute the error when tol=0. For the other versions, it is not as clear to me why the more recent ones are slower. With this in mind, we should run some profiling on the most recent version, since it should not do anything more than what is done in 0.4.3 (the error is not computed if tol=0), yet it is the slowest... My take is that this is related to #442, but I did not investigate it seriously. Version 0.4.3 uses the naive MTTKRP, while recent versions use the less memory-intensive but more costly MTTKRP. The bash and Python scripts are on my GitHub if you want to cross-check.

EDIT: I ran the profiling on the latest TensorLy version (current main), no surprises: the most time-consuming part is the MTTKRP, with lots of reshapes and inefficient matrix-matrix products.
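To make the comparison concrete, the two MTTKRP strategies discussed in this thread can be timed outside TensorLy with a small pure-NumPy sketch (sizes and names are illustrative; this is not TensorLy's actual code):

```python
import time
import numpy as np

rng = np.random.default_rng(0)
I, J, K, R = 100, 100, 100, 10
X = rng.standard_normal((I, J, K))
B = rng.standard_normal((J, R))
C = rng.standard_normal((K, R))

def khatri_rao(b, c):
    # Column-wise Kronecker product, shape (J*K, R).
    return (b[:, None, :] * c[None, :, :]).reshape(-1, b.shape[1])

# Naive mode-0 MTTKRP (0.4.3 style): unfolding times Khatri-Rao product.
t0 = time.perf_counter()
m_naive = X.reshape(I, J * K) @ khatri_rao(B, C)
t_naive = time.perf_counter() - t0

# Column-wise MTTKRP (recent style): one full pass over the tensor per column.
t0 = time.perf_counter()
m_cols = np.stack(
    [np.einsum('ijk,j,k->i', X, B[:, r], C[:, r]) for r in range(R)], axis=1
)
t_cols = time.perf_counter() - t0

print(f"naive: {t_naive:.4f}s, column-wise: {t_cols:.4f}s")
```

Both variants compute the same matrix; the column-wise one avoids materializing the (J*K, R) Khatri-Rao product (which is why the memory-friendly form was adopted), but it reads the whole tensor R times.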
Thanks, great to have numbers. I think the last profiling result is a little misleading, though we do see that the algorithm is significantly slower for newer versions; line-by-line profiling would probably be more informative. I do think we should skip error calculations when we don't need them (fixed number of iterations), as they slow things down quite a bit (that's why I had added an option for this).
Not sure I understand your comment about the profiling being misleading; the profiling I showed is only for the last version anyway. About error computation: after manual inspection, it turns out we compute the error even when tol=0 in all versions > 0.3, except 0.4.3 and the current latest.
In a bunch of versions (>0.4.5, <=0.7), tol=0 actually raises an error, because we do not want to compute the loss but the loss still appears in some part of the code anyway. So indeed, I think 0.4.3 is faster than, say, 0.4.2 because the error is not computed when tol=0, but that does not explain why the latest version is so bad (in fact it makes things even worse!). On the other hand, 0.4.4 is much slower than 0.4.3, and the big change between these two was to factor the MTTKRP out of the code. Version 0.4.3 had explicit computations with the matrix-matrix products inside parafac:

```python
pseudo_inverse = tl.tensor(np.ones((rank, rank)), **tl.context(tensor))
for i, factor in enumerate(factors):
    if i != mode:
        pseudo_inverse = pseudo_inverse*tl.dot(tl.transpose(factor), factor)
factor = tl.dot(unfold(tensor, mode), khatri_rao(factors, skip_matrix=mode))
factor = tl.transpose(tl.solve(tl.transpose(pseudo_inverse), tl.transpose(factor)))
factors[mode] = factor
```

while in 0.4.4 and later the MTTKRP is computed in a separate function:

```python
mttkrp_parts = []
_, rank = _validate_kruskal_tensor(kruskal_tensor)
weights, factors = kruskal_tensor
for r in range(rank):
    component = multi_mode_dot(tensor, [f[:, r] for f in factors], skip=mode)
    mttkrp_parts.append(component)
if weights is None:
    return T.stack(mttkrp_parts, axis=1)
else:
    return T.stack(mttkrp_parts, axis=1)*T.reshape(weights, (1, -1))
```

And we have already discussed that this second version is much less efficient; see the great issue #442 by @yngvem. So again, yes, the CP function is hard to maintain and we should make it less cluttered, but speed is mostly an issue because of our bad MTTKRP computation (bad speed-wise; it is not so bad memory-wise).
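For reference, the 0.4.3-style update translates to plain NumPy roughly as follows (a sketch with made-up dimensions, not TensorLy code): the Hadamard product of the Gram matrices of the fixed factors plays the role of pseudo_inverse, and the new factor is obtained by solving a linear system against the MTTKRP.

```python
import numpy as np

rng = np.random.default_rng(1)
I, J, K, R = 30, 40, 50, 5
X = rng.standard_normal((I, J, K))
factors = [rng.standard_normal((d, R)) for d in (I, J, K)]
mode = 0  # update the first factor

# Hadamard product of the Gram matrices of the non-updated factors.
pseudo_inverse = np.ones((R, R))
for i, factor in enumerate(factors):
    if i != mode:
        pseudo_inverse = pseudo_inverse * (factor.T @ factor)

# MTTKRP via unfolding times Khatri-Rao product, kept explicit as in 0.4.3.
B, C = factors[1], factors[2]
kr = (B[:, None, :] * C[None, :, :]).reshape(-1, R)  # Khatri-Rao, (J*K, R)
mttkrp = X.reshape(I, J * K) @ kr

# Solve pseudo_inverse.T @ new_factor.T = mttkrp.T, mirroring the snippet.
new_factor = np.linalg.solve(pseudo_inverse.T, mttkrp.T).T
factors[mode] = new_factor
```

The solve enforces the least-squares normal equations new_factor @ pseudo_inverse = mttkrp, which is exactly the ALS update for one mode.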
I agree with your points @cohenjer. As a side note, an easy way to speed up the MTTKRP is to use the einsum tenalg backend. For the core one, as I mentioned in the issue you linked to, the change was made so TensorLy could be sparse-safe. One solution is to redefine it in a sparse-specific tensor backend. We also probably need to revisit how we handle sparse tensors; I am not sure how much the sparse backend is currently used. Re profiling: the timings are great, I was just saying that line profiling is probably more informative and easier to read than this report. For error computation, I fixed it in the commit I mentioned in the original comment; it is still needed in a few cases, as you mentioned.
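To illustrate the einsum suggestion: for a third-order tensor, the whole mode-0 MTTKRP collapses into a single einsum call (a hedged sketch, not the actual tenalg backend code), which lets the backend choose a contraction path instead of materializing the Khatri-Rao product:

```python
import numpy as np

rng = np.random.default_rng(2)
I, J, K, R = 20, 30, 40, 4
X = rng.standard_normal((I, J, K))
B = rng.standard_normal((J, R))
C = rng.standard_normal((K, R))

# Mode-0 MTTKRP in one call; optimize=True lets NumPy pick a contraction
# order that never forms the (J*K, R) Khatri-Rao product explicitly.
m_einsum = np.einsum('ijk,jr,kr->ir', X, B, C, optimize=True)

# Reference: explicit Khatri-Rao product followed by a matrix product.
kr = (B[:, None, :] * C[None, :, :]).reshape(-1, R)
m_ref = X.reshape(I, J * K) @ kr
```

Whether this is actually fast depends on the backend's einsum implementation, which matches the observation above that the einsum tenalg backend can be slow with NumPy.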
You are right @JeanKossaifi; in fact, it is quite a bit of work to test all backends and tenalg backends on all versions (but probably necessary to get a better picture). I also 100% agree that the current sparse backend needs a lot of improvements. About the einsum backend: the thing is, last time I checked, the einsum tenalg backend performed horribly with NumPy. Here are a few solutions to the speed issue of CP (the cluttered code is a different story):

I would rather go with 2, but I will run some tests when I find the time, to see which option is simpler and brings better improvements. Any other ideas?
I agree, I think 2 is the way to go, ideally coordinated with some work on the sparse backend.
Just to update: I have not forgotten this issue; I am working on a quite large PR that will include this change (but probably not coordinated with the sparse backend).
CP via ALS is probably the most used function in TensorLy and comes with lots of options. One issue is that, due to these successive additions, bugs (see e.g. this commit) and undue complexity are slowly creeping in, while the code is becoming increasingly hard to read.
Another issue is efficiency: previously it was possible to get a fast version by setting the tolerance to 0 (i.e. no convergence test), but the decomposition has now become increasingly slow.
It might be good to review where we are, decide which features are actually needed, and simplify the code a little.