# fmt: off

# flake8: noqa
import numpy as np
from numpy import linalg

from ase import units

# Constant linear maps shared by the term functions below.
# Bx (3x6) turns the stacked Cartesian vector (r_i, r_j) into r_i - r_j;
# Mx is the very same array object, used by the Morse terms.
# Ax (6x9) turns (r_i, r_j, r_k) into (r_i - r_j, r_k - r_j) for angles.
# Gradients/Hessians in relative coordinates are pulled back to Cartesian
# ones as X.T @ g and X.T @ H @ X.
Ax = np.block([
    [np.eye(3, dtype=int), -np.eye(3, dtype=int), np.zeros((3, 3), dtype=int)],
    [np.zeros((3, 3), dtype=int), -np.eye(3, dtype=int), np.eye(3, dtype=int)],
])
Bx = np.hstack([np.eye(3, dtype=int), -np.eye(3, dtype=int)])
Mx = Bx


class Morse:
    """Parameters of a Morse pair term between ``atomi`` and ``atomj``.

    ``D`` is the well depth, ``alpha`` the width parameter and ``r0`` the
    equilibrium distance, entering ``D * (1 - exp(-alpha * (r - r0)))**2``.
    ``r`` caches the most recently evaluated pair distance (``None`` until
    one of the Morse functions has been called).
    """

    def __init__(self, atomi, atomj, D, alpha, r0):
        self.atomi, self.atomj = atomi, atomj
        self.D, self.alpha, self.r0 = D, alpha, r0
        self.r = None


class Bond:
    """Harmonic bond term ``0.5 * k * (b - b0)**2`` between two atoms.

    ``alpha`` and ``rref`` are optional sequences; when ``alpha`` is given
    the bond Hessians are scaled by ``exp(alpha[0] * (rref[0]**2 - b**2))``.
    ``b`` caches the most recently evaluated bond length (``None`` at first).
    """

    def __init__(self, atomi, atomj, k, b0,
                 alpha=None, rref=None):
        self.atomi, self.atomj = atomi, atomj
        self.k, self.b0 = k, b0
        self.alpha, self.rref = alpha, rref
        self.b = None


class Angle:
    """Bending term for the angle ``a`` at ``atomj`` between bonds i-j and k-j.

    With ``cos=False`` the energy is ``0.5 * k * (a - a0)**2`` (``a0`` in
    radians, the deviation folded by multiples of pi).  With ``cos=True`` it
    is ``0.5 * k * (cos(a) - cos(a0))**2``.  When ``alpha`` is given, the
    Hessians are scaled by ``exp(alpha[0] * (rref[0]**2 - r_ij**2)) *
    exp(alpha[1] * (rref[1]**2 - r_kj**2))``.  ``a`` caches the most recently
    evaluated angle.
    """

    def __init__(self, atomi, atomj, atomk, k, a0, cos=False,
                 alpha=None, rref=None):
        self.atomi, self.atomj, self.atomk = atomi, atomj, atomk
        self.k, self.a0, self.cos = k, a0, cos
        self.alpha, self.rref = alpha, rref
        self.a = None


class Dihedral:
    """Torsion term for the dihedral angle ``d`` defined by atoms i-j-k-l.

    * ``d0 is None``: ``0.5 * k * (1 - cos(2 d))``;
    * ``d0`` given, ``n is None``: ``0.5 * k * (d - d0)**2`` with the
      deviation folded into one 2*pi period;
    * ``d0`` and ``n`` given: ``k * (1 + cos(n d - d0))``.

    When ``alpha`` is given, Hessians are scaled by three factors
    ``exp(alpha[m] * (rref[m]**2 - r**2))`` for the i-j, k-j and k-l
    distances.  ``d`` caches the most recently evaluated dihedral.
    """

    def __init__(self, atomi, atomj, atomk, atoml, k, d0=None, n=None,
                 alpha=None, rref=None):
        self.atomi, self.atomj = atomi, atomj
        self.atomk, self.atoml = atomk, atoml
        self.k, self.d0, self.n = k, d0, n
        self.alpha, self.rref = alpha, rref
        self.d = None


class VdW:
    """12-6 pair term ``Aij / r**12 - Bij / r**6`` between two atoms.

    The coefficients come from the first complete parameter set, checked in
    this order:

    * ``epsilonij`` with ``sigmaij`` (``Aij = 4 eps sigma**12``,
      ``Bij = 4 eps sigma**6``), or with ``rminij`` (``Aij = eps rmin**12``,
      ``Bij = 2 eps rmin**6``);
    * ``Aij`` and ``Bij`` given directly;
    * ``epsiloni``/``epsilonj`` with ``sigmai``/``sigmaj`` or with
      ``rmini``/``rminj``.  These use the same formulas, with the geometric
      mean of the epsilons and the arithmetic mean of the sigmas (or rmins).

    Both coefficients are multiplied by ``scale``.  ``r`` caches the most
    recently evaluated pair distance.

    Raises
    ------
    NotImplementedError
        If none of the supported parameter combinations is complete.
    """

    def __init__(self, atomi, atomj, epsilonij=None, sigmaij=None, rminij=None,
                 Aij=None, Bij=None, epsiloni=None, epsilonj=None,
                 sigmai=None, sigmaj=None, rmini=None, rminj=None, scale=1.0):
        self.atomi = atomi
        self.atomj = atomj
        message = "not implemented combination of vdW parameters."
        if epsilonij is not None:
            if sigmaij is not None:
                self.Aij = scale * 4.0 * epsilonij * sigmaij**12
                self.Bij = scale * 4.0 * epsilonij * sigmaij**6
            elif rminij is not None:
                self.Aij = scale * epsilonij * rminij**12
                self.Bij = scale * 2.0 * epsilonij * rminij**6
            else:
                raise NotImplementedError(message)
        elif Aij is not None and Bij is not None:
            self.Aij = scale * Aij
            self.Bij = scale * Bij
        elif epsiloni is not None and epsilonj is not None:
            if sigmai is not None and sigmaj is not None:
                self.Aij = (scale * 4.0 * np.sqrt(epsiloni * epsilonj)
                            * ((sigmai + sigmaj) / 2.0)**12)
                self.Bij = (scale * 4.0 * np.sqrt(epsiloni * epsilonj)
                            * ((sigmai + sigmaj) / 2.0)**6)
            elif rmini is not None and rminj is not None:
                self.Aij = (scale * np.sqrt(epsiloni * epsilonj)
                            * ((rmini + rminj) / 2.0)**12)
                self.Bij = (scale * 2.0 * np.sqrt(epsiloni * epsilonj)
                            * ((rmini + rminj) / 2.0)**6)
            else:
                # Per-atom epsilons alone are not enough.  Without this,
                # Aij/Bij would stay unset and fail later with AttributeError.
                raise NotImplementedError(message)
        else:
            raise NotImplementedError(message)
        self.r = None


