# multicore
library(parallel)

# Colors
library(RColorBrewer)

# get necessary constants
source('mcgurk_causal_inference_constants.R')

source('mcgurk_causal_inference_functions.R')

# functions for plotting ellipses
source("ellipse.R")

# Output directory for every figure. Keep the trailing forward slash:
# figure paths are built with paste0(FIG_DIR, name).
FIG_DIR <- "figures/"

# Diverging red/blue palette. col.c2 is its first (red) entry and col.c1 its
# last (blue) entry; the full ramp colours the pC1 overlay.
pal <- brewer.pal(9, "RdBu")
col.c2 <- pal[1]
col.c1 <- pal[length(pal)]
color_palette.pc1 <- colorRampPalette(pal, space = "Lab")
colors.syllables <- c("darkgreen", "purple", "orange")

# SD multiplier used when drawing covariance ellipses ("5% bounds").
# NOTE(review): for a 2-D Gaussian the exact 95% multiplier is
# sqrt(qchisq(0.95, df = 2)) ~= 2.448 - confirm 2.453 is intended.
ELLIPSE_SD_MULTIPLIER <- 2.453

#
# These helper functions give every figure a consistent look and feel.
setup.plot = function(mar = rep(0, 4)) {
	# Set the figure margins (none by default) and open an empty, undecorated
	# frame spanning the shared X_LIM x Y_LIM space.
	par(mar = mar)
	plot.clean(ylim = Y_LIM, xlim = X_LIM)
}

make.content_plot = function(name, exp, w = 1.15, h = 1.15, col = colors.syllables, draw.syllables=TRUE, test=FALSE) {
	# Draw one panel of the representational space into FIG_DIR/<name>.pdf:
	# an empty frame, the origin axes, then the caller's drawing code `exp`
	# (a lazily evaluated argument, so it runs in the caller's environment),
	# and finally - unless draw.syllables is FALSE - the syllable labels.
	# With test = TRUE everything is drawn on the current device instead.
	# NOTE(review): `col` is accepted but not used here.
	pdf_path = paste0(FIG_DIR, name, ".pdf")
	as_pdf(pdf_path, w, h, TEST = test, expr = {
		setup.plot()
		draw.axes()
		eval(exp)
		if (draw.syllables) {
			draw.syllables()
		}
	})
}

draw.axes = function() {
	# Light-gray reference lines through the origin: vertical, then horizontal.
	axis_col = "gray80"
	abline(v = 0, col = axis_col)
	abline(h = 0, col = axis_col)
}

draw.syllables = function() {
	# Write each syllable's name at its location, in that syllable's colour.
	syllable_names = row.names(syllables.locations)
	text(syllables.locations, labels = syllable_names, col = colors.syllables, cex = 0.85)
}

draw.av = function(xa, xv, drawUnisensory=FALSE, drawLines=TRUE, av.cex=0.85, uni.cex=1) {
	# Plot the integrated AV estimate for cues xa / xv as a dot, optionally joined
	# to each cue by a gray line (drawLines) and with the cues marked 'a' / 'v'
	# (drawUnisensory). Returns t(av), the integrated estimate transposed.
	#
	# Every location is reshaped to a 1 x 2 (x, y) matrix before plotting:
	# points()/lines() treat a bare length-2 vector as two y-values at x = 1, 2,
	# so the cue markers were misplaced unless xa / xv were already 1-row matrices.
	as_xy = function(p) matrix(p, ncol = 2)

	av = get_av_integration(xa, xv)
	av_xy = as_xy(av)
	points(av_xy, pch=20, cex=av.cex)

	if (drawLines) {
		lines(rbind(as_xy(xa), av_xy), col='gray70')
		lines(rbind(as_xy(xv), av_xy), col='gray70')
	}
	if (drawUnisensory) {
		points(as_xy(xa), pch='a', cex=uni.cex)
		points(as_xy(xv), pch='v', cex=uni.cex)
	}
	return(t(av))
}

