
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



import numpy as np
import math
import gc


def _lt_concat(chunks, dtype):
    """Join a list of 1-D arrays into one array of ``dtype``.

    An empty list yields an empty array instead of the ``ValueError`` that
    ``np.concatenate`` raises (e.g. when there are no projections to loop on).
    """
    if not chunks:
        return np.zeros(0, dtype)
    return np.concatenate(chunks).astype(dtype)


def _lt_build_grid(P, DIA_square, coarse, finegrid):
    """Build the list of grid points ``[x, y, level]``, level by level.

    Returns ``(CXYS, Sigmas, NFIRSTBLOCK)``: the concatenated points of every
    level, one sigma (``step / 0.65``) per level, and the number of points of
    the first level.

    Raises ``ValueError`` when the stopping criterion produces no level at all.
    """
    # Half-width of the grid in steps: half the square plus 1 + 6 extra steps.
    NC = int(DIA_square / 2 / coarse) + 1 + 6
    R_old = -NC * coarse  # negative on the first level: nothing is masked out

    all_CXYS = []
    Sigmas = []
    NFIRSTBLOCK = None
    ilevel = 0

    # finegrid: exactly one level is built.  The former condition
    # ``coarse == P.LT_FINE`` never became false when 2*LT_FINE > LT_MAX_STEP
    # (only NC was doubled), so the loop grew the grid without ever ending.
    while finegrid or (2 * NC * coarse) < 2 * P.LT_DIAMETER:
        Sigmas.append(coarse / 0.65)

        # Regular (2*NC+1) x (2*NC+1) grid centred on 0; x runs fastest.
        offsets = np.arange(-NC, NC + 1) * coarse
        N = len(offsets)
        CXYS = np.column_stack((
            np.tile(offsets, N),
            np.repeat(offsets, N),
            np.full(N * N, ilevel, dtype=offsets.dtype),
        ))

        print(R_old)
        print(np.abs(CXYS[:, 0]).shape)
        print(np.abs(CXYS[:, 1]).shape)
        print(" CXYS.shape ", CXYS.shape)  # debug

        # Keep the points lying beyond the extent of the previous level.
        # NOTE(review): both |x| and |y| must exceed R_old, which keeps only
        # the four corner regions, not the whole frame around the previous
        # level -- confirm this is intended.
        mask = (np.abs(CXYS[:, 0]) > R_old) & (np.abs(CXYS[:, 1]) > R_old)
        print(" mask.shape  ", mask.shape)

        CXYS = CXYS[mask]
        print(" CXYS  ", CXYS)
        print(len(CXYS))

        R_old = NC * coarse

        # Next level: double the step while allowed, else double the extent.
        if coarse * 2 <= P.LT_MAX_STEP:
            coarse = coarse * 2
            ilevel = ilevel + 1
        else:
            NC = NC * 2

        print("NFIRSTBLOCK ", NFIRSTBLOCK)
        if NFIRSTBLOCK is None:
            NFIRSTBLOCK = len(CXYS)
            print("NFIRSTBLOCK ======", NFIRSTBLOCK)
        all_CXYS.append(CXYS)

        if finegrid:
            break

    print(all_CXYS)
    if not all_CXYS:
        raise ValueError(
            "Get_Sparse_LT: no grid level generated "
            "(initial grid extent already reaches P.LT_DIAMETER)"
        )

    return np.concatenate(all_CXYS, axis=0), Sigmas, NFIRSTBLOCK