class Coulomb:
    """Coulomb pair term ``chargeij / r`` in ASE units.

    The charge product is given either directly as ``chargeij`` or as the
    two charges ``chargei`` and ``chargej``.  It is multiplied by ``scale``
    and by the Coulomb constant 8.9875517873681764e9 (SI), which is converted
    to ASE units with ``units.m * units.J / units.C / units.C``.  ``r``
    caches the most recently evaluated pair distance.

    Raises
    ------
    NotImplementedError
        If neither ``chargeij`` nor both ``chargei`` and ``chargej`` are given.
    """

    def __init__(self, atomi, atomj, chargeij=None,
                 chargei=None, chargej=None, scale=1.0):
        self.atomi = atomi
        self.atomj = atomj
        if chargeij is not None:
            self.chargeij = (scale * chargeij * 8.9875517873681764e9
                             * units.m * units.J / units.C / units.C)
        elif chargei is not None and chargej is not None:
            self.chargeij = (scale * chargei * chargej * 8.9875517873681764e9
                             * units.m * units.J / units.C / units.C)
        else:
            # The space before the closing quote keeps the two fragments from
            # fusing into "combinationof".
            raise NotImplementedError("not implemented combination "
                                      "of Coulomb parameters.")
        self.r = None


def get_morse_potential_eta(atoms, morse):
    """Return the Morse damping factor ``eta`` for the pair in ``morse``.

    Beyond the equilibrium distance ``r0`` the factor is
    ``1 - (1 - exp(-alpha * (r - r0)))**2``, i.e. one minus the Morse energy
    divided by ``D``.  At or inside ``r0`` it is exactly 1.0.  The Hessian
    functions multiply by it for Morse pairs that share an atom with the
    term being evaluated.
    """
    separation = linalg.norm(rel_pos_pbc(atoms, morse.atomi, morse.atomj))
    if separation > morse.r0:
        decay = np.exp(-morse.alpha * (separation - morse.r0))
        return 1.0 - (1.0 - decay)**2
    return 1.0


def get_morse_potential_value(atoms, morse):
    """Morse energy ``D * (1 - exp(-alpha * (r - r0)))**2`` of one pair.

    Stores the current pair distance in ``morse.r`` and returns
    ``(i, j, v)``.
    """
    i, j = morse.atomi, morse.atomj
    dist = linalg.norm(rel_pos_pbc(atoms, i, j))
    energy = morse.D * (1.0 - np.exp(-morse.alpha * (dist - morse.r0)))**2
    morse.r = dist
    return i, j, energy


def get_morse_potential_gradient(atoms, morse):
    """Cartesian gradient (6 components: atom i, then atom j) of the Morse energy.

    The radial derivative ``2 D alpha e (1 - e)``, with
    ``e = exp(-alpha (r - r0))``, is applied along the unit vector of r_ij and
    mapped to Cartesian coordinates through ``Mx``.

    Stores the current pair distance in ``morse.r`` and returns
    ``(i, j, gx)``.
    """
    i, j = morse.atomi, morse.atomj
    r_vec = rel_pos_pbc(atoms, i, j)
    dist = linalg.norm(r_vec)
    unit = r_vec / dist
    e = np.exp(-morse.alpha * (dist - morse.r0))
    grad_rel = 2.0 * morse.D * morse.alpha * e * (1.0 - e) * unit
    grad_cart = np.dot(Mx.T, grad_rel)
    morse.r = dist
    return i, j, grad_cart


def get_morse_potential_hessian(atoms, morse, spectral=False):
    """Cartesian Hessian (6x6) of the Morse term.

    The Hessian in the relative vector r_ij is split into a part along the
    bond (projector ``P``) and a part perpendicular to it (``Q = 1 - P``).
    It is then pulled back through ``Mx``.  With ``spectral=True`` every
    eigenvalue is replaced by its absolute value.

    Stores the current pair distance in ``morse.r`` and returns
    ``(i, j, Hx)``.
    """
    i, j = morse.atomi, morse.atomj

    r_vec = rel_pos_pbc(atoms, i, j)
    r_len = linalg.norm(r_vec)
    r_hat = r_vec / r_len
    along = np.tensordot(r_hat, r_hat, axes=0)
    across = np.eye(3) - along

    expo = np.exp(-morse.alpha * (r_len - morse.r0))
    h_rel = (2.0 * morse.D * morse.alpha * expo
             * (morse.alpha * (2.0 * expo - 1.0) * along
                + (1.0 - expo) / r_len * across))
    hess = np.dot(Mx.T, np.dot(h_rel, Mx))

    if spectral:
        vals, vecs = linalg.eigh(hess)
        hess = vecs.dot(np.diag(np.abs(vals)).dot(vecs.T))

    morse.r = r_len
    return i, j, hess


