
import sync
import sys
import thread
import time

from numpy import arange, shape, zeros

import remote_exec
import population


######
#I've got to clean up evaluate and initial in population so that
#the incorporation of the parallel stuff is smoother.
######


def array_round(x):
    """Round every element of an array to the nearest whole number.

    Arguments:

    x -- an array of any shape; it must provide a ``flat`` iterator.

    Returns a new array created with ``zeros(shape(x))`` (same shape as
    x) whose elements are the rounded values of x.
    """
    y = zeros(shape(x))
    # Read and write through the flat views: indexing x[i]/y[i] directly
    # addresses whole rows as soon as x has more than one dimension.
    for i, value in enumerate(x.flat):
        y.flat[i] = int(round(value))

    return y

def divide_list(l,sections):
    """Split a sequence into contiguous pieces of roughly equal length.

    Arguments:

    l -- any object supporting len() and slicing.
    sections -- the number of pieces wanted.

    Returns a list of slices of l, in order.  Split points are placed
    every len(l)/sections items and rounded to the nearest integer, so
    when l holds fewer items than sections some pieces may be empty and
    fewer than ``sections`` pieces may come back.  An empty l yields an
    empty list.
    """
    Ntot = len(l)
    if not Ntot:
        # nothing to split (and arange() cannot step by zero)
        return []
    Neach = Ntot/float(sections)
    # Slice bounds must be true integers; float bounds raise TypeError.
    div_points = [int(round(point)) for point in arange(0,Ntot,Neach)]
    if div_points[-1] != Ntot: div_points.append(Ntot)
    return [l[st:end] for st, end in zip(div_points[:-1],div_points[1:])]

class parallel_pop_initializer:
    def evaluate(self,pop,settings = None):
    #only send the individuals out that need evaluation
        if len(pop):
            Nserv = len(pop.server_list)
            groups = divide_list(pop,Nserv)
            sys.setcheckinterval(10)
            finished = sync.event()
            bar = sync.barrier(Nserv)
            print '************',len(groups), len(pop.server_list), len(pop)
            for i in range(len(groups)):
                inputs = {'sub_pop':groups[i],'settings':settings, 'initializer':pop.initializer}
                returns = ('sub_pop',)
                code = 'initializer.evaluate(sub_pop,settings)'
                data_pack = (inputs,returns,code)
                server = pop.server_list[i]
                thread.start_new_thread(remote_thread_init,(bar,finished,server,data_pack))
            finished.wait()
            sys.setcheckinterval(10)

#what is this?                  for ind in pop: ind.evaluate(force)


import cPickle

def plen(obj):
    """Return the size of obj once pickled with cPickle protocol 1."""
    pickled = cPickle.dumps(obj,1)
    return len(pickled)


class parallel_pop_evaluator:
    def evaluate(self,pop,force = 0):
        #import tree
        #print '1',tree.ref()
        #only send the individuals out that need evaluation
        if force:
            _eval_list = pop.data
        else:
            _eval_list = filter(lambda x: not x.evaluated,pop)
        #print '2',tree.ref()
        eval_list = pop.clone()
        #print '3',tree.ref()
        eval_list.data = _eval_list
        if len(eval_list):
            Nserv = len(pop.server_list)
            groups = divide_list(eval_list,Nserv)
            #print '4',tree.ref()
            sys.setcheckinterval(10)
            finished = sync.event()
            bar = sync.barrier(Nserv)
            #print "EVAL LENGTH!!!", plen(pop.evaluator)
            gr = groups[0]
            print "GROUP LENGTH!!!", plen(groups[0]), len(gr),
            #print "IND!!!", plen(gr[0]),plen(gr[0].root)
            #print '4.5',tree.ref()
            for i in range(len(groups)):
                inputs = {'sub_pop':groups[i], 'evaluator':pop.evaluator, 'force':force}
                returns = ('sub_pop',)
                code = 'evaluator.evaluate(sub_pop,force)'
                data_pack = (inputs,returns,code)
                server = pop.server_list[i]
                thread.start_new_thread(remote_thread_eval,(bar,finished,server,data_pack))
            #print '7',tree.ref()
            finished.wait()
            sys.setcheckinterval(10)
