#!/usr/bin/env python

# intensity.py
# definitions of intensity characters

import collections

import geopandas as gpd
import numpy as np
import pandas as pd
from packaging.version import Version
from tqdm.auto import tqdm  # progress bar

from .utils import deprecated, removed

# True on geopandas >= 1.0 (the "dev" suffix makes 1.0 pre-releases compare as
# >= too); used below to call ``union_all()`` instead of the older
# ``unary_union`` attribute.
GPD_GE_10 = Version(gpd.__version__) >= Version("1.0dev")

# Public API of this module.
__all__ = [
    "AreaRatio",
    "Count",
    "Courtyards",
    "BlocksCount",
    "Reached",
    "NodeDensity",
    "Density",
]


@removed("a direct division of areas or momepy.describe_agg()")
class AreaRatio:
    """
    Calculate covered area ratio or floor area ratio of objects. Either ``unique_id``
    or both ``left_unique_id`` and ``right_unique_id`` are required.

    .. math::
        \\textit{covering object area} \\over \\textit{covered object area}

    Adapted from :cite:`schirmer2015`.

    Parameters
    ----------
    left : GeoDataFrame
        A GeoDataFrame containing objects being covered (e.g. land unit).
    right : GeoDataFrame
        A GeoDataFrame with covering objects (e.g. building).
    left_areas : str, list, np.array, pd.Series
        The name of the left dataframe column, ``np.array``, or
        ``pd.Series`` where area values are stored.
    right_areas : str, list, np.array, pd.Series
        The name of the right dataframe column, ``np.array``, or
        ``pd.Series`` where area values (representing either projected
        or floor area) are stored.
    unique_id : str (default None)
        The name of the unique ID column shared amongst ``left`` and ``right`` gdfs.
        If there is none, it can be generated by :py:func:`momepy.unique_id()`.
    left_unique_id : str, list, np.array, pd.Series (default None)
        The name of the ``left`` dataframe column, ``np.array``, or
        ``pd.Series`` where the shared unique IDs are stored.
    right_unique_id : str, list, np.array, pd.Series (default None)
        The name of the ``right`` dataframe column, ``np.array``, or
        ``pd.Series`` where the shared unique IDs are stored.

    Attributes
    ----------
    series : Series
        A Series containing resulting values.
    left : GeoDataFrame
        The original left GeoDataFrame.
    right : GeoDataFrame
        The original right GeoDataFrame.
    left_areas :  Series
        A Series containing the used left areas.
    right_areas :  Series
        A Series containing the used right areas.
    left_unique_id : str or array-like
        The left ID (column name or values) as passed or derived from ``unique_id``.
    right_unique_id : str or array-like
        The right ID (column name or values) as passed or derived from ``unique_id``.

    Raises
    ------
    ValueError
        If neither ``unique_id`` nor both ``left_unique_id`` and
        ``right_unique_id`` are given.

    Examples
    --------
    >>> tessellation_df['CAR'] = mm.AreaRatio(tessellation_df,
    ...                                       buildings_df,
    ...                                       'area',
    ...                                       'area',
    ...                                       'uID').series
    """

    def __init__(
        self,
        left,
        right,
        left_areas,
        right_areas,
        unique_id=None,
        left_unique_id=None,
        right_unique_id=None,
    ):
        self.left = left
        self.right = right

        # work on copies so the temporary helper columns never touch the inputs
        left = left.copy()
        right = right.copy()

        if unique_id:
            left_unique_id = unique_id
            right_unique_id = unique_id
        elif left_unique_id is None or right_unique_id is None:
            raise ValueError(
                "Unique ID not correctly set. Use either ``unique_id`` or both "
                "``left_unique_id`` and ``right_unique_id``."
            )
        self.left_unique_id = left_unique_id
        self.right_unique_id = right_unique_id

        # array-like IDs are stored as temporary columns so that the merge below
        # can refer to them by name, the same way array-like areas are handled
        array_like = (list, np.ndarray, pd.Series, pd.Index)
        if isinstance(left_unique_id, array_like):
            left["mm_lid"] = left_unique_id
            left_unique_id = "mm_lid"
        if isinstance(right_unique_id, array_like):
            right["mm_rid"] = right_unique_id
            right_unique_id = "mm_rid"

        if not isinstance(left_areas, str):
            left["mm_a"] = left_areas
            left_areas = "mm_a"
        self.left_areas = left[left_areas]
        if not isinstance(right_areas, str):
            right["mm_a"] = right_areas
            right_areas = "mm_a"
        self.right_areas = right[right_areas]

        # total covering area per shared ID (several covering objects may share
        # one ID), keeping only the columns needed for the merge
        look_for = (
            right[[right_unique_id, right_areas]]
            .rename(columns={right_areas: "lf_area"})
            .groupby(right_unique_id)
            .sum()
            .reset_index()
        )
        objects_merged = left[[left_unique_id, left_areas]].merge(
            look_for, left_on=left_unique_id, right_on=right_unique_id, how="left"
        )
        # the left merge against unique (grouped) keys keeps the rows of ``left``
        # in order, so its index can be restored positionally
        objects_merged.index = left.index

        self.series = objects_merged["lf_area"] / objects_merged[left_areas]


