import re, string, sys, os
from math import *
import muscle, translate

# The 61 sense codons (the stop codons TAA/TAG/TGA are deliberately absent).
# Each row holds the codons of one amino acid, in the same row order as
# codon_aa_dict.  main() iterates this list to precompute per-codon site counts.
codon_list = ['ATT', 'ATC', 'ATA', \
	      'CTT', 'CTC', 'CTA', 'CTG', 'TTA', 'TTG', \
	      'GTT', 'GTC', 'GTA', 'GTG', \
	      'TTT', 'TTC', \
	      'ATG', \
	      'TGT', 'TGC', \
	      'GCT', 'GCC', 'GCA', 'GCG', \
	      'GGT', 'GGC', 'GGA', 'GGG', \
	      'CCT', 'CCC', 'CCA', 'CCG', \
	      'ACT', 'ACC', 'ACA', 'ACG', \
	      'TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC', \
	      'TAT', 'TAC', \
	      'TGG', \
	      'CAA', 'CAG', \
	      'AAT', 'AAC', \
	      'CAT', 'CAC', \
	      'GAA', 'GAG', \
	      'GAT', 'GAC', \
	      'AAA', 'AAG', \
	      'CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG']

# Codon -> one-letter amino-acid code.  Unlike codon_list this table also
# covers the three stop codons (mapped to '*') and the alignment gap codon
# '---' (mapped to '-'), so single-base mutants that hit a stop codon can
# still be looked up without a KeyError.
codon_aa_dict = {'ATT': 'I', 'ATC': 'I', 'ATA': 'I', \
	      'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L', 'TTA': 'L', 'TTG': 'L', \
	      'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V', \
	      'TTT': 'F', 'TTC': 'F', \
	      'ATG': 'M', \
	      'TGT': 'C', 'TGC': 'C', \
	      'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A', \
	      'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G', \
	      'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P', \
	      'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T', \
	      'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S', 'AGT': 'S', 'AGC': 'S', \
	      'TAT': 'Y', 'TAC': 'Y', \
	      'TGG': 'W', \
	      'CAA': 'Q', 'CAG': 'Q', \
	      'AAT': 'N', 'AAC': 'N', \
	      'CAT': 'H', 'CAC': 'H', \
	      'GAA': 'E', 'GAG': 'E', \
	      'GAT': 'D', 'GAC': 'D', \
	      'AAA': 'K', 'AAG': 'K', \
	      'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R', 'AGA': 'R', 'AGG': 'R', \
	      'TAA': '*', 'TAG': '*', 'TGA': '*', \
	      '---': '-'}

# Per-species codon lists, looked up by main() and calcRhoRate() and passed
# around as `opt_codon_list` -- presumably the translationally "optimal"
# (preferred) codons of each species; confirm the source of these lists.
# NOTE(review): only 'yeast', 'worm' and 'fly' have entries, although main()
# also contains branches for 'ecoli' and 'mouse'; those species raise KeyError
# at the dictionary lookup before their branch is ever reached.
species_opt_codon_dict = {'yeast': ['ATT','ATC','TTG','GTT','GTC','TTC','TGT','GCT','GCC','GGT','CCA','ACT','ACC','TCT','TCC','TAC','CAA','AAC','CAC','GAA','GAC','AAG','AGA'], \
                          'worm':  ['CCA','TTC','AAG','ATC','ACC','GGA','CGC','AAC','CGT','CTC','GCC','TAC','GTC','TGC','CAC','TCC','GAG','CTT','GAC','GCT','TCT','CAA','GTT','ACT', 'AGA'], \
                          'fly':   ['AGC','GCT','GAC','GGT','CTC','GGC','GTC','GTG','CCC','TCC','CAC','CGT','ACC','TGC','AAC','CGC','CAG','GAG','GCC','TAC','CTG','ATC','TTC','AAG']}

# Nucleotide order used by get_codon_index() to enumerate codons; changing the
# order changes every codon's numeric index.
# NOTE(review): presumably this A/C/G/T order has to match the row/column order
# of the Q matrix written by the HyPhy batch file -- confirm before editing.
base_list = ['A', 'C', 'G', 'T']

