import numpy as np
import numba
from sklearn.base import BaseEstimator
from sklearn.utils import check_array

from umap.sparse import arr_intersect as intersect1d
from umap.sparse import arr_union as union1d
from umap.umap_ import UMAP, make_epochs_per_sample
from umap.spectral import spectral_layout
from umap.layouts import optimize_layout_aligned_euclidean

# Bounds used when drawing random seeds for numba's rng state; the extreme
# int32 values are avoided.
INT32_MIN = np.iinfo(np.int32).min + 1
INT32_MAX = np.iinfo(np.int32).max - 1


@numba.njit(parallel=True)
def in1d(arr, test_set):
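    """Numba-accelerated analogue of ``np.in1d``: return a boolean array
    marking which entries of ``arr`` appear in ``test_set``."""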
    test_set = set(test_set)
    result = np.empty(arr.shape[0], dtype=np.bool_)
    for i in numba.prange(arr.shape[0]):
        result[i] = arr[i] in test_set

    return result


def invert_dict(d):
    return {value: key for key, value in d.items()}


@numba.njit()
def procrustes_align(embedding_base, embedding_to_align, anchors):
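    """Rotate ``embedding_to_align`` onto ``embedding_base`` by solving the
    orthogonal Procrustes problem over a set of anchor points. ``anchors[0]``
    holds row indices into the base embedding and ``anchors[1]`` the
    corresponding row indices into the embedding being aligned."""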
    subset1 = embedding_base[anchors[0]]
    subset2 = embedding_to_align[anchors[1]]
    M = subset2.T @ subset1
    # np.linalg.svd returns the right singular vectors already transposed,
    # so U @ Vh is the optimal orthogonal rotation.
    U, S, Vh = np.linalg.svd(M)
    R = U @ Vh
    return embedding_to_align @ R


def expand_relations(relation_dicts, window_size=3):
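    """Expand the list of pairwise relation dicts into a dense tensor of
    shape ``(n_models, 2 * window_size + 1, max_n_samples)``. Entry
    ``[i, w, k]`` gives the index in model ``i + (w - window_size)`` of the
    sample related to sample ``k`` of model ``i``, or -1 when no such
    relation exists."""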
    max_n_samples = (
        max(
            [max(d.keys()) for d in relation_dicts]
            + [max(d.values()) for d in relation_dicts]
        )
        + 1
    )
    result = np.full(
        (len(relation_dicts) + 1, 2 * window_size + 1, max_n_samples),
        -1,
        dtype=np.int32,
    )
    reverse_relation_dicts = [invert_dict(d) for d in relation_dicts]
    for i in range(result.shape[0]):
        # Forward offsets +1 .. +window_size fill columns
        # window_size + 1 .. 2 * window_size.
        for j in range(window_size):
            result_index = window_size + (j + 1)
            # There is one more model than there are relation dicts, so model
            # i + j + 1 exists whenever i + j + 1 <= len(relation_dicts).
            if i + j + 1 > len(relation_dicts):
                result[i, result_index] = np.full(max_n_samples, -1, dtype=np.int32)
            else:
                mapping = np.arange(max_n_samples)
                for k in range(j + 1):
                    mapping = np.array(
                        [relation_dicts[i + k].get(n, -1) for n in mapping]
                    )
                result[i, result_index] = mapping

        # Backward offsets -1 .. -window_size fill columns
        # window_size - 1 .. 0.
        for j in range(0, -window_size, -1):
            result_index = window_size + (j - 1)
            if i + j - 1 < 0:
                result[i, result_index] = np.full(max_n_samples, -1, dtype=np.int32)
            else:
                mapping = np.arange(max_n_samples)
                for k in range(0, j - 1, -1):
                    mapping = np.array(
                        [reverse_relation_dicts[i + k - 1].get(n, -1) for n in mapping]
                    )
                result[i, result_index] = mapping

    return result


@numba.njit()
def build_neighborhood_similarities(graphs_indptr, graphs_indices, relations):
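    """Compute, for each sample, the Jaccard similarity between its graph
    neighborhood and the neighborhood of its related sample in nearby models,
    with both neighbor sets restricted to samples that have relations.
    ``relations`` is the tensor produced by ``expand_relations``; the result
    has the same shape, with zeros where no relation exists."""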
    result = np.zeros(relations.shape, dtype=np.float32)
    center_index = (relations.shape[1] - 1) // 2
    for i in range(relations.shape[0]):
        base_graph_indptr = graphs_indptr[i]
        base_graph_indices = graphs_indices[i]
        for j in range(relations.shape[1]):
            if i + j - center_index < 0 or i + j - center_index >= len(graphs_indptr):
                continue

            comparison_graph_indptr = graphs_indptr[i + j - center_index]
            comparison_graph_indices = graphs_indices[i + j - center_index]
            for k in range(relations.shape[2]):
                comparison_index = relations[i, j, k]
                if comparison_index < 0:
                    continue

                raw_base_graph_indices = base_graph_indices[
                    base_graph_indptr[k] : base_graph_indptr[k + 1]
                ].copy()
                raw_base_graph_indices = raw_base_graph_indices[
                    raw_base_graph_indices < relations.shape[2]
                ]
                base_indices = relations[i, j][raw_base_graph_indices]
                base_indices = base_indices[base_indices >= 0]
                comparison_indices = comparison_graph_indices[
                    comparison_graph_indptr[comparison_index] : comparison_graph_indptr[
                        comparison_index + 1
                    ]
                ]
                comparison_indices = comparison_indices[
                    in1d(comparison_indices, relations[i, j])
                ]

                intersection_size = intersect1d(
                    base_indices, comparison_indices
                ).shape[0]
                union_size = union1d(base_indices, comparison_indices).shape[0]

                if union_size > 0:
                    result[i, j, k] = intersection_size / union_size
                else:
                    result[i, j, k] = 1.0

    return result


def get_nth_item_or_val(iterable_or_val, n):
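    """Return the ``n``-th element when given a list, tuple, or array of
    per-model parameter values, or the value itself when it is a scalar."""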
    if iterable_or_val is None:
        return None
    if type(iterable_or_val) in (list, tuple, np.ndarray):
        return iterable_or_val[n]
    elif type(iterable_or_val) in (int, float, bool, type(None)):
        return iterable_or_val
    else:
        raise ValueError("Unrecognized parameter type")


PARAM_NAMES = (
    "n_neighbors",
    "n_components",
    "metric",
    "metric_kwds",
    "n_epochs",
    "learning_rate",
    "init",
    "min_dist",
    "spread",
    "set_op_mix_ratio",
    "local_connectivity",
    "repulsion_strength",
    "negative_sample_rate",
    "transform_queue_size",
    "angular_rp_forest",
    "target_n_neighbors",
    "target_metric",
    "target_metric_kwds",
    "target_weight",
    "unique",
)


def set_aligned_params(new_params, existing_params, n_models, param_names=PARAM_NAMES):
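    """Append the parameter values in ``new_params`` to ``existing_params``,
    expanding a scalar parameter into a per-model tuple when the new value
    differs from the existing one."""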
    for param in param_names:
        if param in new_params:
            if isinstance(existing_params[param], list):
                existing_params[param].append(new_params[param])
            elif isinstance(existing_params[param], tuple):
                existing_params[param] = existing_params[param] + (new_params[param],)
            elif isinstance(existing_params[param], np.ndarray):
                existing_params[param] = np.append(
                    existing_params[param], new_params[param]
                )
            else:
                if new_params[param] != existing_params[param]:
                    existing_params[param] = (existing_params[param],) * n_models + (
                        new_params[param],
                    )

    return existing_params


@numba.njit()
def init_from_existing_internal(
    previous_embedding, weights_indptr, weights_indices, weights_data, relation_dict
):
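    """Initialise an embedding for a new model: samples present in
    ``relation_dict`` copy the position of their related sample in
    ``previous_embedding``; other samples take the weighted average of their
    related graph neighbours, falling back to a random position when no
    neighbour has a relation."""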
    n_samples = weights_indptr.shape[0] - 1
    n_features = previous_embedding.shape[1]
    result = np.zeros((n_samples, n_features), dtype=np.float32)

    for i in range(n_samples):
        if i in relation_dict:
            result[i] = previous_embedding[relation_dict[i]]
        else:
            normalisation = 0.0
            for idx in range(weights_indptr[i], weights_indptr[i + 1]):
                j = weights_indices[idx]
                if j in relation_dict:
                    normalisation += weights_data[idx]
                    result[i] += (
                        weights_data[idx] * previous_embedding[relation_dict[j]]
                    )
            if normalisation == 0:
                result[i] = np.random.uniform(-10.0, 10.0, n_features)
            else:
                result[i] /= normalisation

    return result


def init_from_existing(previous_embedding, graph, relations):
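    """Wrapper around ``init_from_existing_internal`` that converts the
    relations dict into a numba typed dict."""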
    typed_relations = numba.typed.Dict.empty(numba.types.int32, numba.types.int32)
    for key, val in relations.items():
        typed_relations[np.int32(key)] = np.int32(val)
    return init_from_existing_internal(
        previous_embedding,
        graph.indptr,
        graph.indices,
        graph.data,
        typed_relations,
    )


class AlignedUMAP(BaseEstimator):
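    """Run UMAP on a sequence of related datasets, regularising the
    individual embeddings so that related samples stay close across
    embeddings.

    Relations between consecutive datasets are given as dicts mapping row
    indices of one dataset to row indices of the next. A minimal usage
    sketch (the toy data and identity relations here are illustrative only):

        import numpy as np
        X0 = np.random.random((100, 8))
        X1 = np.random.random((100, 8))
        # row i of X0 corresponds to row i of X1
        relations = [{i: i for i in range(100)}]
        mapper = AlignedUMAP().fit([X0, X1], relations=relations)
        emb0, emb1 = mapper.embeddings_
    """
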
    def __init__(
        self,
        n_neighbors=15,
        n_components=2,
        metric="euclidean",
        metric_kwds=None,
        n_epochs=None,
        learning_rate=1.0,
        init="spectral",
        alignment_regularisation=1.0e-2,
        alignment_window_size=3,
        min_dist=0.1,
        spread=1.0,
        low_memory=False,
        set_op_mix_ratio=1.0,
        local_connectivity=1.0,
        repulsion_strength=1.0,
        negative_sample_rate=5,
        transform_queue_size=4.0,
        a=None,
        b=None,
        random_state=None,
        angular_rp_forest=False,
        target_n_neighbors=-1,
        target_metric="categorical",
        target_metric_kwds=None,
        target_weight=0.5,
        transform_seed=42,
        force_approximation_algorithm=False,
        verbose=False,
        unique=False,
    ):

        self.n_neighbors = n_neighbors
        self.metric = metric
        self.metric_kwds = metric_kwds

        self.n_epochs = n_epochs
        self.init = init
        self.n_components = n_components
        self.repulsion_strength = repulsion_strength
        self.learning_rate = learning_rate
        self.alignment_regularisation = alignment_regularisation
        self.alignment_window_size = alignment_window_size

        self.spread = spread
        self.min_dist = min_dist
        self.low_memory = low_memory
        self.set_op_mix_ratio = set_op_mix_ratio
        self.local_connectivity = local_connectivity
        self.negative_sample_rate = negative_sample_rate
        self.random_state = random_state
        self.angular_rp_forest = angular_rp_forest
        self.transform_queue_size = transform_queue_size
        self.target_n_neighbors = target_n_neighbors
        self.target_metric = target_metric
        self.target_metric_kwds = target_metric_kwds
        self.target_weight = target_weight
        self.transform_seed = transform_seed
        self.force_approximation_algorithm = force_approximation_algorithm
        self.verbose = verbose
        self.unique = unique

        self.a = a
        self.b = b

    def fit(self, X, y=None, **fit_params):
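        """Fit an aligned UMAP model to the sequence of datasets in ``X``.
        ``fit_params`` must contain ``relations``, a list of dicts (one per
        consecutive pair of datasets) relating row indices across datasets.
        Returns ``self``."""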
        if "relations" not in fit_params:
            raise ValueError(
                "Aligned UMAP requires relations between data to be specified"
            )

        self.dict_relations_ = fit_params["relations"]
        assert type(self.dict_relations_) in (list, tuple)
        # Normalise to a list so that update() can append further relations.
        self.dict_relations_ = list(self.dict_relations_)
        assert type(X) in (list, tuple, np.ndarray)
        assert (len(X) - 1) == (len(self.dict_relations_))

        if y is not None:
            assert type(y) in (list, tuple, np.ndarray)
            assert (len(y) - 1) == (len(self.dict_relations_))
        else:
            y = [None] * len(X)

        # We need n_components to be constant or this won't work
        if type(self.n_components) in (list, tuple, np.ndarray):
            raise ValueError("n_components must be a single integer, and cannot vary")

        self.n_models_ = len(X)

        # Use a local value rather than mutating the n_epochs hyperparameter,
        # which scikit-learn estimators should leave untouched during fit.
        n_epochs = self.n_epochs if self.n_epochs is not None else 200

        self.mappers_ = [
            UMAP(
                n_neighbors=get_nth_item_or_val(self.n_neighbors, n),
                min_dist=get_nth_item_or_val(self.min_dist, n),
                n_epochs=get_nth_item_or_val(n_epochs, n),
                repulsion_strength=get_nth_item_or_val(self.repulsion_strength, n),
                learning_rate=get_nth_item_or_val(self.learning_rate, n),
                init=self.init,
                spread=get_nth_item_or_val(self.spread, n),
                negative_sample_rate=get_nth_item_or_val(self.negative_sample_rate, n),
                local_connectivity=get_nth_item_or_val(self.local_connectivity, n),
                set_op_mix_ratio=get_nth_item_or_val(self.set_op_mix_ratio, n),
                unique=get_nth_item_or_val(self.unique, n),
                n_components=self.n_components,
                metric=self.metric,
                metric_kwds=self.metric_kwds,
                low_memory=self.low_memory,
                random_state=self.random_state,
                angular_rp_forest=self.angular_rp_forest,
                transform_queue_size=self.transform_queue_size,
                target_n_neighbors=self.target_n_neighbors,
                target_metric=self.target_metric,
                target_metric_kwds=self.target_metric_kwds,
                target_weight=self.target_weight,
                transform_seed=self.transform_seed,
                force_approximation_algorithm=self.force_approximation_algorithm,
                verbose=self.verbose,
                a=self.a,
                b=self.b,
            ).fit(X[n], y[n])
            for n in range(self.n_models_)
        ]

        window_size = fit_params.get("window_size", self.alignment_window_size)
        relations = expand_relations(self.dict_relations_, window_size)

        # Collect per-model graph structure into numba typed lists for the
        # aligned layout optimizer.
        indptr_list = numba.typed.List.empty_list(numba.types.int32[::1])
        indices_list = numba.typed.List.empty_list(numba.types.int32[::1])
        heads = numba.typed.List.empty_list(numba.types.int32[::1])
        tails = numba.typed.List.empty_list(numba.types.int32[::1])
        epochs_per_samples = numba.typed.List.empty_list(numba.types.float64[::1])

        for mapper in self.mappers_:
            graph_coo = mapper.graph_.tocoo()
            indptr_list.append(mapper.graph_.indptr)
            indices_list.append(mapper.graph_.indices)
            heads.append(graph_coo.row)
            tails.append(graph_coo.col)
            epochs_per_samples.append(
                make_epochs_per_sample(graph_coo.data, n_epochs)
            )

        rng_state_transform = np.random.RandomState(self.transform_seed)
        regularisation_weights = build_neighborhood_similarities(
            indptr_list,
            indices_list,
            relations,
        )
        # Spectral initialisation of the first embedding, rescaled below to a
        # standard range.
        first_init = spectral_layout(
            self.mappers_[0]._raw_data,
            self.mappers_[0].graph_,
            self.n_components,
            rng_state_transform,
        )
        expansion = 10.0 / np.abs(first_init).max()
        first_embedding = (first_init * expansion).astype(
            np.float32,
            order="C",
        )

        embeddings = numba.typed.List.empty_list(numba.types.float32[:, ::1])
        embeddings.append(first_embedding)
        for i in range(1, self.n_models_):
            next_init = spectral_layout(
                self.mappers_[i]._raw_data,
                self.mappers_[i].graph_,
                self.n_components,
                rng_state_transform,
            )
            expansion = 10.0 / np.abs(next_init).max()
            next_embedding = (next_init * expansion).astype(
                np.float32,
                order="C",
            )
            # Column window_size - 1 of the relations tensor maps samples of
            # model i to samples of model i - 1; its non-negative entries give
            # the anchor pairs for Procrustes alignment against the previous
            # embedding.
            anchor_data = relations[i][window_size - 1]
            left_anchors = anchor_data[anchor_data >= 0]
            right_anchors = np.where(anchor_data >= 0)[0]
            embeddings.append(
                procrustes_align(
                    embeddings[-1],
                    next_embedding,
                    np.vstack([left_anchors, right_anchors]),
                )
            )

        seed_triplet = rng_state_transform.randint(INT32_MIN, INT32_MAX, 3).astype(
            np.int64
        )
        # The embeddings list serves as both head and tail embeddings, since
        # the edges of each graph connect samples within the same embedding.
        self.embeddings_ = optimize_layout_aligned_euclidean(
            embeddings,
            embeddings,
            heads,
            tails,
            n_epochs,
            epochs_per_samples,
            regularisation_weights,
            relations,
            seed_triplet,
            lambda_=self.alignment_regularisation,
            move_other=True,
        )

        # Vertices with no edges in the graph have no meaningful embedding;
        # mark them NaN so downstream consumers can detect them.
        for i, embedding in enumerate(self.embeddings_):
            disconnected_vertices = (
                np.array(self.mappers_[i].graph_.sum(axis=1)).flatten() == 0
            )
            embedding[disconnected_vertices] = np.full(self.n_components, np.nan)

        return self

    def fit_transform(self, X, y=None, **fit_params):
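        """Fit the model and return the list of aligned embeddings."""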
        self.fit(X, y, **fit_params)
        return self.embeddings_

    def update(self, X, y=None, **fit_params):
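        """Add a new dataset ``X`` to an already fitted model. ``fit_params``
        must contain ``relations``, a single dict relating row indices of the
        previous dataset to rows of ``X``. Existing embeddings are effectively
        held fixed (their edges are never sampled during optimization) while
        the new embedding is optimized into alignment."""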
        if "relations" not in fit_params:
            raise ValueError(
                "Aligned UMAP requires relations between data to be specified"
            )

        new_dict_relations = fit_params["relations"]
        assert isinstance(new_dict_relations, dict)

        X = check_array(X)

        self.__dict__ = set_aligned_params(fit_params, self.__dict__, self.n_models_)

        # We need n_components to be constant or this won't work
        if type(self.n_components) in (list, tuple, np.ndarray):
            raise ValueError("n_components must be a single integer, and cannot vary")

        # As in fit, avoid mutating the n_epochs hyperparameter.
        n_epochs = self.n_epochs if self.n_epochs is not None else 200

        new_mapper = UMAP(
            n_neighbors=get_nth_item_or_val(self.n_neighbors, self.n_models_),
            min_dist=get_nth_item_or_val(self.min_dist, self.n_models_),
            n_epochs=get_nth_item_or_val(n_epochs, self.n_models_),
            repulsion_strength=get_nth_item_or_val(
                self.repulsion_strength, self.n_models_
            ),
            learning_rate=get_nth_item_or_val(self.learning_rate, self.n_models_),
            init=self.init,
            spread=get_nth_item_or_val(self.spread, self.n_models_),
            negative_sample_rate=get_nth_item_or_val(
                self.negative_sample_rate, self.n_models_
            ),
            local_connectivity=get_nth_item_or_val(
                self.local_connectivity, self.n_models_
            ),
            set_op_mix_ratio=get_nth_item_or_val(self.set_op_mix_ratio, self.n_models_),
            unique=get_nth_item_or_val(self.unique, self.n_models_),
            n_components=self.n_components,
            metric=self.metric,
            metric_kwds=self.metric_kwds,
            low_memory=self.low_memory,
            random_state=self.random_state,
            angular_rp_forest=self.angular_rp_forest,
            transform_queue_size=self.transform_queue_size,
            target_n_neighbors=self.target_n_neighbors,
            target_metric=self.target_metric,
            target_metric_kwds=self.target_metric_kwds,
            target_weight=self.target_weight,
            transform_seed=self.transform_seed,
            force_approximation_algorithm=self.force_approximation_algorithm,
            verbose=self.verbose,
            a=self.a,
            b=self.b,
        ).fit(X, y)

        self.n_models_ += 1
        self.mappers_ += [new_mapper]

        self.dict_relations_ += [new_dict_relations]

        window_size = fit_params.get("window_size", self.alignment_window_size)
        new_relations = expand_relations(self.dict_relations_, window_size)

        indptr_list = numba.typed.List.empty_list(numba.types.int32[::1])
        indices_list = numba.typed.List.empty_list(numba.types.int32[::1])
        heads = numba.typed.List.empty_list(numba.types.int32[::1])
        tails = numba.typed.List.empty_list(numba.types.int32[::1])
        epochs_per_samples = numba.typed.List.empty_list(numba.types.float64[::1])

        for i, mapper in enumerate(self.mappers_):
            graph_coo = mapper.graph_.tocoo()
            indptr_list.append(mapper.graph_.indptr)
            indices_list.append(mapper.graph_.indices)
            heads.append(graph_coo.row)
            tails.append(graph_coo.col)
            if i == len(self.mappers_) - 1:
                epochs_per_samples.append(
                    make_epochs_per_sample(graph_coo.data, n_epochs)
                )
            else:
                # Effectively freeze the existing models: an epochs-per-sample
                # value larger than n_epochs means these edges are never
                # sampled. The array must have one entry per edge (per nonzero
                # of the graph), not per sample.
                epochs_per_samples.append(
                    np.full(graph_coo.data.shape[0], n_epochs + 1, dtype=np.float64)
                )

        new_regularisation_weights = build_neighborhood_similarities(
            indptr_list,
            indices_list,
            new_relations,
        )

        # TODO: We can likely make this more efficient and not recompute each time
        inv_dict_relations = invert_dict(new_dict_relations)

        new_embedding = init_from_existing(
            self.embeddings_[-1], new_mapper.graph_, inv_dict_relations
        )

        self.embeddings_.append(new_embedding)

        rng_state_transform = np.random.RandomState(self.transform_seed)
        seed_triplet = rng_state_transform.randint(INT32_MIN, INT32_MAX, 3).astype(
            np.int64
        )
        self.embeddings_ = optimize_layout_aligned_euclidean(
            self.embeddings_,
            self.embeddings_,
            heads,
            tails,
            n_epochs,
            epochs_per_samples,
            new_regularisation_weights,
            new_relations,
            seed_triplet,
            lambda_=self.alignment_regularisation,
        )

        return self
