#!/usr/bin/env python
# Author: Samuel Ponc\'e
# Date: 24/04/2013
# Script to compute the ZPR
import sys
import os
from systeme import system
try:
  import numpy as N
except ImportError:
  import warnings
  warnings.warn("The numpy module is missing!")
  raise
try:
  import netCDF4 as nc
except ImportError:
  import warnings
  warnings.warn("The netCDF4 module is missing!")
  raise


#############
# Constants #
#############
# Hard-coded Q-point weights could be given here, e.g.
# wtq = [0.00, 0.125, 0.5,0.375]
# NOTE: wtq is rebuilt further down from the number of Q-points, which
# overwrites any value assigned here.

# Unit conversions
Ha2eV = 27.21138386                  # 1 Hartree expressed in eV
kb_HaK = 3.1668154267112283e-06      # Boltzmann constant in Hartree/Kelvin

# Numerical tolerances (tol6 is the zero threshold for phonon
# frequencies and temperatures in the loops below)
tol6, tol8 = 1E-6, 1E-8

# Interaction with the user: print a banner and describe the limitations
# of the script before any question is asked.
_banner = (
  '\n############################',
  '# Temperature Corrections #',
  '###########################',
  '\nThis script compute the zero-point motion and the temperature dependance \n'
  'of eigenenergies due to electron-phonon interaction. This script can \n'
  'only compute Q-points with the same weight for the moment.\n'
  'WARNING: The first Q-point MUST be the Gamma point\n',
)
print('\n'.join(_banner))

# Type of calculation the user wants to perform.
# NOTE: `type` shadows the builtin, but the rest of the script reads it,
# so the name is kept.
user_input = raw_input('Define the type of calculation you want to perform. Type:\n\
                      1 if you want to run a static AHC calculation\n \
                      2 if you want to run a dynamic AHC calculation\n \
                      3 if you want to run a dynamic AHC calculation with correction terms\n')
try:
  type = int(user_input)
except ValueError:
  raise Exception('The type of calculation must be 1, 2 or 3!')
if type not in (1, 2, 3):
  raise Exception('The type of calculation must be 1, 2 or 3!')
if type == 3:
  # Only types 1 and 2 are combined into the total correction further down;
  # type 3 would silently produce a zero correction.
  raise Exception('Type 3 (dynamic AHC with correction terms) is not implemented yet!')

# Define the output file name. Check it now rather than after the
# (long) computation.
user_input = raw_input('Enter name of the output file\n')
output = user_input.strip()
if not output:
  raise Exception('You should provide a name for the output file')

# Smearing parameter: entered in eV, stored in Hartree
user_input = raw_input('Enter value of the smearing parameter (in eV)\n')
try:
  smearing = float(user_input)
except ValueError:
  raise Exception('The smearing must be a number (in eV)!')
smearing = smearing/Ha2eV

# Temperature dependence analysis?
user_input = raw_input('Do you want to compute the change of eigenergies with temperature? [y/n]\n')
answer = user_input.split()
# Only the first letter matters ('y', 'Y', 'yes' -> 'y'). The rest of the
# script compares `temperature` against the lowercase 'y' / 'n'.
temperature = answer[0][0].lower() if answer else ''
if temperature not in ('y', 'n'):
  raise Exception("Please answer 'y' or 'n'")
if temperature == 'y':
  user_input = raw_input('Introduce the max temperature and the temperature steps. e.g. 2000 50\n')
  # Kept as a list of strings; later code converts the entries to float.
  temp_info = user_input.split()
  try:
    t_max, t_step = float(temp_info[0]), float(temp_info[1])
  except (IndexError, ValueError):
    raise Exception('You should provide the max temperature and the temperature step, e.g. 2000 50')
  if t_max <= 0 or t_step <= 0:
    raise Exception('The max temperature and the temperature step must be positive')

# Get the nb of random Q-points from user (Gamma included, see WARNING above)
user_input = raw_input('Enter the number of random Q-points you have\n')
try:
  nbQ = int(user_input)
