# libraries used
library(parallel)
library(MASS)
library(mvtnorm)

# constants needed
source('mcgurk_causal_inference_constants.R')

#
# Function definitions
#

# Map a log-odds value `llr` in favour of hypothesis 1 to the probability of
# hypothesis 1 via the logistic function, 1 / (1 + exp(-llr)).
# Vectorised over `llr`.
llr_to_pc1 = function(llr) {
	plogis(llr)
}


# Probability of hypothesis 2 given a log-odds value `llr` in favour of
# hypothesis 1: 1 / (1 + exp(llr)), mathematically 1 - llr_to_pc1(llr).
# Vectorised over `llr`.
llr_to_pc2 = function(llr) {
	plogis(llr, lower.tail = FALSE)
}

# Log-likelihood ratio log p(xav | C1) - log p(xav | C2) at the integrated
# audiovisual location `xav`, where C1 is the common-cause hypothesis (both cues
# from the same syllable category) and C2 the separate-causes hypothesis
# (auditory and visual cues from two different categories).
# Reads the globals Sa, Sv, category.Sigmas, syllables.locations and
# n.syllables (expected from the sourced constants file).
llr_given_xav = function(xav) {
	logdnormC1(xav) - logdnormC2(xav)
}

# log p(xav | C1): an equally weighted mixture over the first n.syllables
# categories; component ii is Gaussian with mean syllables.locations[ii, ] and
# covariance category.Sigmas[[ii]] + (Sa^-1 + Sv^-1)^-1.
logdnormC1 = function(xav) {
	# True matrix inverses: an element-wise 1/(1/Sa + 1/Sv) only agrees with
	# the matrix formula when Sa and Sv are diagonal.
	Sav.fused = .symmetrize(solve(solve(Sa) + solve(Sv)))

	log.dens = vapply(seq_len(n.syllables), function(ii) {
		dmvnorm(xav, syllables.locations[ii, ], category.Sigmas[[ii]] + Sav.fused, log = TRUE)
	}, numeric(1))

	# Average in log space so that distant points do not underflow to log(0).
	.log_mean_exp(log.dens)
}

# log p(xav | C2): an equally weighted mixture over every ordered pair (i, j),
# i != j, of auditory category i and visual category j. Per pair:
#   SigA = Sigma_i + Sa,  SigV = Sigma_j + Sv
#   S    = (SigA^-1 + SigV^-1)^-1
#   mu   = S (SigA^-1 mu_i + SigV^-1 mu_j)
# and xav ~ N(mu, S).
logdnormC2 = function(xav) {
	# Unordered pairs followed by their reversals. For three syllables this
	# yields (1,2), (1,3), (2,3), (3,2), (3,1), (2,1).
	unordered = t(combn(n.syllables, 2))
	pairs = rbind(unordered, unordered[rev(seq_len(nrow(unordered))), 2:1, drop = FALSE])

	log.dens = apply(pairs, 1, function(ij) {
		inv.SigA = solve(category.Sigmas[[ij[1]]] + Sa)
		inv.SigV = solve(category.Sigmas[[ij[2]]] + Sv)
		# Same (matrix) inverses for the covariance as for the mean.
		S.pair = .symmetrize(solve(inv.SigA + inv.SigV))
		mu.pair = S.pair %*% t(syllables.locations[ij[1], ] %*% inv.SigA + syllables.locations[ij[2], ] %*% inv.SigV)
		dmvnorm(xav, mean = c(mu.pair), sigma = S.pair, log = TRUE)
	})

	.log_mean_exp(log.dens)
}

# log(mean(exp(log.values))) computed without underflow (log-mean-exp).
.log_mean_exp = function(log.values) {
	top = max(log.values)
	if (!is.finite(top))
		return(top)
	top + log(mean(exp(log.values - top)))
}

# Remove round-off asymmetry from a covariance built out of matrix inverses.
.symmetrize = function(S) {
	(S + t(S)) / 2
}



# Posterior probability of a common cause: add the prior log-odds to the log-likelihood ratio and map the sum to a probability.
pC1_given_xav = function(xav, priorC1 = Pcommon) {
	# Posterior log-odds of C1: logit of the prior plus the log-likelihood
	# ratio from llr_given_xav(); llr_to_pc1() turns it into a probability.
	posterior.log.odds = qlogis(priorC1) + llr_given_xav(xav)
	llr_to_pc1(posterior.log.odds)
}

#
# Causal-inference percept location: a mixture of the integrated audiovisual
# location (weighted by P(C1 | av)) and the auditory location (weighted by
# 1 - P(C1 | av)).
#
# A_V_AV  list with elements a (auditory), v (visual) and av (integrated);
#         v is not read here.
get_ci_percept_location = function(A_V_AV) {
	weight.common = pC1_given_xav(A_V_AV$av)
	auditory.part = (1 - weight.common) * A_V_AV$a
	integrated.part = weight.common * A_V_AV$av

	integrated.part + auditory.part
}

#
# Draw n exemplar locations from one syllable category.
#
# word.i  row index of the category in syllables.locations; when NA, one
#         category is sampled with probabilities syllables.priors
# n       number of exemplars to draw
#
# Returns MASS::mvrnorm output: a vector for n == 1, otherwise an n-row matrix.
get_exemplar = function(word.i = NA, n = 1) {
	category = word.i
	if (is.na(category))
		category = sample(seq_len(n.syllables), 1, prob = syllables.priors)

	mvrnorm(n, mu = syllables.locations[category, ], Sigma = category.Sigmas[[category]])
}