@removed("momepy.describe_agg()")
class Count:
    """
    Calculate the number of elements within an aggregated structure. Aggregated
    structures can typically be blocks, street segments, or street nodes (their
    snapped objects). The right gdf has to have a unique ID of aggregated structures
    assigned beforehand (e.g. using :py:func:`momepy.get_network_id`).
    If ``weighted=True``, the number of elements will be divided by the area or
    length (based on geometry type) of aggregated elements, to return relative value.

    .. math::
        \\sum_{i \\in aggr} (n_i);\\space \\frac{\\sum_{i \\in aggr} (n_i)}{area_{aggr}}

    Adapted from :cite:`hermosilla2012` and :cite:`feliciotti2018`.

    Parameters
    ----------
    left : GeoDataFrame
        A GeoDataFrame containing aggregation to analyse.
    right : GeoDataFrame
        A GeoDataFrame containing objects to analyse.
    left_id : str
        The name of the column where unique ID in the ``left`` gdf is stored.
    right_id : str
        The name of the column where unique ID of
        aggregation in the ``right`` gdf is stored.
    weighted : bool (default False)
        If ``True``, count will be divided by the area or length. The geometry
        type of the first ``left`` geometry decides which one is used.

    Attributes
    ----------
    series : Series
        A Series containing resulting values.
    left : GeoDataFrame
        The original ``left`` GeoDataFrame.
    right : GeoDataFrame
        The original ``right`` GeoDataFrame.
    left_id : Series
        A Series containing used ``left`` ID.
    right_id : Series
        A Series containing used ``right`` ID.
    weighted : bool
        ``True`` if the weighted value was used.

    Raises
    ------
    TypeError
        If ``weighted=True`` and ``left`` is neither polygonal nor linear.

    Examples
    --------
    >>> blocks_df['buildings_count'] = mm.Count(blocks_df,
    ...                                         buildings_df,
    ...                                         'bID',
    ...                                         'bID',
    ...                                         weighted=True).series
    """

    def __init__(self, left, right, left_id, right_id, weighted=False):
        self.left = left
        self.right = right
        self.left_id = left[left_id]
        self.right_id = right[right_id]
        self.weighted = weighted

        # number of right elements per aggregation ID
        count = collections.Counter(right[right_id])
        df = pd.DataFrame.from_dict(count, orient="index", columns=["mm_count"])
        joined = left[[left_id, left.geometry.name]].join(df["mm_count"], on=left_id)
        # aggregations without any element get an explicit zero
        joined["mm_count"] = joined["mm_count"].fillna(0)

        if weighted and not left.empty:
            # positional lookup: the index of ``left`` need not contain label 0
            geom_type = left.geometry.geom_type.iloc[0]
            if geom_type in ["Polygon", "MultiPolygon"]:
                joined["mm_count"] = joined["mm_count"] / left.geometry.area
            elif geom_type in ["LineString", "MultiLineString"]:
                joined["mm_count"] = joined["mm_count"] / left.geometry.length
            else:
                raise TypeError("Geometry type does not support weighting.")

        self.series = joined["mm_count"]


