from __future__ import annotations

import argparse
import sys

import mpi4py

import h5py
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib.ticker import LogLocator, MaxNLocator, NullFormatter


def visualize_elasticity():
    """Plot iteration counts versus number of DOFs for the elasticity benchmark.

    Reads the MPC results from ``bench_edge_output.hdf5`` and the unconstrained
    reference results from ``elasticity_ref.hdf5``. Both files are opened with the
    ``mpio`` driver on ``MPI.COMM_WORLD``. Both series are plotted on a
    logarithmic DOF axis, and each MPC point is annotated with the summed
    ``num_slaves`` value. The figure is saved as
    ``elasticity_iterations_CG<degree>_<ct>.png``.
    """
    # A bare ``import mpi4py`` does not load the ``mpi4py.MPI`` extension module,
    # so import it explicitly rather than relying on h5py having done so.
    from mpi4py import MPI

    def _attr_to_str(value):
        # h5py >= 3 returns variable-length string attributes as ``str`` and
        # fixed-length ones as ``bytes``; accept both (and anything str()-able).
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    comm = MPI.COMM_WORLD

    # Context managers guarantee the files are closed even if a dataset or
    # attribute is missing; indexing (not ``.get``) gives a clear KeyError.
    with h5py.File("bench_edge_output.hdf5", "r", driver="mpio", comm=comm) as f:
        iterations = f["its"][:]
        dofs = f["num_dofs"][:]
        # Summed over axis 1 to get one annotation per run
        # (assumes num_slaves is 2D: runs x something — TODO confirm layout).
        slaves = np.sum(f["num_slaves"][:], axis=1)
        attrs = f["solve_time"].attrs
        solver = _attr_to_str(attrs["solver"])
        ct = _attr_to_str(attrs["ct"])
        degree = _attr_to_str(attrs["degree"])

    with h5py.File("elasticity_ref.hdf5", "r", driver="mpio", comm=comm) as f_ref:
        iterations_ref = f_ref["its"][:]
        dofs_ref = f_ref["num_dofs"][:]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(dofs, iterations, "-ro", label="MPC", markersize=12)
    ax.plot(dofs_ref, iterations_ref, "-bs", label="Unconstrained")

    ax.tick_params(axis="both", which="major", labelsize=20)
    ax.set_xscale("log")
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel("# DOFS", fontsize=20)
    ax.set_ylabel("# Iterations", fontsize=20)

    # Shift the slave-count labels slightly right of and below each marker.
    trans_offset = mtransforms.offset_copy(ax.transData, fig=fig, x=0.025, y=-0.1, units="inches")
    for x, y, n_slaves in zip(dofs, iterations, slaves):
        ax.text(x, y, n_slaves, transform=trans_offset, fontsize=20)

    ax.set_title("Linear elasticity (CG{0:s}, {1:s}) with {2:s}".format(degree, ct, solver), fontsize=25)
    ax.legend(fontsize=15)
    ax.minorticks_on()
    # Leave headroom above the tallest point of either series.
    ax.set_ylim([0, max(np.max(iterations), np.max(iterations_ref)) + 1])
    ax.set_xlim([1e2, 1e8])
    ax.xaxis.set_major_locator(LogLocator(base=10.0, numticks=8))
    ax.xaxis.set_minor_locator(
        LogLocator(base=10.0, subs=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9), numticks=9)
    )
    ax.xaxis.set_minor_formatter(NullFormatter())
    ax.grid(True, which="both", axis="both")
    fig.savefig("elasticity_iterations_CG{0:s}_{1:s}.png".format(degree, ct), bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def visualize_periodic():
    """Plot iteration counts versus number of DOFs for the periodic Poisson benchmark.

    Reads the MPC results from ``periodic_output.hdf5`` and the unconstrained
    reference results from ``periodic_ref_output.hdf5``. Both files are opened
    with the ``mpio`` driver on ``MPI.COMM_WORLD``. Both series are plotted on a
    logarithmic DOF axis, and each MPC point is annotated with the summed
    ``num_slaves`` value. The figure is saved as
    ``periodic_iterations_CG<degree>_<ct>.png``.
    """
    # A bare ``import mpi4py`` does not load the ``mpi4py.MPI`` extension module,
    # so import it explicitly rather than relying on h5py having done so.
    from mpi4py import MPI

    def _attr_to_str(value):
        # h5py >= 3 returns variable-length string attributes as ``str`` and
        # fixed-length ones as ``bytes``; accept both (and anything str()-able).
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    comm = MPI.COMM_WORLD

    # Context managers guarantee the files are closed even if a dataset or
    # attribute is missing; indexing (not ``.get``) gives a clear KeyError.
    with h5py.File("periodic_output.hdf5", "r", driver="mpio", comm=comm) as f:
        iterations = f["its"][:]
        dofs = f["num_dofs"][:]
        # Summed over axis 1 to get one annotation per run
        # (assumes num_slaves is 2D: runs x something — TODO confirm layout).
        slaves = np.sum(f["num_slaves"][:], axis=1)
        attrs = f["solve_time"].attrs
        solver = _attr_to_str(attrs["solver"])
        ct = _attr_to_str(attrs["ct"])
        degree = _attr_to_str(attrs["degree"])

    with h5py.File("periodic_ref_output.hdf5", "r", driver="mpio", comm=comm) as f_ref:
        iterations_ref = f_ref["its"][:]
        dofs_ref = f_ref["num_dofs"][:]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(dofs, iterations, "-ro", label="MPC", markersize=12)
    ax.plot(dofs_ref, iterations_ref, "-bs", label="Unconstrained")

    ax.tick_params(axis="both", which="major", labelsize=20)
    ax.set_xscale("log")
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel("# DOFS", fontsize=20)
    ax.set_ylabel("# Iterations", fontsize=20)

    # Shift the slave-count labels slightly up and to the right of each marker.
    trans_offset = mtransforms.offset_copy(ax.transData, fig=fig, x=0.025, y=0.025, units="inches")
    for x, y, n_slaves in zip(dofs, iterations, slaves):
        ax.text(x, y, n_slaves, transform=trans_offset, fontsize=20)

    ax.set_title("Periodic Poisson (CG {0:s}, {1:s}) with {2:s}".format(degree, ct, solver), fontsize=25)
    ax.legend(fontsize=15)
    ax.minorticks_on()
    # Include the reference series so the "Unconstrained" curve is never clipped.
    ax.set_ylim([0, max(np.max(iterations), np.max(iterations_ref)) + 1])
    ax.set_xlim([1e2, 1e8])
    ax.xaxis.set_major_locator(LogLocator(base=10.0, numticks=8))
    ax.xaxis.set_minor_locator(
        LogLocator(base=10.0, subs=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9), numticks=9)
    )
    ax.xaxis.set_minor_formatter(NullFormatter())
    ax.grid(True, which="both", axis="both")
    fig.savefig("periodic_iterations_CG{0:s}_{1:s}.png".format(degree, ct), bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Command-line driver: each flag selects one benchmark plot to generate.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--elasticity",
        action="store_true",
        dest="elasticity",
        default=False,
        help="Visualize iterations for elasticity",
    )
    parser.add_argument(
        "--periodic",
        action="store_true",
        dest="periodic",
        default=False,
        help="Visualize iterations for periodic",
    )
    args = parser.parse_args()

    # Read the flags straight from the namespace instead of injecting them
    # into the module's globals.
    if args.elasticity:
        visualize_elasticity()

    if args.periodic:
        visualize_periodic()

    # Without any flag the script would otherwise exit silently.
    if not (args.elasticity or args.periodic):
        parser.print_help()
