from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



from pylab import *
from PyMca import EdfFile
import glob
import string

from . import string_six




import sys
from numpy import fft
import os
import pickle 
import numpy 
from  .congrid import congrid

# ---------------------------------------------------------------------------
# FFT backend selection.  Preference order:
#   1. CUDA: pycuda + scikits.cuda FFT, plus the custom kernels compiled below;
#   2. pyfftw;
#   3. PyFFTW3 ("fftw3"); if this import fails too, the script stops here.
# ``usecuda`` / ``usepyfftw`` record the choice (``usepyfftw`` is only defined
# when CUDA is not used).
try:
    import pycuda
    import pycuda.autoinit
    import pycuda.elementwise
    import pycuda.gpuarray as gpuarray
    from pycuda.compiler import SourceModule
    import scikits.cuda.fft as cu_fft

    usecuda=1

    # All kernels work in place on a row-major (Ny, Nx) array with one thread
    # per element (gid = gidy*Nx + gidx); out-of-range threads do nothing.
    modShifta = SourceModule("""
    #include <cuComplex.h>
      // a *= exp(i*K0aV*x): linear phase ramp along the row (x = column index)
      __global__ void shifta(cuDoubleComplex  *a, double K0aV, int Ny, int Nx)
      {
        int gidx = threadIdx.x + blockIdx.x*blockDim.x;
        int gidy = threadIdx.y + blockIdx.y*blockDim.y;
        int gid  = gidy*Nx+gidx;

        cuDoubleComplex c;
        if(gidx<Nx && gidy<Ny) {
            c = a[gid] ;
            a[gid] =   cuCmul(c, make_cuDoubleComplex(cos(K0aV*gidx), sin(K0aV*gidx)));
        }
      }
      // Fourier-space propagator: a *= exp(i*kpar*p), kpar = sqrt(K0^2 - kperp^2),
      // kperp taken in FFT order for a sampling step Vox.
      // (kpar is NaN for evanescent components |kperp| > K0.)
      __global__ void propa(cuDoubleComplex  *a, double K0, double Vox, double p, int Ny, int Nx) {
          int gidx = threadIdx.x + blockIdx.x*blockDim.x;
          int gidy = threadIdx.y + blockIdx.y*blockDim.y;
          int gid  = gidy*Nx+gidx;
          cuDoubleComplex c;
          double kperp, kpar, phase;
          int Nx2 = Nx/2;

          if(gidx<Nx && gidy<Ny) {
              kperp = 2*M_PI/Vox/Nx * ( (gidx+Nx2 )%Nx -Nx2) ; 
              kpar   = sqrt( K0*K0 - kperp*kperp  ) ;
              phase =  kpar * p; 
              c = a[gid] ;
              a[gid] =   cuCmul(c, make_cuDoubleComplex(cos(phase), sin(phase)));
          }
       }
      // Weighted intensity accumulation: inte += w*|a|^2
      __global__ void intensity(cuDoubleComplex  *a, double * inte, double w, int Ny, int Nx) {
          int gidx = threadIdx.x + blockIdx.x*blockDim.x;
          int gidy = threadIdx.y + blockIdx.y*blockDim.y;
          int gid  = gidy*Nx+gidx;
          cuDoubleComplex c;

          if(gidx<Nx && gidy<Ny) {
              c = a[gid] ;
              inte[gid] +=   w*( cuCreal(c)*cuCreal(c) + cuCimag(c)*cuCimag(c) ) ;
          }
       }
      // Gaussian low-pass in Fourier space, sigma expressed in pixels
      __global__ void blur(cuDoubleComplex *a, double sigma, int Ny, int Nx) {
          int gidx = threadIdx.x + blockIdx.x*blockDim.x;
          int gidy = threadIdx.y + blockIdx.y*blockDim.y;
          int gid  = gidy*Nx+gidx;
          cuDoubleComplex c;
          double kperp;
          int Nx2 = Nx/2;

          if(gidx<Nx && gidy<Ny) {
              kperp = 2*M_PI/Nx * ( (gidx+Nx2 )%Nx -Nx2) ; 
              c = a[gid] ;
              c =  cuCmul(c, make_cuDoubleComplex(exp( - kperp*kperp*sigma*sigma/2.0 )  , 0)); 
              a[gid] =   c;
          }
       }
      """)