draw.A_V_AV = function(A, V, cols=rep('gray90', 3), border.cols=NA, lwds=rep(1,3)) {
    # Draw the auditory, visual and integrated-AV ellipses (covariances Sa, Sv,
    # Sav) and invisibly return the integrated AV mean.
    # cols / border.cols / lwds give fill, border colour and line width for
    # A, V and AV, in that order. When border.cols is left at NA (or NULL),
    # A and V are bordered in their fill colours and AV in black.
    # Only a scalar NA selects the default: if(is.na(<length-3 vector>)) is an
    # error in R >= 4.2, which broke callers passing explicit border colours.
    if (is.null(border.cols) || (length(border.cols) == 1L && is.na(border.cols)))
        border.cols = c(cols[1:2], 'black')

    av = get_av_integration(A, V)

    mapply(draw.ellipse, list(A, V, av), list(Sa, Sv, Sav), cols, border.cols, lwds)

    invisible(av)
}

# Helper for drawing numbered example points in the representational space.
# The figures overlap heavily, so one function computes every point set and
# `idx` selects which set is drawn for the graph being made.
draw_mcg_points = function(idx, cex=0.55) {
    # Draw three numbered example points ("1".."3") at text size `cex`.
    # idx selects the set: 1 = A and V cues, 2 = integrated AV,
    # 3 = causal-inference percepts, 4 = A only, 5 = V only.
    # The caller's par(cex) is restored on exit, even if drawing fails.
    #
    # NOTE(review): assumes `ba` / `ga` are 2-D prototype locations that get
    # recycled across the six offsets - confirm against the constants file.
    offsets = c(-.1, 2, -1, -2, 1, -2)
    As = matrix(ba + offsets, nrow = 3, byrow = TRUE)
    Vs = matrix(ga + rev(offsets), nrow = 3, byrow = TRUE)
    # Every set is computed regardless of idx, so the work done (and any random
    # draws the helpers may make) does not depend on which set is shown.
    AVs = t(sapply(1:3, function(ii) get_av_integration(As[ii,], Vs[ii,])))
    CIs = t(sapply(1:3, function(ii) get_ci_percept_location(list(a=As[ii,], v=Vs[ii,], av=AVs[ii,]))))

    pts = function(M) points(M, pch=paste(1:3))

    old_par = par(cex = cex)
    on.exit(par(old_par), add = TRUE)

    switch(idx,
        {
            # 1: unisensory cues
            pts(As)
            pts(Vs)
        },
        # 2: multisensory (integrated) estimates
        pts(AVs),
        # 3: causal-inference percepts
        pts(CIs),
        # 4: auditory only
        pts(As),
        # 5: visual only
        pts(Vs)
    )

    invisible(NULL)
}


draw.image = function(x, y, z, cols, add=TRUE) {
	# `z` is stored picture-style (first row = top of the y range); image()
	# expects z[x, y] with y increasing, so transpose and reverse the y order.
	flipped = t(z)[, nrow(z):1]
	image(x, y, flipped, col = cols, add = add, useRaster = TRUE,
		axes = FALSE, xlab = '', ylab = '')
}


#
draw.classification_overlay = function(add=TRUE, res=300*1.15, z=NA, render=TRUE, alpha=100) {
	# Map of classify_sample() over a res x res grid of the X_LIM x Y_LIM space,
	# optionally drawn as a translucent image in the syllable colours.
	# Pass a previously returned `z` to skip the (parallel, slow) recomputation;
	# a short `z` (the NA default) means "compute it".
	# Returns `z` invisibly: one row per y value (top of the range first) and one
	# column per x value.
	if (length(z) < 10) {
		xs = seq(X_LIM[1], X_LIM[2], length.out=res)
		ys = seq(Y_LIM[1], Y_LIM[2], length.out=res)

		cl = makeCluster(parallel::detectCores())
		# release the worker processes even if classification fails part-way
		on.exit(stopCluster(cl), add=TRUE)
		clusterExport(cl, varlist=c('xs', 'ys'), envir=environment())

		z = parSapply(cl, xs, function(x) {
			# each worker loads the model functions into its own session
			source('mcgurk_causal_inference_functions.R')
			sapply(rev(ys), function(y) {
				classify_sample(c(x, y), syllables.locations, category.Sigmas, syllables.priors)
			})
		})
	} else {
		# rows of z are y values and columns are x values (see the branch above)
		xs = seq(X_LIM[1], X_LIM[2], length.out=ncol(z))
		ys = seq(Y_LIM[1], Y_LIM[2], length.out=nrow(z))
	}

	if (render)
		draw.image(xs, ys, z, cols = getAlphaRGB(colors.syllables, alpha), add=add)

	invisible(z)
}


