#include <R.h>
#include <Rinternals.h>
#include <stdlib.h>
#include <string.h>
#include <zlib.h>
#include <stdio.h>
#include "wzmisc.h"
#include "wzbed.h"
#include "cfile.h"
#include "snames.h"
#include "kstring.h"

/* Print the `yame summary` help text to R's stderr stream and return 1, which
 * main_summary() passes back as its status for -h. */
static int usage(void) {
  REprintf("\n");
  REprintf("Usage: yame summary [options] <query.cm>\n");
  REprintf("Query should be of format 0,1,2,3,6, can be a multi-sample set.\n");
  REprintf("\n");
  REprintf("Options:\n");
  REprintf("    -m        Mask feature (.cx) file, can be multi-sample.\n");
  REprintf("              If '-', the whole sample will be kept in memory, same as -M.\n");
  REprintf("    -M        All masks will be loaded to memory. This saves disk IO.\n");
  REprintf("    -u        Optional universe set as a .cx file. If given, the masks and queries are both subset.\n");
  REprintf("              (Not implemented yet: currently rejected as an unrecognized option.)\n");
  REprintf("    -H        Suppress header printing.\n");
  REprintf("    -q        The backup query file name if the query file name is '-'.\n");
  REprintf("    -F        Use full feature/query file name instead of base name.\n");
  REprintf("    -T        State features always show section names.\n");
  REprintf("    -s        Sample list provided to override the query index file. Only applies to the first query.\n");
  REprintf("    -o        Output file. Defaults to stdout (required in the R build).\n");
  REprintf("    -h        This help.\n");
  REprintf("\n");

  return 1;
}

/* Counts for one (query, mask) comparison row. Arrays of these are built by
 * the summarize1_queryfmt* functions and printed, then freed, by
 * format_stats_and_clean(). */
typedef struct stats_t {
  uint64_t sum_depth;           // sum of depth; accumulated from MU2cov(mu) by the format-3 query path only (0 otherwise)
  double sum_beta;              // sum of MU2beta(mu) over the same sites (format-3 path only); printed as sum_beta / n_o
  uint64_t n_u;                 // universe
  uint64_t n_q;                 // query
  uint64_t n_m;                 // mask
  uint64_t n_o;                 // overlap
  char* sm;                     // mask name; heap string (strdup/ksprintf), freed by format_stats_and_clean()
  char* sq;                     // query name; heap string (strdup/ksprintf), freed by format_stats_and_clean()
} stats_t;

/* Options of `yame summary`, filled from the command line by main_summary()
 * (or directly by yame_summary_cfunc()). */
typedef struct config_t {
  int full_name;                // -F: print full file names instead of base names
  int section_name;             // -T: prefix state keys with the sample name ("<name>-<key>")
  int in_memory;                // -M: load every mask sample into memory up front
  int no_header;                // -H: do not print the header line
  char *fname_mask;             // -m: mask file (strdup'd; NULL when no mask is given)
  char *fname_snames;           // -s: sample-name list overriding the query index (strdup'd)
  char *fname_qry_stdin;        // -q: points into argv; NOTE(review): not read anywhere in the visible code
} config_t;

/* Summarize a binary (format 0/1) query against one mask sample.
 *
 * Returns a heap array of *n_st stats_t records; the caller owns the array and
 * the sm/sq strings inside it (format_stats_and_clean() releases them).
 *   - empty mask (n == 0): one record over all c->n sites, no mask counts.
 *   - binary mask (fmt <= '1'): one record; overlap = query AND mask bits.
 *   - fmt '6' mask: one record restricted to sites in the mask's universe.
 *   - state mask (fmt '2'): one record per mask state; n_m / n_o are per state
 *     while n_q / n_u are shared by every state.
 * Unsupported mask formats, length mismatches and out-of-range state indices
 * abort through error().
 */