@deprecated("courtyards")
class Courtyards:
    """
    Calculate the number of courtyards within the joined structure.

    Buildings are joined into structures by the connected components of
    ``spatial_weights``. Each component is dissolved and its interior rings
    (holes) are counted. If a dissolved structure consists of several polygons,
    their interior rings are summed. Every building receives the value of its
    structure.

    Adapted from :cite:`schirmer2015`.

    Parameters
    ----------
    gdf : GeoDataFrame
        A GeoDataFrame containing objects to analyse.
    spatial_weights : libpysal.weights, optional
        A spatial weights matrix. If None, Queen contiguity matrix
        will be calculated based on objects. It is to denote adjacent
        buildings and is based on integer index.
    verbose : bool (default True)
        If ``True``, shows progress bars in loops and indication of steps.

    Attributes
    ----------
    series : Series
        A Series containing resulting values.
    gdf : GeoDataFrame
        The original GeoDataFrame.
    sw : libpysal.weights
        The spatial weights matrix.

    Examples
    --------
    >>> buildings_df['courtyards'] = mm.Courtyards(buildings_df).series
    Calculating spatial weights...
    """

    def __init__(self, gdf, spatial_weights=None, verbose=True):
        self.gdf = gdf
        gdf = gdf.copy()

        # if weights matrix is not passed, generate it from objects
        if spatial_weights is None:
            if verbose:
                print("Calculating spatial weights...")
            from libpysal.weights import Queen

            spatial_weights = Queen.from_dataframe(
                gdf, silence_warnings=True, use_index=False
            )

        self.sw = spatial_weights

        # component label of each row, in row order
        labels = np.asarray(spatial_weights.component_labels)
        # row positions of every component, grouped in a single pass instead of
        # rescanning all labels for every component (which was O(n * k))
        groups = pd.Series(labels).groupby(labels).indices
        # buffer to avoid multipolygons where buildings touch by corners only
        buffered = gdf.buffer(0.01)

        results = np.zeros(len(gdf), dtype=np.int64)
        for positions in tqdm(groups.values(), total=len(groups), disable=not verbose):
            members = buffered.iloc[positions]
            dissolved = members.union_all() if GPD_GE_10 else members.unary_union
            # a Polygon has no ``geoms``; multi-part results (e.g. components
            # from non-contiguity weights) contribute the holes of each part
            parts = getattr(dissolved, "geoms", [dissolved])
            results[positions] = sum(
                len(getattr(part, "interiors", ())) for part in parts
            )

        self.series = pd.Series(results, index=gdf.index)