draw.pc1_overlay = function(add=TRUE, res=300*1.15, z=NA, render=TRUE) {
	# Map of pC1_given_xav() over a res x res grid of the X_LIM x Y_LIM space,
	# optionally drawn as an image with color_palette.pc1.
	# Pass a previously returned `z` to skip the (parallel, slow) recomputation;
	# a short `z` (the NA default) means "compute it".
	# Returns `z` invisibly: one row per y value (top of the range first) and one
	# column per x value.
	if (length(z) < 10) {
		xs = seq(X_LIM[1], X_LIM[2], length.out=res)
		ys = seq(Y_LIM[1], Y_LIM[2], length.out=res)

		cl = makeCluster(parallel::detectCores())
		# release the worker processes even if evaluation fails part-way
		on.exit(stopCluster(cl), add=TRUE)
		clusterExport(cl, varlist=c('xs', 'ys'), envir=environment())

		z = parSapply(cl, xs, function(x) {
			# each worker loads the model functions into its own session
			source('mcgurk_causal_inference_functions.R')
			sapply(rev(ys), function(y) {
				pC1_given_xav(c(x, y))
			})
		})
	} else {
		# rows of z are y values and columns are x values (see the branch above)
		xs = seq(X_LIM[1], X_LIM[2], length.out=ncol(z))
		ys = seq(Y_LIM[1], Y_LIM[2], length.out=nrow(z))
	}

	if (render)
		draw.image(xs, ys, z, cols = color_palette.pc1(100), add=add)

	invisible(z)
}

draw.ellipse = function(m, S, col='gray90', border.col='black', lwd=1, alpha=0) {
	# Filled ellipse for mean `m` and covariance `S`, always scaled by the shared
	# ELLIPSE_SD_MULTIPLIER so ellipses are comparable across figures.
	fill.ellipse(m, S, col = col, border.col = border.col, lwd = lwd,
		alpha = alpha, sd = ELLIPSE_SD_MULTIPLIER)
}

draw.ci_percept = function(a, v, av) {
	# Mark the causal-inference percept for cues a, v and integrated estimate av:
	# a red line from av to the percept location, and a red dot at that location.
	ciw = get_ci_percept_location(list(a = a, v = v, av = av))

	lines(x = c(av[1], ciw[1]), y = c(av[2], ciw[2]), col = "red")

	# Reshape to a 1 x 2 (x, y) matrix: points() on a bare length-2 vector would
	# draw two points at x = 1 and x = 2 instead of one point at (x, y).
	points(matrix(ciw, ncol = 2), pch = 20, lwd = 1, col = "red", cex=0.85)
}

draw.ci.steps = function(a, v, add=TRUE, redrawAxes=TRUE) {
	# Draw the integration step (AV dot, guide lines, 'a'/'v' cue markers), then
	# the causal-inference percept. `add` is accepted for compatibility but unused.
	# draw.av() already returns t(get_av_integration(a, v)); reuse it rather than
	# integrating twice. The argument is spelled out in full instead of relying on
	# partial matching of `drawUni`.
	av = draw.av(a, v, drawUnisensory=TRUE)

	if (redrawAxes) draw.axes()

	draw.ci_percept(a, v, av=av)
}