static stats_t* summarize1_queryfmt0(
  cdata_t *c, cdata_t *c_mask, uint64_t *n_st, char *sm, char *sq, config_t *config) {

  /* No mask: the whole query vector against its full length. */
  if (c_mask->n == 0) {
    stats_t *out = calloc(1, sizeof(stats_t));
    out->n_u = c->n;
    out->n_q = bit_count(*c);
    out->sm = strdup(sm);
    out->sq = strdup(sq);
    *n_st = 1;
    return out;
  }

  /* Reject unknown formats before checking lengths (same precedence as before). */
  if (!(c_mask->fmt <= '1' || c_mask->fmt == '6' || c_mask->fmt == '2')) {
    REprintf("[%s:%d] Mask format %c unsupported.\n", __func__, __LINE__, c_mask->fmt);
    error("Abort.");
  }
  if (c_mask->n != c->n) {
    REprintf("[%s:%d] mask (N=%"PRIu64") and query (N=%"PRIu64") are of different lengths.\n", __func__, __LINE__, c_mask->n, c->n);
    error("Abort.");
  }

  /* State mask: bucket the query's set bits by mask state. */
  if (c_mask->fmt == '2') {
    if (!c_mask->aux) fmt2_set_aux(c_mask);
    f2_aux_t *saux = (f2_aux_t*) c_mask->aux;
    *n_st = saux->nk;
    stats_t *out = calloc(*n_st, sizeof(stats_t));
    uint64_t n_q = 0;
    for (uint64_t i = 0; i < c->n; ++i) {
      uint64_t state = f2_get_uint64(c_mask, i);
      if (state >= *n_st) {
        REprintf("[%s:%d] State data is corrupted.\n", __func__, __LINE__);
        error("Abort.");
      }
      out[state].n_m++;
      if (FMT0_IN_SET(*c, i)) {
        out[state].n_o++;
        n_q++;
      }
    }
    for (uint64_t k = 0; k < *n_st; ++k) {
      out[k].n_u = c->n;
      out[k].n_q = n_q;
      if (config->section_name) {
        kstring_t label = {0};
        ksprintf(&label, "%s-%s", sm, saux->keys[k]);
        out[k].sm = label.s;
      } else {
        out[k].sm = strdup(saux->keys[k]);
      }
      out[k].sq = strdup(sq);
    }
    return out;
  }

  /* Binary mask carrying its own universe: only count sites inside it. */
  if (c_mask->fmt == '6') {
    stats_t tally = {0};
    for (uint64_t i = 0; i < c->n; ++i) {
      if (!FMT6_IN_UNI(*c_mask, i)) continue;
      tally.n_u++;
      int q = FMT0_IN_SET(*c, i);
      int m = FMT6_IN_SET(*c_mask, i);
      if (q) tally.n_q++;
      if (m) tally.n_m++;
      if (q && m) tally.n_o++;
    }
    stats_t *out = calloc(1, sizeof(stats_t));
    *out = tally;
    out->sm = strdup(sm);
    out->sq = strdup(sq);
    *n_st = 1;
    return out;
  }

  /* Plain bit-vector mask: AND a scratch copy of the query with the mask. */
  uint64_t nbytes = (c->n >> 3) + 1;
  cdata_t both = {0};
  both.n = c->n;
  both.s = malloc(nbytes);
  memcpy(both.s, c->s, nbytes);
  for (uint64_t b = 0; b < nbytes; ++b) both.s[b] &= c_mask->s[b];

  stats_t *out = calloc(1, sizeof(stats_t));
  out->n_u = c->n;
  out->n_q = bit_count(*c);
  out->n_m = bit_count(*c_mask);
  out->n_o = bit_count(both);
  free(both.s);
  out->sm = strdup(sm);
  out->sq = strdup(sq);
  *n_st = 1;
  return out;
}

/* Summarize a state (format 2) query against one mask sample.
 *
 * Returns a heap array of *n_st stats_t records; the caller owns the array and
 * every sm/sq string in it (format_stats_and_clean() releases them).
 *   - empty mask: one record per query state holding that state's size (n_q).
 *   - binary mask (fmt <= '1', or '6' where a site counts only if it is in both
 *     the mask universe and the mask set): one record per query state; n_m is
 *     the shared mask size and n_o the state's overlap with the mask. Query
 *     names are always "<sq>-<state>" here.
 *   - state mask (fmt '2'): one record per (mask state, query state) pair,
 *     stored mask-major (index = mask_state * n_query_states + query_state).
 * Mask/query length mismatches, out-of-range state indices and unsupported
 * mask formats abort through error().
 */