@removed("`.describe()` method of libpysal.graph.Graph")
class BlocksCount:
    """
    Calculates the weighted number of blocks. The number of blocks within neighbours
    defined in ``spatial_weights`` divided by the area covered by the neighbours.

    .. math::
        \\frac{\\textit{number of blocks within neighbours}}{\\textit{area of neighbours}}

    Adapted from :cite:`dibble2017`.

    Parameters
    ----------
    gdf : GeoDataFrame
        A GeoDataFrame containing morphological tessellation.
    block_id : str, list, np.array, pd.Series
        The name of the objects dataframe column, ``np.array``,
        or ``pd.Series`` where block IDs are stored.
    spatial_weights : libpysal.weights
        A spatial weights matrix.
    unique_id : str
        The name of the column with a unique ID used as the ``spatial_weights`` index.
    weighted : bool, default True
        Return value weighted by the analysed area (``True``) or pure count (``False``).
    verbose : bool (default True)
        If ``True``, shows progress bars in loops and indication of steps.

    Attributes
    ----------
    series : Series
        A Series containing resulting values. Elements missing from
        ``spatial_weights`` get ``NaN``.
    gdf : GeoDataFrame
        The original GeoDataFrame.
    block_id : Series
        A  Series containing used block ID.
    sw : libpysal.weights
        The spatial weights matrix
    id : Series
        A Series containing used unique ID.
    weighted : bool
        ``True`` if the weighted value was used.

    Raises
    ------
    ValueError
        If ``weighted`` is neither ``True`` nor ``False``.

    Examples
    --------
    >>> sw4 = mm.sw_high(k=4, gdf='tessellation_df', ids='uID')
    >>> tessellation_df['blocks_within_4'] = mm.BlocksCount(tessellation_df,
    ...                                                     'bID',
    ...                                                     sw4,
    ...                                                     'uID').series
    """

    def __init__(
        self, gdf, block_id, spatial_weights, unique_id, weighted=True, verbose=True
    ):
        self.gdf = gdf
        self.sw = spatial_weights
        self.id = gdf[unique_id]
        self.weighted = weighted

        # validate before any work so an invalid flag is never silently accepted
        # (previously it was only detected for rows present in the weights)
        if weighted is not True and weighted is not False:
            raise ValueError("Attribute 'weighted' needs to be True or False.")

        results_list = []
        data = gdf.copy()
        if not isinstance(block_id, str):
            data["mm_bid"] = block_id
            block_id = "mm_bid"
        self.block_id = data[block_id]
        data = data.set_index(unique_id)

        # areas indexed by ``unique_id``, only needed for the weighted variant
        areas = data.geometry.area if weighted else None

        for index in tqdm(data.index, total=data.shape[0], disable=not verbose):
            if index not in spatial_weights.neighbors:
                results_list.append(np.nan)
                continue

            neighbours = [index, *spatial_weights.neighbors[index]]
            n_blocks = data.loc[neighbours][block_id].unique().shape[0]

            if weighted:
                results_list.append(n_blocks / sum(areas.loc[neighbours]))
            else:
                results_list.append(n_blocks)

        self.series = pd.Series(results_list, index=gdf.index)


@deprecated("describe_reached_agg")
class Reached:
    """
    Calculates the number of objects reached within neighbours on a street network.
    The number of elements within neighbourhood defined in ``spatial_weights``. If
    ``spatial_weights`` are ``None``, it will assume topological distance ``0``
    (element itself). If ``mode='sum'`` and ``values`` is not set, returns the sum
    of areas of reached elements. Requires a ``unique_id`` of network assigned
    beforehand (e.g. using :py:func:`momepy.get_network_id`).

    Parameters
    ----------
    left : GeoDataFrame
        A GeoDataFrame containing streets (either segments or nodes).
    right : GeoDataFrame
        A GeoDataFrame containing elements to be counted.
    left_id : str, list, np.array, pd.Series
        The name of the ``left`` dataframe column, ``np.array``, or ``pd.Series``
        where the IDs of streets (segments or nodes) are stored.
    right_id : str, list, np.array, pd.Series
        The name of the ``right`` dataframe column, ``np.array``, or ``pd.Series``
        where the IDs of the streets (segments or nodes) each element belongs to
        are stored.
    spatial_weights : libpysal.weights (default None)
        A spatial weights matrix.
    mode : str (default 'count')
        Mode of calculation. If ``'count'`` function will return the count of reached
        elements. If ``'sum'``, it will return sum of ``'values'``. If ``'mean'`` it
        will return mean value of ``'values'``. If ``'std'`` it will return standard
        deviation of ``'values'``. If ``'values'`` not set it will use of areas of
        reached elements.
    values : str (default None)
        The name of the objects dataframe column with values used for calculations.
    verbose : bool (default True)
        If ``True``, shows progress bars in loops and indication of steps.

    Attributes
    ----------
    series : Series
        A Series containing resulting values. For modes other than ``'count'``,
        streets reaching no element get ``NaN``.
    left : GeoDataFrame
        The original left GeoDataFrame.
    right : GeoDataFrame
        The original right GeoDataFrame.
    left_id : Series
        A Series containing used left ID.
    right_id : Series
        A Series containing used right ID.
    mode : str
        The mode of calculation.
    sw : libpysal.weights
        The spatial weights matrix (if set).

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``'count'``, ``'sum'``, ``'mean'`` or ``'std'``.

    Examples
    --------
    >>> streets_df['reached'] = mm.Reached(streets_df, buildings_df,
    ...                                    'nID', 'nID').series
    """

    # TODO: allow all modes

    def __init__(
        self,
        left,
        right,
        left_id,
        right_id,
        spatial_weights=None,
        mode="count",
        values=None,
        verbose=True,
    ):
        self.left = left
        self.right = right
        self.sw = spatial_weights
        self.mode = mode

        # aggregation functions for the non-count modes; validated up front so an
        # unsupported mode fails clearly instead of hitting an unbound name later
        aggregators = {"sum": sum, "mean": np.nanmean, "std": np.nanstd}
        if mode != "count" and mode not in aggregators:
            raise ValueError(
                f"Mode '{mode}' is not supported. "
                "Use one of 'count', 'sum', 'mean' or 'std'."
            )
        func = aggregators.get(mode)

        results_list = []

        if not isinstance(right_id, str):
            right = right.copy()
            right["mm_id"] = right_id
            right_id = "mm_id"
        self.right_id = right[right_id]
        if not isinstance(left_id, str):
            left = left.copy()
            left["mm_lid"] = left_id
            left_id = "mm_lid"
        self.left_id = left[left_id]

        right_ids = right[right_id]
        if mode == "count":
            # number of right elements per street ID
            count = collections.Counter(right_ids)

        for index, lid in tqdm(
            left[left_id].items(), total=left.shape[0], disable=not verbose
        ):
            if spatial_weights is None:
                ids = [lid]
            else:
                neighbours = [index, *spatial_weights.neighbors[index]]
                # NOTE(review): ``index`` is a label of ``left`` but is used
                # positionally via ``iloc``; presumably assumes a RangeIndex.
                ids = left.iloc[neighbours][left_id]

            if mode == "count":
                results_list.append(sum(count[nid] for nid in ids))
                continue

            mask = right_ids.isin(ids)
            if not mask.any():
                results_list.append(np.nan)
            elif values:
                results_list.append(func(right.loc[mask][values]))
            else:
                results_list.append(func(right.loc[mask].geometry.area))

        self.series = pd.Series(results_list, index=left.index)