def get_morse_potential_reduced_hessian(atoms, morse):
    """Bond-direction-only, positive semi-definite Morse Hessian (6x6).

    Keeps only the absolute radial curvature ``|2 D alpha**2 e (2 e - 1)|``
    (``e = exp(-alpha (r - r0))``) times the projector onto the bond.  The
    perpendicular term of the full Hessian is dropped.

    Stores the current pair distance in ``morse.r`` and returns
    ``(i, j, Hx)``.
    """
    i, j = morse.atomi, morse.atomj
    r_vec = rel_pos_pbc(atoms, i, j)
    r_len = linalg.norm(r_vec)
    r_hat = r_vec / r_len
    along = np.tensordot(r_hat, r_hat, axes=0)

    expo = np.exp(-morse.alpha * (r_len - morse.r0))
    h_rel = np.abs(2.0 * morse.D * morse.alpha**2 * expo * (2.0 * expo - 1.0)) * along
    hess = np.dot(Mx.T, np.dot(h_rel, Mx))

    morse.r = r_len
    return i, j, hess


def get_bond_potential_value(atoms, bond):
    """Harmonic bond energy ``0.5 * k * (b - b0)**2``.

    Stores the current bond length in ``bond.b`` and returns ``(i, j, v)``.
    """
    i, j = bond.atomi, bond.atomj
    length = linalg.norm(rel_pos_pbc(atoms, i, j))
    energy = 0.5 * bond.k * (length - bond.b0)**2
    bond.b = length
    return i, j, energy


def get_bond_potential_gradient(atoms, bond):
    """Cartesian gradient (6 components: atom i, then atom j) of the bond energy.

    ``k (b - b0)`` along the bond unit vector, mapped through ``Bx``.
    Stores the current bond length in ``bond.b`` and returns ``(i, j, gx)``.
    """
    i, j = bond.atomi, bond.atomj
    r_vec = rel_pos_pbc(atoms, i, j)
    length = linalg.norm(r_vec)
    grad_rel = bond.k * (length - bond.b0) * (r_vec / length)
    grad_cart = np.dot(Bx.T, grad_rel)
    bond.b = length
    return i, j, grad_cart


def get_bond_potential_hessian(atoms, bond, morses=None, spectral=False):
    """Cartesian Hessian (6x6) of the harmonic bond term.

    The Hessian has curvature ``k`` along the bond and ``k (b - b0) / b``
    perpendicular to it.  It is optionally damped by
    ``exp(alpha[0] * (rref[0]**2 - b**2))`` and multiplied by the Morse
    factor of every Morse pair that has an atom in common with the bond.
    With ``spectral=True`` all eigenvalues are replaced by their absolute
    values.

    Stores the bond length in ``bond.b`` and returns ``(i, j, Hx)``.
    """
    i, j = bond.atomi, bond.atomj

    r_vec = rel_pos_pbc(atoms, i, j)
    length = linalg.norm(r_vec)
    unit = r_vec / length
    along = np.tensordot(unit, unit, axes=0)
    across = np.eye(3) - along

    h_rel = bond.k * along + bond.k * (length - bond.b0) / length * across

    if bond.alpha is not None:
        h_rel *= np.exp(bond.alpha[0] * (bond.rref[0]**2 - length**2))

    if morses is not None:
        for morse in morses:
            # One factor per Morse pair touching either bond atom.
            if morse.atomi in (i, j) or morse.atomj in (i, j):
                h_rel *= get_morse_potential_eta(atoms, morse)

    hess = np.dot(Bx.T, np.dot(h_rel, Bx))

    if spectral:
        vals, vecs = linalg.eigh(hess)
        hess = vecs.dot(np.diag(np.abs(vals)).dot(vecs.T))

    bond.b = length
    return i, j, hess


def get_bond_potential_reduced_hessian(atoms, bond, morses=None):
    """Bond-direction-only bond Hessian (6x6): ``k`` times the bond projector.

    Uses the same optional ``alpha``/``rref`` damping and Morse factors as
    the full bond Hessian.  Stores the bond length in ``bond.b`` and returns
    ``(i, j, Hx)``.
    """
    i, j = bond.atomi, bond.atomj

    r_vec = rel_pos_pbc(atoms, i, j)
    length = linalg.norm(r_vec)
    unit = r_vec / length

    h_rel = bond.k * np.tensordot(unit, unit, axes=0)

    if bond.alpha is not None:
        h_rel *= np.exp(bond.alpha[0] * (bond.rref[0]**2 - length**2))

    if morses is not None:
        for morse in morses:
            if morse.atomi in (i, j) or morse.atomj in (i, j):
                h_rel *= get_morse_potential_eta(atoms, morse)

    hess = np.dot(Bx.T, np.dot(h_rel, Bx))

    bond.b = length
    return i, j, hess


def get_bond_potential_reduced_hessian_test(atoms, bond):
    """Rank-one Hessian estimate ``g g^T / (2 v)`` from bond energy and gradient.

    For the harmonic bond this equals ``k`` times the outer product of the
    Cartesian bond-direction vector.  It divides by ``v`` and is therefore
    undefined at ``b == b0``.  Returns ``(i, j, Hx)``.
    """
    _, _, energy = get_bond_potential_value(atoms, bond)
    i, j, grad = get_bond_potential_gradient(atoms, bond)
    return i, j, np.tensordot(grad, grad, axes=0) / energy / 2.0