static stats_t* summarize1_queryfmt2(
  cdata_t *c, cdata_t *c_mask, uint64_t *n_st, char *sm, char *sq, config_t *config) {

  stats_t *st = NULL;
  if (c_mask->n == 0) {          // no mask

    if (!c->aux) fmt2_set_aux(c);
    f2_aux_t *aux = (f2_aux_t*) c->aux;
    *n_st = aux->nk;
    uint64_t *cnts = calloc(aux->nk, sizeof(uint64_t));
    for (uint64_t i=0; i<c->n; ++i) {
      uint64_t iq = f2_get_uint64(c, i);
      if (iq >= (*n_st)) {       // guard the histogram against corrupt indices
        free(cnts);
        REprintf("[%s:%d] State data is corrupted.\n", __func__, __LINE__);
        error("Abort.");
      }
      cnts[iq]++;
    }
    st = calloc(aux->nk, sizeof(stats_t));
    for (uint64_t k=0; k<aux->nk; ++k) {
      st[k].n_u = c->n;
      st[k].n_q = cnts[k];
      st[k].n_m = 0;
      st[k].n_o = 0;
      st[k].sm = strdup(sm);
      if (config->section_name) {
        kstring_t tmp = {0};
        ksprintf(&tmp, "%s-%s", sq, aux->keys[k]);
        st[k].sq = tmp.s;
      } else {
        st[k].sq = strdup(aux->keys[k]);
      }
    }
    free(cnts);

  } else if (c_mask->fmt <= '1' || c_mask->fmt == '6') { // binary mask, optionally with universe

    // The mask is indexed by query position, so lengths must agree.
    if (c_mask->n != c->n) {
      REprintf("[%s:%d] mask (N=%"PRIu64") and query (N=%"PRIu64") are of different lengths.\n", __func__, __LINE__, c_mask->n, c->n);
      error("Abort.");
    }

    if (!c->aux) fmt2_set_aux(c);
    f2_aux_t *aux = (f2_aux_t*) c->aux;
    *n_st = aux->nk;
    int mask_has_uni = (c_mask->fmt == '6');
    uint64_t *cnts = calloc(aux->nk, sizeof(uint64_t));   // per-state overlap with the mask
    uint64_t *cnts_q = calloc(aux->nk, sizeof(uint64_t)); // per-state query size
    uint64_t n_m = 0;
    for (uint64_t i=0; i<c->n; ++i) {
      uint64_t iq = f2_get_uint64(c, i);
      if (iq >= (*n_st)) {
        free(cnts); free(cnts_q);
        REprintf("[%s:%d] State data is corrupted.\n", __func__, __LINE__);
        error("Abort.");
      }
      int in_m = mask_has_uni
        ? (FMT6_IN_UNI(*c_mask, i) && FMT6_IN_SET(*c_mask, i))
        : (FMT0_IN_SET(*c_mask, i) != 0);
      if (in_m) {
        n_m++;
        cnts[iq]++;
      }
      cnts_q[iq]++;
    }
    st = calloc(aux->nk, sizeof(stats_t));
    for (uint64_t k=0; k<aux->nk; ++k) {
      st[k].n_u = c->n;
      st[k].n_q = cnts_q[k];
      st[k].n_o = cnts[k];
      st[k].n_m = n_m;
      st[k].sm = strdup(sm);
      kstring_t tmp = {0};
      ksprintf(&tmp, "%s-%s", sq, aux->keys[k]);
      st[k].sq = tmp.s;
    }
    free(cnts);
    free(cnts_q);

  } else if (c_mask->fmt == '2') { // state mask

    if (c_mask->n != c->n) {
      REprintf("[%s:%d] mask (N=%"PRIu64") and query (N=%"PRIu64") are of different lengths.\n", __func__, __LINE__, c_mask->n, c->n);
      error("Abort.");
    }

    if (!c_mask->aux) fmt2_set_aux(c_mask);
    f2_aux_t *aux_m = (f2_aux_t*) c_mask->aux;

    if (!c->aux) fmt2_set_aux(c);
    f2_aux_t *aux_q = (f2_aux_t*) c->aux;

    *n_st = aux_m->nk * aux_q->nk;
    st = calloc((*n_st), sizeof(stats_t));
    uint64_t *nq = calloc(aux_q->nk, sizeof(uint64_t));
    uint64_t *nm = calloc(aux_m->nk, sizeof(uint64_t));
    for (uint64_t i=0; i<c->n; ++i) {
      uint64_t im = f2_get_uint64(c_mask, i);
      uint64_t iq = f2_get_uint64(c, i);
      if (im >= (uint64_t) aux_m->nk || iq >= (uint64_t) aux_q->nk) {
        free(st); free(nq); free(nm);
        REprintf("[%s:%d] State data is corrupted.\n", __func__, __LINE__);
        error("Abort.");
      }
      st[im * aux_q->nk + iq].n_o++;
      nq[iq]++;
      nm[im]++;
    }
    // n_o is complete after the counting pass; only fill the marginals and names.
    for (uint64_t im=0; im<aux_m->nk; ++im) {
      for (uint64_t iq=0; iq<aux_q->nk; ++iq) {
        stats_t *st1 = &st[im * aux_q->nk + iq];
        st1->n_u = c->n;
        st1->n_q = nq[iq];
        st1->n_m = nm[im];
        if (config->section_name) {
          kstring_t tmp = {0};
          ksprintf(&tmp, "%s-%s", sm, aux_m->keys[im]);
          st1->sm = tmp.s;
        } else {
          st1->sm = strdup(aux_m->keys[im]);
        }
        if (config->section_name) {
          kstring_t tmp = {0};
          ksprintf(&tmp, "%s-%s", sq, aux_q->keys[iq]);
          st1->sq = tmp.s;
        } else {
          st1->sq = strdup(aux_q->keys[iq]);
        }
      }
    }
    free(nq); free(nm);

  } else {                      // other masks
    REprintf("[%s:%d] Mask format %c unsupported.\n", __func__, __LINE__, c_mask->fmt);
    error("Abort.");
  }
  return st;
}

