source('basic_helper.R')
require(parallel)

# Function composition: (f1 %o% f2)(...) evaluates f1(f2(...)).
# Handy where an API accepts exactly one function but you want to chain
# several without writing a wrapper, e.g. `length %o% unique`.
# All arguments are forwarded to the inner function f2 only.
`%o%` <- function(f1, f2) {
    composed <- function(...) f1(f2(...))
    composed
}

# Subset `ll` to the positions where `bools` is TRUE. which() skips FALSE and
# NA entries, so an NA in `bools` simply excludes that element.
get_which <- function(ll, bools) {
    selected <- which(bools)
    ll[selected]
}


# Run the simulation described by `sim_pars`.
#
# One run loads the population (get_fusion_df), then calls
# parReplicate(chunk_size, cl, run_one_sim, sim_pars, pop_df) and transposes the
# result, presumably so each simulated experiment is one row - confirm against
# parReplicate. With n.rep > 1 the run is repeated n.rep times on the SAME
# cluster and the results are stacked with rbind_list.
#
# cl: a cluster from the parallel package. NA (the default) means: create a
#     fork cluster with one worker per core for this call and stop it on exit.
#
# The parameter list is attached as attr(, 'pars') so it travels with the
# results without taking part in any arithmetic on them.
do_sim <- function(sim_pars, chunk_size=10, n.rep=1, cl=NA) {
    if(any(is.na(cl))) {
        cl <- makeForkCluster(detectCores())
        # Give each worker its own L'Ecuyer stream. The seed is drawn from the
        # master RNG rather than the wall clock: clock seconds repeat for calls
        # started within the same second (yielding identical streams), while a
        # master-RNG draw differs per call and is reproducible under set.seed().
        clusterSetRNGStream(cl, sample.int(.Machine$integer.max, 1L))

        # whether by crash or normal exit, do our best to clean up
        on.exit(stopCluster(cl), add=TRUE)
    }

    if(n.rep > 1) {
        # forward `cl` so every replicate reuses this cluster (and its RNG
        # streams) instead of ignoring a caller-supplied one or building its own
        res_ <- rbind_list(
            replicate(n.rep, do_sim(sim_pars, chunk_size, cl=cl), simplify=FALSE)
        )
    } else {
        cat('.')
        pop_df <- get_fusion_df()
        res_ <- t(parReplicate(chunk_size, cl, run_one_sim, sim_pars, pop_df))
    }

    # set the pars as an attribute so that it doesn't affect math on the results
    attr(res_, 'pars') <- sim_pars

    res_
}


# Flatten the scalar simulation parameters into a numeric vector:
#   c(N, n_stimuli, n_trials, is_paired, delta(0))
# is_paired is 1 when group_type is 'paired' and 0 otherwise; delta(0) is the
# value the `delta` function returns at pF = 0.
par_vec <- function(sim_pars) {
    with(sim_pars, c(
        N,
        n_stimuli,
        n_trials,
        as.numeric(group_type == 'paired'),
        delta(0)
    ))
}