def main(species, index_file, ofile):
	"""Run the HyPhy codon model on every ortholog pair listed in index_file.

	For each line of the index file the ORF and its ortholog are fetched,
	translated, aligned at the protein level with muscle and back-threaded
	onto the coding sequences.  Pairs whose aligned fraction is <= 0.8 are
	skipped.  The gap-free codon alignment is written to ./tmp_p.input, HyPhy
	is executed, its result file is parsed by procML(), and one tab-separated
	row is appended to ofile with the columns

	    orf, ortho, length, align_per, likelihood, Nd, Sd, N, S, w, dn, ds,
	    SRd, SCd, SR, SC, psi, dsr, dsc

	Parameters:
	    species    -- key into species_opt_codon_dict; also selects which
	                  index-file columns and which sequence readers are used.
	    index_file -- iterable of whitespace-separated index lines.
	    ofile      -- writable file object; flushed after every row.

	Raises:
	    KeyError   -- species has no entry in species_opt_codon_dict.
	    ValueError -- species has an entry but no sequence-lookup branch.

	A rate whose site count is zero is written as NaN instead of aborting the
	whole batch with ZeroDivisionError.
	"""
	opt_codon_list = species_opt_codon_dict[species]
	codon_ns_dict = {}
	for codon in codon_list:
		codon_ns_dict[codon] = getCodonSynSiteNum(codon, opt_codon_list)
	# Up to ten whitespace-separated columns; every group is optional, so the
	# pattern matches any line (missing columns come back as '').
	regrec = re.compile(r'(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*(\S*)\s*')
	for line in index_file:
		m = regrec.match(line)
		if not m:
			continue
		# Progress echo; the parenthesised form prints identically under
		# Python 2 and Python 3.
		print(line[:-1])
		(orf, ortho, orf_seq, ortho_seq) = _getSeqPair(species, m.groups())
		if orf_seq is None or ortho_seq is None:
			continue
		protein1 = translate.Translate(orf_seq)
		protein2 = translate.Translate(ortho_seq)
		if protein1 is None or protein2 is None:
			continue
		length = len(orf_seq) // 3
		(align_p1, align_p2) = muscle.alignSequences([protein1, protein2])
		align_per = calcAlignPer(align_p1, align_p2)
		if align_per <= 0.8:
			continue
		align_orf_seq = muscle.alignGeneFromProtein(orf_seq, align_p1)
		align_ortho_seq = muscle.alignGeneFromProtein(ortho_seq, align_p2)
		align_length = writeHyphyInputFile(align_orf_seq, align_ortho_seq)
		# NOTE(review): the exit status is ignored, so if HyPhy fails procML()
		# will parse whatever result file the previous pair left behind.
		os.system("/home/tz666/HyPhy/HYPHY_Source/HYPHY ./synAlphaWPsiModelP.bf")
		(likelihood, w, Nd, Sd, psi, SRd, SCd) = procML(align_length)
		(N, S, SR, SC) = getSiteNum(align_orf_seq, align_ortho_seq, codon_ns_dict)
		dn = _safeRatio(Nd, N)
		ds = _safeRatio(Sd, S)
		dsr = _safeRatio(SRd, SR)
		dsc = _safeRatio(SCd, SC)
		ofile.write("%s\t%s\t%i\t%.4f\t%.5f\t%.2f\t%.2f\t%.2f\t%.2f\t%.5f\t%.5f\t%.5f\t%.2f\t%.2f\t%.2f\t%.2f\t%.5f\t%.5f\t%.5f\n" % (orf, ortho, length, align_per, likelihood, Nd, Sd, N, S, w, dn, ds, SRd, SCd, SR, SC, psi, dsr, dsc))
		ofile.flush()

def _safeRatio(changes, sites):
	"""Return changes / sites as a float, or NaN when sites is zero."""
	if float(sites) == 0.0:
		return float('nan')
	return float(changes) / float(sites)