make.barplot = function(name, y, sems=NA, cols=colors.syllables) {
	# Small bar chart of `y` on a fixed [0, 1] scale written to FIG_DIR/<name>.pdf,
	# with translucent fills. Error bars are drawn only when at least one entry
	# of `sems` is not NA.
	fill_cols = getAlphaRGB(cols, 100)
	has_sems = !all(is.na(sems))
	as_pdf(paste0(FIG_DIR, name), w=.77, h=1, {
		par(mar=rep(1/2, 4))
		bar_mids = barplot(y, col = fill_cols, border = NA, ylim = c(0, 1), axes = FALSE)
		draw.axis(2, 0:2/2)
		if (has_sems) {
			ebars.y(bar_mids, y, sems, col = fill_cols)
		}
	})
}


make.pc1_colorbar = function() {
	# Write the vertical colour bar for the pC1 overlay palette to
	# FIG_DIR/pc1_colorbar.tiff (0.2 x 1.25 in, 300 dpi) and invisibly return its path.
	# width/height are spelled out (w/h relied on partial argument matching), and
	# the TIFF device is closed even if drawing fails.
	out_file = paste0(FIG_DIR, 'pc1_colorbar', ".tiff")
	tiff(out_file, width = 0.2, height = 1.25, units = "in", res = 300, compression = "lzw")
	dev_id = dev.cur()
	on.exit(dev.off(dev_id), add = TRUE)

	par(mar = rep(0.2, 4))
	# a 2 x 201 image whose values run -100..100 bottom-to-top gives a smooth ramp
	image(x = 1:2, y = 1:201, z = matrix(rep(-100:100, each = 2), nrow = 2), col = color_palette.pc1(101), axes = FALSE, zlim = c(-100, 100))
	box()

	invisible(out_file)
}


#
# Helpers that call several plotters at once.
# plot.non_cims_model plots each step of the non-CIMS model for a given A and V.
plot.non_cims_model = function(prefx, A, V, n=1e4) {
    # Figures for the model without causal inference, for one A/V stimulus:
    # noisy encoding -> AV integration -> categorization -> predicted responses.
    # Files are named "<prefx>_<step>"; for prefx 'mcg' the numbered example
    # points are overlaid as well. Returns the predicted response proportions.
    step_name = function(step) paste0(prefx, '_', step)
    show_points = prefx == 'mcg'

    make.content_plot(step_name('noisy_encoding'), {
        draw.ellipse(A, Sa)
        draw.ellipse(V, Sv)
        if (show_points) draw_mcg_points(1)
    })

    make.content_plot(step_name('calculate_av'), {
        draw.A_V_AV(A, V, cols=rep('black', 3))
        if (show_points) draw_mcg_points(2)
    })

    make.content_plot(step_name('categorize_representation'), {
        draw.classification_overlay(z=classification.img)
        draw.ellipse(get_av_integration(A, V), Sav, border.col='black')
        if (show_points) draw_mcg_points(2)
    })

    # classify n encoded AV samples and turn the counts into proportions
    av_samples = get_encoded_exemplar(A, V, samples=n)$av
    sample_labels = apply(av_samples, 1, classify_sample)
    ptable = tabulate(sample_labels, nbins=n.syllables) / n
    make.barplot(step_name('without_ci_responses'), ptable)

    ptable
}