#what is this?                  for ind in pop: ind.evaluate(force)
    """
    def evaluate(self,pop,force = 0):
            #only send the individuals out that need evaluation
            _eval_list = filter(lambda x: not x.evaluated,pop)
            eval_list = pop.clone()
            eval_list.data = _eval_list
            if len(eval_list):
                    #finest grain possible
                    groups = divide_list(eval_list,len(eval_list))
                    finished = sync.event()
                    bar = sync.barrier(groups)

                    sys.setcheckinterval(10)
                    Nserv = len(pop.server_list)
                    idx = 0
                    while idx < len(groups):
                            inputs = {'sub_pop':groups[idx], 'evaluator':pop.evaluator}
                            returns = ('sub_pop',)
                            code = 'evaluator.evaluate(sub_pop)'
                            data_pack = (inputs,returns,code)
                            server = pop.server_list[i]
                            thread.start_new_thread(remote_thread_eval,(bar,finished,server,data_pack))
                    #for i in range(len(groups)):
                    #       inputs = {'sub_pop':groups[i], 'evaluator':pop.evaluator}
                    #       returns = ('sub_pop',)
                    #       code = 'evaluator.evaluate(sub_pop)'
                    #       data_pack = (inputs,returns,code)
                    #       server = pop.server_list[i]
                    #       thread.start_new_thread(remote_thread,(bar,finished,server,data_pack))
                    finished.wait()
                    sys.setcheckinterval(10)
#what is this?                  for ind in pop: ind.evaluate(force)
    """

def remote_thread_init(bar,finished,server,data_pack):
    """Thread body: run one initialization job on one remote server.

    Arguments:

    bar -- barrier object; its enter() is called once the job is over.
    finished -- event object; its post() is called after the barrier.
    server -- pair whose two items are handed to remote_exec.remote_exec
              and printed as '%s,%d' on error (presumably host and port).
    data_pack -- (inputs, returns, code) tuple passed to remote.run();
                 inputs['sub_pop'] holds the local individuals.

    The ``__dict__`` of each local individual is updated from the
    individual at the same index in the returned 'sub_pop'.
    """
    try:
        # NOTE(review): the meaning of the trailing 0,1 arguments is not
        # visible here -- see remote_exec.
        remote = remote_exec.remote_exec(server[0],server[1],0,1)
        results = remote.run(data_pack)
        #assign the results from the returned data to the local individuals
        inputs = data_pack[0]
        old = inputs['sub_pop']
        new = results['sub_pop']
        # Indexed on purpose: a reply shorter than the request raises
        # IndexError, which is reported below.
        for i in range(len(old)):
            old[i].__dict__.update(new[i].__dict__)
    except IndexError:
        print 'error in %s,%d' %  server
    finally:
        # Always check in, even when the remote call fails with some other
        # exception; otherwise the thread waiting on ``finished`` (and the
        # other workers waiting on the barrier) would block forever.
        bar.enter()
        finished.post()

def remote_thread_eval(bar,finished,server,data_pack):
    """Thread body: run one evaluation job on one remote server.

    Arguments:

    bar -- barrier object; its enter() is called once the job is over.
    finished -- event object; its post() is called after the barrier.
    server -- pair whose two items are handed to remote_exec.remote_exec
              and printed as '%s,%d' on error (presumably host and port).
    data_pack -- (inputs, returns, code) tuple passed to remote.run();
                 inputs['sub_pop'] holds the local individuals.

    Each returned genome has root.delete_circulars() called and its
    ``root`` attribute removed; the ``__dict__`` of each local individual
    is then updated from the returned genome at the same index.
    """
    try:
        # NOTE(review): the meaning of the trailing 0,1 arguments is not
        # visible here -- see remote_exec.
        remote = remote_exec.remote_exec(server[0],server[1],0,1)
        results = remote.run(data_pack)
        #assign the results from the returned data to the local individuals
        inputs = data_pack[0]
        old = inputs['sub_pop']
        new = results['sub_pop']
        # Drop each returned genome's root before copying its attributes,
        # so the local individuals do not pick up the returned ``root``.
        for gnm in new:
            gnm.root.delete_circulars()
            del gnm.root
        # Indexed on purpose: a reply shorter than the request raises
        # IndexError, which is reported below.
        for i in range(len(old)):
            old[i].__dict__.update(new[i].__dict__)

        # Disabled experiment kept from the original source:
        #   import sys
        #   #r = new[0].root
        #   #print 'ref count',sys.getrefcount(r)
        #   #Huh??? Why do I need to delete the new genomes
        #   #individually here?  Why aren't they garbage collected?
        #   indices = range(len(new))
        #   indices.reverse()
        #   for i in indices:
        #           del new[i]
        #   #print 'ref count',sys.getrefcount(r)
        #   #r.delete_circulars()
        #   #print 'ref count',sys.getrefcount(r)
        #   #del r
    except IndexError:
        print 'error in %s,%d' %  server
    finally:
        # Always check in, even when the remote call fails with some other
        # exception; otherwise the thread waiting on ``finished`` (and the
        # other workers waiting on the barrier) would block forever.
        bar.enter()
        finished.post()