def _getSeqPair(species, fields):
	"""Return (orf, ortho, orf_seq, ortho_seq) for one parsed index line.

	fields is the tuple of regex groups of the line.  Either sequence is None
	when the ortholog column is 'NA' or the reader cannot find the record.
	"""
	if species == 'yeast':
		orf = fields[0]
		ortho = 'ORFN:' + fields[1]
		return (orf, ortho, getOrfSeq_yeast(orf), getOrthoSeq_yeast(ortho, fields[3], fields[4]))
	if species == 'ecoli':
		# NOTE(review): unreachable today -- 'ecoli' is not a key of
		# species_opt_codon_dict, so main() fails before getting here.
		(orf, ortho) = (fields[0], fields[6])
		if ortho == 'NA':
			return (orf, ortho, None, None)
		return (orf, ortho, getSeq_ecoli(orf, 'ecol'), getSeq_ecoli(ortho, 'styp'))
	if species == 'worm':
		(orf, ortho) = (fields[0], fields[1])
		if ortho == 'NA':
			return (orf, ortho, None, None)
		return (orf, ortho, getOrfSeq_worm(orf, 'cele'), getOrfSeq_worm(ortho, 'cbri'))
	if species == 'fly':
		(orf, ortho) = (fields[1], fields[2])
		if ortho == 'NA':
			return (orf, ortho, None, None)
		return (orf, ortho, getSeq_fly(orf, 'dmel'), getSeq_fly(ortho, 'dyak'))
	if species == 'mouse':
		# NOTE(review): unreachable today for the same reason as 'ecoli'.
		# orf_seq_dict / ortho_seq_dict are not defined anywhere in view;
		# presumably they were meant to come from getSeqDict_mouse().
		(orf, ortho) = (fields[1], fields[4])
		if ortho == 'NA':
			return (orf, ortho, None, None)
		return (orf, ortho, orf_seq_dict[orf], ortho_seq_dict[ortho])
	raise ValueError("no sequence lookup defined for species %r" % (species,))

def writeHyphyInputFile(seq1, seq2):
	"""Write the gap-free codon alignment of seq1/seq2 to ./tmp_p.input.

	Codon columns in which either sequence is '---' are removed with rmGap().
	The file is a two-record FASTA with the fixed names 'a' and 'b'.

	Returns the number of codons written per sequence (an int under both
	Python 2 and Python 3).
	"""
	(seq1, seq2) = rmGap(seq1, seq2)
	# The context manager closes the file even if a write fails.
	with open("./tmp_p.input", 'w') as tmp_file:
		tmp_file.write(">a\n")
		tmp_file.write(seq1 + '\n')
		tmp_file.write(">b\n")
		tmp_file.write(seq2 + '\n')
	# Floor division keeps the result an integer under Python 3 as well.
	return len(seq1) // 3

def getSiteNum(seq1, seq2, codon_ns_dict):
	"""Return site counts averaged over the two aligned sequences.

	Parameters:
	    seq1, seq2    -- codon-aligned nucleotide sequences; gap columns are
	                     removed with rmGap() first.
	    codon_ns_dict -- codon -> (nn, ns, nsc, nsr) as produced by
	                     getCodonSynSiteNum().

	Returns (N, S, SR, SC): the per-codon values summed over both sequences
	and halved.  Note the return order is SR before SC although the per-codon
	tuples store nsc before nsr.

	Raises KeyError for any codon missing from codon_ns_dict.
	"""
	(seq1, seq2) = rmGap(seq1, seq2)
	# Floor division: range() needs an int under Python 3.
	n = len(seq1) // 3
	(N, S, Sc, Sr) = (0.0, 0.0, 0.0, 0.0)
	for seq in (seq1, seq2):
		for i in range(n):
			codon = seq[(i*3):(i*3+3)]
			if codon != '---':
				(nn, ns, nsc, nsr) = codon_ns_dict[codon]
				N += nn
				S += ns
				Sc += nsc
				Sr += nsr
	return (N/2, S/2, Sr/2, Sc/2)

def getCodonSynSiteNum(codon, opt_codon_list):
	"""Count nonsynonymous and synonymous sites of a single codon.

	Each of the nine single-base mutants of codon is classified through
	codon_aa_dict.  A mutant coding for the same amino acid is synonymous;
	it is additionally counted as "conserved" when its membership in
	opt_codon_list equals that of the original codon.

	Returns (nn, ns, nsc, nsr) where ns is the synonymous mutant count / 3,
	nn = 3 - ns, nsc is the conserved synonymous count / 3 and
	nsr = ns - nsc.
	"""
	residue = codon_aa_dict[codon]
	codon_is_opt = codon in opt_codon_list
	syn_changes = 0
	conserved_changes = 0
	for pos in range(3):
		for nucleotide in ('A', 'T', 'G', 'C'):
			if nucleotide == codon[pos]:
				continue
			mutant = codon[:pos] + nucleotide + codon[pos+1:]
			if codon_aa_dict[mutant] != residue:
				continue
			syn_changes += 1
			if (mutant in opt_codon_list) == codon_is_opt:
				conserved_changes += 1
	ns = float(syn_changes) / 3
	nsc = float(conserved_changes) / 3
	return (3 - ns, ns, nsc, ns - nsc)