# Summarise a do_sim() result matrix into a single labelled row.
#
# res:       matrix with a 'pval' column and the effect estimate in column 2,
#            carrying the simulation parameters in attr(res, 'pars').
# pop_delta: the true (population) effect the estimates are compared against.
#
# Replicates are split at p < 0.05; replicates with an NA p-value fall in
# neither group. Statistics for an empty group are left NA.
analyze_sim_results <- function(res, pop_delta) {
    sim_pars <- attr(res, 'pars')
    pvals <- res[,'pval']
    # effect estimate is taken by position (column 2), not by name
    est <- res[,2]

    # guard against reading the p-value column twice by mistake
    stopifnot(!identical(pvals, est))

    is_sig <- !is.na(pvals) & pvals < 0.05
    is_ns <- !is.na(pvals) & pvals >= 0.05

    mean_est <- abs_eff_est <- eff_inflation <- eff_wrong_sign <- eff_too_big <- NA
    if(any(is_sig)) {
        sig_est <- est[is_sig]
        mean_est <- mean(sig_est)
        abs_eff_est <- mean(abs(sig_est))
        # exaggeration of significant estimates relative to the true effect
        eff_inflation <- mean(abs(sig_est/pop_delta))
        # share of significant estimates whose sign disagrees with the true effect
        eff_wrong_sign <- mean(sign(pop_delta) != sign(sig_est))
        eff_too_big <- mean(sig_est > pop_delta)
    }

    mean_non_sig <- non_sig_too_small <- NA
    if(any(is_ns)) {
        ns_est <- est[is_ns]
        mean_non_sig <- mean(ns_est)
        non_sig_too_small <- mean(ns_est < pop_delta)
    }

    summary_row <- matrix(c(
        nrow(res), par_vec(sim_pars), pop_delta,
        mean(pvals < 0.05, na.rm=TRUE),
        mean_est, abs_eff_est, eff_inflation, eff_wrong_sign, eff_too_big,
        mean_non_sig, non_sig_too_small
    ), nrow=1)
    colnames(summary_row) <- c('n.rep', 'N', 'n_stim', 'n_trials', 'is_paired', 'delta', 'pop_eff',
        'power', 'mean Eff Est', 'abs Eff Est', 'mean_abs_ratio', 'p(wrong_sign)', 'p(sig > pop_eff)', 'mean non sig est', 'p(non-sig < pop_eff)')

    summary_row
}


# Print a human-readable summary of one simulated data set: the equal-variance
# t-test of pF_hat by group, per-group m_se summaries (scaled to percent via
# round(100*x)), and the B-A effect estimate with its 95% CI half-width.
# Returns `df` invisibly so it can sit inside a pipeline.
do_run_summary <- function(df) {
    as_pct <- function(x) round(100*x)

    tt <- with(df, t.test(pF_hat ~ group, var.equal=TRUE))
    print(tt)

    cat('---\n')
    group_summary <- do_aggregate(df, pF_hat ~ group, as_pct %o% m_se)
    print(group_summary)

    # estimate[1] / estimate[2] are the group A / group B means
    diff_BA <- as.numeric(tt$estimate[2] - tt$estimate[1])
    ci_half_width <- diff(tt$conf.int)/2
    cat(paste('\nEffect Est B-A:', as_pct(diff_BA), '+/-', as_pct(ci_half_width)), '95% CI\n')

    invisible(df)
}


# Simulate one experiment from the population `pop_df` and, by default,
# analyse it.
#
# The fields n_stimuli, N, group_type, delta, limit and n_trials are read from
# sim_pars by name (via with()). The nested calls run innermost first:
# pick stimuli, pick subjects, assign groups, shift/clamp pF, sample trials.
#
# Returns do_run_summary(df) when summary=TRUE, otherwise analyze(df) when
# analyze=TRUE, otherwise the simulated data frame itself.
run_one_sim <- function(sim_pars, pop_df, summary=FALSE, analyze=TRUE) {
    sim_df <- with(sim_pars,
        sample_pF(
            change_pF(
                make_groups(
                    pick_subjects(
                        pick_stimuli(pop_df, n_stimuli),
                        N),
                    group_type),
                delta, limit),
            n_trials)
    )

    if(summary) {
        do_run_summary(sim_df)
    } else if(analyze) {
        # `analyze` is also the logical argument here, but R skips non-function
        # bindings when resolving a call, so this reaches the analyze() generic
        analyze(sim_df)
    } else {
        sim_df
    }
}




