Expose keep_mu option to users to make mu explicit in code and summary #1610

Closed
wants to merge 20 commits into from

Conversation

venpopov
Contributor

@venpopov venpopov commented Mar 5, 2024

Summary

This PR exposes the option keep_mu in check_prefix() to the user-facing functions brm(), stancode(), standata(), validate_formula(), and default_prior(). The default is keep_mu = FALSE, in which case everything behaves as before. When keep_mu = TRUE, the mu parameter is labeled explicitly in the Stan code, the Stan data, and the resulting summary object, so mu is treated like any other parameter. I hope you like the idea: the behavior is completely optional, but it will be useful for our team.

Example

zinb <- read.csv("https://paul-buerkner.github.io/data/fish.csv")

# replicating https://paul-buerkner.github.io/brms/articles/brms_distreg.html#zero-inflated-models with keep_mu = TRUE
fit_mu <- brm(bf(count ~ persons + child + camper, zi ~ child),  data = zinb,
              family = zero_inflated_poisson(), backend = 'cmdstanr', keep_mu = TRUE)

summary(fit_mu)
 Family: zero_inflated_poisson 
  Links: mu = log; zi = logit 
Formula: count ~ persons + child + camper 
         zi ~ child
   Data: zinb (Number of observations: 250) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu_Intercept    -1.09      0.18    -1.44    -0.74 1.00     3123     2705
zi_Intercept    -0.96      0.26    -1.51    -0.49 1.00     3218     2168
mu_persons       0.90      0.05     0.81     0.99 1.00     3123     2986
mu_child        -1.18      0.10    -1.37    -0.99 1.00     3344     2254
mu_camper        0.77      0.10     0.59     0.96 1.00     4121     2786
zi_child         1.22      0.27     0.70     1.78 1.00     3089     2835

Details

This is a follow-up on this issue. It does not allow arbitrary parameters to be the main parameter, but it makes the mu parameter transparent and handles it much more like other parameters, which is a small step in that direction. The behavior is completely optional, and I have tested it extensively to make sure that when the option is FALSE (the default), the current brms behavior is not affected.

I thought this would be a quick feature, so I tried to implement it directly to see if you like it. It turned out to be trickier and took much longer than I expected, but on the plus side I now understand much better how brms works internally. I saw that check_prefix() already has an option keep_mu, but a lot of changes were necessary to make sure that nothing breaks when this option is exposed and set to TRUE.

The basic implementation is:

  • The keep_mu argument is passed explicitly only to validate_formula(). Within validate_formula(), a new attribute of formula$formula called "keep_mu" is created. This attribute is also copied automatically to bterms$mu$formula
  • Similarly to the stan_center_X() function in stan_predictors.R, a new utility function stan_keep_mu() can be used to flexibly determine when this attribute affects behavior
  • The main addition is to check_prefix() and combine_prefix(), which now have a default argument keep_mu = NULL (previously FALSE). If the keep_mu argument is NULL, they use stan_keep_mu() to determine whether to keep the mu prefix. This was necessary because otherwise "_mu" is sometimes added to the wrong variables, such as the response Y
  • Several small changes to various functions in stan-likelihood.R were needed to ensure correct behavior for all models
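
The prefix resolution described in the list above can be sketched roughly as follows. This is a simplified, self-contained illustration in base R, not the code from this PR: resolve_keep_mu() and prefixed_name() are hypothetical stand-ins for stan_keep_mu() and combine_prefix(), and the order in which the explicit argument, the formula attribute, and the brms.keep_mu option are consulted is an assumption.

```r
# Hypothetical helper (not part of brms): decide whether the "mu" prefix is kept.
# An explicit TRUE/FALSE wins; NULL falls back to the "keep_mu" attribute of the
# formula, then to the global option, then to FALSE.
resolve_keep_mu <- function(formula = NULL, keep_mu = NULL) {
  if (!is.null(keep_mu)) {
    return(isTRUE(keep_mu))
  }
  from_attr <- attr(formula, "keep_mu", exact = TRUE)
  if (!is.null(from_attr)) {
    return(isTRUE(from_attr))
  }
  isTRUE(getOption("brms.keep_mu", FALSE))
}

# Hypothetical helper (not part of brms): build a name such as "b_mu" or "b_zi".
# The "mu" distributional parameter is dropped unless keep_mu resolves to TRUE.
prefixed_name <- function(class, dpar = "", formula = NULL, keep_mu = NULL) {
  if (identical(dpar, "mu") && !resolve_keep_mu(formula, keep_mu)) {
    dpar <- ""
  }
  parts <- c(class, dpar)
  paste(parts[nzchar(parts)], collapse = "_")
}

f <- count ~ persons + child + camper
prefixed_name("b", "mu", f, keep_mu = FALSE)  # "b"
attr(f, "keep_mu") <- TRUE
prefixed_name("b", "mu", f)                   # "b_mu"
prefixed_name("b", "zi", f)                   # "b_zi"
```

Keeping the NULL default in the low-level helpers means that only one place has to know where the setting comes from, which is the motivation given above for not passing keep_mu through every call.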

Tests

  • To ensure that all models work properly, I copied the test files test.stancode and test.standata to the tests/local folder. I then rewrote all the tests in the copied files with keep_mu = TRUE and reworked the logic until all tests passed. For every generated Stan code, I also ran the corresponding model locally to ensure that it samples successfully and produces the expected output.
  • I put those tests in local because they largely duplicate the existing tests for keep_mu = FALSE, and you will probably want to decide which of the new tests to keep if you accept the pull request

Additional features and changes

  • To ensure that priors work properly, set_prior no longer suppresses "mu" when dpar = "mu" (introduced originally in Feature request: prior() doesn't understand dpar = "mu" #1368). To keep backwards compatibility, "mu" is instead suppressed in validate_prior if keep_mu = FALSE (the default)
  • When keep_mu = TRUE, priors formally require an explicit dpar = "mu". To still allow specifying priors without dpar = "mu" in that case, I changed the check for whether a prior is valid: a new internal function repair_prior checks whether the prior would become valid if its invalid parts were given dpar = "mu". I added tests with positive and negative examples to ensure correct behavior
  • The keep_mu behavior can also be set via a new global option: options(brms.keep_mu = TRUE)
  • stancode and standata can be applied with the keep_mu argument to an existing model fitted before this change. They then regenerate the Stan code, the Stan data, and the prior as if the model had been run with keep_mu = TRUE (and vice versa). Potentially (I have not done this) these could be combined into the restructure function to allow converting a model fitted with keep_mu = FALSE to keep_mu = TRUE
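
The prior-repair idea from the list above could look roughly like this. It is a minimal base-R sketch under stated assumptions, not the repair_prior function from this PR: repair_prior_sketch() is a hypothetical name, priors are represented as plain data frames with only class and dpar columns (real brmsprior objects carry more columns), and validity is reduced to matching a class/dpar pair against a table of valid rows.

```r
# Hypothetical sketch (not the PR's repair_prior): try to make invalid prior
# rows valid by giving rows without a dpar an explicit dpar = "mu".
repair_prior_sketch <- function(prior, valid) {
  key <- function(d) paste(d$class, d$dpar, sep = "|")
  invalid <- !(key(prior) %in% key(valid))
  candidate <- prior
  fixable <- invalid & candidate$dpar == ""
  candidate$dpar[fixable] <- "mu"
  # rows that are still invalid after the attempted repair are real errors
  still_invalid <- !(key(candidate) %in% key(valid))
  if (any(still_invalid)) {
    stop("Invalid prior for class: ",
         paste(prior$class[still_invalid], collapse = ", "))
  }
  candidate
}

# Valid class/dpar combinations for a model with mu and zi (illustrative only)
valid <- data.frame(class = c("b", "Intercept", "b", "Intercept"),
                    dpar  = c("mu", "mu", "zi", "zi"),
                    stringsAsFactors = FALSE)

# Positive example: the first row lacks dpar and is repaired to dpar = "mu"
user <- data.frame(class = c("b", "b"), dpar = c("", "zi"),
                   stringsAsFactors = FALSE)
repair_prior_sketch(user, valid)$dpar  # "mu" "zi"
```

A row that stays invalid even with dpar = "mu", for example class "sd" in the table above, still raises an error, which corresponds to the negative test cases mentioned in the list.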

Below is an example code file with the new options:

Full example code of different functionality

Replicating https://paul-buerkner.github.io/brms/articles/brms_distreg.html#zero-inflated-models with keep_mu = TRUE

options(mc.cores = 4)
zinb <- read.csv("https://paul-buerkner.github.io/data/fish.csv")

fit_mu <- brm(bf(count ~ persons + child + camper, zi ~ child),  data = zinb,
              family = zero_inflated_poisson(), backend = 'cmdstanr', keep_mu = TRUE)

summary(fit_mu)
plot(fit_mu, variable = c("b_mu_persons", "b_zi_child"))
plot(conditional_effects(fit_mu, effects = "child", dpar = c("mu")), ask = FALSE)

prior no longer removes the mu parameter:

prior <- c(prior(normal(0, 1), class = "b", dpar = "mu"),
           prior(normal(0, 1), class = "b", dpar = "zi"))
prior

#>        prior class coef group resp dpar nlpar   lb   ub source
#> normal(0, 1)     b                   mu       <NA> <NA>   user
#> normal(0, 1)     b                   zi       <NA> <NA>   user

but this will be removed internally if keep_mu = FALSE (default) for backwards compatibility

validate_prior(prior, bf(count ~ persons + child + camper, zi ~ child),
               data = zinb,
               family = zero_inflated_poisson())
               
#>                   prior     class    coef group resp dpar nlpar lb ub       source
#>            normal(0, 1)         b                                             user
#>            normal(0, 1)         b  camper                             (vectorized)
#>            normal(0, 1)         b   child                             (vectorized)
#>            normal(0, 1)         b persons                             (vectorized)
#>            normal(0, 1)         b                      zi                     user
#>            normal(0, 1)         b   child              zi             (vectorized)
#> student_t(3, -2.3, 2.5) Intercept                                          default
#>          logistic(0, 1) Intercept                      zi                  default

and kept if keep_mu = TRUE. keep_mu can be set globally via options as well

options(brms.keep_mu = TRUE)
validate_prior(prior, bf(count ~ persons + child + camper, zi ~ child),
               data = zinb,
               family = zero_inflated_poisson())
           
#>                   prior     class    coef group resp dpar nlpar lb ub       source
#>            normal(0, 1)         b                      mu                     user
#>            normal(0, 1)         b  camper              mu             (vectorized)
#>            normal(0, 1)         b   child              mu             (vectorized)
#>            normal(0, 1)         b persons              mu             (vectorized)
#>            normal(0, 1)         b                      zi                     user
#>            normal(0, 1)         b   child              zi             (vectorized)
#> student_t(3, -2.3, 2.5) Intercept                      mu                  default
#>          logistic(0, 1) Intercept                      zi                  default

# reset the global option to its default
options(brms.keep_mu = FALSE)

stancode and standata also have the keep_mu argument

scode_mu <- stancode(bf(count ~ persons + child + camper, zi ~ child),  data = zinb,
                  family = zero_inflated_poisson(), keep_mu = TRUE)
scode_nomu <- stancode(bf(count ~ persons + child + camper, zi ~ child),  data = zinb,
                  family = zero_inflated_poisson(), keep_mu = FALSE)

Differences in the two codes:

Reduce(setdiff, strsplit(c(scode_mu, scode_nomu), split = "\n"))

 [1] "  int<lower=1> K_mu;  // number of population-level effects"                 
 [2] "  matrix[N, K_mu] X_mu;  // population-level design matrix"                  
 [3] "  int<lower=1> Kc_mu;  // number of population-level effects after centering"
 [4] "  matrix[N, Kc_mu] Xc_mu;  // centered version of X_mu without an intercept" 
 [5] "  vector[Kc_mu] means_X_mu;  // column means of X_mu before centering"       
 [6] "  for (i in 2:K_mu) {"                                                       
 [7] "    means_X_mu[i - 1] = mean(X_mu[, i]);"                                    
 [8] "    Xc_mu[, i - 1] = X_mu[, i] - means_X_mu[i - 1];"                         
 [9] "  vector[Kc_mu] b_mu;  // regression coefficients"                           
[10] "  real Intercept_mu;  // temporary intercept for centered predictors"        
[11] "  lprior += student_t_lpdf(Intercept_mu | 3, -2.3, 2.5);"                    
[12] "    mu += Intercept_mu + Xc_mu * b_mu;"                                      
[13] "  real b_mu_Intercept = Intercept_mu - dot_product(means_X_mu, b_mu);"  

Reduce(setdiff, strsplit(c(scode_nomu, scode_mu), split = "\n"))

 [1] "  int<lower=1> K;  // number of population-level effects"                 
 [2] "  matrix[N, K] X;  // population-level design matrix"                     
 [3] "  int<lower=1> Kc;  // number of population-level effects after centering"
 [4] "  matrix[N, Kc] Xc;  // centered version of X without an intercept"       
 [5] "  vector[Kc] means_X;  // column means of X before centering"             
 [6] "  for (i in 2:K) {"                                                       
 [7] "    means_X[i - 1] = mean(X[, i]);"                                       
 [8] "    Xc[, i - 1] = X[, i] - means_X[i - 1];"                               
 [9] "  vector[Kc] b;  // regression coefficients"                              
[10] "  real Intercept;  // temporary intercept for centered predictors"        
[11] "  lprior += student_t_lpdf(Intercept | 3, -2.3, 2.5);"                    
[12] "    mu += Intercept + Xc * b;"                                            
[13] "  real b_Intercept = Intercept - dot_product(means_X, b);"  
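
The two Reduce(setdiff, ...) calls above can be wrapped in a small helper that makes the direction of the comparison explicit. This is only a convenience sketch: lines_only_in() is a hypothetical name, and it works on any pair of single-string programs, so it is demonstrated here on two short hand-written snippets rather than on real brms output.

```r
# Hypothetical helper: lines that occur in program `a` but not in program `b`
lines_only_in <- function(a, b) {
  split_lines <- function(x) strsplit(as.character(x), "\n", fixed = TRUE)[[1]]
  setdiff(split_lines(a), split_lines(b))
}

a <- "data {\n  int<lower=1> K_mu;\n}"
b <- "data {\n  int<lower=1> K;\n}"
lines_only_in(a, b)  # "  int<lower=1> K_mu;"
lines_only_in(b, a)  # "  int<lower=1> K;"
```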

stancode and standata can regenerate the stan code and data from older fits with keep_mu = TRUE

old_model <- rename_pars(brmsfit_example3)
scode_mu2 <- stancode(old_model, keep_mu = TRUE)
Reduce(setdiff, strsplit(c(scode_mu2, old_model$model), split = "\n"))

 [1] "// generated with brms 2.20.16"                                                                                                                                                                                                                                               
 [2] "  int<lower=1> K_mu;  // number of population-level effects"                                                                                                                                                                                                                  
 [3] "  matrix[N, K_mu] X_mu;  // population-level design matrix"                                                                                                                                                                                                                   
 [4] "  int<lower=1> Kc_mu;  // number of population-level effects after centering"                                                                                                                                                                                                 
 [5] "  int<lower=1> Ksp_mu;  // number of special effects terms"                                                                                                                                                                                                                   
 [6] "  vector[N] Csp_mu_1;"                                                                                                                                                                                                                                                        
 [7] "  vector[N] Z_1_mu_1_1;"                                                                                                                                                                                                                                                      
 [8] "  vector[N] Z_1_mu_1_2;"                                                                                                                                                                                                                                                      
 [9] "  vector[N] Z_1_mu_2_1;"                                                                                                                                                                                                                                                      
[10] "  vector[N] Z_1_mu_2_2;"                                                                                                                                                                                                                                                      
[11] "  matrix[N, Kc_mu] Xc_mu;  // centered version of X_mu without an intercept"                                                                                                                                                                                                  
[12] "  vector[Kc_mu] means_X_mu;  // column means of X_mu before centering"                                                                                                                                                                                                        
[13] "  for (i in 2:K_mu) {"                                                                                                                                                                                                                                                        
[14] "    means_X_mu[i - 1] = mean(X_mu[, i]);"                                                                                                                                                                                                                                     
[15] "    Xc_mu[, i - 1] = X_mu[, i] - means_X_mu[i - 1];"                                                                                                                                                                                                                          
[16] "  vector[Kc_mu] b_mu;  // regression coefficients"                                                                                                                                                                                                                            
[17] "  real Intercept_mu;  // temporary intercept for centered predictors"                                                                                                                                                                                                         
[18] "  vector[Ksp_mu] bsp_mu;  // special effects coefficients"                                                                                                                                                                                                                    
[19] "  vector[N_1] r_1_mu_1;"                                                                                                                                                                                                                                                      
[20] "  vector[N_1] r_1_mu_2;"                                                                                                                                                                                                                                                      
[21] "  r_1_mu_1 = r_1[, 1];"                                                                                                                                                                                                                                                       
[22] "  r_1_mu_2 = r_1[, 2];"                                                                                                                                                                                                                                                       
[23] "  lprior += normal_lpdf(b_mu | 0, 10);"                                                                                                                                                                                                                                       
[24] "  lprior += student_t_lpdf(Intercept_mu | 3, 19.5, 4.4);"                                                                                                                                                                                                                     
[25] "  lprior += normal_lpdf(bsp_mu | 0, 10);"                                                                                                                                                                                                                                     
[26] "    mu += Intercept_mu;"                                                                                                                                                                                                                                                      
[27] "      mu[n] += (bsp_mu[1]) * Xme_1[n] + (bsp_mu[2]) * Xme_1[n] * Csp_mu_1[n] + W_1_1[n] * r_1_mu_1[J_1_1[n]] * Z_1_mu_1_1[n] + W_1_2[n] * r_1_mu_1[J_1_2[n]] * Z_1_mu_1_2[n] + W_1_1[n] * r_1_mu_2[J_1_1[n]] * Z_1_mu_2_1[n] + W_1_2[n] * r_1_mu_2[J_1_2[n]] * Z_1_mu_2_2[n];"
[28] "    target += normal_id_glm_lpdf(Y | Xc_mu, mu, b_mu, sigma);"                                                                                                                                                                                                                
[29] "  real b_mu_Intercept = Intercept_mu - dot_product(means_X_mu, b_mu);" 

sdata_mu2 <- standata(old_model, keep_mu = TRUE)
sdata_old <- standata(old_model)
setdiff(names(sdata_mu2), names(sdata_old))

[1] "K_mu"       "Kc_mu"      "X_mu"       "Ksp_mu"     "Csp_mu_1"   "Z_1_mu_1_1" "Z_1_mu_1_2" "Z_1_mu_2_1" "Z_1_mu_2_2"

@venpopov venpopov changed the title from "Expose keep_mu option to users" to "Expose keep_mu option to users to make mu explicit in code and summary" Mar 5, 2024
@codecov-commenter

Codecov Report

Attention: Patch coverage is 90.58824%, with 8 lines in your changes missing coverage. Please review.

Project coverage is 82.08%. Comparing base (9e7d825) to head (52d2e62).

Files                 Patch %   Lines
R/stan-likelihood.R   84.37%    5 Missing ⚠️
R/priors.R            95.23%    1 Missing ⚠️
R/stancode.R          83.33%    1 Missing ⚠️
R/standata.R          75.00%    1 Missing ⚠️

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1610      +/-   ##
==========================================
+ Coverage   82.06%   82.08%   +0.01%     
==========================================
  Files          70       70              
  Lines       19886    19928      +42     
==========================================
+ Hits        16320    16358      +38     
- Misses       3566     3570       +4     


@paul-buerkner
Owner

Thank you a lot!

The PR is a bit overwhelming though. I wish you had discussed this with me first before doing this major work. The challenge I have now is that it is hard for me to tell whether the approach you chose is a good approach to maintain in the future.

Or, to put it differently, perhaps it would make sense to abandon the "mu requirement" altogether in some more general sense (not clear on the implications yet). In any case, I will need some time to look into your PR and determine whether this is a good way forward.

@paul-buerkner paul-buerkner added this to the 2.21.0 milestone Mar 5, 2024
@venpopov
Contributor Author

venpopov commented Mar 5, 2024

I completely agree, I should have opened an issue to discuss this. I thought it would be a much smaller change, but then I got absorbed in making it work, and by the time I realized I should have opened a discussion I was almost done. So I decided to post the PR and use that as a discussion. It's on me, so I don't want to put any pressure on you. If you decide this is not a good approach, that's totally fine!

@venpopov
Contributor Author

Closing this as we've found a way to circumvent the issue in our own package and, as we discussed, this would be up for a more general restructuring of how mu is treated in brms 3.0 #1660

3 participants