def procML(align_length):
	"""Turn the HyPhy result file into expected substitution counts.

	Reads ./AlphaWPsi_tmp_p.result with readMLOutput(), derives the rate
	proportions and the value t with calcRhoRate(), and scales each
	proportion by align_length * t.  When t is zero all four counts are 0.

	Returns (likelihood, w, Nd, Sd, psi, SRd, SCd).
	"""
	(frequency_list, q_matrix, w, psi, likelihood) = readMLOutput('./AlphaWPsi_tmp_p.result')
	(rho_ns, rho_syn, rho_syn_r, rho_syn_c, t) = calcRhoRate(frequency_list, q_matrix)
	if t == 0:
		return (likelihood, w, 0, 0, psi, 0, 0)
	scale = align_length * t
	return (likelihood, w, scale * rho_ns, scale * rho_syn, psi, scale * rho_syn_r, scale * rho_syn_c)

def readMLOutput(file_name):
	"""Parse a HyPhy result file into frequencies, Q matrix and parameters.

	The file is read as three consecutive sections, each introduced by a
	line containing its marker text:

	    'Equilibrium codon frequency' -- one float per line
	    'Q matrix'                    -- whitespace-separated rows, kept as
	                                     lists of strings
	    'Parameters'                  -- lines of the form 'name = value';
	                                     parsing stops at the line whose first
	                                     token is 'Log', whose fourth token
	                                     minus its last character is stored
	                                     as the likelihood

	Blank lines are ignored in every section.

	Returns (frequency_list, q_matrix, w, psi, likelihood).

	Raises ValueError when 'w', 'psi' or the likelihood line is missing.
	"""
	frequency_list = []
	q_matrix = []
	parameter_dict = {}
	section = None
	# The context manager closes the file even when a line fails to parse.
	with open(file_name, 'r') as ifile:
		for line in ifile:
			if line.find('Equilibrium codon frequency') >= 0:
				section = 'frequency'
				frequency_list = []
				continue
			if line.find('Q matrix') >= 0:
				section = 'matrix'
				q_matrix = []
				continue
			if line.find('Parameters') >= 0:
				section = 'parameters'
				parameter_dict = {}
				continue
			tokens = line.split()
			if not tokens:
				continue
			if section == 'frequency':
				# strip() rather than line[:-1]: the latter silently drops the
				# last digit of a line that has no trailing newline.
				frequency_list.append(float(line.strip()))
			elif section == 'matrix':
				q_matrix.append(tokens)
			elif section == 'parameters':
				if tokens[0] == 'Log':
					# The fourth token carries one trailing character that
					# is not part of the number.
					parameter_dict['likelihood'] = float(tokens[3][:-1])
					break
				parameter_dict[tokens[0]] = float(tokens[2])
	missing = [key for key in ('w', 'psi', 'likelihood') if key not in parameter_dict]
	if missing:
		raise ValueError("%s: missing value(s) %s" % (file_name, ', '.join(missing)))
	return (frequency_list, q_matrix, parameter_dict['w'], parameter_dict['psi'], parameter_dict['likelihood'])

