Possible bug in brms/emmeans integration #1654

Open · wlandau opened this issue May 15, 2024 · 6 comments
wlandau commented May 15, 2024

Related: #1630, https://discourse.mc-stan.org/t/trouble-with-brms-emmeans-integration/34664. I am posting here because I think the issue might be a bug in brms, and the comment section in my Stan Discourse post has not been active.

brms integrates with emmeans for marginal mean calculations, but the results seem off. The reprex below uses the mmrm package's FEV1 dataset, a simulation of a clinical trial with treatment groups in ARMCD and discrete time points for repeated measures in AVISIT. The example compares 4 different methods of estimating marginal means for each combination of ARMCD and AVISIT:

  1. Data summaries: compute means and independent frequentist 95% confidence intervals on the raw data.
  2. lm() + emmeans: fit a model with lm() and get marginal means with emmeans.
  3. brms + custom: fit a model with brms and use a custom linear transformation to map model parameters to marginal means.
  4. brms + emmeans: use the native brms/emmeans integration to estimate marginal means from the fitted brms model.

There is reasonable agreement among approaches (1), (2), and (3), but approach (4) gives very different results from all the others. I ran the following on the current development version of brms (master branch, commit 298b947).

suppressPackageStartupMessages({
  library(brms)
  library(coda)
  library(emmeans)
  library(mmrm)
  library(posterior)
  library(tidyverse)
  library(zoo)
})
emm_options(sep = "|")

packageDescription("brms")$GithubSHA1
#> [1] "298b947fa9cfb914aeb7cb3aab7974aa179682b1"

# FEV data from the mmrm package, using LOCF and then LOCF reversed
# to impute responses. (For this discussion, it is helpful to avoid
# the topic of missingness.)
data(fev_data, package = "mmrm")
data <- fev_data %>%
  mutate(FEV1_CHG = FEV1 - FEV1_BL, USUBJID = as.character(USUBJID)) %>%
  select(-FEV1) %>%
  group_by(USUBJID) %>%
  complete(
    AVISIT,
    fill = as.list(.[1L, c("ARMCD", "FEV1_BL", "RACE", "SEX", "WEIGHT")])
  ) %>%
  ungroup() %>%
  arrange(USUBJID, AVISIT) %>%
  group_by(USUBJID) %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE)) %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE, fromLast = TRUE)) %>%
  ungroup() %>%
  filter(!is.na(FEV1_CHG))
summary_data <- data %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "1_data",
    mean = mean(FEV1_CHG),
    lower = mean(FEV1_CHG) - qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    upper = mean(FEV1_CHG) + qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    .groups = "drop"
  )

# Formula shared by all the models
formula <- FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT +
  AVISIT + RACE + SEX + WEIGHT

# lm with emmeans
model_lm <- lm(formula = formula, data = data)
summary_lm_emmeans <- emmeans(
  object = model_lm,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(ARMCD, AVISIT, emmean, lower.CL, upper.CL) %>%
  rename(mean = emmean, lower = lower.CL, upper = upper.CL) %>%
  mutate(source = "2_lm_emmeans")

# brms with emmeans
model_brms <- brm(data = data, formula = brmsformula(formula))
summary_brms_emmeans <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(ARMCD, AVISIT, emmean, lower.HPD, upper.HPD) %>%
  rename(mean = emmean, lower = lower.HPD, upper = upper.HPD) %>%
  mutate(source = "4_brms_emmeans")

# custom marginal means from brms draws using a custom mapping
# from brms model parameters to marginal means. I would expect the
# emmeans/brms integration to agree with the results below
# (within rounding error + MCMC error), based on what I find with lm()
# (c.f. https://github.com/openpharma/brms.mmrm/issues/53)
proportional_factors <- brmsformula(FEV1_CHG ~ 0 + SEX + RACE) %>%
  make_standata(data = data) %>%
  .subset2("X") %>%
  colMeans() %>%
  t()
grid <- data %>%
  mutate(FEV1_BL = mean(FEV1_BL), FEV1_CHG = 0, WEIGHT = mean(WEIGHT)) %>%
  distinct(ARMCD, AVISIT, FEV1_BL, WEIGHT, FEV1_CHG)
draws_parameters <- model_brms %>%
  as_draws_df() %>%
  as_tibble() %>%
  select(starts_with("b_"), -starts_with("b_sigma"))
mapping <- brmsformula(
    FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + WEIGHT
  ) %>%
  make_standata(data = grid) %>%
  .subset2("X") %>%
  bind_cols(proportional_factors) %>%
  setNames(paste0("b_", colnames(.)))
stopifnot(all(colnames(draws_parameters) %in% colnames(mapping)))
mapping <- as.matrix(mapping)[, colnames(draws_parameters)]
rownames(mapping) <- paste(grid$ARMCD, grid$AVISIT, sep = "|")
draws_custom <- as.matrix(draws_parameters) %*% t(mapping) %>%
  as.data.frame() %>%
  as_tibble()
summary_brms_custom <- draws_custom %>%
  pivot_longer(everything()) %>%
  separate("name", c("ARMCD", "AVISIT")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "3_brms_custom",
    mean = mean(value),
    lower = quantile(value, 0.025),
    upper = quantile(value, 0.975),
    .groups = "drop"
  )

# Compare results
summary <- bind_rows(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans
)
ggplot(summary) +
  geom_point(aes(x = source, y = mean, color = source)) +
  geom_errorbar(aes(x = source, ymin = lower, ymax = upper, color = source)) +
  facet_grid(ARMCD ~ AVISIT) +
  theme_gray(16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  ylab("FEV1_CHG")

[Screenshot: plot from the code above comparing the four approaches, faceted by ARMCD and AVISIT; approaches 1-3 roughly agree, while 4_brms_emmeans differs substantially.]

sessionInfo()
#> R version 4.4.0 (2024-04-24)
#> Platform: aarch64-apple-darwin20
#> Running under: macOS Sonoma 14.5
#>
#> Matrix products: default
#> BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> time zone: America/Indiana/Indianapolis
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods
#> [7] base
#>
#> other attached packages:
#>  [1] zoo_1.8-12       lubridate_1.9.3  forcats_1.0.0
#>  [4] stringr_1.5.1    dplyr_1.1.4      purrr_1.0.2
#>  [7] readr_2.1.5      tidyr_1.3.1      tibble_3.2.1
#> [10] ggplot2_3.5.1    tidyverse_2.0.0  posterior_1.5.0
#> [13] mmrm_0.3.11      emmeans_1.10.1   coda_0.19-4.1
#> [16] brms_2.21.3      Rcpp_1.0.12      abind_1.4-5
#> [19] drake_7.13.10    testthat_3.2.1.1
#>
#> loaded via a namespace (and not attached):
#>   [1] Rdpack_2.6           txtq_0.2.4           gridExtra_2.3
#>   [4] remotes_2.5.0        inline_0.3.19        rlang_1.1.3
#>   [7] magrittr_2.0.3       matrixStats_1.3.0    compiler_4.4.0
#>  [10] loo_2.7.0            callr_3.7.6          vctrs_0.6.5
#>  [13] profvis_0.3.8        pkgconfig_2.0.3      crayon_1.5.2
#>  [16] fastmap_1.1.1        backports_1.4.1      ellipsis_0.3.2
#>  [19] utf8_1.2.4           promises_1.3.0       tzdb_0.4.0
#>  [22] sessioninfo_1.2.2    ps_1.7.6             waldo_0.5.2
#>  [25] cachem_1.0.8         jsonlite_1.8.8       progress_1.2.3
#>  [28] later_1.3.2          parallel_4.4.0       prettyunits_1.2.0
#>  [31] R6_2.5.1             StanHeaders_2.32.7   stringi_1.8.4
#>  [34] parallelly_1.37.1    pkgload_1.3.4        estimability_1.5
#>  [37] brio_1.1.5           bindr_0.1.1          rstan_2.32.6
#>  [40] usethis_2.2.3        bayesplot_1.11.1     httpuv_1.6.15
#>  [43] Matrix_1.7-0         igraph_2.0.3         timechange_0.3.0
#>  [46] tidyselect_1.2.1     rstudioapi_0.16.0    codetools_0.2-20
#>  [49] miniUI_0.1.1.1       curl_5.2.1           processx_3.8.4
#>  [52] listenv_0.9.1        pkgbuild_1.4.4       lattice_0.22-6
#>  [55] shiny_1.8.1.1        withr_3.0.0          bridgesampling_1.1-2
#>  [58] future_1.33.2        desc_1.4.3           RcppParallel_5.1.7
#>  [61] urlchecker_1.0.1     pillar_1.9.0         fstcore_0.9.18
#>  [64] filelock_1.0.3       tensorA_0.36.2.1     checkmate_2.3.1
#>  [67] renv_1.0.7           stats4_4.4.0         distributional_0.4.0
#>  [70] generics_0.1.3       rprojroot_2.0.4      hms_1.1.3
#>  [73] rstantools_2.4.0     munsell_0.5.1        scales_1.3.0
#>  [76] storr_1.2.5          globals_0.16.3       xtable_1.8-4
#>  [79] base64url_1.4        glue_1.7.0           tools_4.4.0
#>  [82] data.table_1.15.4    fs_1.6.4             mvtnorm_1.2-4
#>  [85] grid_4.4.0           rbibutils_2.2.16     QuickJSR_1.1.3
#>  [88] devtools_2.4.5       colorspace_2.1-0     nlme_3.1-164
#>  [91] cli_3.6.2            fst_0.9.8            fansi_1.0.6
#>  [94] Brobdingnag_1.2-9    V8_4.4.2             gtable_0.3.5
#>  [97] digest_0.6.35        htmlwidgets_1.6.4    memoise_2.0.1
#> [100] htmltools_0.5.8.1    lifecycle_1.0.4      mime_0.12
paul-buerkner (Owner) commented:

Thank you for reporting this issue. I am no emmeans expert, so it is hard for me to tell what is going on. @rvlenth do you happen to have an idea, perhaps?

rvlenth (Contributor) commented May 16, 2024

I have no clue.

I am bothered by the fact that there are two (very) different objects named model in this code.

As for the "custom" code, I disagree that it is what emmeans should be doing, simply because whatever all that stuff is, it shouldn't be that complex.

My suggestion for finding out more is to try this, using the second version of model, the one that was produced by brm().

emm_itself <- emmeans(
  object = model,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
)

summary(emm_itself)

So far, we are now seeing directly what emmeans is producing. Are the estimates the same as those in the plot? Do the annotations below the summary provide additional information that was never seen because it was swept away by all the "tidy" post-processing? Because the estimate in the summary is the median of the posterior, how about the results in summary(emm_itself, point.est = mean)?

If you still see the serious discrepancies, do this:

newdata <- emmeans::ref_grid(model)@grid

This gives you the grid of all fixed-effects factors, which is the basis for all emmeans calculations.
Then use brms functions/methods to obtain predictions from model, with newdata as new data. Average those results together over all but the two primary factors, using appropriate weights. That's what emmeans should be doing.
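
As a rough sketch (using model_brms for the brm() fit, and assuming proportional weighting by the observed RACE/SEX counts to mirror wt.nuis = "proportional"), that check might look like:

library(dplyr)
newdata <- emmeans::ref_grid(model_brms)@grid
newdata$estimate <- predict(model_brms, newdata = newdata)[, "Estimate"]
newdata %>%
  left_join(count(data, RACE, SEX), by = c("RACE", "SEX")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(mean = weighted.mean(estimate, w = n), .groups = "drop")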

wlandau (Author) commented May 17, 2024

I am bothered by the fact that there are two (very) different objects named model in this code.

Edited #1654 (comment) to use model_lm and model_brms.

As for the "custom" code, I disagree that it is what emmeans should be doing, simply because whatever all that stuff is, it shouldn't be that complex.

Edited #1654 (comment) to clarify that comment.

So far, we are now seeing directly what emmeans is producing. Are the estimates the same as those in the plot?

Yes:

emm_itself <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
)
summary(emm_itself)
#>  ARMCD AVISIT emmean lower.HPD upper.HPD
#>  PBO   VIS1   -18.08    -22.29   -13.617
#>  TRT   VIS1   -14.81    -18.28   -11.236
#>  PBO   VIS2   -16.08    -19.48   -12.372
#>  TRT   VIS2   -12.68    -16.50    -9.329
#>  PBO   VIS3   -12.53    -15.84    -8.909
#>  TRT   VIS3    -9.71    -13.31    -6.333
#>  PBO   VIS4    -7.93    -11.54    -4.363
#>  TRT   VIS4    -3.46     -7.11     0.102
#> 
#> Results are averaged over the levels of: 2 nuisance factors 
#> Point estimate displayed: median 
#> HPD interval probability: 0.95 
as.data.frame(summary_brms_emmeans)
#>   ARMCD AVISIT       mean      lower       upper         source
#> 1   PBO   VIS1 -18.083219 -22.287808 -13.6172246 4_brms_emmeans
#> 2   TRT   VIS1 -14.812490 -18.276953 -11.2362895 4_brms_emmeans
#> 3   PBO   VIS2 -16.079485 -19.477840 -12.3717137 4_brms_emmeans
#> 4   TRT   VIS2 -12.679113 -16.503203  -9.3292318 4_brms_emmeans
#> 5   PBO   VIS3 -12.527884 -15.841525  -8.9088424 4_brms_emmeans
#> 6   TRT   VIS3  -9.709981 -13.307893  -6.3334955 4_brms_emmeans
#> 7   PBO   VIS4  -7.928348 -11.537075  -4.3630501 4_brms_emmeans
#> 8   TRT   VIS4  -3.462008  -7.109919   0.1019503 4_brms_emmeans
summary(emm_itself)$emmean - summary_brms_emmeans$mean
#> [1] 0 0 0 0 0 0 0 0
summary(emm_itself)$lower.HPD - summary_brms_emmeans$lower
#> [1] 0 0 0 0 0 0 0 0
summary(emm_itself)$upper.HPD - summary_brms_emmeans$upper
#> [1] 0 0 0 0 0 0 0 0

Do the annotations below the summary provide additional information that was never seen because it was swept away by all the "tidy" post-processing?

The summary says the results are averaged over two nuisance factors, whereas the code supplies three. This makes sense because there are no fixed effects for USUBJID, so only RACE and SEX end up as nuisance factors.
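
A quick way to confirm this (a sketch on my part, using the fixef() accessor from brms): the population-level coefficient names include RACE and SEX terms but nothing involving USUBJID.

rownames(brms::fixef(model_brms))
# expect coefficients for FEV1_BL, ARMCD, AVISIT, RACE, SEX, WEIGHT and their
# interactions, but none involving USUBJID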

Because the estimate in the summary is the median of the posterior, how about the results in summary(emm_itself, point.est = mean)?

Only slight differences:

summary_emmeans <- summary(emm_itself, point.est = mean)
max(abs(summary_emmeans$emmean - summary_brms_emmeans$mean))
#> [1] 0.0202332

If you still see the serious discrepancies, do this:

newdata <- emmeans::ref_grid(model)@grid

This gives you the grid of all fixed-effects factors, which is the basis for all emmeans calculations.
Then use brms functions/methods to obtain predictions from model, with newdata as new data. Average those results together over all but the two primary factors, using appropriate weights.

When I do that, I see close agreement with the native lm()/emmeans integration, but strong disagreement with the native brms/emmeans integration.

# Predictions
new_data <- emmeans::ref_grid(model_brms)@grid
predictions <- predict(model_brms, newdata = new_data)
grid <- mutate(new_data, estimate = predictions[, "Estimate"])

# Proportional weights
weighted_grid <- grid %>%
  left_join(y = count(data, RACE, SEX), by = c("RACE", "SEX")) %>%
  rename(.wgt. = n)

# Marginal means
custom <- weighted_grid %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(mean = sum(estimate * .wgt.) / sum(.wgt.)) %>%
  arrange(AVISIT, ARMCD)
custom
#> # A tibble: 8 × 3
#> # Groups:   ARMCD [2]
#>   ARMCD AVISIT   mean
#>   <fct> <fct>   <dbl>
#> 1 PBO   VIS1   -4.67 
#> 2 TRT   VIS1   -1.24 
#> 3 PBO   VIS2   -2.47 
#> 4 TRT   VIS2    0.957
#> 5 PBO   VIS3    1.00 
#> 6 TRT   VIS3    3.78 
#> 7 PBO   VIS4    5.57 
#> 8 TRT   VIS4   10.1  

# Good enough agreement with lm marginal means
summary_lm_emmeans
#> # A tibble: 8 × 6
#>   ARMCD AVISIT   mean  lower  upper source      
#>   <fct> <fct>   <dbl>  <dbl>  <dbl> <chr>       
#> 1 PBO   VIS1   -4.60  -5.98  -3.22  2_lm_emmeans
#> 2 TRT   VIS1   -1.29  -2.76   0.185 2_lm_emmeans
#> 3 PBO   VIS2   -2.54  -3.92  -1.17  2_lm_emmeans
#> 4 TRT   VIS2    0.847 -0.625  2.32  2_lm_emmeans
#> 5 PBO   VIS3    0.984 -0.393  2.36  2_lm_emmeans
#> 6 TRT   VIS3    3.80   2.33   5.27  2_lm_emmeans
#> 7 PBO   VIS4    5.60   4.22   6.98  2_lm_emmeans
#> 8 TRT   VIS4   10.1    8.58  11.5   2_lm_emmeans

max(abs(custom$mean - summary_lm_emmeans$mean))
#> [1] 0.1104108

# Disagreement with the native emmeans/brms integration
max(abs(custom$mean - summary_brms_emmeans$mean))
#> [1] 13.63619

wlandau (Author) commented May 20, 2024

Also, thanks for explaining the role of emmeans::ref_grid(model_brms)@grid in the weighting technique. This object is basically an expand.grid() over the unique levels of all the factors in the fixed effects, including nuisance factors, with continuous variables set at their observed grand means. Each row in the grid is given a weight, and I guess these weights are used to estimate marginal means as weighted averages over rows of predicted responses in the grid. This is the most direct and edifying explanation I have seen about how exactly the reference grid works and what exactly we mean by a "weight" in emmeans. (I read the help files, https://www.jstatsoft.org/article/view/v069i01, and all the vignettes, but I still missed these concepts.) Very helpful.
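
For instance, my mental model of the grid (a sketch of my understanding, not the actual emmeans internals) is roughly:

manual_grid <- expand.grid(
  FEV1_BL = mean(data$FEV1_BL),
  ARMCD   = levels(data$ARMCD),
  AVISIT  = levels(data$AVISIT),
  RACE    = levels(data$RACE),
  SEX     = levels(data$SEX),
  WEIGHT  = mean(data$WEIGHT)
)
nrow(manual_grid)                      # should match nrow(ref_grid(model_lm)@grid)
head(emmeans::ref_grid(model_lm)@grid) # includes the .wgt. column of row weights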

But whether we take the emmeans two-step approach of predict() + weighting, or my reprex's one-step linear transformation from model coefficients to marginal means, the results agree for the frequentist model.

# Create the reference grid.
new_data <- emmeans::ref_grid(model_lm)@grid
grid <- mutate(new_data, estimate = predict(model_lm, newdata = new_data))

# Apply proportional weights.
weighted_grid <- grid %>%
  left_join(y = count(data, RACE, SEX), by = c("RACE", "SEX")) %>%
  mutate(.wgt. = n)

# Compute marginal means using the weighted grid.
summary_lm_emmeans_using_grid <- weighted_grid %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(mean = sum(estimate * .wgt.) / sum(.wgt.)) %>%
  arrange(AVISIT, ARMCD)

# Both approaches agree:
max(abs(summary_lm_emmeans_using_grid$mean - summary_lm_emmeans$mean))
#> [1] 5.329071e-15

rvlenth (Contributor) commented May 23, 2024

We can go all over the place looking at examples and trying to guess what is done, but it shouldn't be too difficult to tell by looking at the code.

The emmeans package provides the infrastructure, but what it does to actually estimate things depends on the emm_basis method for that model class, and in this case that method is part of the package code for brms. Here is that code, copied here for convenience:

> brms:::emm_basis.brmsfit

function (object, trms, xlev, grid, vcov., resp = NULL, dpar = NULL, 
    nlpar = NULL, re_formula = NA, epred = FALSE, ...) 
{
    if (is_equal(dpar, "mean")) {
        warning2("dpar = 'mean' is deprecated. Please use epred = TRUE instead.")
        epred <- TRUE
        dpar <- NULL
    }
    epred <- as_one_logical(epred)
    bterms <- .extract_par_terms(object, resp = resp, dpar = dpar, 
        nlpar = nlpar, re_formula = re_formula, epred = epred)
    if (epred) {
        post.beta <- posterior_epred(object, newdata = grid, 
            re_formula = re_formula, resp = resp, incl_autocor = FALSE, 
            ...)
    }
    else {
        req_vars <- all_vars(bterms$allvars)
        post.beta <- posterior_linpred(object, newdata = grid, 
            re_formula = re_formula, resp = resp, dpar = dpar, 
            nlpar = nlpar, incl_autocor = FALSE, req_vars = req_vars, 
            transform = FALSE, offset = FALSE, ...)
    }
    if (anyNA(post.beta)) {
        stop2("emm_basis.brmsfit created NAs. Please check your reference grid.")
    }
    misc <- bterms$.misc
    if (length(dim(post.beta)) == 3L) {
        ynames <- dimnames(post.beta)[[3]]
        if (is.null(ynames)) {
            ynames <- as.character(seq_len(dim(post.beta)[3]))
        }
        dims <- dim(post.beta)
        post.beta <- matrix(post.beta, ncol = prod(dims[2:3]))
        misc$ylevs = list(rep.meas = ynames)
    }
    attr(post.beta, "n.chains") <- object$fit@sim$chains
    X <- diag(ncol(post.beta))
    bhat <- apply(post.beta, 2, mean)
    V <- cov(post.beta)
    nbasis <- matrix(NA)
    dfargs <- list()
    dffun <- function(k, dfargs) Inf
    environment(dffun) <- baseenv()
    nlist(X, bhat, nbasis, V, dffun, dfargs, misc, post.beta)
}

In the arguments, object is the model object, trms is a terms component, and grid is a data frame with the factor combinations in the reference grid. The function is supposed to set us up to produce predictions with grid as new data. In the returned list, bhat contains the regression coefficients and X is the matrix of linear functions such that X %*% bhat produces the predictions. (More importantly, post.beta is the posterior sample for bhat.) This particular function has a few optional brmsfit-specific arguments (resp, dpar, nlpar, re_formula, epred) which, as this isn't my package, I am in no position to explain, but they affect how things get set up. Some of them are mentioned in the help for predict.brmsfit.
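
One way to poke at these pieces (my own suggestion, not required for the fix): they end up as slots of the resulting emmGrid object, so you can inspect them directly after building the reference grid.

rg <- emmeans::ref_grid(model_brms)
dim(rg@post.beta)            # posterior draws, one column per reference-grid row
head(rg@linfct %*% rg@bhat)  # X %*% bhat: point predictions on the grid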

This is not a very complex function (it seems simpler than a lot of the code in this issue), and I suggest trying to understand what it does. For example, maybe what you need to do is add the argument epred = TRUE?
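
If epred turns out to be the missing piece, a minimal usage sketch (assuming, as the emmeans documentation describes, that extra arguments are passed through ref_grid() to the emm_basis() method) would be:

emm_epred <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX"),
  epred = TRUE
)
summary(emm_epred)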

rvlenth (Contributor) commented May 23, 2024

@wlandau PS -- of course, you should also look at ?emm_basis.brmsfit