except ValueError:
  raise Exception('The value you enter is not an integer!')
if nbQ < 1:
  # The Q-point weights are 1/nbQ, so nbQ must be positive.
  raise Exception('The number of Q-points must be at least 1!')

# Get the path of the DDB files from user
user_input = raw_input('Enter the name of the %s DDB files separated by a space\n' %nbQ)
DDB_files = user_input.split()
if len(DDB_files) != nbQ:
  raise Exception("You sould provide %s DDB files" %nbQ)

# The first DDB file must be the Gamma Q-point.
DDBtmp = system(directory='.',filename=DDB_files[0])
if not N.allclose(DDBtmp.iqpt,[0.0,0.0,0.0]):
  raise Exception('The first Q-point is not Gamma!')

# Take the unperturbed EIG at Gamma. Use the split token so that stray
# whitespace around the file name is ignored, as for the other prompts.
user_input = raw_input('Enter the name of the unperturbed EIG.nc file at Gamma\n')
eig_files = user_input.split()
if len(eig_files) != 1:
  raise Exception("You sould only provide 1 file")
eig0 = system(directory='.',filename=eig_files[0])

# Find the degenerate eigenstates.
# degen[ikpt, iband] is an integer group label. Consecutive bands whose
# unperturbed energies (eig0) are equal within N.allclose tolerance share
# a label; a new label starts at every band that is not degenerate with the
# band just below it. This also labels a non-degenerate last band correctly
# and works when there is a single band.
DDB = system(directory='.',filename=DDB_files[0])
degen = N.zeros((DDB.nkpt,DDB.nband),dtype=int)
for ikpt in N.arange(DDB.nkpt):
  label = 0
  for iband in N.arange(1, DDB.nband):
    if not N.allclose(eig0.EIG[0,ikpt,iband], eig0.EIG[0,ikpt,iband-1]):
      label += 1
    degen[ikpt,iband] = label

# Create the random Q-integration weights (wtq = 1/nbQ).
# The first Q-point (Gamma) gets weight zero; it is also skipped when the
# total correction is accumulated. One weight is used per DDB file (indices
# 0..nbQ-1); the extra last entry is never read.
# NOTE(review): the nbQ-1 random Q-points each get 1/nbQ, so the weights sum
# to (nbQ-1)/nbQ rather than 1 -- confirm this normalization is intended.
wtq = N.ones((nbQ+1))
wtq[0] = 0
wtq = wtq*(1.0/nbQ)

# Initialize the total correction
total_corr = N.zeros((DDBtmp.nkpt,DDBtmp.nband))
if temperature == 'y':
  # numpy needs an integer shape. Use the number of temperatures generated by
  # N.arange(0, Tmax, Tstep), which is what the temperature loops iterate over
  # (Tmax/Tstep undercounts when it is not an integer).
  ntemp = len(N.arange(0, float(temp_info[0]), float(temp_info[1])))
  total_corrT = N.zeros((ntemp,DDBtmp.nkpt,DDBtmp.nband))

# Compute phonon freq. and eigenvector for each Q-point
# from each DDB (1 qpt per DDB file), then accumulate the Fan, Fan-add and
# Debye-Waller (DDW) contributions of that Q-point.


def _print_band_values(values, per_line=8):
  """Print the 1-D sequence `values`, `per_line` entries per row, each as ' %e'."""
  for start in range(0, len(values), per_line):
    print(''.join(' %e' % val for val in values[start:start + per_line]))


if temperature == 'y':
  # Temperatures visited by every temperature loop below. Their (integer)
  # count is the leading dimension of the *_corrT arrays.
  temperatures = N.arange(0, float(temp_info[0]), float(temp_info[1]))
  ntemp = len(temperatures)