def calcRhoRate(frequency_list, q_matrix):
	"""Split the total off-diagonal flux of a codon rate matrix by change type.

	For every ordered pair i != j the flux
	frequency_list[i] * q_matrix[i][j] * frequency_list[j] is accumulated
	into the total, into the synonymous total when both codons encode the
	same amino acid, and into the "r" synonymous total when exactly one of
	the two codons is in the species' codon list from species_opt_codon_dict.

	Reads the module-level globals `species` and `index_codon_dict`, so both
	must be bound before this is called.
	NOTE(review): assumes row/column k of q_matrix and frequency_list[k]
	refer to index_codon_dict[k] -- confirm against the HyPhy batch file.

	Parameters:
	    frequency_list -- list of floats.
	    q_matrix       -- list of rows; entries are converted with float().

	Returns (rho_ns, rho_syn, rho_syn_r, rho_syn_c, t): the four proportions
	of the total flux (all 0 when the total is 0) and t = total / 2.
	"""
	n = len(frequency_list)
	# Loop-invariant; a set gives constant-time membership tests.
	opt_codons = set(species_opt_codon_dict[species])
	sum_all = 0.0
	sum_syn = 0.0
	sum_syn_r = 0.0
	for i in range(n):
		codon_i = index_codon_dict[i]
		aa_i = codon_aa_dict[codon_i]
		i_is_opt = codon_i in opt_codons
		for j in range(n):
			if i == j:
				continue
			codon_j = index_codon_dict[j]
			flux = frequency_list[i] * float(q_matrix[i][j]) * frequency_list[j]
			sum_all += flux
			if aa_i == codon_aa_dict[codon_j]:
				sum_syn += flux
				if i_is_opt != (codon_j in opt_codons):
					sum_syn_r += flux
	if sum_all != 0:
		rho_ns = (sum_all - sum_syn) / sum_all
		rho_syn = sum_syn / sum_all
		rho_syn_r = sum_syn_r / sum_all
		rho_syn_c = (sum_syn - sum_syn_r) / sum_all
	else:
		rho_ns = 0
		rho_syn = 0
		rho_syn_r = 0
		rho_syn_c = 0
	return (rho_ns, rho_syn, rho_syn_r, rho_syn_c, sum_all / 2)

def get_codon_index():
	"""Number the sense codons in base_list order.

	Codons are enumerated with the first base varying slowest and the third
	base fastest, following the order of the module-level base_list; the
	stop codons TAA, TAG and TGA are skipped and consume no index.

	Returns (codon_index_dict, index_codon_dict): codon -> index and the
	inverse index -> codon mapping.
	"""
	stop_codons = ['TAA', 'TAG', 'TGA']
	sense_codons = [first + second + third
	                for first in base_list
	                for second in base_list
	                for third in base_list
	                if first + second + third not in stop_codons]
	codon_to_index = {}
	index_to_codon = {}
	for position, triplet in enumerate(sense_codons):
		codon_to_index[triplet] = position
		index_to_codon[position] = triplet
	return (codon_to_index, index_to_codon)

def rmGap(seq1, seq2):
	"""Drop every codon column in which either sequence is a gap ('---').

	The sequences are walked in steps of three over the length of seq1; a
	trailing partial codon of seq1 is discarded.

	Returns (new_seq1, new_seq2) containing only the retained codons.
	"""
	# Floor division: range() needs an int under Python 3.
	n = len(seq1) // 3
	kept1 = []
	kept2 = []
	for i in range(n):
		codon1 = seq1[(i*3):(i*3+3)]
		codon2 = seq2[(i*3):(i*3+3)]
		if codon1 != '---' and codon2 != '---':
			kept1.append(codon1)
			kept2.append(codon2)
	# join() avoids the quadratic cost of repeated string concatenation.
	return (''.join(kept1), ''.join(kept2))

def getOrfSeq_yeast(orf, fasta_path="/home/tz666/functional_site/data/genome/yeast/orf_scer.fasta"):
	"""Return the upper-cased sequence of the FASTA record named orf.

	A header matches when orf starts right after the '>' and the character
	following it is not '-', so a name is not matched by its own hyphenated
	variants.  Collection stops at the next header line or at end of file.

	Parameters:
	    orf        -- record name to look for.
	    fasta_path -- FASTA file to scan; defaults to the original location.

	Returns the concatenated sequence, or None when the record is missing or
	empty.
	"""
	chunks = []
	collecting = False
	with open(fasta_path, 'r') as orf_file:
		for line in orf_file:
			# A one-character slice cannot raise IndexError on a header that
			# ends right after the name.
			if line.find(orf) == 1 and line[len(orf)+1:len(orf)+2] != '-':
				collecting = True
				continue
			if collecting:
				if line[0] == ">":
					break
				chunks.append(line.rstrip().upper())
	# Joining after the loop also covers a record that is the last one in
	# the file, where no following header triggers the break.
	cDNA = ''.join(chunks)
	if cDNA == '':
		return None
	return cDNA