/* Summarize a format 3 query against one mask sample.
 *
 * Per-site values come from f3_get_mu(); a zero value means the site has no
 * data. For sites that are counted, MU2cov()/MU2beta() are summed into
 * sum_depth/sum_beta. Returns a heap array of *n_st stats_t records (the
 * caller owns it and its sm/sq strings):
 *   - empty mask: one record; every site with data counts as query and overlap.
 *   - binary mask (fmt <= '1', or '6' where a site must be in both the mask
 *     universe and set): one record; sums cover mask sites that have data.
 *   - state mask (fmt '2'): one record per mask state; n_q / n_u are shared.
 * Unsupported mask formats, length mismatches and out-of-range states abort
 * through error().
 */
static stats_t* summarize1_queryfmt3(
  cdata_t *c, cdata_t *c_mask, uint64_t *n_st, char *sm, char *sq, config_t *config) {

  /* No mask: aggregate over every site that carries data. */
  if (c_mask->n == 0) {
    stats_t *out = calloc(1, sizeof(stats_t));
    out->n_u = c->n;
    for (uint64_t i = 0; i < c->n; ++i) {
      uint64_t mu = f3_get_mu(c, i);
      if (!mu) continue;
      out->sum_depth += MU2cov(mu);
      out->sum_beta += MU2beta(mu);
      out->n_o++;
      out->n_q++;
    }
    out->sm = strdup(sm);
    out->sq = strdup(sq);
    *n_st = 1;
    return out;
  }

  if (!(c_mask->fmt <= '1' || c_mask->fmt == '6' || c_mask->fmt == '2')) {
    REprintf("[%s:%d] Mask format %c unsupported.\n", __func__, __LINE__, c_mask->fmt);
    error("Abort.");
  }
  if (c_mask->n != c->n) {
    REprintf("[%s:%d] mask (N=%"PRIu64") and query (N=%"PRIu64") are of different lengths.\n", __func__, __LINE__, c_mask->n, c->n);
    error("Abort.");
  }

  /* State mask: one accumulator per mask state. */
  if (c_mask->fmt == '2') {
    if (!c_mask->aux) fmt2_set_aux(c_mask);
    f2_aux_t *saux = (f2_aux_t*) c_mask->aux;
    *n_st = saux->nk;
    stats_t *out = calloc(*n_st, sizeof(stats_t));
    uint64_t n_q = 0;
    for (uint64_t i = 0; i < c->n; ++i) {
      uint64_t state = f2_get_uint64(c_mask, i);
      uint64_t mu = f3_get_mu(c, i);
      if (state >= *n_st) {
        REprintf("[%s:%d] State data is corrupted.\n", __func__, __LINE__);
        error("Abort.");
      }
      stats_t *cell = &out[state];
      cell->n_m++;
      if (mu) {
        cell->sum_depth += MU2cov(mu);
        cell->sum_beta += MU2beta(mu);
        cell->n_o++;
        n_q++;
      }
    }
    for (uint64_t k = 0; k < *n_st; ++k) {
      out[k].n_q = n_q;
      out[k].n_u = c->n;
      if (config->section_name) {
        kstring_t label = {0};
        ksprintf(&label, "%s-%s", sm, saux->keys[k]);
        out[k].sm = label.s;
      } else {
        out[k].sm = strdup(saux->keys[k]);
      }
      out[k].sq = strdup(sq);
    }
    return out;
  }

  /* Binary mask; a format 6 mask also requires membership in its universe. */
  int mask_has_uni = (c_mask->fmt == '6');
  stats_t *out = calloc(1, sizeof(stats_t));
  out->n_u = c->n;
  for (uint64_t i = 0; i < c->n; ++i) {
    uint64_t mu = f3_get_mu(c, i);
    if (mu) out->n_q++;
    int in_m = mask_has_uni
      ? (FMT6_IN_UNI(*c_mask, i) && FMT6_IN_SET(*c_mask, i))
      : (FMT0_IN_SET(*c_mask, i) != 0);
    if (!in_m) continue;
    out->n_m++;
    if (!mu) continue;
    out->sum_depth += MU2cov(mu);
    out->sum_beta += MU2beta(mu);
    out->n_o++;
  }
  out->sm = strdup(sm);
  out->sq = strdup(sq);
  *n_st = 1;
  return out;
}

