# fmt: off

"""Infrared intensities"""

from math import sqrt
from sys import stdout

import numpy as np

import ase.units as units
from ase.parallel import paropen, parprint
from ase.vibrations.vibrations import Vibrations


class Infrared(Vibrations):
    """Class for calculating vibrational modes and infrared intensities
    using finite difference.

    The vibrational modes are calculated from a finite difference
    approximation of the Dynamical matrix and the IR intensities from
    a finite difference approximation of the gradient of the dipole
    moment. The method is described in:

      D. Porezag, M. R. Pederson:
      "Infrared intensities and Raman-scattering activities within
      density-functional theory",
      Phys. Rev. B 54, 7830 (1996)

    The calculator object (calc) linked to the Atoms object (atoms) must
    have the attribute:

    >>> calc.get_dipole_moment(atoms)

    In addition to the methods included in the ``Vibrations`` class
    the ``Infrared`` class introduces two new methods;
    *get_spectrum()* and *write_spectra()*. The *summary()*, *get_energies()*,
    *get_frequencies()*, *get_spectrum()* and *write_spectra()*
    methods all take an optional *method* keyword.  Use
    method='Frederiksen' to use the method described in:

      T. Frederiksen, M. Paulsson, M. Brandbyge, A. P. Jauho:
      "Inelastic transport theory from first-principles: methodology
      and applications for nanoscale devices",
      Phys. Rev. B 75, 205413 (2007)

    atoms: Atoms object
        The atoms to work on.
    indices: list of int
        List of indices of atoms to vibrate.  Default behavior is
        to vibrate all atoms.
    name: str
        Name to use for files.
    delta: float
        Magnitude of displacements.
    nfree: int
        Number of displacements per degree of freedom, 2 or 4 are
        supported. Default is 2 which will displace each atom +delta
        and -delta in each cartesian direction.
    directions: list of int
        Cartesian coordinates to calculate the gradient
        of the dipole moment in.
        For example, with directions = 2 only the dipole moment in the
        z-direction will be considered, whereas with directions = [0, 1]
        only the dipole moment in the xy-plane will be considered.
        Default behavior is to use the dipole moment in all directions.

    Example:

    >>> from ase.io import read
    >>> from ase.calculators.vasp import Vasp
    >>> from ase.vibrations import Infrared
    >>> water = read('water.traj')  # read pre-relaxed structure of water
    >>> calc = Vasp(prec='Accurate',
    ...             ediff=1E-8,
    ...             isym=0,
    ...             idipol=4,       # calculate the total dipole moment
    ...             dipol=water.get_center_of_mass(scaled=True),
    ...             ldipol=True)
    >>> water.calc = calc
    >>> ir = Infrared(water)
    >>> ir.run()
    >>> ir.summary()
    -------------------------------------
    Mode    Frequency        Intensity
    #    meV     cm^-1   (D/Å)^2 amu^-1
    -------------------------------------
    0   16.9i    136.2i     1.6108
    1   10.5i     84.9i     2.1682
    2    5.1i     41.1i     1.7327
    3    0.3i      2.2i     0.0080
    4    2.4      19.0      0.1186
    5   15.3     123.5      1.4956
    6  195.5    1576.7      1.6437
    7  458.9    3701.3      0.0284
    8  473.0    3814.6      1.1812
    -------------------------------------
    Zero-point energy: 0.573 eV
    Static dipole moment: 1.833 D
    Maximum force on atom in `equilibrium`: 0.0026 eV/Å



    This interface now also works for calculator 'siesta',
    (added get_dipole_moment for siesta).

    Example:

    >>> #!/usr/bin/env python3

    >>> from ase.io import read
    >>> from ase.calculators.siesta import Siesta
    >>> from ase.units import Ang, Ry
    >>> from ase.vibrations import Infrared

    >>> bud = read('bud1.xyz')

    >>> calc = Siesta(label='bud',
    ...       meshcutoff=250 * Ry,
    ...       basis='DZP',
    ...       kpts=[1, 1, 1])

    >>> calc.set_fdf('DM.MixingWeight', 0.08)
    >>> calc.set_fdf('DM.NumberPulay', 3)
    >>> calc.set_fdf('DM.NumberKick', 20)
    >>> calc.set_fdf('DM.KickMixingWeight', 0.15)
    >>> calc.set_fdf('SolutionMethod',      'Diagon')
    >>> calc.set_fdf('MaxSCFIterations', 500)
    >>> calc.set_fdf('PAO.BasisType',  'split')
    >>> #50 meV = 0.003674931 * Ry
    >>> calc.set_fdf('PAO.EnergyShift', 0.003674931 * Ry )
    >>> calc.set_fdf('LatticeConstant', 1.000000 * Ang)
    >>> calc.set_fdf('WriteCoorXmol',       'T')

    >>> bud.calc = calc

    >>> ir = Infrared(bud)
    >>> ir.run()
    >>> ir.summary()

    """

    def __init__(self, atoms, indices=None, name='ir', delta=0.01,
                 nfree=2, directions=None):
        """Set up the finite-difference IR calculation.

        See the class docstring for the meaning of the parameters.
        """
        super().__init__(atoms, indices=indices, name=name,
                         delta=delta, nfree=nfree)
        if atoms.constraints:
            print('WARNING! \n Your Atoms object is constrained. '
                  'Some forces may be unintended set to zero. \n')
        if directions is None:
            self.directions = np.asarray([0, 1, 2])
        else:
            self.directions = np.asarray(directions)
        self.ir = True

    def read(self, method='standard', direction='central'):
        """Compute modes, energies and IR intensities from displacements.

        The forces and dipole moments of the equilibrium and displaced
        structures are obtained through the ``_eq_disp()``, ``_disp()``
        and ``_iter_ai()`` helpers (not defined in this class; presumably
        inherited from ``Vibrations``).

        Parameters:

        method: str
            'standard' or 'frederiksen' (case-insensitive).
        direction: str
            Finite-difference scheme; only 'central' (case-insensitive)
            is implemented.

        Sets the attributes ``dipole_zero``, ``force_zero``, ``H``,
        ``im``, ``modes``, ``hnu`` and ``intensities``.

        Raises NotImplementedError for any direction other than
        'central'.
        """
        self.method = method.lower()
        self.direction = direction.lower()
        assert self.method in ['standard', 'frederiksen']

        # Test the normalized value so that e.g. 'Central' is accepted,
        # consistent with the case-insensitive handling of ``method``.
        if self.direction != 'central':
            raise NotImplementedError(
                'Only central difference is implemented at the moment.')

        disp = self._eq_disp()
        forces_zero = disp.forces()
        dipole_zero = disp.dipole()
        # Norm of the equilibrium dipole moment, converted to Debye
        self.dipole_zero = (sum(dipole_zero**2)**0.5) / units.Debye
        # Largest force norm among the vibrating atoms at equilibrium
        self.force_zero = max(
            sum((forces_zero[j])**2)**0.5 for j in self.indices)

        # Cartesian dipole components excluded from the intensities;
        # this does not depend on the displacement, so compute it once.
        masked = [n for n in range(3) if n not in self.directions]

        ndof = 3 * len(self.indices)
        H = np.empty((ndof, ndof))
        dpdx = np.empty((ndof, 3))
        for r, (a, i) in enumerate(self._iter_ai()):
            disp_minus = self._disp(a, i, -1)
            disp_plus = self._disp(a, i, 1)

            fminus = disp_minus.forces()
            dminus = disp_minus.dipole()

            fplus = disp_plus.forces()
            dplus = disp_plus.dipole()

            if self.nfree == 4:
                disp_mm = self._disp(a, i, -2)
                disp_pp = self._disp(a, i, 2)
                fminusminus = disp_mm.forces()
                dminusminus = disp_mm.dipole()

                fplusplus = disp_pp.forces()
                dplusplus = disp_pp.dipole()
            if self.method == 'frederiksen':
                # Replace the force on the displaced atom by minus the
                # sum of the forces on all other atoms.  Each displaced
                # image must be corrected with its *own* force sum.
                fminus[a] -= fminus.sum(0)
                fplus[a] -= fplus.sum(0)
                if self.nfree == 4:
                    fminusminus[a] -= fminusminus.sum(0)
                    fplusplus[a] -= fplusplus.sum(0)
            if self.nfree == 2:
                H[r] = (fminus - fplus)[self.indices].ravel() / 2.0
                # The sign of dpdx is irrelevant: only its square is used
                dpdx[r] = (dminus - dplus)
            if self.nfree == 4:
                H[r] = (-fminusminus + 8 * fminus - 8 * fplus +
                        fplusplus)[self.indices].ravel() / 12.0
                dpdx[r] = (-dplusplus + 8 * dplus - 8 * dminus +
                           dminusminus) / 6.0
            H[r] /= 2 * self.delta
            dpdx[r] /= 2 * self.delta
        if masked:
            dpdx[:, masked] = 0
        # Calculate eigenfrequencies and eigenvectors
        masses = self.atoms.get_masses()
        # Symmetrize; together with the factors 1/2 above this yields the
        # average of H and its transpose.
        H += H.copy().T
        self.H = H

        # Inverse square root of the mass of each degree of freedom
        self.im = np.repeat(masses[self.indices]**-0.5, 3)
        omega2, modes = np.linalg.eigh(self.im[:, None] * H * self.im)
        self.modes = modes.T.copy()

        # Calculate intensities
        mass_weights = np.sqrt(np.repeat(masses[self.indices], 3) *
                               units._amu / units._me)
        dpdq = dpdx / mass_weights[:, None]
        # Project the mass-weighted dipole gradient onto each mode;
        # row j of dpdQ belongs to mode j.
        dpdQ = np.dot(modes.T, dpdq)
        intensities = (dpdQ**2).sum(axis=1)
        # Conversion factor:
        s = units._hbar * 1e10 / sqrt(units._e * units._amu)
        # Negative eigenvalues give imaginary energies via the complex sqrt
        self.hnu = s * omega2.astype(complex)**0.5
        # Conversion factor from atomic units to (D/Angstrom)^2/amu.
        conv = (1.0 / units.Debye)**2 * units._amu / units._me
        self.intensities = intensities * conv

    def intensity_prefactor(self, intensity_unit):
        """Return (factor, label) for the requested intensity unit.

        ``intensity_unit`` must be '(D/A)2/amu' or 'km/mol'; any other
        value raises RuntimeError.  The factor multiplies the values in
        ``self.intensities``.
        """
        if intensity_unit == '(D/A)2/amu':
            return 1.0, '(D/Å)^2 amu^-1'
        elif intensity_unit == 'km/mol':
            # conversion factor from Porezag PRB 54 (1996) 7830
            return 42.255, 'km/mol'
        else:
            raise RuntimeError('Intensity unit >' + intensity_unit +
                               '< unknown.')

    def summary(self, method='standard', direction='central',
                intensity_unit='(D/A)2/amu', log=stdout):
        """Print a table of mode energies and IR intensities.

        Parameters:

        method: str
            Passed on to ``get_energies()``.
        direction: str
            Passed on to ``get_energies()``.
        intensity_unit: str
            '(D/A)2/amu' or 'km/mol'.
        log: file object or str
            Where to write the table.  If a string is given it is used
            as a filename, opened in append mode, and closed again
            before returning.
        """
        hnu = self.get_energies(method, direction)
        # Conversion factor from eV to cm^-1
        s = 0.01 * units._e / units._c / units._hplanck
        iu, iu_string = self.intensity_prefactor(intensity_unit)
        if intensity_unit == '(D/A)2/amu':
            iu_format = '%9.4f'
        elif intensity_unit == 'km/mol':
            iu_string = '   ' + iu_string
            iu_format = ' %7.1f'

        # Only close the stream if it was opened here; a file object
        # supplied by the caller stays open.
        opened_here = isinstance(log, str)
        if opened_here:
            log = paropen(log, 'a')

        try:
            parprint('-------------------------------------', file=log)
            parprint(' Mode    Frequency        Intensity', file=log)
            parprint('  #    meV     cm^-1   ' + iu_string, file=log)
            parprint('-------------------------------------', file=log)
            for n, e in enumerate(hnu):
                # Imaginary energies are flagged with a trailing 'i'
                if e.imag != 0:
                    c = 'i'
                    e = e.imag
                else:
                    c = ' '
                    e = e.real
                parprint(('%3d %6.1f%s  %7.1f%s  ' + iu_format) %
                         (n, 1000 * e, c, s * e, c,
                          iu * self.intensities[n]),
                         file=log)
            parprint('-------------------------------------', file=log)
            parprint('Zero-point energy: %.3f eV' %
                     self.get_zero_point_energy(), file=log)
            parprint('Static dipole moment: %.3f D' % self.dipole_zero,
                     file=log)
            parprint('Maximum force on atom in `equilibrium`: %.4f eV/Å' %
                     self.force_zero, file=log)
            parprint(file=log)
        finally:
            if opened_here:
                log.close()

    def get_spectrum(self, start=800, end=4000, npts=None, width=4,
                     type='Gaussian', method='standard', direction='central',
                     intensity_unit='(D/A)2/amu', normalize=False):
        """Get infrared spectrum.

        The method returns wavenumbers in cm^-1 with corresponding
        absolute infrared intensity.
        Start and end point, and width of the Gaussian/Lorentzian should
        be given in cm^-1.
        normalize=True ensures the integral over the peaks to give the
        intensity.

        Note that *intensity_unit* is accepted for interface
        compatibility but is not applied here: the folded intensities
        are taken unscaled from ``self.intensities``.
        """
        frequencies = self.get_frequencies(method, direction).real
        intensities = self.intensities
        return self.fold(frequencies, intensities,
                         start, end, npts, width, type, normalize)

    def write_spectra(self, out='ir-spectra.dat', start=800, end=4000,
                      npts=None, width=10,
                      type='Gaussian', method='standard', direction='central',
                      intensity_unit='(D/A)2/amu', normalize=False):
        """Write out infrared spectrum to file.

        First column is the wavenumber in cm^-1, the second column the
        absolute infrared intensities, and
        the third column the absorbance scaled so that data runs
        from 1 to 0. Start and end
        point, and width of the Gaussian/Lorentzian should be given
        in cm^-1.

        The second column is scaled to *intensity_unit* (see
        ``intensity_prefactor()``).  If the folded spectrum is zero
        everywhere, the third column is 1 throughout."""
        # ``normalize`` must be passed by keyword: the positional slot
        # after ``direction`` in get_spectrum() is ``intensity_unit``.
        energies, spectrum = self.get_spectrum(start, end, npts, width,
                                               type, method, direction,
                                               normalize=normalize)

        # Write out spectrum in file. Second column is absolute intensities.
        # Third column is absorbance scaled so that data runs from 1 to 0
        peak = spectrum.max()
        if peak > 0:
            spectrum2 = 1. - spectrum / peak
        else:
            # No intensity anywhere in the window: avoid 0 / 0
            spectrum2 = np.ones_like(spectrum)
        outdata = np.empty([len(energies), 3])
        outdata.T[0] = energies
        outdata.T[1] = spectrum
        outdata.T[2] = spectrum2
        with open(out, 'w') as fd:
            fd.write(f'# {type.title()} folded, width={width:g} cm^-1\n')
            iu, iu_string = self.intensity_prefactor(intensity_unit)
            if normalize:
                iu_string = 'cm ' + iu_string
            fd.write('# [cm^-1] %14s\n' % ('[' + iu_string + ']'))
            for row in outdata:
                fd.write('%.3f  %15.5e  %15.5e \n' %
                         (row[0], iu * row[1], row[2]))