# this function plots the steps of the cims model based on a given A and V
plot.cims_model = function(prefx, A, V, n=1e5) {
    # Figures for the causal-inference (CIMS) model for one A/V stimulus, named
    # "<prefx>_<step>"; for prefx 'mcg' the numbered example points are overlaid.
    # Returns the predicted response proportions over the syllables.
    fname = function(name) paste0(prefx, '_', name)
    is_mcg = prefx == 'mcg'

    AV = get_av_integration(A, V)

    make.content_plot(fname('c1_percept'), {
        draw.A_V_AV(A, V, cols=rep('black', 3))
        if (is_mcg) draw_mcg_points(2)
    })

    make.content_plot(fname('c2_percept'), {
        draw.ellipse(A, Sa)
        draw.ellipse(V, Sv)
        if (is_mcg) draw_mcg_points(4)
    })

    make.content_plot(fname('pc1_c2'), {
        draw.pc1_overlay(z=pc1.img)
        draw.ellipse(AV, Sav, lwd=1, border.col='black')
        if (is_mcg) draw_mcg_points(2)
        # the opaque overlay covers the axes drawn by make.content_plot; redraw them
        draw.axes()
    }, draw.syllables = FALSE)

    mS = get_ci_mean_Sigma(A, V, n=n)
    make.content_plot(fname('combine_c1_c2'), {
        mapply(draw.ellipse, list(A, AV), list(Sa, Sav))
        draw.ellipse(mS$m, mS$S, border.col='orangered', lwd=3)
        if (is_mcg) draw_mcg_points(3)
    }, draw.syllables=FALSE)

    make.content_plot(fname('classify_ci_representation'), {
        draw.classification_overlay(z=classification.img)
        draw.ellipse(mS$m, mS$S, border.col='orangered', lwd=3)
        if (is_mcg) draw_mcg_points(3)
    })

    # Fraction of the CI samples assigned to each syllable. Use n.syllables (as
    # plot.non_cims_model does) instead of a hard-coded 3 so both models bin alike.
    ptable = tabulate(apply(mS$xav, 1, classify_sample), nbins=n.syllables) / nrow(mS$xav)
    make.barplot(fname('cims_predictions'), ptable)

    return(ptable)
}

# confusion matrices

# Colour ramp for confusion-matrix cells, from white (0) through sky and steel
# blue to midnight blue (1); sample it with col.cmat(n).
col.cmat <- colorRampPalette(c("white", "skyblue", "steelblue", "midnightblue"))

# Write the congruent and incongruent confusion matrices as separate PDFs
make.confusion_mat = function(mat, fname=NULL, w=1.4, h=2.58, cong.w=1.4, cong.h=1.4) {
    # Rows 1, 5 and 9 of `mat` go to "<fname>_cong.pdf" and the other six rows to
    # "<fname>_inc.pdf"; each panel shows the rows top-to-bottom and the columns
    # left-to-right on a 0..1 colour scale. With fname = NULL nothing is written
    # and both panels are drawn on the current device (as_pdf TEST mode).
    cong = c(1, 5, 9)
    inc = setdiff(1:9, cong)

    # One panel: image of `sub` plus cell borders sized to its own grid. The
    # original drew 8 horizontal lines on both panels; only those inside the
    # grid were ever visible.
    draw_panel = function(sub) {
        n_row = nrow(sub)
        n_col = ncol(sub)
        par(mar=rep(0.5, 4))
        image(seq_len(n_col), seq_len(n_row), z=t(sub)[, n_row:1], col=col.cmat(4), asp=1, axes=FALSE, zlim=0:1)
        abline(v=0:n_col + 0.5)
        abline(h=0:n_row + 0.5)
    }

    as_pdf(file=paste0(FIG_DIR, fname, '_cong'), w=cong.w, h=cong.h, TEST=is.null(fname), {
        draw_panel(mat[cong, , drop=FALSE])
    })
    as_pdf(file=paste0(FIG_DIR, fname, '_inc'), w=w, h=h, TEST=is.null(fname), {
        draw_panel(mat[inc, , drop=FALSE])
    })
}