def getOrthoSeq_yeast(ortho, spe, source, data_dir="/home/tz666/functional_site/data/genome/yeast/"):
	"""Return the upper-cased sequence of record ortho from a yeast FASTA file.

	The file name is built as 'orf_s' + first three letters of spe + '_' +
	first letter of source + '.fasta', lower-cased.  A header matches when
	ortho starts right after the '>' and is followed by a space.

	Parameters:
	    ortho    -- record name to look for.
	    spe      -- species name; only its first three characters are used.
	    source   -- data source; only its first character is used.
	    data_dir -- directory holding the FASTA files; defaults to the
	                original location.

	Returns the concatenated sequence, or None when the record is missing or
	empty.
	"""
	file_name = ('orf_s' + spe[0:3] + '_' + source[0] + '.fasta').lower()
	chunks = []
	collecting = False
	with open(os.path.join(data_dir, file_name), 'r') as ortho_file:
		for line in ortho_file:
			# A one-character slice cannot raise IndexError on a short header.
			if line.find(ortho) == 1 and line[len(ortho)+1:len(ortho)+2] == ' ':
				collecting = True
				continue
			if collecting:
				if line[0] == ">":
					break
				chunks.append(line.rstrip().upper())
	# Joining after the loop also covers a record that is the last one in
	# the file, where no following header triggers the break.
	cDNA = ''.join(chunks)
	if cDNA == '':
		return None
	return cDNA

def getSeq_ecoli(ortho, spe, data_dir="/home/tz666/functional_site/data/genome/ecoli/"):
	"""Return the upper-cased sequence of record ortho from 'orf_<spe>.fasta'.

	The file name is lower-cased.  A header matches when ortho starts right
	after the '>' and is followed by a space.

	Parameters:
	    ortho    -- record name to look for.
	    spe      -- species tag used in the file name.
	    data_dir -- directory holding the FASTA files; defaults to the
	                original location.

	Returns the concatenated sequence, or None when the record is missing or
	empty.
	"""
	file_name = ('orf_' + spe + '.fasta').lower()
	chunks = []
	collecting = False
	with open(os.path.join(data_dir, file_name), 'r') as ortho_file:
		for line in ortho_file:
			# A one-character slice cannot raise IndexError on a short header.
			if line.find(ortho) == 1 and line[len(ortho)+1:len(ortho)+2] == ' ':
				collecting = True
				continue
			if collecting:
				if line[0] == ">":
					break
				chunks.append(line.rstrip().upper())
	# Joining after the loop also covers a record that is the last one in
	# the file, where no following header triggers the break.
	cDNA = ''.join(chunks)
	if cDNA == '':
		return None
	return cDNA

def getSeq_fly(ortho, spe, data_dir="/home/tz666/functional_site/data/genome/fly/"):
	"""Return the upper-cased sequence of record ortho from '<spe>.cds.fasta'.

	The file name is lower-cased.  A header matches when ortho starts right
	after the '>'.
	NOTE(review): there is no delimiter check, so a name also matches every
	header it is a prefix of, and consecutive matching records are read as
	one sequence -- confirm the identifiers passed in are unambiguous.

	Parameters:
	    ortho    -- record name to look for.
	    spe      -- species tag used in the file name.
	    data_dir -- directory holding the FASTA files; defaults to the
	                original location.

	Returns the concatenated sequence, or None when the record is missing or
	empty (the same convention as the other sequence readers in this file).
	"""
	file_name = (spe + '.cds.fasta').lower()
	chunks = []
	collecting = False
	with open(os.path.join(data_dir, file_name), 'r') as ortho_file:
		for line in ortho_file:
			if line.find(ortho) == 1:
				collecting = True
				continue
			if collecting:
				if line[0] == ">":
					break
				chunks.append(line.rstrip().upper())
	# Joining after the loop also covers a record that is the last one in
	# the file, where no following header triggers the break.
	cDNA = ''.join(chunks)
	if cDNA == '':
		return None
	return cDNA

