"""

script for analyzing dynasor output from a diatomic chain calculation


"""

import pickle
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


def function_Ckw(w, w0, gamma, A):
    """
    Single-peak current correlation function in the angular frequency
    domain,

        C(k,w) = A * w**2 * 2 * gamma * w0**2
                 / ((w**2 - w0**2)**2 + gamma**2 * w**2)

    w is the angular frequency (scalar or array); w0, gamma and A are the
    fit parameters of the peak.
    """
    numerator = A * w**2 * 2 * gamma * w0**2
    denominator = (w**2 - w0**2)**2 + gamma**2 * w**2
    return numerator / denominator


def function_Ckw_double(w, w01, gamma1, A1, w02, gamma2, A2):
    """
    Two-peak current correlation function in the angular frequency domain:
    the sum of two function_Ckw terms, one with parameters
    (w01, gamma1, A1) and one with (w02, gamma2, A2).
    """
    first_peak = function_Ckw(w, w01, gamma1, A1)
    second_peak = function_Ckw(w, w02, gamma2, A2)
    return first_peak + second_peak


def function_Ckt(t, w0, gamma, A):
    """
    Current correlation function in time domain as
    C(k,t) = A * np.exp(-gamma/2.0 * np.abs(t)) * (np.cos(w_e*t) +
             gamma/(2*w_e) * np.sin(w_e*np.abs(t)))
    with w_e = np.sqrt(w0**2 - gamma**2/4.0)

    t is the time (scalar or array); w0, gamma and A are scalar fit
    parameters.

    When w0**2 - gamma**2/4.0 is zero or negative, w_e is replaced by the
    small positive value 1e-10 so that the expression stays finite.
    """
    w_e_squared = w0**2 - gamma**2/4.0
    # "<=" rather than "<": at exactly zero, w_e = 0 would make
    # gamma/(2*w_e) * sin(0) evaluate to inf * 0 = nan.
    if w_e_squared <= 0.0:
        w_e = 1e-10
    else:
        w_e = np.sqrt(w_e_squared)
    return A * np.exp(-gamma/2.0 * np.abs(t)) * \
        (np.cos(w_e*t) + gamma/(2*w_e) * np.sin(w_e*np.abs(t)))


def function_Ckt_double(w, w01, gamma1, A1, w02, gamma2, A2):
    """
    Two-peak current correlation function in the time domain: the sum of
    two function_Ckt terms, one with parameters (w01, gamma1, A1) and one
    with (w02, gamma2, A2).

    The first argument is the time axis. It is named `w` to keep the call
    signature unchanged. It is forwarded to function_Ckt, so the result
    depends only on the arguments and not on any module-level variable.
    """
    return function_Ckt(w, w01, gamma1, A1) + function_Ckt(w, w02, gamma2, A2)


# Unit conversion factors
invfs2mev = 658.2119       # angular frequency: 1/fs -> meV
eV2J = 1.60217662e-19      # energy: eV -> J
amu2kg = 1.66053904e-27    # mass: amu (g/mol) -> kg

# Model parameters used in the analytical dispersion relation
m1, m2 = 100 * amu2kg, 200 * amu2kg   # the two masses, kg
a = 0.2                               # lattice parameter, nm
# spring constant in J/m**2 (the 1e20 factor presumably converts
# 1/Angstrom**2 to 1/m**2 -- TODO confirm)
kspring = 0.2 * eV2J * 1e20

# Read dynasor output file
# NOTE: unpickling can execute arbitrary code; only load files from a
# trusted source.
with open('outputs/dynasor_1Dchain.short.pickle', 'rb') as pickle_file:
    dynasor_data = pickle.load(pickle_file)

# Each item is indexed as (value, name); build a lookup keyed by name
data = {item[1]: item[0] for item in dynasor_data}

k = data['k']  # 2pi / nm
w = data['w']  # 2pi /fs
t = data['t']  # fs


# Analytical dispersion relation of the diatomic chain [Kittel p97].
# The two branches share a common term and differ in the sign of the
# square-root term; the factor 1e-15 converts from 1/s to 1/fs.
mass_sum = m1 + m2
mass_product = m1 * m2
branch_centre = kspring * mass_sum / mass_product
branch_split = kspring * np.sqrt(
    mass_sum**2 / mass_product**2 - 2 * (1 - np.cos(k*a)) / mass_product)
w_theory1 = 1e-15 * np.sqrt(branch_centre + branch_split)  # upper branch
w_theory2 = 1e-15 * np.sqrt(branch_centre - branch_split)  # lower branch


# Fit data and plot
plot_indices = [4, 8, 10, 18, 19]  # which k-points to plot