# color bar for confusion matrices
make.cmat_colorbar = function() {
    # Write the colour bar for the confusion-matrix scale (the 4 colours of
    # col.cmat(4)) to FIG_DIR/av_cmat_colorbar.tiff (0.15 x 0.5 in, 300 dpi) and
    # invisibly return its path. width/height are spelled out (w/h relied on
    # partial argument matching), and the device is closed even if drawing fails.
    out_file = paste0(FIG_DIR, 'av_cmat_colorbar', ".tiff")
    tiff(out_file, width = 0.15, height = 0.5, units = "in", res = 300, compression = "lzw")
    dev_id = dev.cur()
    on.exit(dev.off(dev_id), add = TRUE)

    par(mar = rep(0.2, 4))
    image(x = 1:2, y = 1:201, z = matrix(rep(-100:100, each = 2), nrow = 2), col = col.cmat(4), axes = FALSE, zlim = c(-100, 100))
    box(lwd = 0.5)

    invisible(out_file)
}


#
# low level plot functions
#


# pdf wrapper that evaluates an arbitrary expression into a pdf and invisibly returns the result
as_pdf = function(file, w, h, expr, TEST=FALSE, bg='white') {
    # Evaluate drawing code `expr` into a w x h inch PDF at `file` (".pdf" is
    # appended if missing) and invisibly return the value of `expr`.
    # With TEST = TRUE no file is opened; `expr` draws on the current device.
    if (TEST) {
        return(invisible(eval(expr)))
    }
    if (!grepl("\\.pdf$", file)) {
        file = paste0(file, ".pdf")
    }
    pdf(file, width=w, height=h, useDingbats=FALSE, bg=bg)
    # Close this specific device even if `expr` errors; otherwise later plots
    # would silently land in the stale, half-written PDF.
    dev_id = dev.cur()
    on.exit(dev.off(dev_id), add=TRUE)
    invisible(eval(expr))
}

# helper to draw axes with labels closer to ticks
draw.axis = function(side, at, tcl=-0.3, labels=at, padj=0.5, adj=1, yline=0.65, xline=0.2, ...) {
    # Axis with short ticks whose labels sit closer to the ticks than axis() puts
    # them: ticks come from axis(), labels from mtext() on line xline (sides 1/3)
    # or yline (sides 2/4). Several sides may be given at once.
    if (length(side) > 1) {
        per_side = sapply(side, draw.axis, at=at, tcl=tcl, labels=labels,
                          padj=padj, adj=adj, yline=yline, xline=xline, ...)
        return(invisible(per_side))
    }

    axis(side, at=at, labels=FALSE, tcl=tcl, ...)

    parity = side %% 2
    if (parity == 1) {
        mtext(labels, side=side, at=at, padj=padj, las=1, line=xline)
    } else if (parity == 0) {
        mtext(labels, side=side, at=at, adj=adj, las=1, line=yline)
    }

    invisible(at)
}

# error bar function
ebars.y = function(x, y, sem, length = 0.05, up = TRUE, down = TRUE, code = 2, ...) {
    # Vertical error bars at positions x: flat-headed arrows (angle = 90) from y
    # up to y + sem and/or down to y - sem. Extra arguments go to arrows().
    bar_start = as.numeric(y)
    if (up)
        arrows(x0 = x, y0 = bar_start, y1 = as.numeric(y + sem), angle = 90, code = code, length = length, ...)
    if (down)
        arrows(x0 = x, y0 = bar_start, y1 = as.numeric(y - sem), angle = 90, code = code, length = length, ...)
}


# specify a color name and an alpha level
getAlphaRGB = function(colname, alpha, max=255) {
    # Hex colour string(s) for the given colour(s) with an alpha channel, where
    # `alpha` is on the 0..max scale.
    rgb_rows = t(col2rgb(colname))   # one row per colour: red, green, blue
    rgb(rgb_rows, alpha = alpha, maxColorValue = max)
}

# create a plotting area with bounds but no decorations
plot.clean = function(xlim, ylim, x = 1, y = 1, type = "n", xlab="", ylab="", ...) {
    # Open a plotting frame covering range(xlim) x range(ylim) with no axes,
    # box or labels (nothing is drawn by default, type = "n").
    x_range = range(xlim)
    y_range = range(ylim)
    plot(x, y, type = type, axes = FALSE, xlab = xlab, ylab = ylab,
         xlim = x_range, ylim = y_range, ...)
}