def get_angle_potential_value(atoms, angle):
    """Energy of the bending term for the angle at ``angle.atomj``.

    The energy is harmonic in the angle, ``0.5 * k * (a - a0)**2``, with the
    deviation folded by multiples of pi.  When ``angle.cos`` is true it is
    harmonic in the cosine instead, ``0.5 * k * (cos a - cos a0)**2``.

    Stores the current angle (radians) in ``angle.a`` and returns
    ``(i, j, k, v)``.
    """
    i, j, k = angle.atomi, angle.atomj, angle.atomk

    r_ij = rel_pos_pbc(atoms, i, j)
    r_kj = rel_pos_pbc(atoms, k, j)
    # Clamp round-off so arccos never sees |x| > 1.
    cos_theta = np.clip(np.dot(r_ij / linalg.norm(r_ij),
                               r_kj / linalg.norm(r_kj)), -1.0, 1.0)
    theta = np.arccos(cos_theta)

    if angle.cos:
        delta = np.cos(theta) - np.cos(angle.a0)
    else:
        delta = theta - angle.a0
        delta = delta - np.around(delta / np.pi) * np.pi

    energy = 0.5 * angle.k * delta**2
    angle.a = theta
    return i, j, k, energy


def get_angle_potential_gradient(atoms, angle):
    """Cartesian gradient (9 components: atoms i, j, k) of the bending term.

    The derivatives with respect to r_ij and r_kj are built from the
    components of each bond direction perpendicular to the other bond.  They
    are then mapped through ``Ax``.  In the angle-harmonic form the gradient
    contains ``1 / sin a``, so it is left at zero when ``|sin a| <= 0.001``.

    Stores the current angle in ``angle.a`` and returns ``(i, j, k, gx)``.
    """
    i, j, k = angle.atomi, angle.atomj, angle.atomk

    rij = rel_pos_pbc(atoms, i, j)
    dij = linalg.norm(rij)
    eij = rij / dij
    rkj = rel_pos_pbc(atoms, k, j)
    dkj = linalg.norm(rkj)
    ekj = rkj / dkj
    a = np.arccos(np.clip(np.dot(eij, ekj), -1.0, 1.0))

    # Part of the partner bond direction perpendicular to each bond.
    perp_i = np.dot(np.eye(3) - np.tensordot(eij, eij, axes=0), ekj)
    perp_k = np.dot(np.eye(3) - np.tensordot(ekj, ekj, axes=0), eij)

    grad_rel = np.zeros(6)
    if angle.cos:
        delta = np.cos(a) - np.cos(angle.a0)
        grad_rel[0:3] = angle.k * delta / dij * perp_i
        grad_rel[3:6] = angle.k * delta / dkj * perp_k
    else:
        delta = a - angle.a0
        delta = delta - np.around(delta / np.pi) * np.pi
        sina = np.sin(a)
        if np.abs(sina) > 0.001:
            grad_rel[0:3] = -angle.k * delta / sina / dij * perp_i
            grad_rel[3:6] = -angle.k * delta / sina / dkj * perp_k

    gx = np.dot(Ax.T, grad_rel)
    angle.a = a
    return i, j, k, gx


def get_angle_potential_hessian(atoms, angle, morses=None, spectral=False):
    """Cartesian Hessian (9x9) of the bending term for atoms i, j, k.

    The 6x6 Hessian with respect to the relative vectors (r_ij, r_kj) is
    assembled from projectors onto, and perpendicular to, the two bonds.  It
    is then pulled back through ``Ax``.  For ``angle.cos`` the energy is
    ``0.5 k (cos a - cos a0)**2``, otherwise ``0.5 k (a - a0)**2``.  The
    analytic expressions contain ``1 / sin a``, so near-collinear geometries
    (``|sin a| <= 0.001``) leave the relative Hessian at zero.

    The result is optionally scaled by the ``alpha``/``rref`` exponentials
    and by the Morse factor of any Morse pair touching one of the three
    atoms.  ``spectral=True`` replaces every eigenvalue by its absolute
    value.

    Stores the current angle in ``angle.a`` and returns ``(i, j, k, Hx)``.
    """
    i = angle.atomi
    j = angle.atomj
    k = angle.atomk

    rij = rel_pos_pbc(atoms, i, j)
    dij = linalg.norm(rij)
    dij2 = dij * dij
    eij = rij / dij
    rkj = rel_pos_pbc(atoms, k, j)
    dkj = linalg.norm(rkj)
    dkj2 = dkj * dkj
    ekj = rkj / dkj
    dijdkj = dij * dkj
    eijekj = np.dot(eij, ekj)
    if np.abs(eijekj) > 1.0:
        eijekj = np.sign(eijekj)

    a = np.arccos(eijekj)
    if angle.cos:
        da = np.cos(a) - np.cos(angle.a0)
        cosa0 = np.cos(angle.a0)
    else:
        da = a - angle.a0
        da = da - np.around(da / np.pi) * np.pi
    sina = np.sin(a)
    cosa = np.cos(a)

    Pij = np.tensordot(eij, eij, axes=0)
    Qij = np.eye(3) - Pij
    Pkj = np.tensordot(ekj, ekj, axes=0)
    Qkj = np.eye(3) - Pkj
    Pik = np.tensordot(eij, ekj, axes=0)
    Pki = np.tensordot(ekj, eij, axes=0)
    P = np.eye(3) * eijekj

    QijPkjQij = np.dot(Qij, np.dot(Pkj, Qij))
    QijPkiQkj = np.dot(Qij, np.dot(Pki, Qkj))
    QkjPijQkj = np.dot(Qkj, np.dot(Pij, Qkj))

    Hr = np.zeros((6, 6))
    if np.abs(sina) > 0.001:
        # cot(a) diverges for collinear atoms, so it is only formed inside
        # this guard.  Computed unconditionally, it divides by sin(a) == 0
        # (RuntimeWarning, or FloatingPointError under np.seterr(all='raise')).
        ctga = cosa / sina
        if angle.cos:
            # d2V/da2 / k for V = 0.5 k (cos a - cos a0)**2.
            factor = 1.0 - 2.0 * cosa * cosa + cosa * cosa0
            Hr[0:3, 0:3] = (angle.k * (factor * QijPkjQij / sina
                                       - sina * da * (-ctga * QijPkjQij / sina + np.dot(Qij, Pki)
                                                      - np.dot(Pij, Pki) * 2.0 + (Pik + P))) / sina / dij2)
            Hr[0:3, 3:6] = (angle.k * (factor * QijPkiQkj / sina
                                       - sina * da * (-ctga * QijPkiQkj / sina
                                                      - np.dot(Qij, Qkj))) / sina / dijdkj)
            Hr[3:6, 0:3] = Hr[0:3, 3:6].T
            Hr[3:6, 3:6] = (angle.k * (factor * QkjPijQkj / sina
                                       - sina * da * (-ctga * QkjPijQkj / sina
                                                      + np.dot(Qkj, Pik) -
                                                      np.dot(Pkj, Pik)
                                                      * 2.0 + (Pki + P))) / sina / dkj2)
        else:
            Hr[0:3, 0:3] = (angle.k * (QijPkjQij / sina
                                       + da * (-ctga * QijPkjQij / sina + np.dot(Qij, Pki)
                                               - np.dot(Pij, Pki) * 2.0 + (Pik + P))) / sina / dij2)
            Hr[0:3, 3:6] = (angle.k * (QijPkiQkj / sina
                                       + da * (-ctga * QijPkiQkj / sina
                                               - np.dot(Qij, Qkj))) / sina / dijdkj)
            Hr[3:6, 0:3] = Hr[0:3, 3:6].T
            Hr[3:6, 3:6] = (angle.k * (QkjPijQkj / sina
                                       + da * (-ctga * QkjPijQkj / sina
                                               + np.dot(Qkj, Pik) - np.dot(Pkj, Pik)
                                               * 2.0 + (Pki + P))) / sina / dkj2)

    if angle.alpha is not None:
        Hr *= (np.exp(angle.alpha[0] * (angle.rref[0]**2 - dij**2))
               * np.exp(angle.alpha[1] * (angle.rref[1]**2 - dkj**2)))

    if morses is not None:
        for morse in morses:
            # One factor per Morse pair touching any of the three atoms.
            if morse.atomi in (i, j, k) or morse.atomj in (i, j, k):
                Hr *= get_morse_potential_eta(atoms, morse)

    Hx = np.dot(Ax.T, np.dot(Hr, Ax))

    if spectral:
        eigvals, eigvecs = linalg.eigh(Hx)
        D = np.diag(np.abs(eigvals))
        U = eigvecs
        Hx = np.dot(U, np.dot(D, np.transpose(U)))

    angle.a = a

    return i, j, k, Hx