iiqpt = 0
vkpt = 0
vband = 0
tkpt = N.zeros((DDBtmp.nkpt,3))
DDB = system()
FANterm = system()
EIGR2D = system()
eigq = system()
for ddb_file in DDB_files:
  DDB.__init__(directory='.',filename=ddb_file)

  # Calcul of gprimd from rprimd
  rprimd = DDB.rprim*DDB.acell
  gprimd = N.linalg.inv(N.matrix(rprimd))

  # Transform from 2nd-order matrix (non-cartesian coordinates,
  # masses not included, asr not included ) from DDB to
  # dynamical matrix, in cartesian coordinates, asr not imposed.
  IFC_cart = N.zeros((3,DDB.natom,3,DDB.natom),dtype=complex)
  for ii in N.arange(DDB.natom):
    for jj in N.arange(DDB.natom):
      for dir1 in N.arange(3):
        for dir2 in N.arange(3):
          for dir3 in N.arange(3):
            for dir4 in N.arange(3):
              IFC_cart[dir1,ii,dir2,jj] += gprimd[dir1,dir3]*DDB.IFC[dir3,ii,dir4,jj] \
            *gprimd[dir2,dir4]

  # Reduce the 4 dimensional IFC_cart matrice to 2 dimensional Dynamical matrice.
  ipert1 = 0
  Dyn_mat = N.zeros((3*DDB.natom,3*DDB.natom),dtype=complex)
  while ipert1 < 3*DDB.natom:
    for ii in N.arange(DDB.natom):
      for dir1 in N.arange(3):
        ipert2 = 0
        while ipert2 < 3*DDB.natom:
          for jj in N.arange(DDB.natom):
            for dir2 in N.arange(3):
              Dyn_mat[ipert1,ipert2] = IFC_cart[dir1,ii,dir2,jj]
              ipert2 += 1
        ipert1 += 1

  # Hermitianize the dynamical matrix
  dynmat = N.matrix(Dyn_mat)
  dynmat = 0.5*(dynmat + dynmat.transpose().conjugate())

  # Solve the eigenvalue problem (diagonalize the matrix). eigh only reads
  # one triangle, so it must be given the hermitianized matrix; convert it
  # back to an ndarray so eigvect keeps plain-array indexing semantics.
  [eigval,eigvect]=N.linalg.eigh(N.asarray(dynmat))

  # Orthonormality relation (5.4857990965007152E-4 = electron mass in amu)
  # NOTE(review): only DDB.amu[0] is used, i.e. every atom is given the mass
  # of the first atom type -- confirm for systems with several species.
  eigvect = (eigvect)*N.sqrt(5.4857990965007152E-4/float(DDB.amu[0]))
  # Phonon frequency
  omega = N.sqrt((eigval*5.4857990965007152E-4)/float(DDB.amu[0]))
  print('The phonon frequency of the %s Q-point are:' % DDB.iqpt)
  for k, freq in enumerate(omega[:].real):
    print(' %e Ha' % freq)
    print('with eigenvector:')
    for kk in eigvect[:,k]:
      print('   %e %e i' % (kk.real, kk.imag))

  # Now read the EIGq, EIGR2D and FAN
  user_input = raw_input('Enter the name of the the _EIG.nc that contain\n\
 the %s Q-points. \n' %DDB.iqpt)
  if len(user_input.split()) != 1:
    raise Exception("You sould provide only 1 ***_EIG.nc file" )
  else:
    eigq.__init__(directory='.',filename=user_input.split()[0])

  user_input = raw_input('Enter the name of the the EIGR2D that contain\n\
 the %s Q-points. \n' %DDB.iqpt)
  if len(user_input.split()) != 1:
    raise Exception("You sould provide only 1 ***_EIGR2D file" )
  else:
    EIGR2D.__init__(directory='.',filename=user_input.split()[0])

  user_input = raw_input('Enter the name of the the _FAN that contain\n\
 the %s Q-points. \n' %DDB.iqpt)
  if len(user_input.split()) != 1:
    raise Exception("You sould provide only 1 ***_FAN file" )
  else:
    FANterm.__init__(directory='.',filename=user_input.split()[0])

  # Compute the displacement = eigenvectors of the DDB.
  # Due to metric problem in reduce coordinate we have to work in cartesian
  # but then go back to reduce because our EIGR2D matrix elements are in reduced coord.
  displ_FAN =  N.zeros((3,3),dtype=complex)
  displ_DDW =  N.zeros((3,3),dtype=complex)
  if N.allclose(EIGR2D.iqpt,[0.0,0.0,0.0]):
    # DDW matrix elements are stored at Gamma and reused for the other Q-points
    ddw_save = N.zeros((EIGR2D.nkpt,EIGR2D.nband,3,EIGR2D.natom,3,EIGR2D.natom),dtype=complex)
  elif iiqpt == 0:
    raise Exception('The EIGR2D file of the first Q-point must be at Gamma!')
  eigen_corr =  N.zeros((EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
  fan_corr =  N.zeros((EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
  fan_add = N.zeros((EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
  ddw_corr = N.zeros((EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
  if temperature == 'y':
    eigen_corrT =  N.zeros((ntemp,EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
    fan_corrT =  N.zeros((ntemp,EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
    fan_addT = N.zeros((ntemp,EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
    ddw_corrT = N.zeros((ntemp,EIGR2D.nkpt,EIGR2D.nband),dtype=complex)

  for imode in N.arange(3*EIGR2D.natom): #Loop on perturbation (6 for 2 atoms)
    if omega[imode].real > tol6:
      fan_corrQ =  N.zeros((EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
      ddw_corrQ = N.zeros((EIGR2D.nkpt,EIGR2D.nband),dtype=complex)
      for ikpt in N.arange(EIGR2D.nkpt):
        for iband in N.arange(EIGR2D.nband):
          for iatom1 in N.arange(EIGR2D.natom):
            for iatom2 in N.arange(EIGR2D.natom):
              for idir1 in N.arange(0,3):
                for idir2 in N.arange(0,3):
                  displ_FAN[idir1,idir2] = eigvect[3*iatom2+idir2,imode].conj()\
                      *eigvect[3*iatom1+idir1,imode]/(2.0*omega[imode].real)
                  displ_DDW[idir1,idir2] = (eigvect[3*iatom2+idir2,imode].conj()\
                     *eigvect[3*iatom2+idir1,imode]+eigvect[3*iatom1+idir2,imode].conj()\
                     *eigvect[3*iatom1+idir1,imode])/(4.0*omega[imode].real)
              # Now switch to reduced coordinates in 2 steps (more efficient)
              tmp_displ_FAN = N.zeros((3,3),dtype=complex)
              tmp_displ_DDW = N.zeros((3,3),dtype=complex)
              for idir1 in N.arange(3):
                for idir2 in N.arange(3):
                  tmp_displ_FAN[:,idir1] = tmp_displ_FAN[:,idir1]+displ_FAN[:,idir2]*gprimd[idir2,idir1]
                  tmp_displ_DDW[:,idir1] = tmp_displ_DDW[:,idir1]+displ_DDW[:,idir2]*gprimd[idir2,idir1]
              displ_red_FAN = N.zeros((3,3),dtype=complex)
              displ_red_DDW = N.zeros((3,3),dtype=complex)
              for idir1 in N.arange(3):
                for idir2 in N.arange(3):
                  displ_red_FAN[idir1,:] = displ_red_FAN[idir1,:] + tmp_displ_FAN[idir2,:]*gprimd[idir2,idir1]
                  displ_red_DDW[idir1,:] = displ_red_DDW[idir1,:] + tmp_displ_DDW[idir2,:]*gprimd[idir2,idir1]
              # Now compute the T=0 shift due to this q point
              for idir1 in N.arange(3):
                for idir2 in N.arange(3):
                  fan_corrQ[ikpt,iband] += EIGR2D.EIG2D[ikpt,iband,idir1,iatom1,idir2,iatom2]*\
                      displ_red_FAN[idir1,idir2]
                  # DDW matrix only computed at Gamma
                  if N.allclose(EIGR2D.iqpt,[0.0,0.0,0.0]):
                    ddw_save[ikpt,iband,idir1,iatom1,idir2,iatom2] = EIGR2D.EIG2D[ikpt,iband,idir1,iatom1,idir2,iatom2]
                    ddw_corrQ[ikpt,iband] += ddw_save[ikpt,iband,idir1,iatom1,idir2,iatom2]*\
                        displ_red_DDW[idir1,idir2]
                  else:
                    ddw_corrQ[ikpt,iband] += ddw_save[ikpt,iband,idir1,iatom1,idir2,iatom2]*\
                        displ_red_DDW[idir1,idir2]
          if(type == 2):
            if temperature == 'y':
              for jband in N.arange(EIGR2D.nband):
                delta_E = eigq.EIG[0,ikpt,jband]-eig0.EIG[0,ikpt,iband] + smearing*1j
                for tt, T in enumerate(temperatures):
                  if T < tol6:
                    bose = 0
                  else:
                    bose = 1.0/(N.exp(omega[imode].real/(kb_HaK*T))-1)

                  fan_addT[tt,ikpt,iband] += FANterm.FAN[ikpt,iband,imode,jband]*(\
                                         (bose+0.5)*(2*delta_E/(delta_E**2-(omega[imode].real)**2)) \
                                         - (1-EIGR2D.occ[iband])*(omega[imode].real/(delta_E**2-(omega[imode].real)**2))\
                                         -(bose+0.5)*2/delta_E)/(2.0*omega[imode].real)
            else:
              for jband in N.arange(EIGR2D.nband):
                delta_E = eigq.EIG[0,ikpt,jband]-eig0.EIG[0,ikpt,iband] + smearing*1j

                fan_add[ikpt,iband] += FANterm.FAN[ikpt,iband,imode,jband]*(\
                                  (0+0.5)*(2*delta_E/(delta_E**2-(omega[imode].real)**2)) \
                                - (1-EIGR2D.occ[iband])*(omega[imode].real/(delta_E**2-(omega[imode].real)**2))\
                                 -(0+0.5)*2/delta_E)/(2.0*omega[imode].real)

      if temperature == 'y':
        for tt, T in enumerate(temperatures):
          if T < tol6:
            bose = 0
          else:
            bose = 1.0/(N.exp(omega[imode].real/(kb_HaK*T))-1)
          fan_corrT[tt,:,:] += fan_corrQ[:,:]*(2*bose+1.0)
          ddw_corrT[tt,:,:] += ddw_corrQ[:,:]*(2*bose+1.0)
      else:
        fan_corr[:,:] += fan_corrQ[:,:]
        ddw_corr[:,:] += ddw_corrQ[:,:]

  # Weighted contribution of this Q-point; the first one (Gamma) is skipped.
  if temperature == 'y':
    if type == 1:
      eigen_corrT[:,:,:] = (fan_corrT[:,:,:]- ddw_corrT[:,:,:])*wtq[iiqpt]
    if type == 2:
      eigen_corrT[:,:,:] = (fan_corrT[:,:,:]+ fan_addT[:,:,:] - ddw_corrT[:,:,:])*wtq[iiqpt]
    if iiqpt != 0:
      total_corrT[:,:,:] += eigen_corrT[:,:,:].real
  else:
    if type == 1:
      eigen_corr[:,:] = (fan_corr[:,:]- ddw_corr[:,:])*wtq[iiqpt]
    if type ==2:
      eigen_corr[:,:] = (fan_corr[:,:]+ fan_add[:,:] - ddw_corr[:,:])*wtq[iiqpt]
    if iiqpt != 0:
      total_corr[:,:] += eigen_corr[:,:].real

  iiqpt +=1
  print('Computation running: %s Q-points calculated' % iiqpt)

  # Per-Q-point report (eV); every band value is printed in full.
  if temperature == 'n':
    print('Fan term (eV) for the Q-point: %s' % DDB.iqpt)
    for ikpt in N.arange(EIGR2D.nkpt):
      print('Kpt %s' % (EIGR2D.kpt[ikpt,:],))
      _print_band_values(fan_corr[ikpt,:].real*Ha2eV)

    if (type ==2 or type ==3):
      print('Fan ADD term (eV) for the Q-point: %s' % DDB.iqpt)
      for ikpt in N.arange(EIGR2D.nkpt):
        print('Kpt %s' % (EIGR2D.kpt[ikpt,:],))
        _print_band_values(fan_add[ikpt,:].real*Ha2eV)

    print('--------------------------------------------')
    print('DDW term (eV) for the Q-point: %s' % DDB.iqpt)
    for ikpt in N.arange(EIGR2D.nkpt):
      print('Kpt %s' % (EIGR2D.kpt[ikpt,:],))
      _print_band_values(ddw_corr[ikpt,:].real*Ha2eV)

    print('--------------------------------------------')
    print('Fan+DDW term (eV) for the Q-point: %s' % DDB.iqpt)
    for ikpt in N.arange(EIGR2D.nkpt):
      print('Kpt %s' % (EIGR2D.kpt[ikpt,:],))
      _print_band_values(eigen_corr[ikpt,:].real*Ha2eV)

  # Make a copy to free memory
  vkpt = DDB.nkpt
  vband = DDB.nband
  tkpt = EIGR2D.kpt[:,:]

# Make the average on degenerate energies: every run of consecutive bands
# sharing the same `degen` label receives the mean of their corrections.
# Runs of any length are handled, not only pairs and triplets.
for ikpt in N.arange(vkpt):
  start = 0
  while start < vband:
    # Extend the run while the label matches the first band of the run
    stop = start + 1
    while stop < vband and degen[ikpt,stop] == degen[ikpt,start]:
      stop += 1
    if stop - start > 1:
      if temperature == 'y':
        # Average over bands separately for every temperature
        mean_T = total_corrT[:,ikpt,start:stop].mean(axis=-1)
        total_corrT[:,ikpt,start:stop] = mean_T[:,N.newaxis]
      else:
        total_corr[ikpt,start:stop] = total_corr[ikpt,start:stop].mean()
    start = stop

# Print the total correction (eV). With temperatures, the first slice of
# total_corrT (T = 0 K, since the temperature grid starts at 0) is printed.
if temperature in ('y', 'n'):
  if temperature == 'y':
    total_printed = total_corrT[0]
  else:
    total_printed = total_corr
  print('--------------------------------------------')
  print('Total correction (eV) for %s Q-point.' % (nbQ-1))
  for ikpt in N.arange(vkpt):
    print('Kpt %s' % (tkpt[ikpt,:],))
    band_values = total_printed[ikpt,:]*Ha2eV
    # 8 values per row; no value is skipped or truncated.
    for start in range(0, len(band_values), 8):
      print(''.join(' %e' % val for val in band_values[start:start+8]))

# Write the results into the output file: the total correction (eV) per
# k-point, 6 values per line. With temperatures, the first temperature slice
# (T = 0 K) is written, followed by the temperature dependence of k-point 0.
if temperature == 'y':
  total_written = total_corrT[0]
else:
  total_written = total_corr
with open(output,"w") as O:
  O.write("Total correction of the ZPM (eV) for "+str(nbQ-1)+" Q points\n")
  for ikpt in N.arange(vkpt):
    O.write('Kpt: '+str(tkpt[ikpt,:])+"\n")
    band_values = [str(val) for val in total_written[ikpt,:]*Ha2eV]
    # Create a new line every 6 values
    for start in range(0, len(band_values), 6):
      O.write(' '.join(band_values[start:start+6])+'\n')
  if temperature == 'y':
    # NOTE(review): the header says "at Gamma" but k-point index 0 is
    # written -- assumes the first k-point is Gamma; confirm.
    O.write("Temperature dependence at Gamma\n")
    temperatures = N.arange(0, float(temp_info[0]), float(temp_info[1]))
    for iband in N.arange(vband):
      O.write('Band: '+str(iband)+"\n")
      for tt, T in enumerate(temperatures):
        O.write(str(T)+" "+str(total_corrT[tt,0,iband]*Ha2eV)+"\n")