/* Summarize a format 6 query (set membership plus its own universe, read via
 * FMT6_IN_UNI / FMT6_IN_SET) against one mask sample. Only sites inside the
 * query's universe are counted.
 *
 * Returns a heap array of *n_st stats_t records (caller owns it and its sm/sq
 * strings):
 *   - empty mask: one record with the universe size and query set size.
 *   - binary mask (fmt <= '1'): one record.
 *   - fmt '6' mask: one record over the intersection of both universes.
 *   - state mask (fmt '2'): one record per mask state; n_q / n_u are shared.
 * Unsupported mask formats, length mismatches and out-of-range states abort
 * through error().
 */
static stats_t* summarize1_queryfmt6(
  cdata_t *c, cdata_t *c_mask, uint64_t *n_st, char *sm, char *sq, config_t *config) {

  /* No mask: just measure the query inside its own universe. */
  if (c_mask->n == 0) {
    stats_t *out = calloc(1, sizeof(stats_t));
    for (uint64_t i = 0; i < c->n; ++i) {
      if (!FMT6_IN_UNI(*c, i)) continue;
      out->n_u++;
      if (FMT6_IN_SET(*c, i)) out->n_q++;
    }
    out->sm = strdup(sm);
    out->sq = strdup(sq);
    *n_st = 1;
    return out;
  }

  if (!(c_mask->fmt <= '1' || c_mask->fmt == '6' || c_mask->fmt == '2')) {
    REprintf("[%s:%d] Mask format %c unsupported.\n", __func__, __LINE__, c_mask->fmt);
    error("Abort.");
  }
  if (c_mask->n != c->n) {
    REprintf("[%s:%d] mask (N=%"PRIu64") and query (N=%"PRIu64") are of different lengths.\n", __func__, __LINE__, c_mask->n, c->n);
    error("Abort.");
  }

  /* State mask: per-state mask size and overlap within the query universe. */
  if (c_mask->fmt == '2') {
    if (!c_mask->aux) fmt2_set_aux(c_mask);
    f2_aux_t *saux = (f2_aux_t*) c_mask->aux;
    *n_st = saux->nk;
    stats_t *out = calloc(*n_st, sizeof(stats_t));
    uint64_t n_q = 0, n_u = 0;
    for (uint64_t i = 0; i < c->n; ++i) {
      uint64_t state = f2_get_uint64(c_mask, i);
      if (state >= *n_st) {
        REprintf("[%s:%d] State data is corrupted.\n", __func__, __LINE__);
        error("Abort.");
      }
      if (!FMT6_IN_UNI(*c, i)) continue;
      n_u++;
      out[state].n_m++;
      if (FMT6_IN_SET(*c, i)) {
        n_q++;
        out[state].n_o++;
      }
    }
    for (uint64_t k = 0; k < *n_st; ++k) {
      out[k].n_q = n_q;
      out[k].n_u = n_u;
      if (config->section_name) {
        kstring_t label = {0};
        ksprintf(&label, "%s-%s", sm, saux->keys[k]);
        out[k].sm = label.s;
      } else {
        out[k].sm = strdup(saux->keys[k]);
      }
      out[k].sq = strdup(sq);
    }
    return out;
  }

  /* Binary mask; a format 6 mask further restricts to its own universe. */
  int mask_has_uni = (c_mask->fmt == '6');
  stats_t *out = calloc(1, sizeof(stats_t));
  for (uint64_t i = 0; i < c->n; ++i) {
    if (!FMT6_IN_UNI(*c, i)) continue;
    if (mask_has_uni && !FMT6_IN_UNI(*c_mask, i)) continue;
    out->n_u++;
    int q = FMT6_IN_SET(*c, i);
    int m = mask_has_uni ? FMT6_IN_SET(*c_mask, i) : FMT0_IN_SET(*c_mask, i);
    if (q) out->n_q++;
    if (m) out->n_m++;
    if (q && m) out->n_o++;
  }
  out->sm = strdup(sm);
  out->sq = strdup(sq);
  *n_st = 1;
  return out;
}

/* Dispatch on the query's format to the matching summarize1_queryfmt*()
 * routine. Returns that routine's stats array (count in *n_st). Query formats
 * other than 0, 1, 2, 3 and 6 abort through error(). */