def get_angle_potential_reduced_hessian(atoms, angle, morses=None):
    """Positive semi-definite approximation of the angle Hessian (9x9).

    Keeps only the outer product of the angle gradient with respect to
    (r_ij, r_kj), weighted by the absolute angular curvature.  The weight is
    ``k`` for the angle-harmonic form and ``|k (1 - 2 cos^2 a + cos a cos a0)|``
    for the cosine form.  Near-collinear geometries (``|sin a| <= 0.001``)
    give zero.  The same optional ``alpha``/``rref`` and Morse damping as
    the full Hessian is applied.

    Stores the current angle in ``angle.a`` and returns ``(i, j, k, Hx)``.
    """
    i, j, k = angle.atomi, angle.atomj, angle.atomk

    rij = rel_pos_pbc(atoms, i, j)
    dij = linalg.norm(rij)
    eij = rij / dij
    rkj = rel_pos_pbc(atoms, k, j)
    dkj = linalg.norm(rkj)
    ekj = rkj / dkj
    a = np.arccos(np.clip(np.dot(eij, ekj), -1.0, 1.0))
    sina = np.sin(a)

    Hr = np.zeros((6, 6))
    if np.abs(sina) > 0.001:
        Pij = np.tensordot(eij, eij, axes=0)
        Qij = np.eye(3) - Pij
        Pkj = np.tensordot(ekj, ekj, axes=0)
        Qkj = np.eye(3) - Pkj
        Pki = np.tensordot(ekj, eij, axes=0)
        Hr[0:3, 0:3] = np.dot(Qij, np.dot(Pkj, Qij)) / (dij * dij)
        Hr[0:3, 3:6] = np.dot(Qij, np.dot(Pki, Qkj)) / (dij * dkj)
        Hr[3:6, 0:3] = Hr[0:3, 3:6].T
        Hr[3:6, 3:6] = np.dot(Qkj, np.dot(Pij, Qkj)) / (dkj * dkj)
        if angle.cos:
            cosa = np.cos(a)
            weight = np.abs(1.0 - 2.0 * cosa * cosa + cosa * np.cos(angle.a0))
            Hr = Hr * weight * angle.k / (sina * sina)
        else:
            Hr = Hr * angle.k / (sina * sina)

    if angle.alpha is not None:
        Hr *= (np.exp(angle.alpha[0] * (angle.rref[0]**2 - dij**2))
               * np.exp(angle.alpha[1] * (angle.rref[1]**2 - dkj**2)))

    if morses is not None:
        for morse in morses:
            if morse.atomi in (i, j, k) or morse.atomj in (i, j, k):
                Hr *= get_morse_potential_eta(atoms, morse)

    Hx = np.dot(Ax.T, np.dot(Hr, Ax))

    angle.a = a
    return i, j, k, Hx


def get_angle_potential_reduced_hessian_test(atoms, angle):
    """Rank-one Hessian estimate ``g g^T / (2 v)`` for the bending term.

    Built from the energy and Cartesian gradient.  It divides by ``v`` and is
    therefore undefined at the reference angle.  Returns ``(i, j, k, Hx)``.
    """
    _, _, _, energy = get_angle_potential_value(atoms, angle)
    i, j, k, grad = get_angle_potential_gradient(atoms, angle)
    return i, j, k, np.tensordot(grad, grad, axes=0) / energy / 2.0


