## ----include = FALSE----------------------------------------------------------
# Show chunk output directly beneath its source, with each output line
# prefixed by "#>".
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

## ----setup--------------------------------------------------------------------
library(ggplot2)
library(dplyr)
library(tibble)
library(purrr)
library(patchwork)
library(masc)

## -----------------------------------------------------------------------------
# Simulate how overall option value relates to MASC fixation counts and
# choice consistency.
#
# For every threshold in `thresholds`, each of `n_subjects` simulated subjects
# completes `n_trials` trials. On each trial, `n_options` single-attribute
# option values are drawn from a standard normal, and one common random shift
# is added to all of them. The shift moves the overall (mean) value from trial
# to trial without changing the differences between options. Each trial is
# then passed to rMASC().
#
# Args:
#   n_trials: number of trials per subject.
#   n_subjects: number of simulated subjects per threshold.
#   thresholds: threshold values, passed to rMASC() as `theta`.
#   n_options: number of options per trial (must be >= 2). Default 5.
#   shift_range: length-2 numeric giving the lower and upper bound of the
#     uniform common shift added to all option values. Default c(-4, 4).
#
# Returns:
#   A tibble with one row per trial and these columns:
#     subject, trial, threshold;
#     overall_value (mean option value);
#     value_diff (best value minus the mean of the remaining options);
#     n_fixations (the `rt` field of rMASC()'s first raw result);
#     correct (the `correct` field of that result).
simulate_overall_value_effects <- function(n_trials = 200,
                                           n_subjects = 100,
                                           thresholds = c(0.001, 0.1, 0.2),
                                           n_options = 5,
                                           shift_range = c(-4, 4)) {
  stopifnot(
    n_options >= 2,          # value_diff needs at least one non-best option
    length(shift_range) == 2
  )

  # Column names expected for the single-attribute trial data frame
  option_names <- paste0("opt", seq_len(n_options), "_att1")

  # Draw option values, then add one common uniform shift. The shift changes
  # the overall value but leaves the relative differences between options
  # unchanged.
  generate_controlled_values <- function() {
    values <- rnorm(n_options)
    values + runif(1, shift_range[1], shift_range[2])
  }

  map_dfr(thresholds, function(thresh) {
    map_dfr(seq_len(n_subjects), function(s) {
      map_dfr(seq_len(n_trials), function(t) {
        values <- generate_controlled_values()

        # One-row data frame with columns opt1_att1 ... optK_att1
        trial_data <- data.frame(matrix(values, nrow = 1, ncol = n_options))
        names(trial_data) <- option_names

        trial_result <- rMASC(
          data = trial_data,
          n_options = n_options,
          n_attributes = 1,
          theta = thresh,
          w = 1  # single attribute weight
        )$raw[[1]]

        best <- which.max(values)
        tibble(
          subject = s,
          trial = t,
          threshold = thresh,
          overall_value = mean(values),
          value_diff = values[best] - mean(values[-best]),
          n_fixations = trial_result$rt,
          correct = trial_result$correct
        )
      })
    })
  })
}

# Seed the RNG so the figures are reproducible, then simulate 20 subjects with
# 200 trials each at every default threshold.
set.seed(2025L)
results <- simulate_overall_value_effects(n_subjects = 20, n_trials = 200)

## -----------------------------------------------------------------------------
# Give each threshold a descriptive facet label. Factor levels are ordered
# from the most conservative to the most liberal setting.
results$threshold_label <- factor(
  results$threshold,
  levels = c(0.001, 0.1, 0.2),
  labels = c(
    "Conservative (θ = 0.001)",
    "Moderate (θ = 0.1)",
    "Liberal (θ = 0.2)"
  )
)

# Keep only trials whose value difference rounds to 1. Cut their overall value
# into 40 equal-width bins, then summarise choice consistency for each
# threshold and bin: the proportion correct, the mean overall value, the trial
# count, and the binomial standard error of the proportion.
results_binned <- results %>%
  filter(round(value_diff) == 1) %>%
  mutate(
    correct = as.numeric(correct),
    value_bin = cut(overall_value, breaks = 40)
  ) %>%
  group_by(threshold_label, value_bin) %>%
  summarise(
    mean_consistency = mean(correct),
    mean_value = mean(overall_value),
    n = n(),
    se = sqrt(mean_consistency * (1 - mean_consistency) / n),
    .groups = "drop"
  )