except Exception:
    # Missing modules, no usable device or a kernel compile error all fall
    # back to a CPU FFT library.
    usecuda=0
    try:
        import pyfftw
        usepyfftw = 1
    except ImportError:
        usepyfftw=0
        import fftw3

##############################################################################################
# HOW MANY  GPUS?
##
# Under the OAR batch scheduler, ``oarprint`` lists the host(s) and GPU
# number(s) allocated to this job, one "host gpu[,gpu...]" line per host
# (possibly wrapped in the quotes of the -F format).  When CUDA is in use a
# context is created on the allocated GPU (the first one if several were
# granted; the last parsed line wins).  Outside OAR the device picked by
# ``pycuda.autoinit`` is kept.
import subprocess as sub

try:
    args = ["oarprint", "host", "-P", "host,gpu_num", "-F", "'% %'"]
    p = sub.Popen(args=args, stdout=sub.PIPE, stderr=sub.PIPE)
    resources, errors = p.communicate()
except OSError:
    # oarprint is not available: not running under OAR.
    resources = ""
if isinstance(resources, bytes):
    # Python 3: communicate() returns bytes.
    resources = resources.decode("utf-8", "replace")
print( " resources gpu ", resources)

## example
# gpu0102 1,0
# gpu0101 1
gpu = None
gpus_string = " "
for l in resources.split("\n"):
    l = l.strip()
    if l and l[0] == l[-1] and l[0] in ["'", '"']:
        l = l[1:-1].strip()
    if not l:
        continue
    fields = l.split()
    if len(fields) != 2:
        print(" WARNING : cannot parse oarprint line ", repr(l))
        continue
    node, gpus = fields
    gpus_s = gpus.split(",")
    if len(gpus_s) > 1:
        print( " WARNING : you have been allocated ", len(gpus_s)," gpus ")
        print( " but you will use only one gpu! ")
    gpus_string = gpus_s[0]
    gpu = int(gpus_string)

if usecuda:
    import pycuda.tools
    if gpu is not None:
        if gpu != -1:
            dev = pycuda.driver.Device(gpu)
            dev.make_context()
        else:
            pycuda.tools.make_default_context()

# else:
#     pass
#     # the gpu-string line will be taken from the command line
#     gpus_string=""
#     if len(sys.argv)==3:
#         gpus_string=" "+sys.argv[2]
#     del  sys.argv[2]

    print( "Ma carte est ", pycuda.driver.Context.get_device().name())

# Parse the parameter file given as first command-line argument.
# Lines mentioning one of the scalar parameters below are executed as Python
# (e.g. "NPROJ = 720"), which defines a module-level variable of that name.
# SECURITY: exec() runs arbitrary code taken from the parameter file -- only
# use parameter files from trusted sources.
_EXEC_KEYS = ("NPROJ", "dbratio", "plength", "OVERSAMPLE", "SIGMAPIXEL",
              "IMAGE_PIXEL_SIZE_1", "SSOURCE_RAD", "SENE", "NENE", "ekev")

with open(sys.argv[1], "r") as parinput:
    for line in parinput:
        # Execute each matching line once, even if it mentions several keys.
        if any(key in line for key in _EXEC_KEYS):
            exec(line)

        if "PROJ_OUTPUT_FILE" in line:
            # Turn the printf-style file name ("name%04d.edf") into a glob
            # pattern ("name*.edf") used to find the input projections.
            items = line.split()
            path = items[2]
            numberpos = path.find("%")
            if numberpos != -1:
                postfixpos = path.rfind(".")
                path = path[:numberpos]+"*"+path[postfixpos:]

        if "PROPA_OUTPUT_FILE" in line:
            # printf-style name of the propagated projections written below.
            items = line.split()
            PROPAPROJ_OUTPUT_FILE = items[2]
        

# Report the parameters, locate the projection files and allocate the
# sinogram (one row per projection file found).
print( "\n* THE INPUT FILE SAYS YOU HAVE GENERATED ", NPROJ, " PROJECTION ")
print( "                 AND PROJECTIONS ARE SEARCHED AS PATTERN ", path)

fl = sorted(glob.glob(path))[:NPROJ]
if not fl:
    sys.exit("ERROR: no projection file matches the pattern %s" % path)
if len(fl) < NPROJ:
    # Rows for missing files would otherwise stay zero and be propagated.
    print(" WARNING : only ", len(fl), " projections found, ", NPROJ, " expected ")