class ga_parallel_pop(population.population):
    parallel_evaluator = parallel_pop_evaluator()
    parallel_initializer = parallel_pop_initializer()
    def __init__(self,genome,size=1,server_list=None):
        """Arguments:

           genome -- a genome object.
           size -- number.  The population size.  The genome will be
                   replicated size times to fill the population.
           server_list -- a list of tuple pairs with machine names and
                          ports listed for the available servers
                          ex: [(ee.duke.edu,8000),('elsie.ee.duke.edu',8000)]
        """
        population.population.__init__(self,genome,size)
        assert(server_list)
        self.server_list = server_list
    def initialize(self,settings = None):
        """This method **must** be called before a genetic algorithm
           begins evolving the population.  It takes care of initializing
           the individual genomes, evaluating them, and scaling the population.
           It also clears and intializes the statistics for the population.

           Arguments:

           settings -- dictionary of genetic algorithm parameters.  These
                       are passed on to the genomes for initialization.
        """
        self.stats = {'current':{},'initial':{},'overall':{}}
        self.stats['ind_evals'] = 0

        print "beigninning genome generation"
        b = time.clock()
        self.parallel_initializer.evaluate(self,settings)
        e = time.clock()
        print "finished generation: ", e-b
        self.touch();
        b = time.clock()
        self.evaluate()
        e = time.clock()
        print "evaluation time: ", e-b
        self.scale()
        self.update_stats()
        self.stats['initial']['avg'] = self.stats['current']['avg']
        self.stats['initial']['max'] = self.stats['current']['max']
        self.stats['initial']['min'] = self.stats['current']['min']
        self.stats['initial']['dev'] = self.stats['current']['dev']

    def evaluate(self, force = 0):
        """ call the parallel_evaluator instead of the evaluator directly
        """
        self.selector.clear()
        self.parallel_evaluator.evaluate(self,force)
        #self.post_evaluate()
        #all of the remaining should be put in post eval...
        self.sort()
        #this is a cluge to get eval count to work correctly
        preval = self.stats['ind_evals']
        for ind in self:
            self.stats['ind_evals'] = self.stats['ind_evals'] + ind.evals
            ind.evals = 0
        print 'evals: ', self.stats['ind_evals'] - preval
        self.touch()
        self.evaluated = 1

########################## test stuff ############################
#import genome
#import gene
#import time
#
#import socket
#
#class objective:
#    def __init__(self,wait=.01):
#        self.wait = wait
#    def evaluate(self,genome):
#        time.sleep(self.wait)
#        return sum(genome.array(),axis=0)
#
#def test_pop(server_list,size=100,wait=.01):
#    obj = objective(wait)
#    the_gene = gene.float_gene((0,2.5))
#    genome_ = genome.list_genome(the_gene.replicate(5))
#    genome_.evaluator = obj
#    pop = ga_parallel_pop(genome_,size,server_list)
#    print  '########### awaiting evaluation#############'
#    pop.initialize()
#    print ' evaluation done!'
#    print 'best:', pop.best()
#    print 'worst',pop.worst()
#
#
#def gen_pop():
#    genome.list_genome.evaluator = objective()
#    gene = gene.float_gene((0,2.5))
#    genome_ = genome.list_genome(gene.replicate(5))
#    pop = ga_parallel_pop(genome_,100,[(host,port),])
#    return pop
#
#import os
#
#import parallel_pop
#
#
#def test_pop2(server_list,size=100,wait=.01):
#    import hmm_gnm,os
#    genome = hmm_gnm.make_genome()
#    #pop = ga_parallel_pop(genome,4,server_list)
#    global galg
#    #genome.target = targets[0]
#    pop = ga_parallel_pop(genome,1,server_list)
#    galg = hmm_gnm.class_ga(pop)
#    galg.settings.update({ 'pop_size':6,'gens':2,'p_mutate':.03,
#                                'dbase':os.environ['HOME'] + '/all_lift3', 'p_cross':0.9, 'p_replace':.6,
#                                'p_deviation': -.001})
#    galg.evolve()
#
#    print  '########### awaiting evaluation#############'
#    pop.initialize()
#    print ' evaluation done!'
#    print 'best:', pop.best()
#    print 'worst',pop.worst()
#
#import thread
#def test():
#    host = socket.gethostname()
#    port = 8000
#    server_list = [(host,port),(host,port+1)]
#    for server in server_list:
#        host,port = server
#        thread.start_new_thread(remote_exec.server,(host,port))
#    thread.start_new_thread(test_pop2,(server_list,))
#
#def test2(machines=32,size=100,wait=.01):
#    import time
#    t1 = time.time()
#    #requires that servers are started on beowulf 1 and 2.
#    import beowulf
#    server_list = beowulf.beowulf.servers[:machines]
#    thread.start_new_thread(test_pop,(server_list,size,wait))
#    print 'total time:', time.time()-t1