def get_dihedral_potential_value(atoms, dihedral):
    """Torsion energy for the dihedral angle defined by atoms i-j-k-l.

    The signed dihedral ``d`` is the angle between the normals
    r_ij x r_kj and r_kj x r_kl.  The energy is:

    * ``0.5 k (1 - cos 2d)`` if ``d0`` is None;
    * ``0.5 k (d - d0)**2``, folded into one 2*pi period, if only ``n`` is
      None;
    * ``k (1 + cos(n d - d0))`` otherwise.

    Stores ``d`` in ``dihedral.d`` and returns ``(i, j, k, l, v)``.
    """
    i, j = dihedral.atomi, dihedral.atomj
    k, l = dihedral.atomk, dihedral.atoml

    rij = rel_pos_pbc(atoms, i, j)
    rkj = rel_pos_pbc(atoms, k, j)
    rkl = rel_pos_pbc(atoms, k, l)

    normal_j = np.cross(rij, rkj)
    normal_k = np.cross(rkj, rkl)
    cos_d = np.clip(np.dot(normal_j / linalg.norm(normal_j),
                           normal_k / linalg.norm(normal_k)), -1.0, 1.0)
    d = np.sign(np.dot(rkj, np.cross(normal_j, normal_k))) * np.arccos(cos_d)

    if dihedral.d0 is None:
        energy = 0.5 * dihedral.k * (1.0 - np.cos(2.0 * d))
    elif dihedral.n is None:
        dd = d - dihedral.d0
        dd = dd - np.around(dd / np.pi / 2.0) * np.pi * 2.0
        energy = 0.5 * dihedral.k * dd**2
    else:
        energy = dihedral.k * (1.0 + np.cos(dihedral.n * d - dihedral.d0))

    dihedral.d = d
    return i, j, k, l, energy


def get_dihedral_potential_gradient(atoms, dihedral):
    """Cartesian gradient (12 components: atoms i, j, k, l) of the torsion energy.

    First the gradient of the dihedral angle ``d`` is built.  The i and l
    parts lie along the two plane normals.  The j and k parts follow from
    them using the projections of r_ij and r_kl onto r_kj.  The result is
    then multiplied by dV/dd for the energy form selected as in
    :func:`get_dihedral_potential_value`.

    Stores ``d`` in ``dihedral.d`` and returns ``(i, j, k, l, gx)``.
    """
    i, j = dihedral.atomi, dihedral.atomj
    k, l = dihedral.atomk, dihedral.atoml

    rij = rel_pos_pbc(atoms, i, j)
    rkj = rel_pos_pbc(atoms, k, j)
    rkl = rel_pos_pbc(atoms, k, l)
    dkj = linalg.norm(rkj)
    dkj2 = dkj * dkj

    normal_j = np.cross(rij, rkj)
    norm_j = linalg.norm(normal_j)
    normal_k = np.cross(rkj, rkl)
    norm_k = linalg.norm(normal_k)
    cos_d = np.clip(np.dot(normal_j / norm_j, normal_k / norm_k), -1.0, 1.0)

    dd_dri = dkj / (norm_j * norm_j) * normal_j
    dd_drl = -dkj / (norm_k * norm_k) * normal_k
    proj_i = np.dot(rij, rkj) / dkj2
    proj_l = np.dot(rkj, rkl) / dkj2
    gx = np.concatenate([
        dd_dri,
        (proj_i - 1.0) * dd_dri - proj_l * dd_drl,
        (proj_l - 1.0) * dd_drl - proj_i * dd_dri,
        dd_drl,
    ])

    d = np.sign(np.dot(rkj, np.cross(normal_j, normal_k))) * np.arccos(cos_d)

    if dihedral.d0 is None:
        slope = dihedral.k * np.sin(2.0 * d)
    elif dihedral.n is None:
        dd = d - dihedral.d0
        dd = dd - np.around(dd / np.pi / 2.0) * np.pi * 2.0
        slope = dihedral.k * dd
    else:
        slope = -dihedral.k * dihedral.n * \
            np.sin(dihedral.n * d - dihedral.d0)
    gx *= slope

    dihedral.d = d
    return i, j, k, l, gx


def get_dihedral_potential_hessian(atoms, dihedral, morses=None,
                                   spectral=False):
    """Finite-difference Cartesian Hessian (12x12) of the torsion term.

    Each of the twelve coordinates of atoms i, j, k, l is displaced in turn
    by a forward step of 1e-6 on a copy of ``atoms``.  The change in the
    analytic gradient gives one column.  Row and column contributions are
    averaged, so the result is symmetric.  Afterwards the optional
    ``alpha``/``rref`` damping, the Morse factors of pairs touching any of
    the four atoms, and ``spectral`` (absolute eigenvalues) are applied.

    ``dihedral.d`` is set by the gradient at the unperturbed geometry.  The
    displaced evaluations use a separate ``Dihedral`` object and leave it
    unchanged.  Returns ``(i, j, k, l, Hx)``.
    """
    step = 0.000001

    i, j, k, l, grad0 = get_dihedral_potential_gradient(atoms, dihedral)

    probe = Dihedral(dihedral.atomi, dihedral.atomj,
                     dihedral.atomk, dihedral.atoml,
                     dihedral.k, dihedral.d0, dihedral.n)
    flat_index = [3 * atom + axis for atom in (i, j, k, l) for axis in range(3)]

    Hx = np.zeros((12, 12))
    for col, coord in enumerate(flat_index):
        displaced = atoms.copy()
        flat = np.reshape(displaced.get_positions(), -1)
        flat[coord] += step
        displaced.set_positions(np.reshape(flat, (len(displaced), 3)))
        *_, grad = get_dihedral_potential_gradient(displaced, probe)
        half_column = 0.5 * (grad - grad0) / step
        Hx[col, :] += half_column
        Hx[:, col] += half_column

    if dihedral.alpha is not None:
        dij = linalg.norm(rel_pos_pbc(atoms, i, j))
        dkj = linalg.norm(rel_pos_pbc(atoms, k, j))
        dkl = linalg.norm(rel_pos_pbc(atoms, k, l))
        Hx *= (np.exp(dihedral.alpha[0] * (dihedral.rref[0]**2 - dij**2))
               * np.exp(dihedral.alpha[1] * (dihedral.rref[1]**2 - dkj**2))
               * np.exp(dihedral.alpha[2] * (dihedral.rref[2]**2 - dkl**2)))

    if morses is not None:
        for morse in morses:
            if morse.atomi in (i, j, k, l) or morse.atomj in (i, j, k, l):
                Hx *= get_morse_potential_eta(atoms, morse)

    if spectral:
        vals, vecs = linalg.eigh(Hx)
        Hx = vecs.dot(np.diag(np.abs(vals)).dot(vecs.T))

    return i, j, k, l, Hx


