
Permit the user to pass a function to ArviZ to compute log likelihood on demand for memory-intensive models #2197

Open
pjheslin opened this issue Jan 22, 2023 · 3 comments


@pjheslin

When my Stan model computes and saves the log likelihood of my data, the resulting files are huge. When I try to read these files in with arviz.from_cmdstan(), I run out of memory.

The R implementation of LOO-PSIS seems to have a provision for such cases. The documentation says that, instead of passing the log likelihood as an array or matrix, you can pass an R function to loo() to calculate and return the log likelihood of each data point separately, based on the data and the draws from the posterior: https://mc-stan.org/loo/reference/loo.html#methods-by-class-

There is a vignette which explains that in these cases, you should not use the Generated Quantities block of your Stan program to compute the log likelihood. Instead, you write an R function to calculate it that loo() will call repeatedly:
https://mc-stan.org/loo/articles/loo2-large-data.html

This way, the size of your Stan output does not explode, and you can just calculate the log likelihood for each point as loo() requires it without running out of memory.

I have looked for equivalent functionality in ArviZ, but cannot find it. Does it exist? It would be great not to have to switch to R for this.

Thanks!

@OriolAbril
Member

I am not sure it is possible as of now. In general we use a slightly different approach for this: Dask handles the chunking and the organization of the computation graph. It integrates very well with xarray and is generally more efficient than looping over each data point, because it works block-wise (loading a few blocks into memory at a time) and parallelizes computations over them, which makes it possible to operate on data that doesn't fit in RAM.
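As a toy illustration of that block-wise idea (made-up numbers, no ArviZ involved): dask splits an array into chunks and only materializes a few of them at a time when computing a reduction.

```python
import dask.array as da

# A toy (chain, draw, obs) array split into chunks along the last axis;
# nothing is computed until .compute() is called, and even then dask
# processes the chunks a few at a time instead of loading everything.
x = da.ones((4, 1000, 20_000), chunks=(4, 1000, 5_000))
total = x.sum().compute()  # computed block-wise, then aggregated
print(total)  # 80_000_000 ones summed
```

The same pattern scales to arrays far larger than RAM because peak memory is set by the chunk size, not the array size.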

I have used ess and rhat on data that doesn't fit in memory, but I am not sure loo has been updated to allow this. And I know from_cmdstan would need to be updated (or alternatively, don't compute the pointwise log likelihood in Stan and do so later in Python with Dask).

Is this something you think you could help out with? Or test from a PR? I think nobody on the team is currently using this, so it has been on the back burner for a bit.

@pjheslin
Author

Yes, exactly -- the idea is that in these cases you refrain from computing the pointwise log likelihood in Stan. You compute it later in Python by passing ArviZ a function that gets called for each point. I don't know enough about ArviZ to say whether this should be an option specifically for loo, as seems to be the case in R, or whether it should be integrated with from_cmdstan.

I'd be happy to help with testing, but I don't know anything about xarray or Dask.

@OriolAbril
Member

I'll try to use a linear regression to illustrate the steps. Say you have 20_000_000 observations and you run a regression on them with multiple predictors, generating 4000 posterior samples of the 5 slopes, the intercept, and the sigma. There are no generated quantities, so from_cmdstan will work properly.

You can then use python+xarray+dask to compute the pointwise log likelihood values without loading them into memory.

We start from the constant_data, observed_data, and posterior groups. We will assume the posterior has 3 variables: two with dims (chain, draw) (the intercept and sigma), and the slopes with dims (chain, draw, predictor), so it is similar to the "centered_eight" example data. In constant_data we have one variable with dims (predictor, obs_id), and observed_data also has one variable, with dims (obs_id).

Our pointwise log likelihood will have dims (chain, draw, obs_id), that is, shape (4, 1000, 20_000_000). As float64, that array would take 640 GB of RAM (8 bytes per number × 4000 × 20_000_000 numbers, × 1e-9 to convert to GB). Depending on the RAM available, we could use chunks of (4, 1000, 5000) -> 160 MB or (4, 1000, 8000) -> 256 MB; dask's docs have more info on chunking.
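A quick back-of-the-envelope check of those sizes (pure arithmetic, nothing ArviZ-specific):

```python
ITEMSIZE = 8  # bytes per float64 element

# Full pointwise log likelihood array: (4 chains, 1000 draws, 20_000_000 obs)
full_gb = 4 * 1000 * 20_000_000 * ITEMSIZE / 1e9

# Per-chunk sizes for the two chunking options above
chunk_small_mb = 4 * 1000 * 5_000 * ITEMSIZE / 1e6
chunk_big_mb = 4 * 1000 * 8_000 * ITEMSIZE / 1e6

print(full_gb, chunk_small_mb, chunk_big_mb)  # 640.0 160.0 256.0
```

So the full array is hopeless to hold in memory, while any single chunk is trivially small.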

That would be something like:

from scipy.stats import norm # or any distribution
from xarray_einstats.stats import XrContinuousRV

const_data = idata.constant_data.chunk(obs_id=5000)
obs_data = idata.observed_data.chunk(obs_id=5000)
# optional, save chunks in inferencedata: idata.constant_data = const_data
post = idata.posterior
log_lik = XrContinuousRV(norm).logpdf(
    obs_data["y"],
    post["intercept"] + (post["slope"] * const_data["X"]).sum("predictor"),
    post["sigma"],
    apply_kwargs={"dask": "allowed"}  # you might need to try both "allowed" and "parallelized"
)
idata.add_groups(log_likelihood=log_lik.to_dataset(name="y"))

If possible, dask="allowed" is preferred, but it doesn't always work, in which case you might need to fall back to "parallelized".

If you are using a normal, you can also write down the operations for the log likelihood with plain Python operators straight away, but I shared the "extra" step of using the einstats wrappers of scipy, which allow computing it for any distribution, even on dask arrays.
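For the normal case, writing the log density by hand really is just a couple of operators. A minimal sketch with toy scalar values (standing in for the xarray objects above), checked against scipy:

```python
import numpy as np
from scipy.stats import norm

# Toy stand-ins for obs_data["y"], the linear predictor, and post["sigma"]
y, mu, sigma = 1.2, 0.5, 2.0

# Normal log density written with plain operators:
# log N(y | mu, sigma) = -0.5*log(2*pi*sigma^2) - (y - mu)^2 / (2*sigma^2)
manual = -0.5 * np.log(2 * np.pi * sigma**2) - (y - mu) ** 2 / (2 * sigma**2)

print(np.isclose(manual, norm.logpdf(y, mu, sigma)))  # True
```

Because the expression is pure elementwise arithmetic, the same line works unchanged on xarray DataArrays backed by dask, with broadcasting handling the (chain, draw, obs_id) dimensions.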

Eventually, you'd need to call az.loo using the dask integration as shown in https://python.arviz.org/en/stable/user_guide/Dask.html, but that isn't implemented yet. I'll try to send a PR next week, as it shouldn't take long. I figured I'd share this overview first to see if it makes sense to you. It might also be a good idea to already start generating the log likelihood data as a dask array, adapting the steps above to your case, and then calling rhat on the log likelihood group. Even if the rhat values themselves aren't meaningful, it will serve to check that everything is working.
