# -*- coding: utf-8 -*-
# Copyright 2007-2023 The HyperSpy developers
#
# This file is part of RosettaSciIO.
#
# RosettaSciIO is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RosettaSciIO is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with RosettaSciIO. If not, see <https://www.gnu.org/licenses/#GPL>.

# The EMD format is an HDF5-based standard proposed at Lawrence Berkeley
# National Lab (see https://emdatasets.com/ for more information).
# FEI later developed another EMD format, also based on the HDF5 standard. This
# reader first checks whether the file has been saved by Velox (FEI EMD format)
# and uses either the EMD class or the FEIEMDReader class to read the file.
# Writing files is only supported for the Berkeley EMD format.


import logging
import math
import os
import re

import dask.array as da
import h5py
import numpy as np

from rsciio._hierarchical import get_signal_chunks
from rsciio.utils.tools import _UREG, DTBox

EMD_VERSION = "0.2"

_logger = logging.getLogger(__name__)


class EMD_NCEM:
    """Class for reading and writing the Berkeley variant of the electron
    microscopy datasets (EMD) file format. It reads files EMD NCEM, including
    files generated by the prismatic software.

    Attributes
    ----------
    dictionaries : list
        List of signal dictionaries which are passed to the file_reader.
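
    Examples
    --------
    A minimal usage sketch; "example.emd" stands in for an EMD NCEM file on
    disk:

    >>> import h5py
    >>> emd = EMD_NCEM()
    >>> with h5py.File("example.emd", "r") as f:
    ...     emd.read_file(f, lazy=False)
    >>> signal_dicts = emd.dictionaries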
    """

    def read_file(self, file, lazy=None, dataset_path=None, stack_group=None):
        """
        Read the data from an EMD file.

        Parameters
        ----------
        file : file handle
            Handle of the file to read the data from.
        lazy : bool, optional
            Load the data lazily. The default is False.
        dataset_path : None, str or list of str
            Path of the dataset. If None, load all supported datasets,
            otherwise the specified dataset. The default is None.
        stack_group : bool, optional
            Stack datasets of groups with common name. Relevant for EMD file
            version >= 0.5, where groups can be named 'group0000',
            'group0001', etc. If None (default), set to True when loading
            all supported datasets.
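
        Examples
        --------
        A sketch, assuming ``file`` is an open ``h5py.File`` handle and
        "experimental/science_data/data" is a hypothetical dataset path in it:

        >>> reader = EMD_NCEM()
        >>> reader.read_file(file, dataset_path="experimental/science_data/data")
        >>> d = reader.dictionaries[0]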
        """
        self.file = file
        self.lazy = lazy

        if isinstance(dataset_path, list):
            if stack_group:
                _logger.warning(
                    "The arguments 'dataset_path' and "
                    "'stack_group' are not compatible; 'stack_group' is "
                    "ignored."
                )
            stack_group = False
            dataset_path = dataset_path.copy()
        elif isinstance(dataset_path, str):
            dataset_path = [dataset_path]
        # if 'dataset_path' is not provided, we load all supported datasets
        elif dataset_path is None:
            dataset_path = self.find_dataset_paths(file)
            if stack_group is None:
                stack_group = True

        self.dictionaries = []

        while len(dataset_path) > 0:
            path = dataset_path.pop(0)
            group_paths = [os.path.dirname(path)]
            dataset_name = os.path.basename(path)

            if stack_group:
                # Find all the datasets in this group which are also listed
                # in dataset_path:
                # 1. add them to 'group_paths'
                # 2. remove them from 'dataset_path'
                group_basename = group_paths[0]
                if self._is_prismatic_file and "ppotential" not in path:
                    # In prismatic files, the group names end with '0000',
                    # except for 'ppotential'
                    group_basename = group_basename[:-4]
                for _path in dataset_path[:]:
                    if path != _path and group_basename in _path:
                        group_paths.append(os.path.dirname(_path))
                        dataset_path.remove(_path)
                title = os.path.basename(group_basename)
            else:
                title = os.path.basename(group_paths[0])

            _logger.debug(f"Loading dataset: {path}")

            om = self._parse_original_metadata()
            data, axes = self._read_data_from_groups(
                group_paths, dataset_name, title, om
            )

            md = self._parse_metadata(group_paths[0], title=title)
            d = {
                "data": data,
                "axes": axes,
                "metadata": md,
                "original_metadata": om,
            }
            self.dictionaries.append(d)

    @classmethod
    def find_dataset_paths(cls, file, supported_dataset=True):
        """
        Find the paths of all datasets stored in valid EMD groups.

        Parameters
        ----------
        file : hdf5 file handle
        supported_dataset : bool, optional
            If True (default), returns the paths of all supported datasets,
            otherwise returns the paths of the unsupported datasets.
            This is relevant for groups containing auxiliary dataset(s) which
            are not supported by HyperSpy or described in the EMD NCEM dataset
            specification.

        Returns
        -------
        datasets : list
            List of paths to the matching datasets.

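        Examples
        --------
        A sketch, assuming ``f`` is an open ``h5py.File`` handle:

        >>> paths = EMD_NCEM.find_dataset_paths(f)
        >>> unsupported = EMD_NCEM.find_dataset_paths(f, supported_dataset=False)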
        """

        def append_dataset_path(item_name, item):
            # Collect the dataset when its "supported" status matches the
            # requested `supported_dataset` flag
            is_supported = os.path.basename(item_name).startswith(
                (
                    "data",
                    "counted_datacube",
                    "datacube",
                    "diffractionslice",
                    "realslice",
                    "pointlistarray",
                    "pointlist",
                )
            )
            if supported_dataset is is_supported and isinstance(item, h5py.Dataset):
                grp = file.get(os.path.dirname(item_name))
                if cls._get_emd_group_type(grp):
                    dataset_path.append(item_name)

        dataset_path = []
        file.visititems(append_dataset_path)

        return dataset_path

    @property
    def _is_prismatic_file(self):
        return "4DSTEM_simulation" in self.file.keys()

    @property
    def _is_py4DSTEM_file(self):
        return "4DSTEM_experiment" in self.file.keys()

    @staticmethod
    def _get_emd_group_type(group):
        """Return the value of the 'emd_group_type' attribute if it exist,
        otherwise returns False
        """
        return group.attrs.get("emd_group_type", False)

    @staticmethod
    def _read_dataset(dataset):
        """Read dataset and use the h5py AsStrWrapper when the dataset is of
        string type (h5py 3.0 and newer)
        """
        chunks = dataset.chunks
        if chunks is None:
            chunks = "auto"
        if h5py.check_string_dtype(dataset.dtype) and hasattr(dataset, "asstr"):
            # h5py 3.0 and newer
            # https://docs.h5py.org/en/3.0.0/strings.html
            data = dataset.asstr()[:]
        else:
            data = dataset[:]
        return data, chunks

    def _read_emd_version(self, group):
        """Return the group version if the group is an EMD group, otherwise
        return None.
        """
        if "version_major" in group.attrs.keys():
            version = [
                str(group.attrs.get(v)) for v in ["version_major", "version_minor"]
            ]
            version = ".".join(version)
            return version
        else:
            return None

    def _read_data_from_groups(
        self, group_path, dataset_name, stack_key=None, original_metadata={}
    ):
        axes = []
        transpose_required = dataset_name != "datacube"

        dataset_list = [self.file.get(f"{key}/{dataset_name}") for key in group_path]

        if None in dataset_list:
            raise IOError("Dataset can't be found.")

        if len(dataset_list) > 1:
            # Stack the datasets from the different groups and squeeze
            # dimensions of size one
            if self.lazy:
                data_list = [
                    da.from_array(*self._read_dataset(d)) for d in dataset_list
                ]
                if transpose_required:
                    data_list = [da.transpose(d) for d in data_list]
                data = da.stack(data_list)
                data = da.squeeze(data)
            else:
                data_list = [self._read_dataset(d)[0] for d in dataset_list]
                if transpose_required:
                    data_list = [np.transpose(d) for d in data_list]
                data = np.stack(data_list).squeeze()
        else:
            d = dataset_list[0]
            if self.lazy:
                data = da.from_array(*self._read_dataset(d))
            else:
                data = self._read_dataset(d)[0]
            if transpose_required:
                data = data.transpose()

        shape = data.shape

        if len(dataset_list) > 1:
            offset, scale, units = 0, 1, None
            if self._is_prismatic_file and "depth" in stack_key:
                simu_om = original_metadata.get("simulation_parameters", {})
                if "numSlices" in simu_om.keys():
                    scale = simu_om["numSlices"]
                    scale *= simu_om.get("sliceThickness", 1.0)
                if "zStart" in simu_om.keys():
                    offset = simu_om["zStart"]
                    # when zStart = 0, the first image is not at zero but
                    # the first output: numSlices * sliceThickness (=scale)
                    if offset == 0:
                        offset = scale
                units = "Å"
                total_thickness = (
                    simu_om.get("tile", [0, 0, 0])[2]
                    * simu_om.get("cellDimension", [0, 0, 0])[0]
                )
                if not math.isclose(
                    total_thickness, len(dataset_list) * scale, rel_tol=1e-4
                ):
                    _logger.warning(
                        "Depth axis is non-uniform and its offset "
                        "and scale can't be set accurately."
                    )
                    # When non-uniform/non-linear axes are implemented, adjust
                    # the final depth to the "total_thickness"
                    offset, scale, units = 0, 1, None
            axes.append(
                {
                    "index_in_array": 0,
                    "name": stack_key if stack_key is not None else None,
                    "offset": offset,
                    "scale": scale,
                    "size": len(dataset_list),
                    "units": units,
                    "navigate": True,
                }
            )

            array_indices = np.arange(1, len(shape))
            dim_indices = array_indices
        else:
            array_indices = np.arange(0, len(shape))
            # dim indices start from 1
            dim_indices = array_indices + 1

        if transpose_required:
            dim_indices = dim_indices[::-1]

        for arr_index, dim_index in zip(array_indices, dim_indices):
            dim = self.file.get(f"{group_path[0]}/dim{dim_index}")
            offset, scale = self._parse_axis(dim)
            if self._is_prismatic_file:
                if dataset_name == "datacube":
                    # For datacube (4D STEM), the signal axes are the
                    # detector coordinates
                    sig_dim = ["dim3", "dim4"]
                else:
                    sig_dim = ["dim1", "dim2"]

                navigate = dim.name.split("/")[-1] not in sig_dim

            else:
                navigate = False
            axes.append(
                {
                    "index_in_array": arr_index,
                    "name": self._parse_attribute(dim, "name"),
                    "units": self._parse_attribute(dim, "units"),
                    "size": shape[arr_index],
                    "offset": offset,
                    "scale": scale,
                    "navigate": navigate,
                }
            )
        return data, axes

    def _parse_attribute(self, obj, key):
        value = obj.attrs.get(key)
        if value is not None:
            if not isinstance(value, str):
                value = value.decode()
            if key == "units":
                # Extract all the bracketed unit tokens, e.g. "[n_m]" -> "nm"
                units_list = re.findall(r"(\[.+?\])", value)
                units_list = [u[1:-1].replace("_", "") for u in units_list]
                value = " * ".join(units_list)
                try:
                    units = _UREG.parse_units(value)
                    value = f"{units:~}"
                except Exception:
                    # In case it fails parsing units
                    pass
        return value

    def _parse_metadata(self, group_basename, title=""):
        filename = self.file if isinstance(self.file, str) else self.file.filename
        md = {
            "General": {
                "title": title.replace("_depth", ""),
                "original_filename": os.path.split(filename)[1],
            },
            "Signal": {"signal_type": ""},
        }
        if "CBED" in group_basename:
            md["Signal"]["signal_type"] = "electron_diffraction"
        return md

    def _parse_original_metadata(self):
        f = self.file
        om = {"EMD_version": self._read_emd_version(self.file.get("/"))}
        for group_name in ["microscope", "sample", "user", "comments"]:
            group = f.get(group_name)
            if group is not None:
                om.update(
                    {group_name: {key: value for key, value in group.attrs.items()}}
                )

        if self._is_prismatic_file:
            md_mapping = {
                "i": "filenameAtoms",
                "a": "algorithm",
                "fx": "interpolationFactorX",
                "fy": "interpolationFactorY",
                "F": "numFP",
                "ns": "numSlices",
                "te": "includeThermalEffects",
                "oc": "includeOccupancy",
                "3D": "save3DOutput",
                "4D": "save3DOutput",
                "DPC": "saveDPC_CoM",
                "ps": "savePotentialSlices",
                "nqs": "nyquistSampling",
                "px": "realspacePixelSizeX",
                "py": "realspacePixelSizeY",
                "P": "potBound",
                "s": "sliceThickness",
                "zs": "zStart",
                "E": "E0",
                "A": "alphaBeamMax",
                "rx": "probeStepX",
                "ry": "probeStepY",
                "df": "probeDefocus",
                "sa": "probeSemiangle",
                "d": "detectorAngleStep",
                "tx": "probeXtilt",
                "ty": "probeYtilt",
                "c": "cellDimension",
                "t": "tile",
                "wx": "scanWindowX",
                "wy": "scanWindowY",
                "wxr": "scanWindowX_r",
                "wyr": "scanWindowY_r",
                "2D": "integrationAngle",
            }
            simu_md = f.get(
                "4DSTEM_simulation/metadata/metadata_0/original/simulation_parameters"
            )
            om["simulation_parameters"] = {
                md_mapping.get(k, k): v for k, v in simu_md.attrs.items()
            }

        return om

    @staticmethod
    def _parse_axis(axis_data):
        """
        Estimate offset and scale from a 1D array.
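
        Examples
        --------
        For a uniform numeric axis, the offset is the first value and the
        scale is the mean spacing:

        >>> import numpy as np
        >>> offset, scale = EMD_NCEM._parse_axis(np.array([10.0, 10.5, 11.0]))
        >>> float(offset), float(scale)
        (10.0, 0.5)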
        """
        if axis_data.ndim > 0 and np.issubdtype(axis_data.dtype, np.number):
            offset, scale = axis_data[0], np.diff(axis_data).mean()
        else:
            # This is a string; return default values.
            # When non-uniform axes are supported, we should be able to
            # parse the string.
            offset, scale = 0, 1
        return offset, scale

    def write_file(self, file, signal, **kwargs):
        """
        Write signal to file.

        Parameters
        ----------
        file : str or h5py file handle
            If str, filename of the file to write; otherwise, an h5py file
            handle.
        signal : instance of hyperspy signal
            The signal to save.
        **kwargs : dict
            Keyword arguments are passed to the ``h5py.Group.create_dataset``
            method.

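        Examples
        --------
        A minimal sketch; the ``signal`` dictionary below follows the
        RosettaSciIO signal dictionary layout (``data``, ``axes``,
        ``metadata``, ``original_metadata``):

        >>> import numpy as np
        >>> signal = {
        ...     "data": np.arange(10),
        ...     "axes": [{"name": "x", "units": "nm", "offset": 0.0,
        ...               "scale": 1.0, "size": 10, "navigate": False}],
        ...     "metadata": {"General": {"title": "test"}, "Signal": {}},
        ...     "original_metadata": {},
        ... }
        >>> EMD_NCEM().write_file("output.emd", signal)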
        """
        if isinstance(file, str):
            emd_file = h5py.File(file, "w")
        else:
            emd_file = file
        # Write version:
        ver_maj, ver_min = EMD_VERSION.split(".")
        emd_file.attrs["version_major"] = ver_maj
        emd_file.attrs["version_minor"] = ver_min

        # Write attributes from the original_metadata
        om = DTBox(signal["original_metadata"], box_dots=True)
        for group_name in ["microscope", "sample", "user", "comments"]:
            group = emd_file.require_group(group_name)
            d = om.get(group_name, None)
            if d is not None:
                for key, value in d.items():
                    group.attrs[key] = value

        # Write signals:
        signal_group = emd_file.require_group("signals")
        signal_group.attrs["emd_group_type"] = 1
        self._write_signal_to_group(signal_group, signal, **kwargs)
        emd_file.close()

    def _write_signal_to_group(self, signal_group, signal, chunks=None, **kwargs):
        # Save data:
        title = signal["metadata"]["General"]["title"] or "__unnamed__"
        dataset = signal_group.require_group(title)
        data = signal["data"].T
        maxshape = tuple(None for _ in data.shape)
        if np.issubdtype(data.dtype, np.dtype("U")):
            # Saving numpy unicode type is not supported in h5py
            data = data.astype(np.dtype("S"))
        if chunks is None:
            if isinstance(data, da.Array):
                # For lazy datasets, use the current dask chunking by default
                chunks = tuple([c[0] for c in data.chunks])
            else:
                signal_axes = [
                    i for i, axis in enumerate(signal["axes"]) if not axis["navigate"]
                ]
                chunks = get_signal_chunks(data.shape, data.dtype, signal_axes)
        # when chunks=True, we leave it to h5py `guess_chunk`
        elif chunks is not True:
            # Need to reverse since the data is transposed when saving
            chunks = chunks[::-1]

        dataset.create_dataset(
            "data", data=data, maxshape=maxshape, chunks=chunks, **kwargs
        )

        array_indices = np.arange(0, len(data.shape))
        dim_indices = (array_indices + 1)[::-1]
        # Iterate over all dimensions:
        for i, dim_index in zip(array_indices, dim_indices):
            key = f"dim{dim_index}"
            axis = signal["axes"][i]
            offset = axis["offset"]
            scale = axis["scale"]
            dim = dataset.create_dataset(key, data=[offset, offset + scale])
            name = axis["name"]
            if name is None:
                name = ""
            dim.attrs["name"] = name
            units = axis["units"]
            if units is None:
                units = ""
            else:
                # Mirror the reader's unit parsing: interleave '_' between
                # characters and wrap in brackets, e.g. "nm" -> "[n_m]"
                units = "[{}]".format("_".join(list(units)))
            dim.attrs["units"] = units
        # Write metadata:
        dataset.attrs["emd_group_type"] = 1
        for key, value in signal["metadata"]["Signal"].items():
            try:  # If something h5py can't handle is saved in the metadata...
                dataset.attrs[key] = value
            except Exception:  # ...let the user know what could not be added!
                _logger.warning(
                    "The following information couldn't be "
                    f"written in the file: {key}: {value}"
                )


def read_emd_version(group):
    """Function to read the emd file version from a group. The EMD version is
    saved in the attributes 'version_major' and 'version_minor'.

    Parameters
    ----------
    group : hdf5 group
        The group to extract the version from.

    Returns
    -------
    version : str
        The version as 'major.minor'; empty string if the version is not
        defined in this group.

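    Examples
    --------
    A sketch, assuming "example.emd" is an EMD file with the version saved on
    its root group:

    >>> import h5py
    >>> with h5py.File("example.emd", "r") as f:
    ...     version = read_emd_version(f["/"])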
    """
    major = group.attrs.get("version_major", None)
    minor = group.attrs.get("version_minor", None)
    if major is not None and minor is not None:
        return f"{major}.{minor}"
    else:
        return ""
