
from pylab import *
from PyMca import EdfFile
import glob
import string
import sys
from numpy import fft
import numpy

import os
import subprocess as sub
import sys

from . import string_six


##############################################################################################
# GPU SELECTION
#
# Under the OAR batch scheduler, `oarprint` reports the host and GPU number(s)
# allocated to this job, one host per line, e.g.
#     gpu0102 1,0
#     gpu0101 1
# The "-F" format string is passed without a shell, so its single quotes are
# printed literally and may wrap each line; they are stripped below.
# Only the first allocated GPU is used.  When oarprint is unavailable nothing
# is selected here and pycuda.autoinit (imported further down) picks a device.
##
try:
    args = ["oarprint", "host", "-P", "host,gpu_num", "-F", "'% %'"]
    p = sub.Popen(args=args, stdout=sub.PIPE, stderr=sub.PIPE)
    resources, errors = p.communicate()
    if isinstance(resources, bytes):
        # Python 3 pipes return bytes; the parsing below needs text.
        resources = resources.decode("utf-8", "replace")
except OSError:
    # oarprint is not installed / not on PATH: not running under OAR.
    resources = ""
print(" resources gpu ", resources)
if len(resources):
    for l in resources.split("\n"):
        l = l.strip()
        # Remove the literal quotes produced by the "'% %'" format.
        if len(l) >= 2 and l[0] == l[-1] and l[0] in ["'", '"']:
            l = l[1:-1].strip()
        if len(l) == 0:
            continue
        fields = l.split()
        if len(fields) != 2:
            print(" WARNING : cannot parse oarprint line ", l)
            continue
        node, gpus = fields
        gpus_s = gpus.split(",")
        if len(gpus_s) > 1:
            print(" WARNING : you have been allocated ", len(gpus_s), " gpus ")
            print(" but you will use only one gpu! ")
        # Use the first allocated GPU; converting the whole "1,0" would fail.
        gpus_string = gpus_s[0]
        gpu = int(gpus_string)
        if gpu != -1:
            # pycuda.autoinit builds its context with
            # pycuda.tools.make_default_context(), which honours CUDA_DEVICE.
            # This gives a single context on the allocated GPU instead of a
            # default-device context plus a second one pushed on top of it.
            os.environ["CUDA_DEVICE"] = str(gpu)
        try:
            import pycuda.autoinit
        except ImportError:
            print(" WARNING : pycuda is not available, no GPU context created ")
        # One GPU per job: ignore any further allocation lines.
        break


# else:
#     pass
#     # the gpu-string would then be taken from the command line
#     gpus_string=""
#     if len(sys.argv)==3:
#         gpus_string=" "+sys.argv[2]
#     del  sys.argv[2]




# Report the device only if a context was created above; otherwise pycuda
# has not been imported yet and there is no current context to query.
if "pycuda.autoinit" in sys.modules:
    import pycuda.driver
    print("Ma carte est ", pycuda.driver.Context.get_device().name())


# Choose the FFT backend: PyCUDA + scikits.cuda on the GPU when they import
# and initialise, otherwise pyFFTW, otherwise the python-fftw3 binding.
# Both flags are always defined so later tests never hit a NameError.
usecuda = 0
usepyfftw = 0
try:
    import pycuda
    import pycuda.autoinit
    import pycuda.elementwise
    import pycuda.gpuarray as gpuarray
    from pycuda.compiler import SourceModule
    import scikits.cuda.fft as cu_fft

    usecuda = 1
except Exception:
    # Any failure here (missing package, no device, driver error) falls back
    # to a CPU FFT; KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        import pyfftw
        usepyfftw = 1
    except Exception:
        # Last resort; raises ImportError if python-fftw3 is missing too.
        import fftw3

print(" USECUDA", usecuda)

# Parse the parameter file named on the command line (sys.argv[1]).  Lines
# mentioning one of the recognised parameters are executed as Python
# statements (e.g. "NPROJ = 100"), which defines those names at module level.
# NOTE(review): exec() runs arbitrary code from the parameter file -- only
# use trusted input files.
with open(sys.argv[1], "r") as parinput:
    for line in parinput:
        if "NPROJ" in line:
            exec(line)
        if "PROJ_OUTPUT_FILE" in line:
            # The third token is the printf-style projection file name, e.g.
            # "proj_%04d.edf"; turn it into the glob pattern "proj_*.edf".
            items = line.split()
            path = items[2]
            numberpos = path.find("%")
            postfixpos = path.rfind(".")
            if numberpos >= 0:
                # Only keep a suffix when an extension follows the '%' field;
                # with rfind() == -1 the old slice kept just the last char.
                suffix = path[postfixpos:] if postfixpos > numberpos else ""
                path = path[:numberpos] + "*" + suffix

        if "PROPA_OUTPUT_FILE" in line:
            # printf-style name pattern for the propagated output files.
            items = line.split()
            PROPAPROJ_OUTPUT_FILE = items[2]

        if ("dbratio" in line or "plength" in line or "IMAGE_PIXEL_SIZE_1" in line
                or "SSOURCE_RAD" in line or "NSOURCE" in line or "SENE" in line
                or "NENE" in line):
            exec(line)
        if "ekev" in line:
            exec(line)
        

print("\n* THE INPUT FILE SAYS YOU HAVE GENERATED ", NPROJ, " PROJECTION ")
print("                 AND PROJECTIONS ARE SEARCHED AS PATTERN ", path)

# Projection files in name order, limited to the NPROJ declared in the input.
fl = glob.glob(path)
fl.sort()
fl = fl[:NPROJ]
if not fl:
    # Fail with an explicit message instead of an IndexError on fl[0].
    sys.exit("ERROR : no projection file matches the pattern " + path)
if len(fl) < NPROJ:
    print(" WARNING : only ", len(fl), " projections found; the remaining sinogram rows stay zero ")

print("\n\n* FIRST PROJECTION ", fl[0])
print("  LAST             ", fl[-1])
print("\n")
# The first projection fixes the sinogram width (its axis-1 length).
data = EdfFile.EdfFile(fl[0]).GetData(0)
datashape = data.shape

# One sinogram row per projection (real and complex buffers).
sino = zeros([NPROJ, datashape[1]], "d")
Csino = zeros([NPROJ, datashape[1]], "D")

# Promote the scalar inputs read from the parameter file to floating point.
voxelsize = IMAGE_PIXEL_SIZE_1 * 1.0
ekev = ekev * 1.0
dbratio = dbratio * 1.0

# Lengths in Angstrom: micron -> 1e4 A, metre -> 1e10 A.
voxelsize_A = voxelsize * 10000.0
plength_A = plength * 1.0e10

# Echo the physical and coherence parameters, one print() per row.
_parameter_report = (
    ("\n* THE INPUT FILE SAYS YOU VOXEL IS  ", voxelsize, " MICRON "),
    ("                      DELTA/BETA   ", dbratio),
    ("                      PROPAGATION  ", plength, " METERS "),
    ("                      ENERGY       ", ekev, " KeV "),
    ("\n",),
    ("* CONCERNING COHERENCE:",),
    ("         SSOURCE_RAD               ", SSOURCE_RAD, " Radians "),
    ("         NSOURCE                   ", NSOURCE, " Points "),
    ("         SENE                      ", SENE, " Kev "),
    ("         NENE                      ", NENE, " Points "),
)
for _report_row in _parameter_report:
    print(*_report_row)

# Fill the sinogram: row k comes from the k-th projection file.
for proj_index, proj_name in enumerate(fl):
    sino[proj_index] = EdfFile.EdfFile(proj_name).GetData(0)


if usecuda:
    # GPU backend.  Every sinogram row is padded to twice its width; aux_a is
    # the host staging array that gets uploaded to the device.
    aux_a = numpy.zeros([sino.shape[0], sino.shape[1] * 2], dtype='D')
    aux_b = numpy.zeros([sino.shape[0], sino.shape[1] * 2], dtype='D')

    # Device buffers: FFT input/output, the unshifted sample-plane wave and
    # the real-valued intensity accumulator.
    arrayin = gpuarray.empty([sino.shape[0], sino.shape[1] * 2], numpy.complex128)
    arrayout = gpuarray.empty([sino.shape[0], sino.shape[1] * 2], numpy.complex128)
    array_signal = gpuarray.empty([sino.shape[0], sino.shape[1] * 2], numpy.complex128)
    array_convoluted = gpuarray.zeros([sino.shape[0], sino.shape[1] * 2], numpy.float64)
    data = numpy.zeros([sino.shape[0], sino.shape[1] * 2], numpy.complex128)
    # Batched 1-D complex FFT along each padded row.
    fplan = cu_fft.Plan([sino.shape[1] * 2], numpy.complex128, numpy.complex128, batch=sino.shape[0])
    # cu_fft.ifft is called without scaling (presumably unnormalised), so a
    # 1/N factor per transform is applied when intensities are accumulated.
    fact4norm = 1.0 / aux_a.shape[1]

    modShifta = SourceModule("""
#include <cuComplex.h>
  // shifta: a[y][x] *= exp(i * K0aV * x), a linear phase ramp along each row
  // (x = column index).  Both cos and sin use the column index gidx, which
  // matches the CPU path's exp(1j*K0*a*arange(Nx)*voxelsize).
  __global__ void shifta(cuDoubleComplex  *a, double K0aV, int Ny, int Nx)
  {
    int gidx = threadIdx.x + blockIdx.x*blockDim.x;
    int gidy = threadIdx.y + blockIdx.y*blockDim.y;
    int gid  = gidy*Nx+gidx;

    cuDoubleComplex c;
    if(gidx<Nx && gidy<Ny) {
        c = a[gid] ;
        a[gid] =   cuCmul(c, make_cuDoubleComplex(cos(K0aV*gidx), sin(K0aV*gidx)));
    }
  }
  // propa: multiply each Fourier component by exp(i * sqrt(K0^2 - kperp^2) * p),
  // with kperp taken in FFT (wrapped) order along the row.
  __global__ void propa(cuDoubleComplex  *a, double K0, double Vox, double p, int Ny, int Nx) {
      int gidx = threadIdx.x + blockIdx.x*blockDim.x;
      int gidy = threadIdx.y + blockIdx.y*blockDim.y;
      int gid  = gidy*Nx+gidx;
      cuDoubleComplex c;
      double kperp, kpar, phase;
      int Nx2 = Nx/2;

      if(gidx<Nx && gidy<Ny) {
          kperp = 2*M_PI/Vox/Nx * ( (gidx+Nx2 )%Nx -Nx2) ; 
          kpar   = sqrt( K0*K0 - kperp*kperp  ) ;
          phase =  kpar * p; 
          c = a[gid] ;
          // c = cuCmul(c, make_cuDoubleComplex(  exp( - kperp*kperp*Vox*Vox*1.4/8 )       , 0));
          a[gid] =   cuCmul(c, make_cuDoubleComplex(cos(phase), sin(phase)));
      }
   }
  // intensity: inte += w * |a|^2 (weighted accumulation).
  __global__ void intensity(cuDoubleComplex  *a, double * inte, double w, int Ny, int Nx) {
      int gidx = threadIdx.x + blockIdx.x*blockDim.x;
      int gidy = threadIdx.y + blockIdx.y*blockDim.y;
      int gid  = gidy*Nx+gidx;
      cuDoubleComplex c;


      if(gidx<Nx && gidy<Ny) {
          c = a[gid] ;
          inte[gid] +=   w*( cuCreal(c)*cuCreal(c) + cuCimag(c)*cuCimag(c) ) ;
      }
   }
  """)


elif not usepyfftw:
    # python-fftw3 backend; reuse FFTW planning wisdom from earlier runs.
    try:
        if os.path.exists("fftw3_wisdom"):
            fftw3.import_wisdom_from_file("fftw3_wisdom")
    except Exception:
        # Best effort: missing or stale wisdom only costs planning time.
        pass

    aux_a = fftw3.create_aligned_array([sino.shape[0], sino.shape[1] * 2], dtype='D')
    aux_b = fftw3.create_aligned_array([sino.shape[0], sino.shape[1] * 2], dtype='D')

    fft_object_a = fftw3.Plan(aux_a, aux_b, direction="forward", flags=('measure',))
    fft_object_b = fftw3.Plan(aux_b, aux_a, direction="backward", flags=('measure',))

    fftw3.export_wisdom_to_file("fftw3_wisdom")

    # No axes are specified, so the plans span the whole 2-D array --
    # NOTE(review): presumably a 2-D transform, hence the 1/(Nx*Ny) factor.
    fact4norm = 1.0 / aux_a.shape[1] / aux_a.shape[0]

else:
    # pyFFTW backend: 1-D transforms along each row (axes=(-1,)), 4 threads.
    import pickle  # needed for the wisdom cache; not imported at file level

    try:
        if os.path.exists("pyfftw_wisdom"):
            # NOTE: unpickling can execute code; the file is assumed to be the
            # cache written by this script.  Pickle data is binary.
            with open("pyfftw_wisdom", "rb") as wisdom_file:
                wisdom = pickle.load(wisdom_file)
            pyfftw.import_wisdom(wisdom)
    except Exception:
        # Best effort: missing or stale wisdom only costs planning time.
        pass

    aux_a = pyfftw.n_byte_align_empty([sino.shape[0], sino.shape[1] * 2], 16, 'complex128')
    aux_b = pyfftw.n_byte_align_empty([sino.shape[0], sino.shape[1] * 2], 16, 'complex128')

    fft_object_a = pyfftw.FFTW(aux_a, aux_b, direction="FFTW_FORWARD", flags=('FFTW_MEASURE',), axes=(-1,), threads=4)
    fft_object_b = pyfftw.FFTW(aux_b, aux_a, direction="FFTW_BACKWARD", flags=('FFTW_MEASURE',), axes=(-1,), threads=4)

    wisdom = pyfftw.export_wisdom()
    with open("pyfftw_wisdom", "wb") as wisdom_file:
        pickle.dump(wisdom, wisdom_file)

    # Calling an FFTW object normalises the inverse transform by default
    # (normalise_idft=True), so no extra scaling is required.
    fact4norm = 1.0





# Sample-plane wave: the amplitude attenuation is half of sino (presumably sino
# holds -log of the transmitted intensity) and the phase is locked to it by the
# fixed delta/beta ratio.
absorption = sino / 2.0
phase = absorption * dbratio
complex_signal = exp(-absorption - phase * 1.0j)

# CPU-path intensity accumulator, shaped like the padded FFT buffers.
convoluted = numpy.zeros(aux_a.shape, "d")

# Energy spread: NENE offsets over [-SENE, +SENE] with normalised Gaussian
# weights, or a single unweighted offset of zero.
if NENE > 1:
    enes = numpy.linspace(-1 * SENE, 1 * SENE, num=NENE, endpoint=True, retstep=False)
    enew = numpy.exp(-enes * enes / SENE / SENE / 2.0)
    enew /= enew.sum()
else:
    enes, enew = [0.0], [1.0]

print(aux_a.shape)

# Angular spatial frequencies of a padded row; blurring = 1 means the
# optional blur (kept below as a comment) is switched off.
freqs = fft.fftfreq(aux_a.shape[1], voxelsize_A) * 2 * pi
# blurring = exp( -freqs*freqs*voxelsize_A*voxelsize_A*1.4/8.0  )
blurring = 1


if usecuda:
    # Launch geometry for the 2-D kernels: 32x8 threads per block and enough
    # blocks to cover the padded array.  Integer (ceil) division is required:
    # PyCUDA needs int grid sizes and `Nx/32` is a float under Python 3.
    Ny, Nx = aux_a.shape[0], aux_a.shape[1]
    cuda_block = (32, 8, 1)
    cuda_grid = ((Nx + 31) // 32, (Ny + 7) // 8)
    shifta_kernel = modShifta.get_function("shifta")
    propa_kernel = modShifta.get_function("propa")
    intensity_kernel = modShifta.get_function("intensity")

    # Upload the padded sample-plane wave once; each (energy, angle) pair
    # below starts from a device-to-device copy of it.
    aux_a[:, :sino.shape[1]] = complex_signal
    aux_a[:, sino.shape[1]:] = 1.0 + 0.0j
    array_signal.set(aux_a)
else:
    # Transverse angular frequencies of a padded row (FFT order); they do not
    # depend on energy or angle, so they are computed once.
    kperp = fft.fftfreq(aux_a.shape[1], voxelsize_A) * 2 * pi

# Source angles and their normalised Gaussian weights do not depend on the
# energy, so they are computed once as well.
if NSOURCE > 1:
    angles = numpy.linspace(-5 * SSOURCE_RAD, 5 * SSOURCE_RAD, num=NSOURCE, endpoint=True, retstep=False)
    # angles = angles   *K0*aux_a.shape[1]*voxelsize_A/2.0/numpy.pi
    # angles = angles.round()*2.0*numpy.pi/(K0*aux_a.shape[1]*voxelsize_A)
    anglew = numpy.exp(-angles * angles / SSOURCE_RAD / SSOURCE_RAD / 2.0)
    anglew = anglew / anglew.sum()
else:
    angles = [0.0]
    anglew = [1.0]

for ed, ew in zip(enes, enew):
    # Wave number; 12.39842 is h*c in keV*Angstrom, matching the *_A lengths.
    K0 = 2 * pi * (ekev + ed) / 12.39842

    if not usecuda:
        # Free-space propagators for this energy (same for every angle).
        pp = sqrt(K0 * K0 - kperp * kperp)
        pp_forw = exp(1.0j * pp * plength_A)
        pp_back = exp(-1.0j * pp * plength_A)

    for a, w in zip(angles, anglew):
        print(" a,w ", a, w, "   ed, ew  ", ed, ew)
        if usecuda:
            pycuda.driver.memcpy_dtod(arrayin.gpudata, array_signal.gpudata, array_signal.nbytes)
            shifta_kernel(arrayin, numpy.float64(K0 * a * voxelsize_A),
                          numpy.int32(Ny), numpy.int32(Nx),
                          block=cuda_block, grid=cuda_grid)
            cu_fft.fft(arrayin, arrayout, fplan)
            propa_kernel(arrayout, numpy.float64(K0), numpy.float64(voxelsize_A),
                         numpy.float64(plength_A),
                         numpy.int32(Ny), numpy.int32(Nx),
                         block=cuda_block, grid=cuda_grid)
            cu_fft.ifft(arrayout, arrayin, fplan)
            # Accumulate ew*w*|field|^2, with the 1/N inverse-FFT scaling.
            intensity_kernel(arrayin, array_convoluted,
                             numpy.float64(ew * w * fact4norm * fact4norm),
                             numpy.int32(Ny), numpy.int32(Nx),
                             block=cuda_block, grid=cuda_grid)
        else:
            aux_a[:, :sino.shape[1]] = complex_signal
            aux_a[:, sino.shape[1]:] = 1.0 + 0.0j
            # Written in place: the FFT plans are bound to aux_a / aux_b.
            aux_a[:, :] = aux_a * numpy.exp(1.0j * K0 * a * numpy.arange(aux_a.shape[1]) * voxelsize_A)
            fft_object_a()
            aux_b[:] = aux_b * pp_forw
            # * blurring
            fft_object_b()
            result = array(aux_a[:, :] * fact4norm)
            ressino = abs(result)
            ressino = ressino * ressino
            convoluted += ew * w * ressino

# Drop the padding: keep only the original detector columns of the
# accumulated intensity (downloaded from the device on the CUDA path).
ressino = (array_convoluted.get() if usecuda else convoluted)[:, :sino.shape[1]]

def _make_format_coord(img):
    """Build a ``format_coord`` callback that reports *img*'s value under the cursor.

    The image and its shape are captured when the callback is created, so
    later rebinding of module-level names cannot make it read the wrong array.
    """
    numrows, numcols = img.shape

    def format_coord(x, y):
        col = int(x + 0.5)
        row = int(y + 0.5)
        if 0 <= col < numcols and 0 <= row < numrows:
            z = img[row, col]
            return 'x=%1.4f, y=%1.4f, z=%1.4f' % (x, y, z)
        return 'x=%1.4f, y=%1.4f' % (x, y)

    return format_coord


def _show_image(img):
    """Display *img* in a new figure (jet colormap, nearest interpolation)."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.imshow(img, cmap=cm.jet, interpolation='nearest')
    ax.format_coord = _make_format_coord(img)
    return fig, ax


# Input sinogram and propagated intensity, each in its own window.
fig, ax = _show_image(sino)
fig2, ax = _show_image(ressino)

# NOTE(review): plt.show() blocks until the windows are closed; the EDF output
# files are written only after that.
plt.show()


# Write each row of the propagated sinogram to its own EDF file as float32,
# naming it with the printf-style PROPA_OUTPUT_FILE pattern and the row index.
for row_number, row_data in enumerate(ressino):
    row_data = array(row_data, "f")
    out_edf = EdfFile.EdfFile(PROPAPROJ_OUTPUT_FILE % row_number, "w+")
    out_edf.WriteImage({}, row_data)


# Copy of the simulated intensities (referenced by the disabled code below).
Datas = array(ressino)


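# ## START (disabled experiment: back-propagate the simulated detector signal to the sample, re-apply the a-priori phase, then propagate forward again)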
# absorption = -log(ressino)/2.0
# phase = dbratio * absorption
# complex_signal = exp( -absorption - phase*(0+1.0j) )

# # complex_signal  = result

# # for i in range(12):
# #
#     # WAVE SYNTHESIS AT THE DETECTOR
#     aux_a[:,:sino.shape[1]]=complex_signal
#     bordervalue = (complex_signal[:,0]+complex_signal[:,-1])/2.0
#     aux_a[:,sino.shape[1]:]= bordervalue[:, newaxis]

#     ## BACKWARD
#     fft_object_a()
#     aux_b[:] = aux_b * pp_back
#     fft_object_b()
#     result = array( aux_a[:,:sino.shape[1]] * fact4norm   ) 


#     ## SIMULATED SIGNAL AT THE SAMPLE
#     ressino = abs(result)
#     ressino=ressino*ressino
#     absorption = -log(ressino)/2.0

#        # A PRIORI PHASE
#     phase = dbratio * absorption
#     complex_signal = exp( -absorption - phase*(0+1.0j) )


#     print( complex_signal[0,0])
#     print( complex_signal[0,-1])


#     # WAVE SYNTHESIS AT THE SAMPLE
#     aux_a[:,:sino.shape[1]]=complex_signal
#     bordervalue = (complex_signal[:,0]+complex_signal[:,-1])/2.0
#     aux_a[:,sino.shape[1]:]= bordervalue[:, newaxis]

#     ## FORWARD
#     fft_object_a()
#     aux_b[:] = aux_b * pp_forw
#     fft_object_b()
#     result = array( aux_a[:,:sino.shape[1]] * fact4norm   ) 

#     ## SIMULATED SIGNAL AT THE DETECTOR
#     simphase      = log( result  ).imag  # simulated phase
#     complex_signal = exp(  + simphase*(0+1.0j) ) * sqrt(Datas)

# fig3 = plt.figure()
# ax = fig3.add_subplot(111)
# ax.imshow(ressino, cmap=cm.jet, interpolation='nearest')
# numrows, numcols = ressino.shape
# def format_coord(x, y):
#     col = int(x+0.5)
#     row = int(y+0.5)
#     if col>=0 and col<numcols and row>=0 and row<numrows:
#         z = ressino[row,col]
#         return 'x=%1.4f, y=%1.4f, z=%1.4f'%(x, y, z)
#     else:
#         return 'x=%1.4f, y=%1.4f'%(x, y)
# ax.format_coord = format_coord

# plt.show()
