
#/*##########################################################################
# Copyright (C) 2001-2013 European Synchrotron Radiation Facility
#
#              PyHST2  
#  European Synchrotron Radiation Facility, Grenoble,France
#
# PyHST2 is  developed at
# the ESRF by the Scientific Software  staff.
# Principal author for PyHST2: Alessandro Mirone.
#
# This program is free software; you can redistribute it and/or modify it 
# under the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option) 
# any later version.
#
# PyHST2 is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# PyHST2; if not, write to the Free Software Foundation, Inc., 59 Temple Place,
# Suite 330, Boston, MA 02111-1307, USA.
#
# PyHST2 follows the dual licensing model of Trolltech's Qt and Riverbank's PyQt
# and cannot be used as a free plugin for a non-free program. 
#
# Please contact the ESRF industrial unit (industry@esrf.fr) if this license 
# is a problem for you.
#############################################################################*/



from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



from . import getSysOutput
from . import getCpuSetRange
import os
import string

from .  import string_six


import mpi4py.MPI as MPI
import sys
import numpy
import multiprocessing

def _run_command(args):
    """Run the argv list *args* and return ``(stdout, stderr, returncode)``.

    On Python 3 the output streams are decoded to ``str``
    (``universal_newlines=True``, which ``text=True`` aliases on 3.7+);
    on Python 2 they are returned as the native ``str``.
    """
    import subprocess
    kwargs = {}
    if sys.version_info[0] >= 3:
        kwargs["universal_newlines"] = True
    proc = subprocess.Popen(args=args, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, **kwargs)
    out, err = proc.communicate()
    return out, err, proc.returncode


def _read_mem_total_bytes():
    """Return the host's MemTotal read from /proc/meminfo, scaled from kB.

    Raises:
        Exception: if the MemTotal line does not report its value in kB.
    """
    msg, _errors, _code = _run_command(["grep", "MemTotal", "/proc/meminfo"])
    # Validate the unit before parsing, so a malformed/empty line yields a
    # clear error instead of an IndexError/ValueError.
    if "kB" not in msg:
        raise Exception(" MemTotal is not given in kB :", msg)
    # NOTE: proc(5) documents meminfo "kB" as 1024 bytes; the factor 1000 is
    # kept from the original code (a slight, conservative underestimate).
    return float(msg.split()[1]) * 1000


def _parse_gpu_ids(spec, hostname):
    """Return the GPU ids listed for *hostname* in a comma-separated *spec*.

    Tokens are scanned left to right: after a token equal to *hostname*,
    every following integer token is collected until a non-integer token
    (e.g. the next host name) is met.  Example:
    ``_parse_gpu_ids("nodeA,0,1,nodeB,2", "nodeA") == [0, 1]``.
    """
    gpus = []
    collecting = False
    for tok in spec.split(","):
        if collecting:
            try:
                gpus.append(int(tok))
            except ValueError:
                collecting = False
        if tok == hostname:
            collecting = True
    return gpus


def setCpuSet(maxnargs=3):
    """Pin this MPI process to its share of the host's cores and pick its GPUs.

    Every rank of ``MPI.COMM_WORLD`` must call this, because it performs
    collective operations (``allgather``, ``Barrier``). Ranks that report
    the same ``MPI.Get_processor_name()`` are treated as sharing a host.
    The cores of the current affinity set are split evenly between those
    ranks, in the order returned by ``getCpuSetRange.getCoresOrdered()``.
    Each rank's share is applied with ``taskset -pc``.

    Args:
        maxnargs: if ``len(sys.argv) >= maxnargs``, ``sys.argv[maxnargs - 1]``
            is parsed as a ``host,gpu,gpu,host,gpu,...`` list. The GPUs
            listed for this host are split evenly among its ranks.

    Returns:
        ``(mygpus, MemPerProc, coresperproc, cpusperproc)``:

        - ``mygpus``: list of GPU ids for this rank (empty if the argv
          entry is absent).
        - ``MemPerProc``: MemTotal (kB * 1000) divided by
          ``multiprocessing.cpu_count()``.
        - ``coresperproc``: number of cores assigned to this rank.
        - ``cpusperproc``: ``coresperproc // ncores4cpu``, where
          ``ncores4cpu`` comes from ``getCoresOrdered()``.

    Raises:
        Exception: if /proc/meminfo's MemTotal is not expressed in kB.
        AssertionError: if ranks on the same host see different affinity
            sets, or those cores cannot be divided evenly among them.
    """
    comm = MPI.COMM_WORLD
    myrank = comm.Get_rank()
    mypname = MPI.Get_processor_name()

    # Current affinity, e.g. "pid 2928's current affinity list: 0-3\n".
    cpuset_string, _errors, _code = _run_command(
        ["taskset", "-cp", str(os.getpid())])
    # presumably a sequence of core ids -- TODO confirm in getCpuSetRange
    myavailable_cores_on_host = getCpuSetRange.getCpuSetRange(
        str(cpuset_string).strip())
    print(" AVAILABLE ", myavailable_cores_on_host, cpuset_string)

    Ntotal_cores_on_host = multiprocessing.cpu_count()
    MemPerProc = _read_mem_total_bytes() / Ntotal_cores_on_host

    # One collective replaces nprocs broadcasts; element i comes from rank i.
    gathered = comm.allgather((myavailable_cores_on_host, mypname))

    # Ranks on this host, in increasing rank order.
    comrade_procs = [rank for rank, (_cores, name) in enumerate(gathered)
                     if name == mypname]
    assert myrank in comrade_procs

    my_gathered_cores = gathered[myrank][0]
    for iproc in comrade_procs:
        assert gathered[iproc][0] == my_gathered_cores, (
            iproc, myrank, gathered[iproc][0], my_gathered_cores)

    ncomrades = len(comrade_procs)
    mycomraderank = comrade_procs.index(myrank)

    assert len(myavailable_cores_on_host) % ncomrades == 0
    coresperproc = len(myavailable_cores_on_host) // ncomrades

    ncores4cpu, cores_ordered = getCpuSetRange.getCoresOrdered()
    cpusperproc = coresperproc // ncores4cpu
    print(" cpusperproc ", cpusperproc)

    # Fancy indexing (not slicing) so an out-of-range share still raises.
    first = mycomraderank * coresperproc
    my_cores = numpy.array(cores_ordered)[
        numpy.arange(first, first + coresperproc)]
    core_list = ",".join(str(n) for n in my_cores)

    _out, errors, returncode = _run_command(
        ["taskset", "-pc", core_list, str(os.getpid())])
    if returncode != 0:
        # Best effort, as before, but no longer silent.
        print(" WARNING: taskset could not set affinity ", core_list, errors)

    comm.Barrier()

    if len(sys.argv) < maxnargs:
        mygpus = []
    else:
        gpus = _parse_gpu_ids(sys.argv[maxnargs - 1], mypname)
        pos1 = (len(gpus) * mycomraderank) // ncomrades
        pos2 = (len(gpus) * (mycomraderank + 1)) // ncomrades
        mygpus = gpus[pos1:pos2]

    return mygpus, MemPerProc, coresperproc, cpusperproc