static stats_t* summarize1(cdata_t *c, cdata_t *c_mask, uint64_t *n_st, char *sm, char *sq, config_t *config) {
  switch (c->fmt) {
  case '0':
  case '1':
    return summarize1_queryfmt0(c, c_mask, n_st, sm, sq, config);
  case '2':
    return summarize1_queryfmt2(c, c_mask, n_st, sm, sq, config);
  case '3':
    return summarize1_queryfmt3(c, c_mask, n_st, sm, sq, config);
  case '6':
    return summarize1_queryfmt6(c, c_mask, n_st, sm, sq, config);
  default:
    REprintf("[%s:%d] Query format %c unsupported.\n", __func__, __LINE__, c->fmt);
    error("Abort.");
  }
  return NULL;                  /* not reached: error() does not return */
}

/* Print one tab-separated row per record in st, then free every record's
 * sm/sq strings and the array itself (st must not be used afterwards).
 *
 * Columns: query file, query name, mask file ("NA" without a mask), mask name,
 * N_univ, N_query, N_mask, N_overlap, log2 odds ratio of the query/mask 2x2
 * table ("NA" without a mask or when either off-diagonal cell is zero), mean
 * beta sum_beta / n_o ("NA" when n_o is 0) and depth sum_depth / n_m (n_u when
 * n_m is 0; "NA" when sum_depth is 0). File names are reduced with
 * get_basename() unless config->full_name (-F) is set.
 */
static void format_stats_and_clean(
    stats_t *st, uint64_t n_st, const char *fname_qry, config_t *config, FILE *out_fp) {

  const char *qry_label = config->full_name ? fname_qry : get_basename(fname_qry);
  const char *mask_label = "NA";
  if (config->fname_mask)
    mask_label = config->full_name ? config->fname_mask : get_basename(config->fname_mask);

  for (uint64_t k = 0; k < n_st; ++k) {
    const stats_t *s = st + k;

    char lor[20] = "NA";
    if (config->fname_mask) {
      /* Cells of the 2x2 table: in query only, in mask only, in neither. */
      double q_not_m = s->n_q - s->n_o;
      double m_not_q = s->n_m - s->n_o;
      double neither = s->n_u - s->n_q - s->n_m + s->n_o;
      if (q_not_m * m_not_q > 0)
        snprintf(lor, sizeof(lor), "%1.2f", log2(neither * s->n_o / (q_not_m * m_not_q)));
    }

    fprintf(out_fp, "%s\t%s\t%s\t%s\t%" PRIu64 "\t%" PRIu64 "\t%" PRIu64 "\t%" PRIu64 "\t%s",
            qry_label, s->sq, mask_label, s->sm, s->n_u, s->n_q, s->n_m, s->n_o, lor);

    if (s->n_o) fprintf(out_fp, "\t%1.3f", s->sum_beta / s->n_o);
    else        fputs("\tNA", out_fp);

    if (s->sum_depth) fprintf(out_fp, "\t%1.1f", (double) s->sum_depth / (s->n_m ? s->n_m : s->n_u));
    else              fputs("\tNA", out_fp);

    fputc('\n', out_fp);
  }

  for (uint64_t k = 0; k < n_st; ++k) {
    free(st[k].sm);
    free(st[k].sq);
  }
  free(st);
}

  /* Decode a freshly read record in place: formats below '2' are converted to
   * format 0 by convertToFmt0(), every other format goes through decompress2().
   * Despite the name, it is applied to query records as well as masks. */
static void prepare_mask(cdata_t *c) {
  if (c->fmt >= '2') {
    decompress2(c);
    return;
  }
  convertToFmt0(c);
}

