Gradient computation tutorial #42

Open · yoavram opened this issue Dec 28, 2022 · 1 comment

yoavram commented Dec 28, 2022

Hi
I think the docs are missing an example of how we can use sunode to compute gradients with respect to model parameters using the adjoint solver.
I'm not even sure which gradients are computed - gradients of the solution? Sundials talks about a "derived function", which I take to be, for example, a loss function, but I'm not sure how this applies to sunode.
I understand I should use solve_backward, but the function is not documented. I'm especially confused about what the "grads" argument is and what the output variables are - "grad_out" vs "lamda_out".
If you explain this, I would be happy to share a notebook that uses sunode to fit an ODE - this could be useful for the documentation.
Thanks!

aseyboldt (Member) commented

Hi :-)

You are right, we could really use some documentation on that. If you can share a notebook that we can add to the docs, that would be great!

If you use solve_ivp from the as_pytensor module (the old as_aesara), the adjoint solver will be used by default (controlled by the derivatives: str = 'adjoint' argument).
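
For reference, a rough sketch of what that looks like inside a PyMC model. It follows the Lotka-Volterra example from the sunode README; the decay_rhs function and the synthetic observations are mine, and the exact signature and return tuple of solve_ivp should be checked against the sunode version you have installed:

```python
import numpy as np
import pymc as pm
import sunode.wrappers.as_pytensor

def decay_rhs(t, y, p):
    # Simple exponential decay: dy/dt = -rate * y
    return {'y': -p.rate * y.y}

times = np.linspace(0, 10, 21)

with pm.Model():
    rate = pm.HalfNormal('rate', sigma=1.0)
    y_start = pm.HalfNormal('y_start', sigma=1.0)

    # derivatives='adjoint' is the default; the gradient of the model's log
    # probability with respect to `rate` and `y_start` is then computed with
    # the adjoint solver behind the scenes.
    y_hat, *_ = sunode.wrappers.as_pytensor.solve_ivp(
        y0={'y': (y_start, ())},
        params={'rate': (rate, ()), 'extra': np.zeros(1)},  # 'extra' dummy entry, as in the README example
        rhs=decay_rhs,
        tvals=times,
        t0=times[0],
        derivatives='adjoint',
    )

    pm.Normal('obs', mu=y_hat['y'], sigma=0.1, observed=np.exp(-0.5 * times))
```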

About your questions:

The adjoint solver corresponds to the backward step in reverse mode autodiff, or the pullback from differential geometry.

We assume that we want to compute the gradient of some larger function $h: \mathbb{R}^n \to \mathbb{R}$. In an application that could for instance be a posterior log probability function that maps parameter values to their unnormalized density. We split this large function into smaller parts, and one of those parts is the function $f$ that solves the ODE, so $h(x) = g(f(x))$, where $f$ is a function $f: \mathbb{R}^n \to \mathbb{R}^m$. Here $\mathbb{R}^n$ contains all parameters, initial conditions, and time points where we want to evaluate the solution, and $\mathbb{R}^m$ contains the solution of the ODE at those points. The other part, $g$, is the function that maps the solution of the ODE to a log prob value (i.e. a likelihood).

When we compute the gradient of $h$ we can isolate the contribution of $f$ using the chain rule, and basically ask: if we already know the gradient of the later part of $h$, namely $g$, what would then be the gradient of $h$? So we define a function that takes those gradients of $g$ as input and returns the gradients of $h$. This is exactly what happens in solve_backward. The gradients of $g$ are called grads in the code. The final gradients of $h$ are split into two parts: grad_out, for the gradients with respect to the parameters, and lamda_out for the gradients with respect to the initial conditions (-lamda_out actually, that's how this was defined by sundials for some reason...).
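
A tiny self-contained illustration of the same split (plain NumPy, not the sunode API; the ODE $\dot y = -\theta y$ is a toy choice so that $f$ has a closed form and the chain rule can be checked with finite differences):

```python
import numpy as np

# Toy setup: f solves dy/dt = -theta * y with y(0) = y0, evaluated at times ts,
# so f(theta, y0)[i] = y0 * exp(-theta * ts[i]).  g maps that solution to a
# scalar log-prob-like value, and h = g(f(.)).
ts = np.array([0.5, 1.0, 2.0])
target = np.array([0.8, 0.6, 0.3])

def f(theta, y0):
    return y0 * np.exp(-theta * ts)

def g(y):
    return -0.5 * np.sum((y - target) ** 2)

def h(theta, y0):
    return g(f(theta, y0))

def grads_of_g(y):
    # Gradient of g with respect to the ODE solution: this plays the role of "grads".
    return -(y - target)

def backward(theta, y0, grads):
    # The solve_backward-style step: given the gradient of g, return the gradient
    # of h with respect to the parameters (the role of grad_out) and with respect
    # to the initial conditions (the role of lamda_out, up to the sign convention).
    y = f(theta, y0)
    dy_dtheta = -ts * y            # d f / d theta, elementwise
    dy_dy0 = np.exp(-theta * ts)   # d f / d y0, elementwise
    return grads @ dy_dtheta, grads @ dy_dy0

theta, y0 = 0.7, 1.2
grad_theta, grad_y0 = backward(theta, y0, grads_of_g(f(theta, y0)))

# Check against central finite differences of h.
eps = 1e-6
print(grad_theta, (h(theta + eps, y0) - h(theta - eps, y0)) / (2 * eps))
print(grad_y0, (h(theta, y0 + eps) - h(theta, y0 - eps)) / (2 * eps))
```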

This is essentially also how sundials does things internally, only that it generalizes it a bit more. The idea is that the way we think about "the function $f$ that solves the ODE" isn't as general as it could be. Instead of just asking what the solution will be at certain points, we could also say that $f$ should return the solution function itself. That means $f$ is a function from parameters and initial conditions to the solution function of the ODE, and correspondingly $g$ would then be a function that takes a function as input and returns a scalar. This allows some things that sunode doesn't currently support, like computing the gradient of an integral over the solution. For instance, you could have a loss function that compares the solution function to a target solution.

In what context are you using sunode? If you don't use the pytensor wrappers, you'll have to apply the chain rule yourself to get gradients of the composite function.

I hope this explanation helps at least a bit. Feel free to ask for clarification if something is not clear; this isn't the easiest subject to write about. :-)
