Simple Emax model fit with Stan

Kenta Yoshida

2024-12-06

library(rstanemax)
library(dplyr)
library(ggplot2)
set.seed(12345)

This vignette provides an overview of the workflow for Emax model analysis using this package.

Typical workflow

Model run with stan_emax function

stan_emax() is the main function of this package and performs Emax model analysis on the data. It requires at least two arguments: formula and data. In the formula argument, you specify which columns of data are used as exposure and response, in a format similar to that of stats::lm(), e.g. response ~ exposure.

data(exposure.response.sample)

fit.emax <- stan_emax(response ~ exposure,
  data = exposure.response.sample,
  # the next line is only to make the example go fast enough
  chains = 2, iter = 1000, seed = 12345
)
fit.emax
#> ---- Emax model fit with rstanemax ----
#> 
#>        mean se_mean    sd  2.5%   25%   50%   75%  97.5%  n_eff Rhat
#> emax  91.98    0.30  5.81 79.98 88.05 91.82 95.83 103.07 373.88 1.01
#> e0     5.60    0.24  4.62 -3.49  2.80  5.62  8.40  15.34 374.00 1.00
#> ec50  75.01    0.87 20.13 43.85 60.95 71.67 86.46 122.91 537.42 1.00
#> gamma  1.00     NaN  0.00  1.00  1.00  1.00  1.00   1.00    NaN  NaN
#> sigma 16.55    0.06  1.57 13.84 15.44 16.42 17.56  19.78 804.82 1.00
#> 
#> * Use `extract_stanfit()` function to extract raw stanfit object
#> * Use `extract_param()` function to extract posterior draws of key parameters
#> * Use `plot()` function to visualize model fit
#> * Use `posterior_predict()` or `posterior_predict_quantile()` function to get
#>   raw predictions or make predictions on new data
#> * Use `extract_obs_mod_frame()` function to extract raw data 
#>   in a processed format (useful for plotting)
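
The parameter names in the summary map onto the Emax curve. As a minimal sketch, assuming the standard sigmoidal Emax form, response = e0 + emax * exposure^gamma / (ec50^gamma + exposure^gamma), the curve can be evaluated by hand in base R. The helper emax_curve() is a hypothetical name used for illustration and is not part of rstanemax. Plugging in posterior means gives only a rough point estimate, not the posterior summary of the curve.

```r
# Hypothetical helper (not part of rstanemax): evaluates the sigmoidal Emax curve
emax_curve <- function(exposure, e0, emax, ec50, gamma = 1) {
  e0 + emax * exposure^gamma / (ec50^gamma + exposure^gamma)
}

# Posterior means copied from the summary printed above (gamma is fixed at 1)
plug_in <- emax_curve(c(0, 100, 1000), e0 = 5.60, emax = 91.98, ec50 = 75.01)
round(plug_in, 1)
```

At exposure 0 the curve equals e0, and at large exposures it approaches e0 + emax.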

The plot() function shows the estimated Emax model curve with 95% credible intervals of parameters.

plot(fit.emax)

The output of the plot() function (for a stanemax object) is a ggplot object, so you can apply additional settings as you would in ggplot2.
Here is an example that uses a log scale for the x axis. Note that exposure == 0 has no finite position on a log scale, so those points hang at the far left and distort the curve there.

plot(fit.emax) + scale_x_log10() + expand_limits(x = 1)
#> Warning in scale_x_log10(): log-10 transformation introduced infinite values.
#> log-10 transformation introduced infinite values.
#> log-10 transformation introduced infinite values.
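
The warning arises because the base-10 logarithm of zero is not finite. A quick base R check, using made-up exposure values rather than the sample data, shows which points are affected:

```r
# Illustrative exposure values (not taken from exposure.response.sample)
exposure <- c(0, 10, 100, 1000)

# log10(0) is -Inf, which a log-scaled axis cannot place at a finite position
log_exposure <- log10(exposure)
is.finite(log_exposure)
```

Only the zero-exposure value is flagged as non-finite; all positive exposures transform normally.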

The raw output from rstan is stored in the fitted object, and you can access it with the extract_stanfit() function.

class(extract_stanfit(fit.emax))
#> [1] "stanfit"
#> attr(,"package")
#> [1] "rstan"

Prediction of response with new exposure data

The posterior_predict() function predicts the response for new exposure data. If newdata is not provided, the function returns predictions at the exposures in the original data. The default output is a matrix of posterior predictions, but you can also set returnType to "dataframe" or "tibble" to get the posterior predictions in long format. See the help for rstanemax::posterior_predict() for a description of the two predictions, respHat and response.

response.pred <- posterior_predict(fit.emax, newdata = c(0, 100, 1000), returnType = "tibble")

response.pred %>% select(mcmcid, exposure, respHat, response)
#> # A tibble: 3,000 × 4
#>    mcmcid exposure   respHat response
#>     <int>    <dbl> <dbl[1d]>    <dbl>
#>  1      1        0      7.85     9.64
#>  2      1      100     57.4     79.8 
#>  3      1     1000     90.8     85.0 
#>  4      2        0     11.0     -3.31
#>  5      2      100     62.4     58.8 
#>  6      2     1000     95.0     87.4 
#>  7      3        0      2.74   -15.5 
#>  8      3      100     56.7     56.8 
#>  9      3     1000     93.8     85.4 
#> 10      4        0      6.86    49.2 
#> # ℹ 2,990 more rows
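
The long format lends itself to grouped summaries. As a rough sketch on simulated draws (the data frame toy.pred and its numbers are made up to mimic the mcmcid, exposure, respHat, and response columns), per-exposure intervals can be computed with dplyr. The 5% and 95% quantile levels and the summary column names are arbitrary choices for illustration.

```r
library(dplyr)

set.seed(1)
n_draws <- 500

# Toy long-format draws; simulated for illustration, not taken from the fit
toy.pred <- data.frame(
  mcmcid = rep(seq_len(n_draws), each = 2),
  exposure = rep(c(0, 100), times = n_draws)
) %>%
  mutate(
    respHat = rnorm(n(), mean = ifelse(exposure == 0, 5, 60), sd = 4),
    response = respHat + rnorm(n(), sd = 16)
  )

# One row per exposure, with quantiles of each prediction type
toy.summary <- toy.pred %>%
  group_by(exposure) %>%
  summarise(
    hat_low = quantile(respHat, 0.05, names = FALSE),
    hat_med = quantile(respHat, 0.5, names = FALSE),
    hat_high = quantile(respHat, 0.95, names = FALSE),
    resp_low = quantile(response, 0.05, names = FALSE),
    resp_high = quantile(response, 0.95, names = FALSE)
  )
toy.summary
```

In this toy setup the interval for response is wider than the one for respHat, because response carries additional simulated noise.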

You can also get quantiles of the predictions with the posterior_predict_quantile() function.

resp.pred.quantile <- posterior_predict_quantile(fit.emax, newdata = seq(0, 5000, by = 100))
resp.pred.quantile
#> # A tibble: 51 × 11
#>    exposure covemax covec50 cove0 Covariates ci_low ci_med ci_high pi_low pi_med
#>       <dbl> <fct>   <fct>   <fct> <chr>       <dbl>  <dbl>   <dbl>  <dbl>  <dbl>
#>  1        0 1       1       1     ""          -1.77   5.62    13.8  -23.7   5.61
#>  2      100 1       1       1     ""          52.8   58.9     64.6   30.2  58.6 
#>  3      200 1       1       1     ""          67.7   72.9     77.7   46.7  73.4 
#>  4      300 1       1       1     ""          74.7   79.3     84.0   52.5  78.9 
#>  5      400 1       1       1     ""          78.5   83.2     87.8   55.3  84.0 
#>  6      500 1       1       1     ""          80.8   85.6     90.4   59.5  86.8 
#>  7      600 1       1       1     ""          82.3   87.5     92.4   61.1  87.6 
#>  8      700 1       1       1     ""          83.4   88.8     93.9   61.2  88.6 
#>  9      800 1       1       1     ""          84.3   89.7     95.1   62.0  88.9 
#> 10      900 1       1       1     ""          85.0   90.6     96.0   63.2  91.1 
#> # ℹ 41 more rows
#> # ℹ 1 more variable: pi_high <dbl>

The input data can be obtained in the same format with the extract_obs_mod_frame() function.

obs.formatted <- extract_obs_mod_frame(fit.emax)

The predicted quantiles and the formatted observations are particularly useful when you want to plot the estimated Emax curve.

ggplot(resp.pred.quantile, aes(exposure, ci_med)) +
  geom_line() +
  geom_ribbon(aes(ymin = ci_low, ymax = ci_high), alpha = .5) +
  geom_ribbon(aes(ymin = pi_low, ymax = pi_high), alpha = .2) +
  geom_point(
    data = obs.formatted,
    aes(y = response)
  ) +
  labs(y = "response")

Posterior draws of the Emax model parameters can be extracted with the extract_param() function.

posterior.fit.emax <- extract_param(fit.emax)
posterior.fit.emax
#> # A tibble: 1,000 × 6
#>    mcmcid  emax    e0  ec50     gamma     sigma
#>     <int> <dbl> <dbl> <dbl> <dbl[1d]> <dbl[1d]>
#>  1      1  89.7  7.85  81.2         1      13.8
#>  2      2  90.3 11.0   75.9         1      17.3
#>  3      3  98.6  2.74  82.7         1      17.7
#>  4      4  88.5  6.86  80.5         1      19.4
#>  5      5  93.3  3.21  62.7         1      18.2
#>  6      6  97.2  5.13  74.0         1      15.6
#>  7      7  91.7 10.7  115.          1      18.1
#>  8      8  91.3  5.96  64.5         1      17.9
#>  9      9  94.2  7.48  98.7         1      16.5
#> 10     10  96.3 10.4  124.          1      17.1
#> # ℹ 990 more rows
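
Each row of these draws defines one Emax curve. As a minimal sketch, assuming the standard Emax form and a normal residual error with standard deviation sigma (both are assumptions here, not statements about the package internals), the first draw can be turned into predictions by hand. The parameter values are the rounded numbers printed above for mcmcid 1.

```r
# Parameters of the first draw (mcmcid == 1), as printed above (rounded)
draw1 <- list(emax = 89.7, e0 = 7.85, ec50 = 81.2, gamma = 1, sigma = 13.8)
exposure.new <- c(0, 100, 1000)

# Expected response for this draw, assuming the standard Emax curve
resp.hat <- with(
  draw1,
  e0 + emax * exposure.new^gamma / (ec50^gamma + exposure.new^gamma)
)
round(resp.hat, 1)

# One simulated observation per exposure, assuming normal residual error
set.seed(1)
resp.sim <- resp.hat + rnorm(length(resp.hat), sd = draw1$sigma)
```

Up to rounding, resp.hat agrees with the respHat values shown earlier for mcmcid 1 at exposures 0, 100, and 1000 (7.85, 57.4, and 90.8).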

Fix parameter values in Emax model

You can fix the values of Emax, E0, and/or gamma (the Hill coefficient) in the Emax model. See the help for stan_emax() for details. By default, gamma is fixed at 1, and Emax and E0 are estimated from the data.

Below is an example of estimating gamma from the data by setting gamma.fix = NULL.

data(exposure.response.sample)

fit.emax.sigmoidal <- stan_emax(response ~ exposure,
  data = exposure.response.sample,
  gamma.fix = NULL,
  # the next line is only to make the example go fast enough
  chains = 2, iter = 1000, seed = 12345
)
fit.emax.sigmoidal
#> ---- Emax model fit with rstanemax ----
#> 
#>        mean se_mean    sd  2.5%   25%   50%   75%  97.5%  n_eff Rhat
#> emax  90.23    0.70 10.96 72.34 83.01 89.08 95.38 117.81 244.83 1.01
#> e0     6.68    0.24  5.08 -4.46  3.42  7.06  9.91  16.04 458.54 1.00
#> ec50  78.22    2.94 35.63 41.83 59.63 70.96 87.60 153.98 147.35 1.01
#> gamma  1.16    0.02  0.36  0.58  0.93  1.11  1.34   2.02 461.44 1.00
#> sigma 16.76    0.06  1.68 13.97 15.56 16.64 17.76  20.28 725.79 1.01
#> 
#> * Use `extract_stanfit()` function to extract raw stanfit object
#> * Use `extract_param()` function to extract posterior draws of key parameters
#> * Use `plot()` function to visualize model fit
#> * Use `posterior_predict()` or `posterior_predict_quantile()` function to get
#>   raw predictions or make predictions on new data
#> * Use `extract_obs_mod_frame()` function to extract raw data 
#>   in a processed format (useful for plotting)
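
To see what gamma changes, here is a small base R sketch assuming the standard sigmoidal Emax form. The helper emax_curve() is hypothetical (not part of rstanemax), and the parameter values are illustrative rather than estimates from the fit.

```r
# Hypothetical helper (not part of rstanemax): evaluates the sigmoidal Emax curve
emax_curve <- function(exposure, e0, emax, ec50, gamma = 1) {
  e0 + emax * exposure^gamma / (ec50^gamma + exposure^gamma)
}

exposure.grid <- c(25, 50, 100, 200)

# Illustrative parameter values; only gamma differs between the two curves
curve.g1 <- emax_curve(exposure.grid, e0 = 0, emax = 100, ec50 = 50, gamma = 1)
curve.g2 <- emax_curve(exposure.grid, e0 = 0, emax = 100, ec50 = 50, gamma = 2)
rbind(gamma_1 = curve.g1, gamma_2 = curve.g2)
```

Both curves pass through half of emax at exposure == ec50. A larger gamma makes the curve steeper around that point: lower below ec50 and higher above it.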

You can compare the posterior predictions of the two models (in this case they are very close to each other):

exposure_pred <- seq(min(exposure.response.sample$exposure),
  max(exposure.response.sample$exposure),
  length.out = 100
)

pred1 <-
  posterior_predict_quantile(fit.emax, exposure_pred) %>%
  mutate(model = "Emax")
pred2 <-
  posterior_predict_quantile(fit.emax.sigmoidal, exposure_pred) %>%
  mutate(model = "Sigmoidal Emax")

pred <- bind_rows(pred1, pred2)


ggplot(pred, aes(exposure, ci_med, color = model, fill = model)) +
  geom_line() +
  geom_ribbon(aes(ymin = ci_low, ymax = ci_high), alpha = .3) +
  geom_ribbon(aes(ymin = pi_low, ymax = pi_high), alpha = .1, color = NA) +
  geom_point(
    data = exposure.response.sample, aes(exposure, response),
    color = "black", fill = NA, size = 2
  ) +
  labs(y = "response")

Set covariates

You can specify categorical covariates for Emax, EC50, and E0. See the help for stan_emax() for details.

data(exposure.response.sample.with.cov)

test.data <-
  mutate(exposure.response.sample.with.cov,
    SEX = ifelse(cov2 == "B0", "MALE", "FEMALE")
  )

fit.cov <- stan_emax(
  formula = resp ~ conc, data = test.data,
  param.cov = list(emax = "SEX"),
  # the next line is only to make the example go fast enough
  chains = 2, iter = 1000, seed = 12345
)
fit.cov
#> ---- Emax model fit with rstanemax ----
#> 
#>                mean se_mean    sd  2.5%   25%    50%    75%  97.5%  n_eff Rhat
#> emax[FEMALE]  81.37    0.16  4.10 73.53 78.51  81.22  83.96  89.68 626.21 1.00
#> emax[MALE]    87.74    0.20  4.94 77.52 84.52  87.77  91.13  97.34 637.35 1.00
#> e0            15.05    0.10  2.37 10.38 13.48  15.09  16.71  19.37 584.32 1.00
#> ec50         106.74    1.01 22.61 66.76 91.48 104.90 119.40 156.34 503.80 1.01
#> gamma          1.00     NaN  0.00  1.00  1.00   1.00   1.00   1.00    NaN  NaN
#> sigma         10.54    0.04  0.97  8.86  9.84  10.48  11.18  12.65 631.19 1.00
#> 
#> * Use `extract_stanfit()` function to extract raw stanfit object
#> * Use `extract_param()` function to extract posterior draws of key parameters
#> * Use `plot()` function to visualize model fit
#> * Use `posterior_predict()` or `posterior_predict_quantile()` function to get
#>   raw predictions or make predictions on new data
#> * Use `extract_obs_mod_frame()` function to extract raw data 
#>   in a processed format (useful for plotting)

plot(fit.cov)
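
With a covariate on emax only, each group gets its own emax while e0 and ec50 are shared. The following base R sketch plugs the posterior means from the summary above into the standard Emax form with gamma = 1. This is an assumption-laden approximation for intuition, not a substitute for posterior predictions, and the concentration values are chosen arbitrarily.

```r
# Posterior means from the summary above; e0 and ec50 are shared across groups
emax.by.sex <- c(FEMALE = 81.37, MALE = 87.74)
e0 <- 15.05
ec50 <- 106.74
conc <- c(0, 100, 1000)

# Plug-in curve per group, assuming the standard Emax form with gamma = 1
curves <- sapply(emax.by.sex, function(emax) e0 + emax * conc / (ec50 + conc))
rownames(curves) <- paste0("conc=", conc)
round(curves, 1)
```

The two groups coincide at conc = 0, where only e0 matters, and separate as concentration increases.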

You can extract the posterior draws with extract_param() and evaluate differences between groups.

fit.cov.posterior <-
  extract_param(fit.cov)

emax.posterior <-
  fit.cov.posterior %>%
  select(mcmcid, SEX, emax) %>%
  tidyr::pivot_wider(names_from = SEX, values_from = emax) %>%
  mutate(delta = FEMALE - MALE)

ggplot(emax.posterior, aes(delta)) +
  geom_histogram(bins = 30) +
  labs(x = "emax[FEMALE] - emax[MALE]")

# Credible interval of delta
quantile(emax.posterior$delta, probs = c(0.025, 0.05, 0.5, 0.95, 0.975))
#>       2.5%         5%        50%        95%      97.5% 
#> -15.253614 -14.060222  -6.515276   1.249568   2.790219

# Posterior probability of emax[FEMALE] < emax[MALE]
mean(emax.posterior$delta < 0)
#> [1] 0.921
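
The reshaping step is easier to follow on a tiny example. The sketch below uses a made-up data frame, toy.draws, with three draws per group; the numbers are invented for illustration and are not taken from the fit.

```r
library(dplyr)

# Toy long-format draws (made-up numbers): one row per draw and group
toy.draws <- data.frame(
  mcmcid = rep(1:3, each = 2),
  SEX = rep(c("FEMALE", "MALE"), times = 3),
  emax = c(80, 88, 83, 85, 90, 86)
)

# One row per draw, one column per group, then the within-draw difference
toy.delta <- toy.draws %>%
  tidyr::pivot_wider(names_from = SEX, values_from = emax) %>%
  mutate(delta = FEMALE - MALE)

# Share of draws in which emax[FEMALE] is below emax[MALE]
mean(toy.delta$delta < 0)
```

Taking the difference within each draw, rather than between group-level summaries, preserves any posterior correlation between the two emax parameters.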