# -*- coding: utf-8 -*-
import os
import sys
from pyfold import *
import numpy as np
import tarfile as t
import random

# Target secondary structure in dot-bracket notation ('.', '(' and ')').
# The driver script below passes it as `targ` to get_fit/fit_dom to score folded shapes.
targ='..(((((....(((((((.((((....)))).))))))).(((.(((((((..((((....(((....)))...))))))))))))))....)))))..'

def get_fit(shape, targ):
    '''
    Return the fitness of a structure ``shape`` relative to ``targ``.

    Fitness starts at 100 and drops by 2.5 per unit of
    ``hamming(shape, targ)``. The result is clamped at 0.0 from below.
    (An older hyperbolic fitness function is no longer used.)

    shape -- structure string to score
    targ  -- target structure string; must be non-empty

    Raises ValueError if ``targ`` is empty or None. Previously this case fell
    through to an UnboundLocalError on the return statement.
    '''
    if not targ:
        raise ValueError('get_fit needs a non-empty target structure')
    return max(0.0, 100 - 2.5 * hamming(shape, targ))
    
def fit_dom(shp1, shp2, targ, h):
    '''
    Combine the fitnesses of two allele structures under dominance coefficient h.

    Each structure is scored with get_fit against ``targ``. The combined
    fitness is the better score minus ``h`` times the shortfall of the worse
    score from the maximum of 100:

        max(f1, f2) - h * (100 - min(f1, f2))

    The result is clamped at 0.0 from below, so this always returns a float.
    With h == 0 the better allele alone determines fitness.
    (This replaced an earlier geometric-mean combination of the two scores.)
    '''
    fit1 = get_fit(shp1, targ)
    fit2 = get_fit(shp2, targ)
    comb_fit = max(fit1, fit2) - h * (100 - min(fit1, fit2))
    # clamp with a float so the return type matches the unclamped path
    return max(0.0, comb_fit)
    
def recomb2(seq1, seq2, rs_val):
    '''
    Return the two recombinant products of ``seq1`` and ``seq2``.

    Positions 0 .. len(seq1)-1 are visited in order. At each position a
    crossover happens with probability rs_val/99.0. A crossover exchanges the
    tails of the two current sequences from that position onward, and several
    crossovers may occur in one call. Uses one numpy.random draw per position.

    NOTE(review): the 99.0 divisor is hard-coded; presumably a per-site
    normalisation for the sequence length in use -- confirm.
    '''
    rec_seq1, rec_seq2 = seq1, seq2
    threshold = rs_val / 99.0
    rands = np.random.random(len(seq1))
    for i, r in enumerate(rands):
        if r < threshold:
            # Swap the tails simultaneously. The old sequential assignments read
            # the already-updated rec_seq1, so rec_seq2 never actually changed.
            rec_seq1, rec_seq2 = (rec_seq1[:i] + rec_seq2[i:],
                                  rec_seq2[:i] + rec_seq1[i:])
    return rec_seq1, rec_seq2
        
def seg(g1, g2, rs_val):
    """
    Reassort chromosomes between two diploid genotype records.

    g1 and g2 are indexed as (chrom0, chrom1, shape0, shape1, ...), which is
    the layout get_pop builds for diploid entries. With probability ``rs_val``,
    each of the two chromosome slots is independently (p = 0.5) either kept or
    swapped between the genotypes. A swapped chromosome carries its shape
    (index chrom+2) along with it. Otherwise both genotypes pass through
    unchanged.

    Returns (seqs1, seqs2, shps1, shps2), each a 2-tuple. The pass-through
    branch now also returns 2-tuples of sequences instead of the whole input
    records, so both branches have the same shape.
    """
    if np.random.random() > rs_val:
        # no reassortment this time
        return tuple(g1[0:2]), tuple(g2[0:2]), tuple(g1[2:4]), tuple(g2[2:4])
    seg1, seg2, shps1, shps2 = [], [], [], []
    for chrom in range(2):
        if np.random.random() < 0.5:
            src1, src2 = g1, g2
        else:
            src1, src2 = g2, g1
        seg1.append(src1[chrom])
        seg2.append(src2[chrom])
        shps1.append(src1[chrom + 2])
        shps2.append(src2[chrom + 2])
    return tuple(seg1), tuple(seg2), tuple(shps1), tuple(shps2)


def get_pop(popfile):
    '''
    Read a population file and return one record per individual.

    The first 4 lines are skipped as a header. The remaining lines are read in
    pairs: a genotype line, then a comma-separated data line.

    * Haploid (genotype line has a single field):
        data = shape, _, _, count, fitness, ...
        -> ``count`` copies of (seq, shape, fitness)
    * Diploid (genotype line has two or more comma-separated fields):
        data = shape1, shape2, _, _, count, fitness, ...
        -> ``count`` copies of (seq1, seq2, shape1, shape2, fitness)
    '''
    # `with` so the handle is closed; it used to be leaked on every call
    with open(popfile) as fh:
        dat = fh.readlines()[4:]
    pop = []
    for i in range(0, len(dat), 2):
        gtype = dat[i].strip().split(',')
        gdat = dat[i + 1].split(',')
        if len(gtype) == 1:
            num_inds = int(gdat[3])
            rec = (gtype[0], gdat[0], float(gdat[4]))
        else:
            num_inds = int(gdat[4])
            rec = (gtype[0], gtype[1], gdat[0], gdat[1], float(gdat[5]))
        # records are immutable tuples, so sharing one object per individual is safe
        pop.extend([rec] * num_inds)
    # TODO: will need to change for new population format in rec mutation files
    return pop
    