@deprecated("node_density")
class NodeDensity:
    """
    Calculate the density of nodes neighbours on street network defined in
    ``spatial_weights``. Calculated as the number of neighbouring
    nodes / cumulative length of street network within neighbours.
    ``node_start`` and ``node_end`` is standard output of
    :py:func:`momepy.nx_to_gdf` and is compulsory.

    Adapted from :cite:`dibble2017`.

    Parameters
    ----------
    left : GeoDataFrame
        A GeoDataFrame containing nodes of street network.
    right : GeoDataFrame
        A GeoDataFrame containing edges of street network.
    spatial_weights : libpysal.weights
        A spatial weights matrix capturing relationship between nodes.
    weighted : bool (default False)
        If ``True``, density will take into account node degree as ``k-1``.
    node_degree : str (optional)
        The name of the column of ``left`` containing node degree.
        Required if ``weighted=True``.
    node_start : str (default 'node_start')
        The name of the column of ``right`` containing the ID of the starting nodes.
    node_end : str (default 'node_end')
        The name of the column of ``right`` containing the ID of the ending node.
    verbose : bool (default True)
        If ``True``, shows progress bars in loops and indication of steps.

    Attributes
    ----------
    series : Series
        A Series containing resulting values. Nodes whose neighbourhood has no
        edge length get ``0``.
    left : GeoDataFrame
        The original left GeoDataFrame.
    right : GeoDataFrame
        The original right GeoDataFrame.
    node_start : Series
        A Series containing used ids of starting node.
    node_end : Series
        A Series containing used ids of ending node.
    sw : libpysal.weights
        The spatial weights matrix.
    weighted : bool
        The used weighted value.
    node_degree : Series
        A Series containing used node degree values (only set if ``weighted``).

    Examples
    --------
    >>> nodes['density'] = mm.NodeDensity(nodes, edges, sw).series
    """

    def __init__(
        self,
        left,
        right,
        spatial_weights,
        weighted=False,
        node_degree=None,
        node_start="node_start",
        node_end="node_end",
        verbose=True,
    ):
        self.left = left
        self.right = right
        self.sw = spatial_weights
        self.weighted = weighted
        if weighted:
            self.node_degree = left[node_degree]
        # honour the configurable column names (they were previously ignored in
        # the length mask, which always read 'node_start' / 'node_end')
        starts = right[node_start]
        ends = right[node_end]
        self.node_start = starts
        self.node_end = ends

        results_list = []
        lengths = right.geometry.length

        for index in tqdm(left.index, total=left.shape[0], disable=not verbose):
            neighbours = [index, *spatial_weights.neighbors[index]]
            if weighted:
                # NOTE(review): ``neighbours`` are labels used positionally via
                # ``iloc``; presumably assumes a RangeIndex on ``left``.
                neighbour_nodes = left.iloc[neighbours]
                number_nodes = sum(neighbour_nodes[node_degree] - 1)
            else:
                number_nodes = len(neighbours)

            # total length of edges whose both ends lie within the neighbourhood
            length = lengths.loc[starts.isin(neighbours) & ends.isin(neighbours)].sum()

            results_list.append(number_nodes / length if length > 0 else 0)

        self.series = pd.Series(results_list, index=left.index)