def get_dihedral_potential_reduced_hessian(atoms, dihedral, morses=None):
    """Rank-one, positive semi-definite approximation of the torsion Hessian.

    Builds ``|d2V/dd2| * g g^T``, where ``g`` is the Cartesian gradient of
    the dihedral angle ``d`` (12 components for atoms i, j, k, l).  The
    second derivative ``d2V/dd2`` follows the same rules that select the
    energy form:

    * ``d0 is None``: ``2 k cos(2 d)``;
    * ``n is None``: ``k`` (harmonic);
    * otherwise: ``-k n**2 cos(n d - d0)``.

    The result is optionally damped by the ``alpha``/``rref`` exponentials
    and by the Morse factor of any Morse pair touching one of the four atoms.

    Stores the current dihedral in ``dihedral.d`` and returns
    ``(i, j, k, l, Hx)``.
    """
    i = dihedral.atomi
    j = dihedral.atomj
    k = dihedral.atomk
    l = dihedral.atoml

    rij = rel_pos_pbc(atoms, i, j)
    rkj = rel_pos_pbc(atoms, k, j)
    dkj = linalg.norm(rkj)
    dkj2 = dkj * dkj
    rkl = rel_pos_pbc(atoms, k, l)

    rijrkj = np.dot(rij, rkj)
    rkjrkl = np.dot(rkj, rkl)

    rmj = np.cross(rij, rkj)
    dmj = linalg.norm(rmj)
    dmj2 = dmj * dmj
    emj = rmj / dmj
    rnk = np.cross(rkj, rkl)
    dnk = linalg.norm(rnk)
    dnk2 = dnk * dnk
    enk = rnk / dnk
    emjenk = np.dot(emj, enk)
    if np.abs(emjenk) > 1.0:
        emjenk = np.sign(emjenk)

    d = np.sign(np.dot(rkj, np.cross(rmj, rnk))) * np.arccos(emjenk)

    dddri = dkj / dmj2 * rmj
    dddrl = -dkj / dnk2 * rnk

    gx = np.zeros(12)

    gx[0:3] = dddri
    gx[3:6] = (rijrkj / dkj2 - 1.0) * dddri - rkjrkl / dkj2 * dddrl
    gx[6:9] = (rkjrkl / dkj2 - 1.0) * dddrl - rijrkj / dkj2 * dddri
    gx[9:12] = dddrl

    gg = np.tensordot(gx, gx, axes=0)
    # The three cases are mutually exclusive, mirroring the energy: with
    # d0 None the energy is 0.5 k (1 - cos 2d) regardless of n.  A plain
    # second ``if`` would let the ``n is None`` branch overwrite it (n
    # defaults to None), or evaluate cos(n d - None) when n is set.
    if dihedral.d0 is None:
        Hx = np.abs(2.0 * dihedral.k * np.cos(2.0 * d)) * gg
    elif dihedral.n is None:
        Hx = dihedral.k * gg
    else:
        Hx = (np.abs(-dihedral.k * dihedral.n**2
                     * np.cos(dihedral.n * d - dihedral.d0)) * gg)

    if dihedral.alpha is not None:
        # rij, rkj/dkj and rkl were already computed above.
        dij = linalg.norm(rij)
        dkl = linalg.norm(rkl)
        Hx *= (np.exp(dihedral.alpha[0] * (dihedral.rref[0]**2 - dij**2))
               * np.exp(dihedral.alpha[1] * (dihedral.rref[1]**2 - dkj**2))
               * np.exp(dihedral.alpha[2] * (dihedral.rref[2]**2 - dkl**2)))

    if morses is not None:
        for morse in morses:
            if morse.atomi in (i, j, k, l) or morse.atomj in (i, j, k, l):
                Hx *= get_morse_potential_eta(atoms, morse)

    dihedral.d = d

    return i, j, k, l, Hx


def get_dihedral_potential_reduced_hessian_test(atoms, dihedral):
    """Rank-one Hessian estimate for the torsion term from its gradient ``g``.

    * ``n is None``: ``g g^T / (2 v)`` with the energy ``v``;
    * otherwise: ``g g^T cos(arg) / (k sin(arg)**2)`` with
      ``arg = n * d - d0``, where ``d`` is the value that the gradient call
      just stored in ``dihedral.d``.

    Returns ``(i, j, k, l, Hx)``.
    """
    i, j, k, l, grad = get_dihedral_potential_gradient(atoms, dihedral)
    outer = np.tensordot(grad, grad, axes=0)

    if dihedral.n is not None:
        arg = dihedral.n * dihedral.d - dihedral.d0
        return i, j, k, l, (outer / dihedral.k / np.sin(arg) / np.sin(arg)
                            * np.cos(arg))

    i, j, k, l, energy = get_dihedral_potential_value(atoms, dihedral)
    return i, j, k, l, outer / energy / 2.0