# Like parLapply(cl, X, fun, ...), but processes X in N_CHUNKS sequential
# batches. Every batch except the last is written to an RDS file as soon as it
# finishes and read back at the end. Extra arguments in `...` are forwarded
# to `fun`.
#
# Returns an unnamed list the same length as X, in the same order.
# Chunk files get unique names in the working directory, so concurrent runs
# cannot collide, and they are removed on exit even if a batch fails.
parLapply_chunk <- function(cl, X, fun, ..., N_CHUNKS=10) {
    lx <- length(X)
    result <- vector('list', lx)
    if(lx == 0) return(result)

    # at least one chunk, never more chunks than elements
    N_CHUNKS <- max(1L, min(as.integer(N_CHUNKS), lx))
    sz <- lx %/% N_CHUNKS

    # chunk k covers starts[k]:ends[k]; the last chunk absorbs the remainder
    starts <- (seq_len(N_CHUNKS) - 1L) * sz + 1L
    ends <- c(starts[-1] - 1L, lx)

    # chunks 1 .. N_CHUNKS-1 are spilled to disk
    n_spilled <- N_CHUNKS - 1L
    chunk_files <- vapply(seq_len(n_spilled), function(ii)
        tempfile(pattern=paste0('.__CHUNK__', ii, '_'), tmpdir='.', fileext='.RDS'),
        character(1))
    on.exit(unlink(chunk_files), add=TRUE)

    for(ii in seq_len(n_spilled)) {
        cat('starting chunk ', ii, 'of', N_CHUNKS, '\n')
        saveRDS(parLapply(cl, X[starts[ii]:ends[ii]], fun, ...),
                file=chunk_files[ii], compress=FALSE)
    }

    # process the last chunk (no need to write this one out)
    cat('starting chunk ', N_CHUNKS, 'of', N_CHUNKS, '\n')
    last <- starts[N_CHUNKS]:ends[N_CHUNKS]
    result[last] <- parLapply(cl, X[last], fun, ...)

    # read in the other chunks
    for(ii in seq_len(n_spilled)) {
        result[starts[ii]:ends[ii]] <- readRDS(file=chunk_files[ii])
    }

    result
}

# number of distinct values; any extra arguments are passed on to unique()
len_uniq <- function(...) length(unique(...))

# Draw n items from the distinct values (the support) of x, by default with
# replacement.
# Indexes into unique(x) with sample.int() instead of calling sample()
# directly. sample(u, n) treats a single number u >= 1 as 1:u, so it would
# silently return values absent from x whenever x has one distinct numeric
# value. For longer supports the RNG draws are the same as sample()'s.
sample_unique <- function(x, n, replace=TRUE) {
    support <- unique(x)
    support[sample.int(length(support), n, replace=replace)]
}


# Load the population data and reshape it to long format (one row per
# subject x stimulus). Kept as a function so intermediates stay out of the
# global environment.
#
# file: CSV with one column per stimulus and one row per subject. It defaults
#       to the original data set.
#
# Returns a data.frame with columns:
#   sid      - integer subject index (CSV row number)
#   mcg      - the CSV cell value; used downstream as the subject's pF
#   stimulus - stimulus name from the column header, with '_' and '.avi'
#              removed and then abbreviated to 5 characters
# Only stimuli with no missing values are kept.
get_fusion_df <- function(file='us_mcg_only.csv') {

    # per the original notes: n=66 subjects have 14 stimuli, n=165 have 9
    population <- read.csv(file=file, row.names=NULL)

    # keep only fully observed stimuli (the 9 seen by all 165 subjects);
    # drop=FALSE so a single surviving column stays a data.frame
    complete <- vapply(population, function(col) !anyNA(col), logical(1))
    population <- population[, complete, drop=FALSE]

    total_subj <- nrow(population)
    total_stimuli <- ncol(population)

    # fixed=TRUE: '.avi' is a literal extension, not a regex where '.' matches anything
    stimuli <- gsub('_', '', names(population), fixed=TRUE)
    stimuli <- gsub('.avi', '', stimuli, fixed=TRUE)
    stimuli <- abbreviate(stimuli, 5)

    data.frame(
        sid=rep(seq_len(total_subj), total_stimuli),
        mcg=unlist(population),
        stimulus=rep(stimuli, each=total_subj),
        stringsAsFactors=FALSE
    )
}