def getSeqDict_mouse(spe, data_dir="/home/tz666/functional_site/data/genome/mouse/"):
	"""Load every record of '<spe>.txt' (lower-cased) into a dict.

	Record keys are taken from the header line:
	  * a header shorter than 30 characters (newline included) that is not
	    the first header of the file is keyed by line[1:19];
	  * any other header is keyed by the text after the last '|' of its
	    first whitespace-free token;
	  * a short first header without '|' falls back to line[1:19].

	Parameters:
	    spe      -- species tag used in the file name.
	    data_dir -- directory holding the file; defaults to the original
	                location.

	Returns a dict mapping record key -> upper-cased concatenated sequence.
	The final record of the file is included.

	Raises ValueError for a header of 30+ characters that contains no '|'.
	"""
	regrec = re.compile(r'>(\S*)\|(\S*)')
	file_name = (spe + '.txt').lower()

	def header_key(line, is_first):
		# The first header always tries the '|' pattern; later headers only
		# do so when they are long enough to carry a second identifier.
		if is_first or len(line) >= 30:
			m = regrec.match(line)
			if m:
				return m.groups()[1]
			if len(line) >= 30:
				raise ValueError("unrecognised header line: %r" % (line,))
		return line[1:19]

	seq_dict = {}
	current_key = None
	chunks = []
	with open(os.path.join(data_dir, file_name), 'r') as orf_file:
		for line in orf_file:
			if line[0] == ">":
				if current_key is not None:
					seq_dict[current_key] = ''.join(chunks)
				current_key = header_key(line, current_key is None)
				chunks = []
				continue
			if current_key is not None:
				chunks.append(line.rstrip().upper())
	# Store the record still being collected at end of file; without this
	# the last record is silently lost.
	if current_key is not None:
		seq_dict[current_key] = ''.join(chunks)
	return seq_dict

def getOrfSeq_worm(orf, spe, data_dir="/home/tz666/functional_site/data/genome/worm/"):
	"""Return the longest sequence among the records of '<spe>.fasta' matching orf.

	A record matches when orf occurs anywhere in its header line after the
	leading '>'.  Every matching record is collected separately, including
	consecutive matches and a match that is the last record of the file.
	Empty records and the placeholder 'SEQUENCE UNAVAILABLE' are discarded
	before the longest remaining sequence is chosen; on a tie the first one
	in file order wins.

	Parameters:
	    orf      -- text to look for in header lines.
	    spe      -- species tag used in the file name.
	    data_dir -- directory holding the FASTA files; defaults to the
	                original location.

	Returns the upper-cased sequence, or None when nothing usable is found.
	"""
	records = []
	chunks = []
	collecting = False
	with open(os.path.join(data_dir, spe + ".fasta")) as orf_file:
		for line in orf_file:
			if line[0] == ">":
				# Close the record in progress before looking at the new
				# header, so adjacent matches are not glued together.
				if collecting:
					records.append(''.join(chunks))
				chunks = []
				collecting = line.find(orf) >= 1
				continue
			if collecting:
				chunks.append(line.rstrip().upper())
	# A matching record at end of file has no following header to close it.
	if collecting:
		records.append(''.join(chunks))
	usable = [r for r in records if r != '' and r != 'SEQUENCE UNAVAILABLE']
	if not usable:
		return None
	return max(usable, key=len)

def calcAlignPer(s1, s2):
	"""Return the fraction of alignment columns in which both residues are present.

	Columns where both sequences have '-' are excluded from the denominator.
	Columns are read over the length of s1.

	Returns a float in [0, 1]; 0.0 when there is no countable column (empty
	input or only double-gap columns) instead of raising ZeroDivisionError.
	"""
	aligned_count = 0
	aa_length = 0
	for i, res1 in enumerate(s1):
		res2 = s2[i]
		if res1 != '-' and res2 != '-':
			aligned_count += 1
		if not (res1 == '-' and res2 == '-'):
			aa_length += 1
	if aa_length == 0:
		return 0.0
	return float(aligned_count) / float(aa_length)

# Codon <-> index lookups; calcRhoRate() reads index_codon_dict as a module
# global, so this must run before main().
(codon_index_dict, index_codon_dict) = get_codon_index()

# calcRhoRate() also reads `species` as a module global, so it has to be bound
# at module level before main() is called.
species = 'yeast'
# Nested context managers close both files even if main() raises part-way
# through the batch.
with open("/home/tz666/HyPhy/data/ortho_index/yeast_ortho_sbay_MIT_part1.dat", 'r') as index_file:
	with open("/home/tz666/HyPhy/result/yeast_part1.dat", 'w') as ofile:
		main(species, index_file, ofile)
