
Optimizing a variable loss function #233

Answered by jcmgray
wkretschmer asked this question in Q&A

Hi @wkretschmer. No, sadly I think this kind of 'non-optimized variable' is not supported at the moment. It would be a nice feature to have, especially for jax, which does indeed have very slow compilation.
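A minimal sketch in plain jax (no quimb involved) of why a changing loss variable hurts with jax. If the variable is captured as a constant inside a freshly built jitted loss, each new value triggers a new trace and compile. If it is instead passed as an ordinary argument with the same shape and dtype, one compiled function is reused. The names `make_closure_loss`, `loss_with_arg` and the `traces` counter are illustrative only. The counter works because Python side effects inside a jitted function run only while that function is being traced.

```python
import jax
import jax.numpy as jnp

# Counts how many times each function body is traced (i.e. compiled).
# Python side effects inside jax.jit only run during tracing.
traces = {"closure": 0, "argument": 0}


def make_closure_loss(target):
    # `target` is baked into the compiled function as a constant,
    # so every new target means a new jitted function and a new trace.
    @jax.jit
    def loss(x):
        traces["closure"] += 1
        return jnp.sum((x - target) ** 2)

    return loss


@jax.jit
def loss_with_arg(x, target):
    # `target` is a traced argument: new values with the same shape and
    # dtype reuse the already-compiled function.
    traces["argument"] += 1
    return jnp.sum((x - target) ** 2)


x = jnp.zeros(3, dtype=jnp.float32)
for t in (0.0, 1.0, 2.0):
    target = jnp.full(3, t, dtype=jnp.float32)
    make_closure_loss(target)(x)
    loss_with_arg(x, target)

# The closure count grows with every new target, the argument count does not.
print(traces)
```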

My suggestions would be:

  • use a different autodiff_backend, or try setting jit_fn=False (this might be slow, however).
  • optimize the expression outside of TNOptimizer, see e.g. https://quimb.readthedocs.io/en/latest/examples/ex_quimb_within_jax_flax_optax.html, perhaps using the qtn.pack / qtn.unpack functions.
  • try and add the feature! I think it's probably not too difficult (especially if only supporting the jax backend), if a bit fiddly. Basically one would want to pass something like 'loss_variables'…
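A rough sketch of the second suggestion, assuming plain jax. The optimized parameters live in a pytree; here a hypothetical dict `{"w": ...}` stands in for roughly the kind of parameter pytree that qtn.pack would give you from a tensor network, with qtn.unpack rebuilding the network inside the loss (check the linked example for the exact usage). The non-optimized variable is passed to the jitted step as a separate argument, so it can change between runs without recompiling. The plain gradient-descent update and the names `loss`, `sgd_step` and `optimize` are illustrative assumptions; in practice you would likely swap in an optax optimizer as the linked example does.

```python
import jax
import jax.numpy as jnp


def loss(params, target):
    # `params` holds the optimized arrays; `target` is a non-optimized
    # loss variable that may change from one optimization run to the next.
    return jnp.sum((params["w"] - target) ** 2)


@jax.jit
def sgd_step(params, target, lr):
    # jax.grad differentiates w.r.t. the first argument (params) only.
    grads = jax.grad(loss)(params, target)
    # Plain gradient descent applied leaf-by-leaf over the params pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def optimize(params, target, lr=0.1, steps=100):
    for _ in range(steps):
        params = sgd_step(params, target, lr)
    return params


params0 = {"w": jnp.zeros(4, dtype=jnp.float32)}
# Re-optimize for several values of the loss variable; sgd_step is
# compiled once because the shapes and dtypes never change.
results = {
    t: optimize(params0, jnp.full(4, t, dtype=jnp.float32))
    for t in (0.5, -1.0)
}
```

Because only the array values of `target` differ between runs, the expensive compilation happens once, which is essentially what a 'loss_variables'-style option inside TNOptimizer would need to arrange.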

Answer selected by wkretschmer