# Basic steps
#   0. Set N, n_trials, n_stimuli, and delta (note that delta is a FUNCTION of pF, to allow shift/scale/uncertainty etc)
#   1. Pick N subjects w/ replacement
#   2. Pick p stimuli w/ replacement
#   3. Obtain p(F) for each subject_i at each stimulus_j => pF_ij
#   4. Assign group [1, N/2] => Group A; [N/2 + 1, N] => Group B        (for paired data, Group B is a copy of Group A)
#   5a. Shift pF for Group B at each stimulus by delta
#   5b. Enforce pF \in (0.025, 0.975)
#   6. For each subject/stimulus pair, sample k trials from binomial centered at pF_ij => pF_hat
#   7. Conduct hypothesis test Group A vs. Group B => pval
#   8. Repeat 6-7 'n' times => pval_vec
#   9. Summarize distribution of pval_vec
#   10. Repeat 0-9 with particular N, p, k, and delta

# 1
# Draw N subjects (by sid, via sample_unique, so with replacement) and return
# their rows stacked. The k-th draw is relabelled 'subj_00k', so a subject
# drawn twice appears as two distinct subjects.
pick_subjects <- function(df, N) {
    relabel <- function(ii) 'subj_' %&% formatC(ii, width=3, flag='0')

    drawn <- sample_unique(df$sid, N)

    # copy each drawn subject's rows under a fresh id (handles duplicate draws)
    per_draw <- sapply_ii(drawn, function(sid, ii) {
        rows <- df[df$sid == sid,]
        rows$sid <- relabel(ii)
        rows
    }, simplify = FALSE)

    rbind_list(per_draw)
}

# 2
# Choose stimuli and return their rows stacked, relabelling the k-th pick as
# 'stim_00k' so a stimulus picked twice appears as two distinct stimuli.
#
# n_stimuli: a number, which samples that many stimuli via sample_unique (with
#            replacement); or a character vector naming specific stimuli.
#            Every name must exist in df$stimulus, otherwise this stops with
#            "stimulus not found: <names>".
pick_stimuli <- function(df, n_stimuli) {
    make_stim_id <- function(ii) 'stim_' %&% formatC(ii, width=3, flag='0')

    if(is.character(n_stimuli)) {
        keep <- n_stimuli
        # check membership of every requested name; `!=` against the stimulus
        # list would compare element-wise with recycling and miss unknowns
        unknown <- setdiff(keep, df$stimulus)
        if(length(unknown) > 0) {
            stop(paste0('stimulus not found: ', paste(unknown, collapse=', ')))
        }
    } else {
        keep <- sample_unique(df$stimulus, n_stimuli)
    }

    # this is somewhat involved to handle duplicate stimuli being selected
    per_draw <- sapply_ii(keep, function(stim, ii) {
        a <- df[df$stimulus == stim,]
        a$stimulus <- make_stim_id(ii)
        a
    }, simplify = FALSE)

    rbind_list(per_draw)
}

# 3/4
# Assign rows to groups A and B. group_type defaults to 'paired' and may be
# abbreviated (match.arg). 'unpaired' is handled by make_groups.unpaired,
# 'paired' by make_groups.paired.
make_groups <- function(df, group_type=c('paired', 'unpaired')) {
    group_type <- match.arg(group_type)
    if(group_type == 'unpaired') {
        make_groups.unpaired(df)
    } else {
        make_groups.paired(df)
    }
}

# Between-subjects design: a random half of the distinct subjects (by sid) is
# put in group B and the rest in group A. All rows of one subject share a group.
# Prepends class 'unpaired' so analyze() dispatches to analyze.unpaired.
make_groups.unpaired <- function(df) {
    # every row starts in group A
    df$group <- 'A'

    subjects <- unique(df$sid)
    moved <- sample(subjects, length(subjects)/2, replace=FALSE)
    df$group <- replace(df$group, df$sid %in% moved, 'B')

    class(df) <- c('unpaired', class(df))
    df
}

# Within-subjects design: every row appears twice, once labelled group A and
# once group B (A copies first). Prepends class 'paired' for analyze() dispatch.
make_groups.paired <- function(df) {
    # each call works on its own copy of df
    with_group <- function(label) {
        df$group <- label
        df
    }

    stacked <- rbind(with_group('A'), with_group('B'))
    class(stacked) <- c('paired', class(stacked))
    stacked
}