print( "\n\n* FIRST PROJECTION ", fl[0])
print(     "  LAST             ", fl[-1])
print( "\n")
data = EdfFile.EdfFile(fl[0]).GetData(0)

# NOTE(review): each EDF image is assumed to hold a single detector row whose
# length is data.shape[1] -- confirm.
sino  = zeros([len(fl), data.shape[1]],"d")

voxelsize = IMAGE_PIXEL_SIZE_1*1.0
ekev = ekev*1.0
dbratio=dbratio*1.0

print( "\n* THE INPUT FILE SAYS YOU VOXEL IS  ", voxelsize, " MICRON ")
print(    "                      DELTA/BETA   ", dbratio)
print(    "                      PROPAGATION  ", plength , " METERS ")
print(    "                      ENERGY       ", ekev , " KeV ")
print( "\n")
print(    "* CONCERNING COHERENCE:")
print(    "         SSOURCE_RAD               ", SSOURCE_RAD  , " Radians ")
print(    "         SENE                      ", SENE     , " Kev ")
print(    "         NENE                      ", NENE     , " Points ")
print(    "         SIGMAPIXEL                ", SIGMAPIXEL)
print(    "         OVERSAMPLE                ", OVERSAMPLE)


##  OVERSAMPLE = 10

voxelsize_A = voxelsize *10000.0   # micron -> Angstrom
plength_A   = plength   *1.0e10    # metre  -> Angstrom


# Read every projection row into the sinogram ("leggo" = "reading").
print( " leggo .... " )
for i, nome in enumerate(fl):
    sino[i] = EdfFile.EdfFile(nome).GetData(0)

# Resample every row OVERSAMPLE times finer (linear interpolation); the voxel
# size shrinks by the same factor.
print( " congrid " )
voxelsize_A = voxelsize_A/OVERSAMPLE
sinonew = zeros([sino.shape[0], data.shape[1]*OVERSAMPLE],"d")
for i in range(0, sino.shape[0], 2):
    print( i)
    # NOTE(review): two input rows are resampled to a target of ONE row, which
    # is then broadcast to both rows i and i+1 -- confirm this is intended
    # (a target of [2, ...] would keep the rows distinct).
    sinonew[i:i+2,:] = congrid(sino[i:i+2,:],[1,sino.shape[1]*OVERSAMPLE ] , minusone=False, method="linear"  )
print( " OK " )

# Copy the original last column into the last oversampled column.
sinonew[:,-1] = sino[:,-1]