def get_vdw_potential_value(atoms, vdw):
    """12-6 pair energy ``Aij / r**12 - Bij / r**6``.

    Stores the pair distance in ``vdw.r`` and returns ``(i, j, v)``.
    """
    i, j = vdw.atomi, vdw.atomj
    dist = linalg.norm(rel_pos_pbc(atoms, i, j))
    energy = vdw.Aij / dist**12 - vdw.Bij / dist**6
    vdw.r = dist
    return i, j, energy


def get_vdw_potential_gradient(atoms, vdw):
    """Cartesian gradient (6 components) of the 12-6 pair energy.

    The radial derivative ``-12 A / r**13 + 6 B / r**7`` is applied along the
    unit vector of r_ij and mapped through ``Bx``.  Stores the distance in
    ``vdw.r`` and returns ``(i, j, gx)``.
    """
    i, j = vdw.atomi, vdw.atomj
    r_vec = rel_pos_pbc(atoms, i, j)
    dist = linalg.norm(r_vec)
    slope = -12.0 * vdw.Aij / dist**13 + 6.0 * vdw.Bij / dist**7
    grad_cart = np.dot(Bx.T, slope * (r_vec / dist))
    vdw.r = dist
    return i, j, grad_cart


def get_vdw_potential_hessian(atoms, vdw, spectral=False):
    """Cartesian Hessian (6x6) of the 12-6 pair energy.

    The radial curvature ``156 A / r**14 - 42 B / r**8`` acts along the bond,
    and ``(dV/dr) / r`` acts perpendicular to it.  With ``spectral=True`` all
    eigenvalues are replaced by their absolute values.  Stores the distance
    in ``vdw.r`` and returns ``(i, j, Hx)``.
    """
    i, j = vdw.atomi, vdw.atomj

    r_vec = rel_pos_pbc(atoms, i, j)
    dist = linalg.norm(r_vec)
    unit = r_vec / dist
    along = np.tensordot(unit, unit, axes=0)
    across = np.eye(3) - along

    curvature = 156.0 * vdw.Aij / dist**14 - 42.0 * vdw.Bij / dist**8
    slope = -12.0 * vdw.Aij / dist**13 + 6.0 * vdw.Bij / dist**7
    h_rel = curvature * along + slope / dist * across
    hess = np.dot(Bx.T, np.dot(h_rel, Bx))

    if spectral:
        vals, vecs = linalg.eigh(hess)
        hess = vecs.dot(np.diag(np.abs(vals)).dot(vecs.T))

    vdw.r = dist
    return i, j, hess


def get_coulomb_potential_value(atoms, coulomb):
    """Coulomb pair energy ``chargeij / r``.

    Stores the pair distance in ``coulomb.r`` and returns ``(i, j, v)``.
    """
    i, j = coulomb.atomi, coulomb.atomj
    dist = linalg.norm(rel_pos_pbc(atoms, i, j))
    energy = coulomb.chargeij / dist
    coulomb.r = dist
    return i, j, energy


def get_coulomb_potential_gradient(atoms, coulomb):
    """Cartesian gradient (6 components) of the Coulomb pair energy.

    ``-chargeij / r**2`` along the unit vector of r_ij, mapped through ``Bx``.
    Stores the distance in ``coulomb.r`` and returns ``(i, j, gx)``.
    """
    i, j = coulomb.atomi, coulomb.atomj
    r_vec = rel_pos_pbc(atoms, i, j)
    dist = linalg.norm(r_vec)
    slope = -coulomb.chargeij / dist / dist
    grad_cart = np.dot(Bx.T, slope * (r_vec / dist))
    coulomb.r = dist
    return i, j, grad_cart


def get_coulomb_potential_hessian(atoms, coulomb, spectral=False):
    """Cartesian Hessian (6x6) of the Coulomb pair energy.

    The curvature ``2 q / r**3`` acts along the bond, and ``(dV/dr) / r``
    (with ``dV/dr = -q / r**2``) acts perpendicular to it.  With
    ``spectral=True`` all eigenvalues are replaced by their absolute values.
    Stores the distance in ``coulomb.r`` and returns ``(i, j, Hx)``.
    """
    i, j = coulomb.atomi, coulomb.atomj

    r_vec = rel_pos_pbc(atoms, i, j)
    dist = linalg.norm(r_vec)
    unit = r_vec / dist
    along = np.tensordot(unit, unit, axes=0)
    across = np.eye(3) - along

    curvature = 2.0 * coulomb.chargeij / dist**3
    slope = -coulomb.chargeij / dist / dist
    h_rel = curvature * along + slope / dist * across
    hess = np.dot(Bx.T, np.dot(h_rel, Bx))

    if spectral:
        vals, vecs = linalg.eigh(hess)
        hess = vecs.dot(np.diag(np.abs(vals)).dot(vecs.T))

    coulomb.r = dist
    return i, j, hess


def rel_pos_pbc(atoms, i, j):
    """
    Return difference between two atomic positions,
    correcting for jumps across PBC.

    The difference ``r_i - r_j`` is expressed in fractional cell
    coordinates.  Along every periodic direction, the nearest integer number
    of cell vectors is subtracted, so that component ends up in [-0.5, 0.5).
    Non-periodic directions are never shifted.  If no direction is periodic,
    the plain difference is returned without inverting the cell.  This keeps
    isolated molecules, whose cell is typically all zeros, from raising
    ``LinAlgError``.

    ``i`` and ``j`` may be integers (result shape ``(3,)``) or integer
    arrays of equal length (result shape ``(n, 3)``).
    """
    positions = atoms.get_positions()
    d = positions[i, :] - positions[j, :]
    pbc = np.asarray(atoms.get_pbc(), dtype=bool)
    if not pbc.any():
        return d
    cell_t = atoms.get_cell().T
    f = np.floor(np.dot(linalg.inv(cell_t), d.T) + 0.5)
    # Minimum-image shifts only make sense along periodic directions.
    f[~pbc] = 0.0
    d -= np.dot(cell_t, f).T
    return d