# For each threshold, regress the number of fixations on overall value while
# controlling for value difference. Keep the overall-value slope and its
# p-value, plus a text label positioned at the minimum overall value and the
# maximum fixation count.
regression_stats <- results %>%
  group_by(threshold, threshold_label) %>%
  summarise(
    model = list(lm(n_fixations ~ overall_value + value_diff)),
    beta = coef(model[[1]])["overall_value"],
    p_value = summary(model[[1]])$coefficients["overall_value", "Pr(>|t|)"],
    x_pos = min(overall_value),
    y_pos = max(n_fixations),
    stat_label = sprintf("β = %.3f\np = %.3e", beta, p_value),
    .groups = "drop"
  )

# For each threshold, fit a linear probability model of choice consistency
# (correct coded 0/1) on overall value and its square, controlling for value
# difference. Each coefficient is reported with its own p-value. In a
# quadratic model the linear term's p-value alone does not test curvature, so
# the squared term's p-value is reported as well.
quad_stats <- results %>%
  mutate(correct = as.numeric(correct)) %>%  # Convert logical to numeric
  group_by(threshold_label) %>%
  summarize(
    model = list(lm(correct ~ overall_value + I(overall_value^2) + value_diff)),
    beta_linear = coef(first(model))["overall_value"],
    beta_quad = coef(first(model))["I(overall_value^2)"],
    # p-value of the linear term (column name kept for existing users)
    p_value = summary(first(model))$coefficients["overall_value", "Pr(>|t|)"],
    # p-value of the quadratic (curvature) term
    p_value_quad = summary(first(model))$coefficients[
      "I(overall_value^2)", "Pr(>|t|)"
    ],
    # Label anchor: left edge of the data, near the top of the [0, 1] axis
    x_pos = min(overall_value),
    y_pos = 0.95,
    stat_label = sprintf("β1 = %.3f (p = %.3e)\nβ2 = %.3f (p = %.3e)",
                         beta_linear, p_value, beta_quad, p_value_quad),
    .groups = "drop"
  )

## ----fig.width=14, fig.height=10, out.width="100%"----------------------------
# Styling shared by every panel of the combined figure. Adding a list to a
# ggplot adds each element in turn.
panel_theme <- list(
  theme_classic(),
  theme(
    legend.position = "none",
    panel.grid.minor = element_blank(),
    strip.text = element_text(face = "bold")
  )
)

# Top row: simulated fixation counts against overall value, one panel per
# threshold. Each panel shows a linear fit and the overall-value coefficient
# from `regression_stats`, which also controls for value difference.
p1 <- ggplot(results, aes(x = overall_value, y = n_fixations)) +
  geom_point(alpha = 0.1, size = 1) +
  geom_smooth(method = "lm",
              formula = y ~ x,
              color = "green",
              se = TRUE) +
  geom_text(data = regression_stats,
            aes(x = x_pos, y = y_pos, label = stat_label),
            hjust = 0, vjust = 1,
            size = 3) +
  facet_wrap(~threshold_label) +
  labs(x = "Overall Value",
       y = "Predicted Number of Fixations") +
  panel_theme

# Trial-level consistency coded 0/1, used by both fitted lines below
consistency_data <- results %>%
  mutate(correct = as.numeric(correct))

# Bottom row: binned choice consistency (mean ± binomial SE) for trials with a
# value difference near 1. Quadratic (red) and linear (blue) fits to all
# trials are overlaid.
p2 <- ggplot(results_binned, aes(x = mean_value, y = mean_consistency)) +
  geom_errorbar(aes(ymin = mean_consistency - se,
                    ymax = mean_consistency + se),
                width = 0.2, alpha = 0.5) +
  geom_point(size = 2) +
  geom_smooth(data = consistency_data,
              aes(x = overall_value, y = correct),
              method = "lm",
              formula = y ~ poly(x, 2),
              color = "red",
              se = TRUE) +
  geom_smooth(data = consistency_data,
              aes(x = overall_value, y = correct),
              method = "lm",
              formula = y ~ x,
              color = "blue",
              se = TRUE) +
  geom_text(data = quad_stats,
            aes(x = x_pos, y = y_pos, label = stat_label),
            hjust = 0, vjust = 1,
            size = 3) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey") +
  facet_wrap(~threshold_label) +
  # Zoom with the coordinate system instead of scale limits. Scale limits turn
  # out-of-range values into NA, which would drop error bars and clip the
  # confidence ribbons wherever they extend beyond [0, 1].
  coord_cartesian(ylim = c(0, 1)) +
  labs(x = "Overall Value",
       y = "Choice Consistency",
       caption = paste("Points: trials with value difference rounded to 1;",
                       "lines: fits to all trials")) +
  panel_theme

# Stack the two rows with equal heights under a shared title
combined_plot <- p1 / p2 +
  plot_layout(heights = c(1, 1)) +
  plot_annotation(
    title = "Overall Value Effects on Fixations and Choice Consistency",
    theme = theme(plot.title = element_text(size = 14, face = "bold"))
  )

print(combined_plot)