def _lt_projection_matrix(P, CXYS, LT_REBINNING, num_bins_rebin,
                          LT_MARGE_REBIN, numpjs_span_rebin, row_width):
    """Build the COO triplets (I, J, C) linking grid points to sinogram bins.

    For every projection, each grid point whose projected position falls in
    the padded detector range contributes two consecutive entries (lower bin,
    then upper bin).  Returns three lists of per-projection arrays.
    """
    point_ids = np.arange(len(CXYS))
    I_chunks, J_chunks, C_chunks = [], [], []

    for iproj in range(numpjs_span_rebin):
        print(iproj)
        if isinstance(P.angles, np.ndarray):
            # NOTE(review): indexed by iproj, whereas axis_corrections below
            # is indexed by iproj*LT_REBINNING -- confirm for LT_REBINNING > 1.
            angolo = P.angles[iproj]
        else:
            angolo = (iproj * LT_REBINNING) * P.ANGLE_BETWEEN_PROJECTIONS

        # Position of each point on the detector, in rebinned bin units.
        X = CXYS[:, 0] * math.cos(angolo)
        X = X - CXYS[:, 1] * math.sin(angolo)
        X = X + P.ROTATION_AXIS_POSITION
        X = X + P.axis_corrections[iproj * LT_REBINNING]
        X /= LT_REBINNING

        # LT_DEBUG
        if iproj == 0:
            print(("min = %f ; max = %f" % (CXYS[:, 0].min(), CXYS[:, 0].max())))

        mask = ((np.floor(X) > -LT_MARGE_REBIN - 1)
                & (np.ceil(X) < num_bins_rebin + LT_MARGE_REBIN))
        Xm = X[mask]

        # Lower bin index inside a row padded by LT_MARGE_REBIN on the left.
        # NOTE(review): an integer position on the last padded bin makes
        # ``ifloor + 1`` fall on the first bin of the next row.
        ifloor = np.floor(Xm).astype(np.int64) + LT_MARGE_REBIN
        # Start of the sinogram row; levels are stacked numpjs_span_rebin apart.
        row = (iproj + CXYS[mask, 2] * numpjs_span_rebin) * row_width

        # NOTE(review): the weights subtract the unpadded position Xm from the
        # padded index ifloor, so they equal (LT_MARGE_REBIN + 1 - frac) and
        # (frac - LT_MARGE_REBIN) rather than (1 - frac) and frac.  Kept as in
        # the original code; confirm against the consumer of this matrix.
        w_lower = (ifloor + 1) - Xm
        w_upper = -(ifloor - Xm)

        I_chunks.append(np.column_stack((row + ifloor, row + (ifloor + 1))).ravel())
        J_chunks.append(np.repeat(point_ids[mask], 2))
        C_chunks.append(np.column_stack((w_lower, w_upper)).ravel())

    return I_chunks, J_chunks, C_chunks


def _lt_slice_matrix(first_block, DIA_square, S):
    """Build the COO triplets (I, J, C) spreading first-level points on the slice.

    Each point is rendered as a 2-D Gaussian of width ``S`` truncated at
    3*S and clipped to the DIA_square x DIA_square slice.  Returns three
    lists of per-point arrays; points whose window is empty are skipped.
    """
    square_indexes = np.arange(DIA_square * DIA_square).reshape(DIA_square, DIA_square)
    Y, X = np.mgrid[0:DIA_square, 0:DIA_square]
    half = (DIA_square - 1.0) / 2.0  # grid coordinates are centred on the slice

    # Original author's note: the Gaussian is normalised to unit L1 norm on
    # the slice.  1/(sigma*sqrt(2*pi)) * exp(-x**2/(2*sigma**2)) has L1 = 1
    # and the prefactor is squared for the x*y product.  When projecting,
    # keep in mind the pi/(2*num_projs) factor also present in the CUDA kernel.
    I_chunks, J_chunks, C_chunks = [], [], []
    for J, c in enumerate(first_block):
        print(c)
        Xm = c[0] + half
        Ym = c[1] + half

        # Absolute limits of the +/- 3 sigma window, clipped to the slice.
        Xstart = max(0, int(round(Xm - 3 * S)))
        Xend = min(DIA_square, int(round(Xm + 3 * S)) + 1)
        Ystart = max(0, int(round(Ym - 3 * S)))
        Yend = min(DIA_square, int(round(Ym + 3 * S)) + 1)

        print("  Xstart, Xend,  Ystart, Yend   ", Xstart, Xend, Ystart, Yend)
        if Xstart >= Xend or Ystart >= Yend:
            continue

        indexes = square_indexes[Ystart:Yend, Xstart:Xend].flatten()
        print((indexes.size))
        print(("J = %d" % J))

        YY = Y[Ystart:Yend, Xstart:Xend] - Ym
        XX = X[Ystart:Yend, Xstart:Xend] - Xm
        R2 = YY * YY + XX * XX
        GG = np.exp(-R2 / (2.0 * S * S)) / (S * S * 2 * np.pi)

        I_chunks.append(indexes)
        C_chunks.append(GG.flatten())
        J_chunks.append(np.full(indexes.size, J, dtype=indexes.dtype))

    return I_chunks, J_chunks, C_chunks