# Add Gaussian encoding noise (covariance encodingSigma) to an exemplar
# location, `samples` times. Always returns a matrix with `samples` rows, even
# when samples == 1 (mvrnorm returns a bare vector in that case).
sample_from_exemplar = function(exemplar, encodingSigma, samples = 1) {
	draws = mvrnorm(n = samples, mu = exemplar, Sigma = encodingSigma)
	matrix(draws, nrow = samples)
}

# Draw `samples` noisy encodings of an auditory and a visual exemplar and fuse
# each auditory/visual pair into an audiovisual estimate.
#
# Aexemplar  location of the auditory exemplar
# Vexemplar  location of the visual exemplar (defaults to Aexemplar)
# samples    number of encodings to draw
#
# Returns list(a, v, av), each a matrix with `samples` rows: auditory
# encodings (noise Sa), visual encodings (noise Sv), and their fusion by
# get_av_integration().
get_encoded_exemplar = function(Aexemplar, Vexemplar = Aexemplar, samples = 1) {
	result = list(a = sample_from_exemplar(Aexemplar, Sa, samples), v = sample_from_exemplar(Vexemplar, Sv, samples))

	# Fill a preallocated samples x d matrix row by row; t(sapply(...)) came
	# out as 1 x samples when the cues are one-dimensional.
	fused = matrix(NA_real_, nrow = nrow(result$a), ncol = ncol(result$a))
	for (ii in seq_len(nrow(fused)))
		fused[ii, ] = get_av_integration(result$a[ii, ], result$v[ii, ])
	result$av = fused

	return(result)
}


# Precision-weighted fusion of an auditory cue `a` and a visual cue `v`:
# Sav (a invSa + v invSv)^T, using the globals Sav, invSa and invSv.
# Returns a plain numeric vector.
get_av_integration = function(a, v) {
	weighted.row = a %*% invSa + v %*% invSv
	fused = Sav %*% t(weighted.row)
	as.vector(fused)
}


# Take into account all noise sources when calculating probabilities
# MAP classification of a single percept `x` among syllable categories.
#
# syllables  matrix of category means, one row per category
# Sigmas     list of category covariances, parallel to the rows of syllables
# priors     prior probability of each category (one per row of syllables)
# noiseVar   encoding-noise covariance added to every category covariance
#
# Returns the row index of the category with the largest
# log-likelihood + log-prior.
classify_sample = function(x, syllables=syllables.locations, Sigmas=category.Sigmas, priors=syllables.priors, noiseVar=Sav) {
	stopifnot(length(priors) == nrow(syllables))

	log.lik = vapply(seq_len(nrow(syllables)), function(ii) {
		dmvnorm(x, syllables[ii, ], Sigmas[[ii]] + noiseVar, log = TRUE)
	}, numeric(1))

	# `priors` used to be accepted but ignored, making this maximum-likelihood
	# rather than the posterior-based classification the signature implies.
	# as.vector() drops any names so the returned index stays unnamed.
	which.max(log.lik + log(as.vector(priors)))
}


# helper to create ci locations for a given AV syllable.
# Monte Carlo estimate of the distribution of causal-inference percept
# locations for an audiovisual token with auditory exemplar `s1` and visual
# exemplar `s2`, simulated in parallel worker processes.
#
# s1, s2  exemplar locations passed to get_encoded_exemplar()
# n       total number of simulated encodings, split across workers
# seed    optional; when non-NULL, workers get reproducible independent RNG
#         streams via clusterSetRNGStream(). NULL keeps unseeded workers.
#
# Returns list(m = column means, S = covariance, xav = matrix of the n
# simulated percept locations, one per row).
# Each worker sources 'mcgurk_causal_inference_functions.R' relative to its
# working directory.
get_ci_mean_Sigma = function(s1, s2, n=10000, seed=NULL) {
	n = floor(n)
	stopifnot(length(n) == 1, n >= 1)

	# detectCores() may return NA; never start more workers than samples.
	k = parallel::detectCores()
	if (is.na(k) || k < 1)
		k = 1
	k = min(k, n)

	# Split n so the per-worker counts sum to exactly n (as.integer(n/k) per
	# worker dropped the remainder, and gave 0 samples when n < k).
	per.worker = as.integer(rep(n %/% k, k) + (seq_len(k) <= n %% k))

	cl = makeCluster(k)
	# Release the workers even if a worker task errors.
	on.exit(stopCluster(cl), add = TRUE)
	if (!is.null(seed))
		clusterSetRNGStream(cl, seed)
	clusterEvalQ(cl, {
		source('mcgurk_causal_inference_functions.R')
		NULL
	})

	# s1 and s2 reach the workers through this closure's environment.
	chunks = parLapply(cl, per.worker, function(m) {
		av = get_encoded_exemplar(s1, s2, samples = m)
		locs = array(NA, dim = dim(av$av))
		for (jj in seq_len(nrow(locs)))
			locs[jj, ] = get_ci_percept_location(list(a = av$a[jj, ], v = av$v[jj, ], av = av$av[jj, ]))
		locs
	})
	ci.locs = do.call(rbind, chunks)

	return (list(m=colMeans(ci.locs), S = cov(ci.locs), xav=ci.locs))
}