@removed("`.describe()` method of libpysal.graph.Graph")
class Density:
    """
    Calculate the gross density.

    For every element, the values of the element and its neighbours (as defined
    by ``spatial_weights``) are summed and divided by the summed areas of the
    same elements. Elements missing from ``spatial_weights`` get ``NaN``.

    .. math::
        \\frac{\\sum \\text {values}}{\\sum \\text {areas}}

    Adapted from :cite:`dibble2017`.

    Parameters
    ----------
    gdf : GeoDataFrame
        A GeoDataFrame containing objects to analyse.
    values : str, list, np.array, pd.Series
        The name of the dataframe column, ``np.array``, or ``pd.Series``
        where character values are stored.
    spatial_weights : libpysal.weights
        A spatial weights matrix.
    unique_id : str
        The name of the column with unique ID used as ``spatial_weights`` index.
    areas :  str, list, np.array, pd.Series (optional)
        The name of the dataframe column, ``np.array``, or ``pd.Series``
        where area values are stored. If ``None``, gdf.geometry.area will be used.
    verbose : bool (default True)
        If ``True``, shows progress bars in loops and indication of steps.

    Attributes
    ----------
    series : Series
        A Series containing resulting values.
    gdf : GeoDataFrame
        The original GeoDataFrame.
    values : Series
        A Series containing used values.
    sw : libpysal.weights
        The spatial weights matrix.
    id : Series
        A Series containing used unique ID.
    areas : Series
        A Series containing used area values.

    Examples
    --------
    >>> tessellation_df['floor_area_dens'] = mm.Density(tessellation_df,
    ...                                                 'floor_area',
    ...                                                 sw,
    ...                                                 'uID').series
    """

    def __init__(
        self, gdf, values, spatial_weights, unique_id, areas=None, verbose=True
    ):
        self.gdf = gdf
        self.sw = spatial_weights
        self.id = gdf[unique_id]

        frame = gdf.copy()

        # array-like inputs become temporary columns so both can be addressed by name
        if not (values is None or isinstance(values, str)):
            frame["mm_v"] = values
            values = "mm_v"
        self.values = frame[values]

        if areas is None:
            frame["mm_a"] = frame.geometry.area
            areas = "mm_a"
        elif not isinstance(areas, str):
            frame["mm_a"] = areas
            areas = "mm_a"
        self.areas = frame[areas]

        frame = frame.set_index(unique_id)
        neighbour_map = spatial_weights.neighbors

        densities = []
        for uid in tqdm(frame.index, total=frame.shape[0], disable=not verbose):
            if uid not in neighbour_map:
                densities.append(np.nan)
                continue
            # the element itself followed by its neighbours
            vicinity = frame.loc[[uid, *neighbour_map[uid]]]
            densities.append(np.sum(vicinity[values]) / np.sum(vicinity[areas]))

        self.series = pd.Series(densities, index=gdf.index)