def Get_Sparse_LT(P, finegrid=False):
    """Build the sparse matrices of the LT grid-point decomposition.

    A multi-level grid of points ``[x, y, level]`` is generated, then two
    sparse matrices in coordinate (I, J, C) form are computed: one linking
    the grid points to sinogram bins for every projection, and one spreading
    the points of the first level on the slice as truncated Gaussians.

    Parameters
    ----------
    P : object
        Parameter holder.  Attributes read: NUM_IMAGE_1, LT_REBINNING,
        START_VOXEL_1, END_VOXEL_1, LT_MARGIN, LT_FINE, LT_COARSE,
        LT_DIAMETER, LT_MAX_STEP, numpjs_span (or numpjs), angles,
        ANGLE_BETWEEN_PROJECTIONS, ROTATION_AXIS_POSITION, axis_corrections.
        If ``P.numpjs_span`` is missing it is set to ``P.numpjs``.
    finegrid : bool
        If True a single level of step ``P.LT_FINE`` is built; otherwise
        levels start at ``P.LT_COARSE`` and are added until the grid extent
        reaches ``P.LT_DIAMETER``.

    Returns
    -------
    object
        Instance whose attributes are ``Csparse`` (float32), ``Isparse``,
        ``Jsparse`` (int32), ``Csparse_2slice`` (float32), ``Isparse_2slice``,
        ``Jsparse_2slice`` (int32), ``Sigmas`` (float32, one per level) and
        ``SINO_leading_dimension``.

    Raises
    ------
    ValueError
        If the slice square is empty or no grid level is generated.

    Notes
    -----
    Diagnostic information is printed on standard output.
    """
    num_bins = P.NUM_IMAGE_1
    LT_REBINNING = P.LT_REBINNING
    DIA_square = P.END_VOXEL_1 - P.START_VOXEL_1 - 2 * P.LT_MARGIN + 1
    if DIA_square < 1:
        # Fail early: an empty slice square cannot receive any coefficient.
        raise ValueError(
            "Get_Sparse_LT: empty slice square (DIA_square = %s)" % (DIA_square,)
        )

    coarse = P.LT_FINE if finegrid else P.LT_COARSE
    CXYS, Sigmas, NFIRSTBLOCK = _lt_build_grid(P, DIA_square, coarse, finegrid)

    LT_MARGE = 3 * P.LT_MAX_STEP

    # FIXME (original): stopgap until something better -- fall back on
    # P.numpjs when P.numpjs_span has not been defined.
    try:
        P.numpjs_span
    except AttributeError:
        P.numpjs_span = P.numpjs
    else:
        print((P.numpjs_span))

    # Re-binning
    num_bins_rebin = num_bins // LT_REBINNING
    LT_MARGE_REBIN = LT_MARGE // LT_REBINNING  # ? (questioned in the original)
    numpjs_span_rebin = P.numpjs_span // LT_REBINNING
    LT_MAX_STEP_REBIN = P.LT_MAX_STEP // LT_REBINNING
    # Padded length of one sinogram row.
    row_width = num_bins_rebin + 6 * LT_MAX_STEP_REBIN

    Isparse, Jsparse, Csparse = _lt_projection_matrix(
        P, CXYS, LT_REBINNING, num_bins_rebin, LT_MARGE_REBIN,
        numpjs_span_rebin, row_width)

    print("NFIRSTBLOCK ", NFIRSTBLOCK)
    Isparse_2slice, Jsparse_2slice, Csparse_2slice = _lt_slice_matrix(
        CXYS[:NFIRSTBLOCK], DIA_square, Sigmas[0])

    SINO_leading_dimension = numpjs_span_rebin * row_width

    res = {
        "Csparse": _lt_concat(Csparse, "f"),
        "Isparse": _lt_concat(Isparse, "i"),
        "Jsparse": _lt_concat(Jsparse, "i"),
        "Csparse_2slice": _lt_concat(Csparse_2slice, "f"),
        "Isparse_2slice": _lt_concat(Isparse_2slice, "i"),
        "Jsparse_2slice": _lt_concat(Jsparse_2slice, "i"),
        "Sigmas": np.array(Sigmas, "f"),
        "SINO_leading_dimension": SINO_leading_dimension,
    }

    print(" SHAPES 2slice ")
    print(res["Csparse_2slice"].shape)
    print(res["Isparse_2slice"].shape)
    print(res["Jsparse_2slice"].shape)
    print(res["Csparse"].shape)
    print(res["Isparse"].shape)
    print(res["Jsparse"].shape)

    # The results are exposed as attributes of a throw-away class instance.
    res = type('MyObject_for_LTSparse', (object,), res)()

    return res