# 5
# Build the true pF column: start from mcg, apply `delta` to group B rows only,
# then pass every row through `limit`. Both delta and limit are functions of a
# pF vector.
change_pF <- function(df, delta, limit) {
    p <- df$mcg
    in_B <- df$group == 'B'
    p[in_B] <- delta(p[in_B])
    df$pF <- limit(p)
    df
}

# 6
# Simulate n_trials Bernoulli trials per row at that row's true pF and store
# the observed proportion of successes in pF_hat.
sample_pF <- function(df, n_trials) {
    p_true <- df$pF
    n_success <- rbinom(length(p_true), n_trials, p_true)
    df$pF_hat <- n_success / n_trials
    df
}

# 7
# 'Pr(>F)' column of the lmerTest ANOVA table for a fitted lmer model.
# lmerTest >= 3.0 dropped the 'merModLmerTest' class and no longer exports
# anova(); it provides as_lmerModLmerTest() instead. Use whichever API the
# installed lmerTest offers, so this works with both old and new versions.
lmer.pval <- function(mod) {
    lmer_ns <- asNamespace('lmerTest')
    if(exists('as_lmerModLmerTest', envir=lmer_ns, inherits=FALSE)) {
        tab <- stats::anova(lmerTest::as_lmerModLmerTest(mod))
    } else {
        tab <- lmerTest::anova(as(mod, 'merModLmerTest'))
    }
    tab[['Pr(>F)']]
}

# S3 generic: dispatches on the class that make_groups() prepends to the
# simulated data frame ('unpaired' or 'paired').
analyze <- function(df) {
    UseMethod('analyze')
}

# If every element of x is identical, add uniform jitter in [-eps, eps] so a
# t-test on it does not fail on constant data; otherwise return x unchanged.
fix_constancy <- function(x, eps=0.01) {
    n <- length(x)
    all_same <- identical(x, rep(x[1], n))
    if(!all_same) {
        return (x)
    }
    runif(n, -eps, eps) + x
}

# Equal-variance two-sample t-test of pF_hat between groups A and B.
#
# Groups whose pF_hat values are all identical are first jittered by
# fix_constancy (group A first, then group B). With more than one stimulus,
# pF_hat is averaged per group and subject via do_aggregate before testing.
#
# Returns c(pval = <t-test p-value>, eff = <mean B - mean A>). The effect is
# unnamed before combining, so the names are exactly 'pval' and 'eff' (not
# 'eff.mean in group B').
analyze.unpaired <- function(df) {
    # check for constancy
    in_A <- df$group == 'A'
    df$pF_hat[in_A] <- fix_constancy(df$pF_hat[in_A])
    in_B <- df$group == 'B'
    df$pF_hat[in_B] <- fix_constancy(df$pF_hat[in_B])

    if(length(unique(df$stimulus)) > 1) {
        # for speed reasons, just collapse across stimuli for now, lmer doesn't
        # change things b/c we're generating data with fixed effects (subject to truncation)
        df <- do_aggregate(df, pF_hat ~ group + sid, mean)
    }

    # t.test is faster than lm
    # with() is slightly faster than data=df
    mod <- with(df, t.test(pF_hat ~ group, var.equal=TRUE))

    # estimate is c(mean in group A, mean in group B)
    eff <- unname(mod$estimate[2] - mod$estimate[1])

    c(pval=mod$p.value, eff=eff)
}

# Analysis for paired designs (make_groups.paired output) is not written yet;
# always signals an error.
analyze.paired <- function(df) {
    stop('not implemented')
}


# Sometimes the computer needs a break: pause for about `seconds` seconds,
# printing 'sleep.' followed by one dot per extra second, then 'wake!'.
# Always sleeps at least one second; fractional values are rounded up.
# (A countdown loop that stops only at exactly 0 would never end for values
# <= 0 or non-integers.)
do_sleep <- function(seconds){
    stopifnot(is.numeric(seconds), length(seconds) == 1L, is.finite(seconds))
    n_ticks <- max(1, ceiling(seconds))

    cat('\nsleep.')
    Sys.sleep(1)
    for(tick in seq_len(n_ticks - 1)) {
        cat('.')
        Sys.sleep(1)
    }
    cat('wake!\n')
}