void main_summary1(
    char *fname_qry, config_t config, cfile_t cf_mask, cdata_t *c_masks, 
    uint64_t c_masks_n, snames_t snames_mask, FILE *out_fp) {
  
  // Open and validate query file
  cfile_t cf_qry = open_cfile(fname_qry);
  if (!cf_qry.fh) {
    REprintf("Error: Failed to open query file: %s\n", fname_qry);
    return;
  }
  
  // Load sample names
  snames_t snames_qry = {0};
  if (config.fname_snames) {
    snames_qry = loadSampleNames(config.fname_snames, 1);
  } else {
    snames_qry = loadSampleNamesFromIndex(fname_qry);
  }
  
  for (uint64_t kq = 0; ; ++kq) {
    cdata_t c_qry = read_cdata1(&cf_qry);
    if (c_qry.n == 0) break;
    
    if (c_qry.fmt == '7') { // Skip format 7
      free_cdata(&c_qry); 
      c_qry.s = NULL;
      continue;
    }
    
    kstring_t sq = {0};
    if (snames_qry.n) kputs(snames_qry.s[kq], &sq);
    else ksprintf(&sq, "%"PRIu64"", kq + 1);
    
    prepare_mask(&c_qry);
    
    // Ensure mask file exists before processing
    if (config.fname_mask) {   
      
      if (c_masks_n) {  // Mask is loaded in memory
        for (uint64_t km = 0; km < c_masks_n; ++km) {
          cdata_t c_mask = c_masks[km];
          
          kstring_t sm = {0};
          if (snames_mask.n) kputs(snames_mask.s[km], &sm);
          else ksprintf(&sm, "%"PRIu64"", km + 1);
          
          uint64_t n_st = 0;
          stats_t *st = summarize1(&c_qry, &c_mask, &n_st, sm.s, sq.s, &config);
          
          format_stats_and_clean(st, n_st, fname_qry, &config, out_fp);
          free(sm.s);
        }
      } else {  // Seekable mask file
        if (bgzf_seek(cf_mask.fh, 0, SEEK_SET) != 0) {
          REprintf("Error: Cannot seek mask file: %s\n", config.fname_mask);
          return;
        }
        
        for (uint64_t km = 0;; ++km) {
          cdata_t c_mask = read_cdata1(&cf_mask);
          if (c_mask.n == 0) break;
          
          prepare_mask(&c_mask);
          
          kstring_t sm = {0};
          if (snames_mask.n) kputs(snames_mask.s[km], &sm);
          else ksprintf(&sm, "%"PRIu64"", km + 1);
          
          uint64_t n_st = 0;
          stats_t *st = summarize1(&c_qry, &c_mask, &n_st, sm.s, sq.s, &config);
          
          format_stats_and_clean(st, n_st, fname_qry, &config, out_fp);
          free(sm.s);
          free_cdata(&c_mask);
        }
      }
    } else {  // No mask file provided
      
      kstring_t sm = {0}; 
      cdata_t c_mask = {0};
      kputs("global", &sm);
      
      uint64_t n_st = 0;
      stats_t *st = summarize1(&c_qry, &c_mask, &n_st, sm.s, sq.s, &config);
      
      format_stats_and_clean(st, n_st, fname_qry, &config, out_fp);
      free(sm.s);
    }
    
    free(sq.s);
    free_cdata(&c_qry); 
    c_qry.s = NULL;
  }
  
  if (c_masks_n) {
    for (uint64_t i = 0; i < c_masks_n; ++i) free_cdata(&c_masks[i]);
    free(c_masks);
  }
  
  bgzf_close(cf_qry.fh);
  cleanSampleNames2(snames_qry);
}



/* The design, first 10 bytes are uint64_t (length) + uint16_t (0=vec; 1=rle) */
int main_summary(int argc, char *argv[]) {
  int c;
  config_t config = {0};
  char *output_file = NULL;
  // FILE *out_fp = stdout; // Default to stdout
  FILE *out_fp = NULL;
  while ((c = getopt(argc, argv, "m:u:MHFTs:q:o:h")) >= 0) {
    switch (c) {
    case 'm': config.fname_mask = strdup(optarg); break;
    case 'M': config.in_memory = 1; break;
    case 'H': config.no_header = 1; break;
    case 'F': config.full_name = 1; break;
    case 'T': config.section_name = 1; break;
    case 's': config.fname_snames = strdup(optarg); break;
    case 'q': config.fname_qry_stdin = optarg; break;
    case 'o': output_file = strdup(optarg); break;  // Output file option
    case 'h': return usage();
    default:
      usage();
    wzfatal("Unrecognized option: %c.\n", c);
    }
  }
  
  if (optind + 1 > argc) { 
    usage(); 
    wzfatal("Please supply input file.\n"); 
  }
  
  if (output_file) {
    out_fp = fopen(output_file, "w");
    if (!out_fp) {
      wzfatal("Error: Cannot open output file %s for writing.\n", output_file);
    }
  }
  
  if (output_file) {
    out_fp = fopen(output_file, "w");
    if (!out_fp) wzfatal("Error: Cannot open output file %s for writing.\n", output_file);
      } else {
    #ifndef BUILDING_R
    out_fp = stdout;
    #else
    wzfatal("Output to stdout is disabled in the R build. Please supply -o <file>.");
    #endif
  }

  cfile_t cf_mask;
  int unseekable = 0;
  snames_t snames_mask = {0};
  cdata_t *c_masks = NULL;
  uint64_t c_masks_n = 0;
  
  if (config.fname_mask) {
    cf_mask = open_cfile(config.fname_mask);
    unseekable = bgzf_seek(cf_mask.fh, 0, SEEK_SET);
    snames_mask = loadSampleNamesFromIndex(config.fname_mask);
  }
  
  if (config.in_memory || unseekable) { /* Load masks into memory */
c_masks = calloc(1, sizeof(cdata_t));
    c_masks_n = 0;
    for (;; ++c_masks_n) {
      cdata_t c_mask = read_cdata1(&cf_mask);
      if (c_mask.n == 0) break;
      prepare_mask(&c_mask);
      c_masks = realloc(c_masks, (c_masks_n + 1) * sizeof(cdata_t));
      c_masks[c_masks_n] = c_mask;
    }
  }
  
  /* Write header if needed */
  if (!config.no_header) {
    fprintf(out_fp, "QFile\tQuery\tMFile\tMask\tN_univ\tN_query\tN_mask\tN_overlap\tLog2OddsRatio\tBeta\tDepth\n");
  }
  
  for (int j = optind; j < argc; ++j) {
    main_summary1(argv[j], config, cf_mask, c_masks, c_masks_n, snames_mask, out_fp);
  }
  
  /* Clean up memory */
  if (config.fname_snames) free(config.fname_snames);
  if (config.fname_mask) {
    bgzf_close(cf_mask.fh);
    free(config.fname_mask);
  }
  cleanSampleNames2(snames_mask);
  
  /* Free mask memory */
  if (c_masks_n) {
    for (uint64_t i = 0; i < c_masks_n; ++i) free_cdata(&c_masks[i]);
    free(c_masks);
  }
  
  /* Close output file if it was used */
  if (output_file) {
    fclose(out_fp);
    free(output_file);
  }
  
  return 0;
}