def mutlevels(maxfit, minfit, steps):
    '''
    Return ``steps`` evenly spaced values from maxfit down to minfit, inclusive.

    e.g. mutlevels(1, 0, 3) -> [1.0, 0.5, 0.0]

    steps == 1 yields [maxfit] instead of dividing by zero; steps <= 0 yields [].
    '''
    if steps <= 0:
        return []
    if steps == 1:
        return [float(maxfit)]
    # float() avoids silent integer division under Python 2 when ints are passed
    d = float(maxfit - minfit) / (steps - 1)
    return [minfit + i * d for i in range(steps - 1, -1, -1)]

# ---------------------------------------------------------------------------
# Driver. For each replicate archive rec_seg_dat/ml2_<argv[1]>-<rep>.tar.gz,
# each sampled generation and each of the 15 mutation levels, it:
#   - draws 1000 random parent pairs from the stored population;
#   - makes offspring by segregation (argv[2] == 's', diploid records) or by
#     recombination (any other argv[2], haploid records);
#   - records weighted relative fitness changes, offspring vs. parent.
# One CSV row per (rep, gen, mlev) goes to rec_seg_dat/epi_<argv[2]>.
#
# Command-line arguments, as read below:
#   argv[1] -- label embedded in the archive file names
#   argv[2] -- 's' selects segregation mode; anything else selects recombination
#   argv[3] -- dominance coefficient h (segregation mode only)
# ---------------------------------------------------------------------------
expected_fit = mutlevels(.95, .01, 15)
# U for each mutation level: -ln(expected fitness)
ulist = -np.log(expected_fit)
os.chdir('rec_seg_dat')
rs_val = .1
lin = 1
seg_mode = sys.argv[2] == 's'
# h is loop-invariant; parse it once (it is only supplied in segregation mode)
h = float(sys.argv[3]) if seg_mode else None
# index of the fitness field: diploid records are 5-tuples, haploid 3-tuples
fit_idx = 4 if seg_mode else 2
# '%s', not '%i': argv[2] is a string, and '%i' raised TypeError before any work
with open('epi_%s' % sys.argv[2], 'w') as f:
    # five names for the five fields written per row (the old header lacked 'lin')
    f.write('lin,U,mfit_diff,vfit_diff,rep\n')
    for i in range(1, 25):
        try:
            tf = t.open('ml2_%s-%i.tar.gz' % (sys.argv[1], i))
        except IOError:
            continue
        with tf:
            for gen in [1000, 2000, 3000, 4000, 5000]:
                for mlev in range(1, 16):
                    cur_f = 'pymut_%i_run_%i' % (mlev, gen)
                    try:
                        tf.extract(cur_f)
                    except KeyError:
                        # this (mlev, gen) snapshot is not in the archive
                        continue
                    pop = get_pop(cur_f)
                    tot_fit = sum(gt[fit_idx] for gt in pop) / 100.0
                    if not tot_fit:
                        # Empty population or all-zero fitness: the pair weights
                        # below would divide by zero, and random.choice fails on [].
                        sys.stderr.write('skipping %s in rep %i: zero total fitness\n'
                                         % (cur_f, i))
                        continue
                    norm = 0.0
                    fit_diffs = []
                    for _ in range(1000):
                        gdat1 = random.choice(pop)
                        gdat2 = random.choice(pop)
                        pfit1 = gdat1[fit_idx]
                        pfit2 = gdat2[fit_idx]
                        # weight of this parent pair: product of the parents' scaled fitness shares
                        cnorm = (pfit1 / tot_fit) * (pfit2 / tot_fit)
                        if seg_mode:
                            _, _, ns1, ns2 = seg(gdat1, gdat2, rs_val)
                            ofit1 = fit_dom(ns1[0], ns1[1], targ, h)
                            # second offspring is scored on its OWN shapes (was ns1[1])
                            ofit2 = fit_dom(ns2[0], ns2[1], targ, h)
                        else:
                            r1, r2 = recomb2(gdat1[0], gdat2[0], rs_val)
                            if r1 != gdat1[0]:
                                ofit1 = get_fit(pyfold(r1), targ)
                                ofit2 = get_fit(pyfold(r2), targ)
                            else:
                                # no effective crossover: offspring equal parents, zero change
                                ofit1, ofit2 = pfit1, pfit2
                        g1_fit_diff = (ofit1 - pfit1) / pfit1 if pfit1 else 0
                        g2_fit_diff = (ofit2 - pfit2) / pfit2 if pfit2 else 0
                        fit_diffs.append(g1_fit_diff * cnorm)
                        fit_diffs.append(g2_fit_diff * cnorm)
                        norm += cnorm * 2
                    # NOTE(review): every entry is already weighted by its pair's cnorm,
                    # and norm is the sum of those weights. A weighted mean would DIVIDE
                    # by norm rather than multiply. Left unchanged so output stays
                    # comparable with earlier runs -- confirm intent.
                    fit_diffs = np.array(fit_diffs) * norm
                    f.write('%i,%f,%f,%f,%i\n' % (lin, ulist[mlev - 1], fit_diffs.mean(),
                                                 fit_diffs.var(ddof=1), i))
                    lin += 1
os.chdir('../')