sinoold = sino
# Rows processed per chunk.  Must be an int: with "from __future__ import
# division" the "/" operator yields a float, which range() rejects; and at
# least one row so the chunk loop always advances.
NSTEP = max(1, int(sino.shape[0] // OVERSAMPLE))

def _show_image(img):
    """Show *img* in a new figure whose status bar also reports the value under the cursor.

    The image and its shape are captured per call, so several figures opened
    before ``plt.show()`` keep independent coordinate read-outs.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.imshow(img, cmap=cm.jet, interpolation='nearest')
    numrows, numcols = img.shape

    def format_coord(x, y):
        col = int(x+0.5)
        row = int(y+0.5)
        if 0 <= col < numcols and 0 <= row < numrows:
            z = img[row, col]
            return 'x=%1.4f, y=%1.4f, z=%1.4f'%(x, y, z)
        return 'x=%1.4f, y=%1.4f'%(x, y)

    ax.format_coord = format_coord
    return fig


# Process the oversampled sinogram in chunks of NSTEP rows.  For every chunk
# the transmitted wave is propagated over ``plength`` for each energy sample,
# the intensities are summed with the energy weights, blurred by a Gaussian of
# width sigma_pixel, binned back to the original pixel size and written out.
for NSTART in range(0, sinonew.shape[0], NSTEP):
    sino = sinonew[NSTART:min(NSTART+NSTEP, sinonew.shape[0])]

    # Blur width: angular source size projected over plength, combined in
    # quadrature with SIGMAPIXEL.
    # NOTE(review): the 1.0e6 factor and whether SIGMAPIXEL is in original or
    # oversampled pixels are not evident from this file -- confirm.
    sigma_pixel =  SSOURCE_RAD*plength/voxelsize_A*1.0e6
    sigma_pixel = numpy.sqrt(  sigma_pixel*sigma_pixel+SIGMAPIXEL*SIGMAPIXEL  )

    # Each row is zero-padded to twice its length before the FFTs.
    Ny, Nx = sino.shape[0], sino.shape[1]*2

    if usecuda:
        aux_a = numpy.zeros([Ny, Nx], dtype='D')
        aux_b = numpy.zeros([Ny, Nx], dtype='D')

        arrayin  = gpuarray.empty([Ny, Nx], numpy.complex128)

        print( Ny*Nx)
        print( Ny, Nx)

        arrayout = gpuarray.empty([Ny, Nx], numpy.complex128)

        array_signal = gpuarray.empty([Ny, Nx], numpy.complex128)
        array_convoluted = gpuarray.zeros([Ny, Nx], numpy.float64)

        data = numpy.zeros([Ny, Nx], numpy.complex128)
        # Batched 1-D FFTs, one per padded row.
        fplan = cu_fft.Plan([Nx], numpy.complex128, numpy.complex128, batch=Ny)
        # The forward/inverse pair is used unscaled, hence the explicit 1/Nx.
        fact4norm = 1.0/Nx

        # Launch geometry for the 2-D kernels.  Grid dimensions must be ints;
        # with true division "Nx/32" would be a float.
        cuda_block = (32, 8, 1)
        cuda_grid = (Nx//32+1, Ny//8+1)

        # Smooth each oversampled row with a Gaussian of OVERSAMPLE/3 pixels.
        aux_a[:] = 0
        aux_a[:, :sino.shape[1]].real = sino
        aux_a[:, :].imag = 0
        arrayin.set(aux_a)
        cu_fft.fft(arrayin, arrayout, fplan)
        modShifta.get_function("blur")(arrayout, numpy.float64(OVERSAMPLE/3.0),
                                       numpy.int32(Ny), numpy.int32(Nx),
                                       block=cuda_block, grid=cuda_grid)
        cu_fft.ifft(arrayout, arrayin, fplan)
        sino = arrayin.get().real[:, :sino.shape[1]]*fact4norm

    elif not usepyfftw:
        # PyFFTW3 backend.  The normalisation 1/(Ny*Nx) below implies that the
        # plans transform both axes of the padded array.
        try:
            if os.path.exists("fftw3_wisdom"):
                fftw3.import_wisdom_from_file("fftw3_wisdom")
        except Exception:
            pass  # unusable wisdom only costs planning time

        aux_a = fftw3.create_aligned_array([Ny, Nx], dtype='D')
        aux_b = fftw3.create_aligned_array([Ny, Nx], dtype='D')

        convoluted     = fftw3.create_aligned_array(aux_a.shape, "d")
        convoluted_fft = fftw3.create_aligned_array(aux_a.shape, "d")
        convoluted[:] = 0.0

        fft_object_a = fftw3.Plan(aux_a, aux_b, direction="forward",  flags=('measure', ))
        fft_object_b = fftw3.Plan(aux_b, aux_a, direction="backward", flags=('measure', ))

        # NOTE(review): float64 -> float64 "forward"/"backward" plans; confirm
        # the library maps them to a Fourier-like transform, which the real
        # Gaussian ``blurring`` applied below assumes.
        fft_object_a_float = fftw3.Plan(convoluted,     convoluted_fft, direction="forward",  flags=('measure', ))
        fft_object_b_float = fftw3.Plan(convoluted_fft, convoluted,     direction="backward", flags=('measure', ))

        fftw3.export_wisdom_to_file("fftw3_wisdom")

        fact4norm = 1.0/aux_a.shape[1]/aux_a.shape[0]

    else:
        # pyfftw backend: 1-D transforms along each row with 4 threads.  The
        # inverse transform normalises by default, hence fact4norm = 1.
        try:
            if os.path.exists("pyfftw_wisdom"):
                # Pickle data is binary: text mode fails on Python 3.
                with open("pyfftw_wisdom", "rb") as wisdom_file:
                    wisdom = pickle.load(wisdom_file)
                pyfftw.import_wisdom(wisdom)
        except Exception:
            pass  # unusable wisdom only costs planning time

        aux_a = pyfftw.n_byte_align_empty([Ny, Nx], 16, 'complex128')
        aux_b = pyfftw.n_byte_align_empty([Ny, Nx], 16, 'complex128')

        convoluted     = pyfftw.n_byte_align_empty(aux_a.shape, "d")
        convoluted_fft = pyfftw.n_byte_align_empty(aux_a.shape, "d")
        convoluted[:] = 0.0

        fft_object_a = pyfftw.FFTW(aux_a, aux_b, direction="FFTW_FORWARD",  flags=('FFTW_MEASURE', ), axes=(-1,), threads=4)
        fft_object_b = pyfftw.FFTW(aux_b, aux_a, direction="FFTW_BACKWARD", flags=('FFTW_MEASURE', ), axes=(-1,), threads=4)

        # NOTE(review): pyfftw.FFTW documents complex<->complex and
        # real->complex / complex->real schemes; a float64 -> float64
        # FFTW_FORWARD plan may be rejected -- confirm.
        fft_object_a_float = pyfftw.FFTW(convoluted,     convoluted_fft, direction="FFTW_FORWARD",  flags=('FFTW_MEASURE', ), axes=(-1,), threads=4)
        fft_object_b_float = pyfftw.FFTW(convoluted_fft, convoluted,     direction="FFTW_BACKWARD", flags=('FFTW_MEASURE', ), axes=(-1,), threads=4)

        wisdom = pyfftw.export_wisdom()
        with open("pyfftw_wisdom", "wb") as wisdom_file:
            pickle.dump(wisdom, wisdom_file)

        fact4norm = 1.0

    # Transmitted wave: amplitude exp(-sino/2), phase -dbratio*sino/2.
    absorption = sino/2.0
    phase = dbratio * absorption
    complex_signal = exp( -absorption - phase*(0+1.0j) )

    # Energy offsets and normalised Gaussian weights (width SENE keV).
    if NENE > 1:
        enes = numpy.linspace(-1*SENE, 1*SENE, num=NENE, endpoint=True, retstep=False)
        enew = numpy.exp( -enes*enes/SENE/SENE/2.0  )
        enew = enew / enew.sum()
    else:
        enes = [0.0]
        enew = [1.0]

    print( aux_a.shape)

    # Transverse wave-vector in rad/Angstrom (sample spacing voxelsize_A).
    kperp = fft.fftfreq(aux_a.shape[1], voxelsize_A)*2*pi
    # sigma_pixel is in pixels, so the blur must use rad/pixel (spacing 1),
    # exactly like the CUDA "blur" kernel.
    freqs_pixel = fft.fftfreq(aux_a.shape[1])*2*pi
    blurring = exp( -freqs_pixel*freqs_pixel*sigma_pixel*sigma_pixel/2.0 )

    if usecuda:
        # Padded signal (transmission 1 outside the sample) kept on the GPU.
        aux_a[:, :sino.shape[1]] = complex_signal
        aux_a[:, sino.shape[1]:] = 1.0+0.0j
        array_signal.set(aux_a)

    for ed, ew in zip(enes, enew):
        # Wavenumber in 1/Angstrom (hc = 12.39842 keV*Angstrom).
        K0 = 2*pi*(ekev+ed)/12.39842

        print(  "   ed, ew  " , ed, ew)
        if usecuda:
            pycuda.driver.memcpy_dtod(arrayin.gpudata, array_signal.gpudata, array_signal.nbytes)
            cu_fft.fft(arrayin, arrayout, fplan)
            modShifta.get_function("propa")(arrayout, numpy.float64(K0), numpy.float64(voxelsize_A),
                                            numpy.float64(plength_A),
                                            numpy.int32(Ny), numpy.int32(Nx),
                                            block=cuda_block, grid=cuda_grid)
            cu_fft.ifft(arrayout, arrayin, fplan)
            modShifta.get_function("intensity")(arrayin, array_convoluted, numpy.float64(ew*fact4norm*fact4norm),
                                                numpy.int32(Ny), numpy.int32(Nx),
                                                block=cuda_block, grid=cuda_grid)
        else:
            aux_a[:, :sino.shape[1]] = complex_signal
            aux_a[:, sino.shape[1]:] = 1.0+0.0j
            fft_object_a()

            pp = sqrt(K0*K0 - kperp*kperp)
            pp_forw = exp( 1.0j  * pp  * plength_A)
            pp_back = exp( -1.0j  * pp  * plength_A)
            aux_b[:] = aux_b * pp_forw

            fft_object_b()
            result = array( aux_a[:, :] * fact4norm )
            ressino = abs(result)
            ressino = ressino*ressino
            convoluted += ew * ressino

    # Blur the accumulated intensity by the source-size Gaussian.
    if usecuda:
        aux_a[:, :] = array_convoluted.get()
        arrayin.set(aux_a)

        cu_fft.fft(arrayin, arrayout, fplan)

        modShifta.get_function("blur")(arrayout, numpy.float64(sigma_pixel),
                                       numpy.int32(Ny), numpy.int32(Nx),
                                       block=cuda_block, grid=cuda_grid)

        cu_fft.ifft(arrayout, arrayin, fplan)

        ressino = arrayin.get().real[:, :sino.shape[1]]*fact4norm
    else:
        fft_object_a_float()
        convoluted_fft *= blurring
        fft_object_b_float()
        ressino = convoluted[:, :sino.shape[1]] * fact4norm

    # Bin the oversampled columns back to the original pixel size (both
    # backends, so the written projections have the same width).
    ressino = ressino.reshape(ressino.shape[0], ressino.shape[1]//OVERSAMPLE, OVERSAMPLE)
    ressino = ressino.sum(axis=-1)/OVERSAMPLE

    if NSTART == 0:
        _show_image(sino)
        _show_image(ressino)
        plt.show()

    print( "PROJECTIONS FROM " , NSTART, "  TO " , NSTART+ressino.shape[0])
    for i, line in enumerate(ressino):
        edf = EdfFile.EdfFile(PROPAPROJ_OUTPUT_FILE%(i+NSTART), "w+")
        edf.WriteImage({}, array(line, "f"))


# Copy of the propagated intensity of the LAST processed chunk (the only
# ``ressino`` still bound after the loop).  It is used as the measured
# intensity by the commented-out sketch below.  That sketch alternates
# backward and forward propagation with the CPU-path arrays (aux_a, aux_b,
# pp_forw, pp_back).  It imposes the a-priori delta/beta phase at the sample
# and the measured amplitude sqrt(Datas) at the detector.
Datas = array(ressino)


# ## START
# absorption = -log(ressino)/2.0
# phase = dbratio * absorption
# complex_signal = exp( -absorption - phase*(0+1.0j) )

# # complex_signal  = result

# # for i in range(12):
# #
#     # WAVE SYNTHESIS AT THE DETECTOR
#     aux_a[:,:sino.shape[1]]=complex_signal
#     bordervalue = (complex_signal[:,0]+complex_signal[:,-1])/2.0
#     aux_a[:,sino.shape[1]:]= bordervalue[:, newaxis]

#     ## BACKWARD
#     fft_object_a()
#     aux_b[:] = aux_b * pp_back
#     fft_object_b()
#     result = array( aux_a[:,:sino.shape[1]] * fact4norm   ) 


#     ## SIMULATED SIGNAL AT THE SAMPLE
#     ressino = abs(result)
#     ressino=ressino*ressino
#     absorption = -log(ressino)/2.0

#        # A PRIORI PHASE
#     phase = dbratio * absorption
#     complex_signal = exp( -absorption - phase*(0+1.0j) )


#     print( complex_signal[0,0])
#     print( complex_signal[0,-1])


#     # WAVE SYNTHESIS AT THE SAMPLE
#     aux_a[:,:sino.shape[1]]=complex_signal
#     bordervalue = (complex_signal[:,0]+complex_signal[:,-1])/2.0
#     aux_a[:,sino.shape[1]:]= bordervalue[:, newaxis]

#     ## FORWARD
#     fft_object_a()
#     aux_b[:] = aux_b * pp_forw
#     fft_object_b()
#     result = array( aux_a[:,:sino.shape[1]] * fact4norm   ) 

#     ## SIMULATED SIGNAL AT THE DETECTOR
#     simphase      = log( result  ).imag  # simulated phase
#     complex_signal = exp(  + simphase*(0+1.0j) ) * sqrt(Datas)

# fig3 = plt.figure()
# ax = fig3.add_subplot(111)
# ax.imshow(ressino, cmap=cm.jet, interpolation='nearest')
# numrows, numcols = ressino.shape
# def format_coord(x, y):
#     col = int(x+0.5)
#     row = int(y+0.5)
#     if col>=0 and col<numcols and row>=0 and row<numrows:
#         z = ressino[row,col]
#         return 'x=%1.4f, y=%1.4f, z=%1.4f'%(x, y, z)
#     else:
#         return 'x=%1.4f, y=%1.4f'%(x, y)
# ax.format_coord = format_coord

# plt.show()