/* Return the first element of x, raising an R error that names `what` unless
 * x is a character vector of length >= 1 whose first element is not NA. */
static const char *summary_cfunc_string_arg(SEXP x, const char *what) {
  if (!isString(x) || LENGTH(x) < 1 || STRING_ELT(x, 0) == NA_STRING)
    error("%s must be a single non-NA character string.", what);
  return CHAR(STRING_ELT(x, 0));
}

/* .Call entry point: summarize query file str1 against mask file str2 and
 * write the table, with header, to the file str3. An empty mask string ("")
 * summarizes without a mask ("global" rows). Masks are streamed from disk
 * when the mask file is seekable and loaded into memory otherwise.
 * Returns R_NilValue. Bad arguments and an unopenable mask raise R errors.
 * An unopenable output file is reported with REprintf() and returns early.
 */
SEXP yame_summary_cfunc(SEXP str1, SEXP str2, SEXP str3) {
  const char *fname_qry = summary_cfunc_string_arg(str1, "Query file");
  const char *fname_mask = summary_cfunc_string_arg(str2, "Mask file");
  const char *fname_output = summary_cfunc_string_arg(str3, "Output file");

  // Debugging info
  REprintf("Query file = %s\n", fname_qry);
  REprintf("Mask file = %s\n", fname_mask);

  // Open the output file for writing
  FILE *fp = fopen(fname_output, "w");
  if (fp == NULL) {
    REprintf("Error: Cannot open output file: %s\n", fname_output);
    return R_NilValue;
  }

  // Print header line
  fprintf(fp, "QFile\tQuery\tMFile\tMask\tN_univ\tN_query\tN_mask\tN_overlap\tLog2OddsRatio\tBeta\tDepth\n");

  // "" means "no mask"; main_summary1 then reports "global" rows.
  config_t config = {0};
  config.fname_mask = fname_mask[0] ? strdup(fname_mask) : NULL;

  cfile_t cf_mask = {0};
  snames_t snames_mask = {0};
  cdata_t *c_masks = NULL;
  uint64_t c_masks_n = 0;

  if (config.fname_mask) {
    cf_mask = open_cfile(config.fname_mask);
    if (!cf_mask.fh) {
      fclose(fp);
      free(config.fname_mask);
      error("Cannot open mask file: %s", fname_mask);
    }
    int unseekable = bgzf_seek(cf_mask.fh, 0, SEEK_SET);
    snames_mask = loadSampleNamesFromIndex(config.fname_mask);

    if (config.in_memory || unseekable) {
      for (;;) {
        cdata_t c_mask = read_cdata1(&cf_mask);
        if (c_mask.n == 0) break;
        prepare_mask(&c_mask);
        cdata_t *grown = realloc(c_masks, (c_masks_n + 1) * sizeof(cdata_t));
        if (!grown) error("Out of memory while loading masks.");
        c_masks = grown;
        c_masks[c_masks_n++] = c_mask;
      }
    }
  }

  // Call main_summary1() directly instead of `system("yame summary ...")`
  main_summary1((char *)fname_qry, config, cf_mask, c_masks, c_masks_n, snames_mask, fp);

  // Cleanup: close the mask before releasing its name.
  fclose(fp);
  if (config.fname_mask) {
    bgzf_close(cf_mask.fh);
    free(config.fname_mask);
  }
  cleanSampleNames2(snames_mask);
  for (uint64_t i = 0; i < c_masks_n; ++i) free_cdata(&c_masks[i]);
  free(c_masks);
  return R_NilValue;
}