# Figure holding one row of (time, frequency) panels per plotted k-point
fig = plt.figure(figsize=(14, 13), dpi=100, facecolor='w', edgecolor='k')
props = dict(boxstyle='round', facecolor='wheat', alpha=0.9)
LW = 1.5          # line width of the fitted curves
blue = '#1F77B4'  # colour of fitted / analytical curves
red = '#D62728'   # colour of the data points
# Upper x limit of the frequency panels: 40% above the highest
# analytical frequency
w_max = 1.4 * np.max(w_theory1)
# Upper x limit of the time panels: the full sampled time range
t_max = np.max(t)

fitted_frequencies = []
n_rows = len(plot_indices)
# Most recently created pair of axes; they receive the x labels and the
# legend after the loop. They stay None if no k-point was plotted.
ax_t = ax_w = None

# The first and last k-points (gamma points) are skipped
for i in range(1, len(k) - 1):
    # Sum the three 'Cl_k_t_*' / 'Cl_k_w_*' entries for this k-point
    ckt = data['Cl_k_t_0_0'][:, i] + data['Cl_k_t_0_1'][:, i] \
        + data['Cl_k_t_1_1'][:, i]
    ckw = np.abs(data['Cl_k_w_0_0'][:, i] + data['Cl_k_w_0_1'][:, i] +
                 data['Cl_k_w_1_1'][:, i])

    # Use the analytical frequencies as initial guess for the fits;
    # the parameter order is (w01, gamma1, A1, w02, gamma2, A2)
    guess_time = [w_theory1[i], 1e-10, 5e-6, w_theory2[i], 1e-10, 1e-5]
    guess_freq = [w_theory1[i], 1e-3, 1, w_theory2[i], 1e-5, 1]

    time_params, _ = curve_fit(function_Ckt_double, t, ckt, p0=guess_time)

    # freq fit might fail sometimes because it is badly conditioned
    freq_params, _ = curve_fit(function_Ckw_double, w, ckw, p0=guess_freq,
                               maxfev=10000)

    # Only the two frequencies from the time-domain fit are kept
    fitted_frequencies.append([time_params[0], time_params[3]])

    if i in plot_indices:
        row = plot_indices.index(i)
        ax_t = fig.add_subplot(n_rows, 2, 2*row + 1)
        ax_w = fig.add_subplot(n_rows, 2, 2*row + 2)

        ax_t.plot(t, function_Ckt_double(t, *time_params), color=blue,
                  linewidth=LW)
        ax_t.plot(t, ckt, 'o', color=red)
        ax_t.set_xlim([0.0, t_max])
        # tick_params expects booleans; the string 'off' is not read as
        # False by current matplotlib and would leave the ticks visible.
        ax_t.tick_params(axis='y', which='both', left=False, right=False,
                         labelleft=False)

        ax_w.plot(w, function_Ckw_double(w, *freq_params), color=blue,
                  linewidth=LW, label='fit')
        ax_w.plot(w, ckw, 'o', color=red, label='correlation data')
        ax_w.set_xlim([0.0, w_max])
        ax_w.tick_params(axis='y', which='both', left=False, right=False,
                         labelleft=False)

fitted_frequencies = np.array(fitted_frequencies)
if ax_w is not None:
    ax_t.set_xlabel('Time (fs)')
    ax_w.set_xlabel('Angular Frequency (2pi/fs)')
    ax_w.legend(numpoints=1, loc='best')
    # Lay out once, after labels and legend exist, so they are accounted for
    fig.tight_layout()

# ------------------------
# Second figure: analytical branches (lines) together with the
# frequencies from the time-domain fits (markers)

fig = plt.figure(figsize=(7, 5), dpi=100, facecolor='w', edgecolor='k')
ax_disp = fig.add_subplot(1, 1, 1)

k_reduced = k * a / (2 * np.pi)   # x axis: k * a / 2pi
k_reduced_fit = k_reduced[1:-1]   # the fits skipped the first and last k

ax_disp.plot(k_reduced, w_theory1, color=blue, linewidth=LW,
             label='analytical')
ax_disp.plot(k_reduced, w_theory2, color=blue, linewidth=LW)
ax_disp.plot(k_reduced_fit, fitted_frequencies[:, 0], 'o', color=red,
             label='correlation fit values')
ax_disp.plot(k_reduced_fit, fitted_frequencies[:, 1], 'o', color=red)

ax_disp.set_xlabel('k * a / 2pi')
ax_disp.set_ylabel('Angular Frequency (2pi/fs)')
ax_disp.set_xlim([0.0, 1.0])
ax_disp.legend(numpoints=1, loc='best')

plt.show()
