# -*- coding: utf-8 -*-
"""Pyrgg functions module."""
from typing import Dict, List, Tuple, Union, Any, IO, Callable
import os
from random import randint
from json import loads as json_loads
from json import dump as json_dump
from pickle import dump as pickle_dump
from yaml import safe_dump as yaml_dump
import datetime
import pyrgg.params


def is_weighted(max_weight: float, min_weight: float, signed: bool) -> bool:
    """
    Decide whether the requested weight range describes a weighted graph.

    The graph counts as unweighted when both bounds are 0, or when both bounds
    are 1 and the sign flag is off; every other combination is weighted.

    :param max_weight: maximum weight
    :param min_weight: minimum weight
    :param signed: weight sign flag
    """
    if max_weight != min_weight:
        return True
    unweighted = min_weight == 0 or (min_weight == 1 and not signed)
    return not unweighted


def get_min_max_weight(weight_dict: Dict[int, List[float]]) -> Tuple[float, float]:
    """
    Return the smallest and largest absolute edge weight.

    Signs are ignored: every weight is compared by magnitude. An empty
    dictionary (or one whose lists are all empty) raises ValueError.

    :param weight_dict: weight dictionary
    """
    magnitudes = []
    for weights in weight_dict.values():
        magnitudes.extend(abs(weight) for weight in weights)
    return min(magnitudes), max(magnitudes)


def is_signed(weight_dict: Dict[int, List[float]]) -> bool:  # pragma: no cover
    """
    Report whether any edge carries a negative weight.

    :param weight_dict: weight dictionary
    """
    return any(
        weight < 0
        for weights in weight_dict.values()
        for weight in weights
    )


def has_self_loop(edge_dict: Dict[int, List[int]]) -> bool:  # pragma: no cover
    """
    Report whether any vertex lists itself among its neighbours.

    :param edge_dict: edge dictionary
    """
    return any(vertex in neighbours for vertex, neighbours in edge_dict.items())


def is_multigraph(edge_dict: Dict[int, List[int]]) -> bool:
    """
    Report whether any vertex has a repeated neighbour (parallel edges).

    :param edge_dict: edge dictionary
    """
    for neighbours in edge_dict.values():
        if len(neighbours) != len(set(neighbours)):
            return True
    return False


def get_precision(input_number: float) -> int:
    """
    Return precision of input number.

    The precision is the number of digits after the decimal point in the
    number's ``str`` form, e.g. ``1.0`` -> 1, ``10.25`` -> 2, ints -> 0.
    Scientific notation is handled as well (``1e-05`` -> 5,
    ``1.5e-07`` -> 8). Splitting on ``"."`` alone returned 0 and 5 for
    those values.

    :param input_number: input number
    :return: number of decimal places, or 0 if it cannot be determined
    """
    from decimal import Decimal, InvalidOperation
    try:
        exponent = Decimal(str(input_number)).as_tuple().exponent
    except (InvalidOperation, ValueError, TypeError):
        return 0
    # For NaN / Infinity the exponent is a marker string ('n', 'N', 'F').
    if not isinstance(exponent, int):
        return 0
    return max(0, -exponent)


def calculate_threshold(min_edges: int, max_edges: int, vertex_degree: int) -> int:
    """
    Calculate threshold for generate_branches function.

    When the open range [floor, ceiling] is non-degenerate (ceiling > floor),
    a random integer in that range is returned. Here floor is how many edges
    the vertex still lacks to reach ``min_edges`` (never below 0), and
    ceiling is how many edges it may still gain before ``max_edges``.
    Otherwise ``min_edges`` itself is returned.

    :param min_edges: minimum number of edges (connected to each vertex)
    :param max_edges: maximum number of edges (connected to each vertex)
    :param vertex_degree: vertex degree
    """
    floor = max(min_edges - vertex_degree, 0)
    ceiling = max_edges - vertex_degree
    if ceiling <= floor:
        # NOTE(review): this fallback returns min_edges even when the vertex
        # is already at/over max_edges (ceiling <= 0) — confirm that the
        # caller intends this rather than returning `floor`.
        return min_edges
    return randint(floor, ceiling)


def is_float(input_number: Union[float, int, str]) -> bool:
    """
    Check input for float conversion.

    Returns True only when the input converts to a float that has a non-zero
    fractional part. Integral values ("3", 3, "1.0") and inputs that cannot
    be converted at all (None, "abc") return False. Previously a
    non-numeric string raised ValueError out of this predicate.

    :param input_number: input number
    """
    try:
        _, decimal_part = divmod(float(input_number), 1)
    except (TypeError, ValueError):
        return False
    return bool(decimal_part)


def handle_string(string: str) -> str:
    """
    Return the input string unchanged, rejecting the empty string.

    :param string: input string
    :raises ValueError: if the string is exactly ""
    """
    if string != "":
        return string
    raise ValueError


def handle_pos_int(input_number: Union[float, int, str]) -> int:
    """
    Convert the input to int and reject negative values.

    Zero is accepted. Conversion errors from ``int()`` propagate unchanged.

    :param input_number: input number
    :raises ValueError: if the converted value is below zero
    """
    number = int(input_number)
    if number >= 0:
        return number
    raise ValueError


def handle_pos_even(input_number: Union[float, int, str]) -> int:
    """
    Validate that the input is a non-negative even integer.

    The value is parsed and sign-checked by handle_pos_int, then rejected
    if it is odd.

    :param input_number: input number
    :raises ValueError: if the value is negative or odd
    """
    candidate = handle_pos_int(input_number)
    if candidate % 2 == 0:
        return candidate
    raise ValueError


def handle_natural_number(input_number: Union[float, int, str]) -> int:
    """
    Convert the input to int and require it to be at least 1.

    :param input_number: input number
    :raises ValueError: if the converted value is below 1
    """
    parsed = int(input_number)
    if parsed >= 1:
        return parsed
    raise ValueError


def handle_str_to_number(string: str) -> Union[float, int]:
    """
    Convert string to float or int.

    Values with a fractional part become float and integer literals become
    int. Integral float literals such as "1.0" or "1e3" fall back to float.
    Previously they reached ``int()`` and were rejected with ValueError.

    :param string: input string
    :raises ValueError: if the string is not a number at all
    """
    if is_float(string):
        return float(string)
    try:
        return int(string)
    except ValueError:
        # Integral but not an integer literal, e.g. "1.0" or "1e3".
        return float(string)


def handle_str_prob(string: str) -> float:
    """
    Convert string to float and raise ValueError if string is invalid.

    The value must lie in the closed interval [0, 1]. The chained comparison
    also rejects NaN ("nan" parses as a number, and separate ``< 0`` /
    ``> 1`` checks both evaluate False for it).

    :param string: input string
    :raises ValueError: if the value is not a number in [0, 1]
    """
    val = handle_str_to_number(string)
    if not 0 <= val <= 1:
        raise ValueError
    return val


def handle_str_to_bool(string: str) -> bool:
    """
    Interpret a "0"/"1" answer as a boolean.

    :param string: input string
    :raises ValueError: if the integer value is neither 0 nor 1
    """
    flag = int(string)
    if flag in (0, 1):
        return flag == 1
    raise ValueError


def handle_output_format(string: str) -> int:
    """
    Parse an output-format index and check it against the suffix menu.

    :param string: input string
    :raises ValueError: if the index is negative or not a key of SUFFIX_MENU
    """
    index = handle_pos_int(string)
    if index in pyrgg.params.SUFFIX_MENU:
        return index
    raise ValueError


def handle_engine(string: str) -> int:
    """
    Parse an engine index and check it against the engine menu.

    :param string: input string
    :raises ValueError: if the index is negative or not a key of ENGINE_MENU
    """
    index = handle_pos_int(string)
    if index in pyrgg.params.ENGINE_MENU:
        return index
    raise ValueError


# Maps each input item name to the function that parses/validates the raw
# answer string for it; the menu loops treat any exception as invalid input.
ITEM_HANDLERS = dict(
    file_name=handle_string,
    output_format=handle_output_format,
    weight=handle_str_to_bool,
    engine=handle_engine,
    vertices=handle_pos_int,
    number_of_files=handle_pos_int,
    max_weight=handle_str_to_number,
    min_weight=handle_str_to_number,
    min_edges=handle_pos_int,
    max_edges=handle_pos_int,
    edge_number=handle_pos_int,
    sign=handle_str_to_bool,
    direct=handle_str_to_bool,
    self_loop=handle_str_to_bool,
    multigraph=handle_str_to_bool,
    config=handle_str_to_bool,
    probability=handle_str_prob,
    blocks=handle_natural_number,
    inter_probability=handle_str_prob,
    intra_probability=handle_str_prob,
    attaching_edge_number=handle_natural_number,
    mean_degree=handle_pos_even,
    rewiring_probability=handle_str_prob,
    space_dimension=handle_natural_number,
    cutoff_threshold=handle_str_prob,
)


def print_description() -> None:
    """Print the project links followed by the description framed by '#' rules."""
    print(pyrgg.params.PYRGG_LINKS)
    print_line(40)
    for chunk in ("\n", pyrgg.params.PYRGG_DESCRIPTION, "\n"):
        print(chunk)
    print_line(40)


def print_line(num: int = 11, char: str = "#") -> None:
    """
    Print ``char`` repeated ``num`` times on a single line.

    :param num: number of character in this line
    :param char: character
    """
    rule = "".join([char] * num)
    print(rule)


def convert_bytes(num: int) -> str:
    """
    Convert num to idiomatic byte unit.

    The value is divided by 1024 until it fits a unit. Anything too large
    for TB is reported in PB instead of falling off the end of the loop,
    which previously returned None.

    :param num: the input number.
    :return: the size formatted as ``"%3.1f <unit>"``
    """
    *smaller_units, largest_unit = ['bytes', 'KB', 'MB', 'GB', 'TB', 'PB']
    for unit in smaller_units:
        if num < 1024.0:
            return "%3.1f %s" % (num, unit)
        num /= 1024.0
    return "%3.1f %s" % (num, largest_unit)


def get_file_size(path: str) -> str:  # pragma: no cover
    """
    Return the size of the file at ``path`` as a human-readable string.

    :param path: file path
    """
    size_in_bytes = os.stat(path).st_size
    return convert_bytes(size_in_bytes)


def convert_time(input_time: float) -> str:
    """
    Convert input_time from sec to DD,HH,MM,SS format.

    The input is rounded to whole seconds first, so the displayed numbers
    and their singular/plural units always agree. For example, 1.2 s gives
    "01 second" rather than "01 seconds", and 59.6 s carries over to
    "01 minute, 00 seconds" rather than "60 seconds".

    :param input_time: input time in sec
    :return: e.g. "01 day, 17 hours, 40 minutes, 00 seconds"
    """
    total_seconds = round(float(input_time))
    days, remainder = divmod(total_seconds, 24 * 3600)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    parts = []
    for value, unit in ((days, "day"), (hours, "hour"),
                        (minutes, "minute"), (seconds, "second")):
        if value != 1:
            unit += "s"
        parts.append("{0:02d} {1}".format(value, unit))
    return ", ".join(parts)


def filter_input(input_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filter input data.

    Works on a shallow copy, so the caller's dictionary is left untouched:

    * negative ``min_edges`` / ``max_edges`` / ``vertices`` are made positive;
    * swapped weight and edge bounds are put back in (min, max) order;
    * for simple graphs (``multigraph`` off) the edge bounds are capped at
      the number of possible neighbours. That cap is ``vertices``, minus one
      when self loops are disallowed, and never below 0.

    :param input_dict: input dictionary
    :return: the filtered copy
    """
    filtered_dict = input_dict.copy()
    for key in ["min_edges", "max_edges", "vertices"]:
        if filtered_dict[key] < 0:
            filtered_dict[key] *= -1

    if filtered_dict["min_weight"] > filtered_dict["max_weight"]:
        filtered_dict["min_weight"], filtered_dict["max_weight"] = (
            filtered_dict["max_weight"], filtered_dict["min_weight"]
        )

    if filtered_dict["min_edges"] > filtered_dict["max_edges"]:
        filtered_dict["min_edges"], filtered_dict["max_edges"] = (
            filtered_dict["max_edges"], filtered_dict["min_edges"]
        )

    # Read vertices only after its sign has been normalized above; reading it
    # earlier made the cap negative for negative vertex counts.
    edges_upper_threshold = filtered_dict["vertices"]
    if not filtered_dict["self_loop"]:
        edges_upper_threshold -= 1
    edges_upper_threshold = max(edges_upper_threshold, 0)

    if not filtered_dict["multigraph"]:
        for key in ["min_edges", "max_edges"]:
            filtered_dict[key] = min(filtered_dict[key], edges_upper_threshold)

    return filtered_dict


def get_input(input_func: Callable[[str], str] = input) -> Dict[str, Any]:
    """
    Collect all generation parameters interactively.

    Starts from a fresh dictionary of defaults and prompts for the main menu
    items, then for the items required by the chosen engine. It then derives
    block sizes and the probability matrix and returns the filtered result.

    :param input_func: input function
    """
    defaults = dict(
        file_name="",
        number_of_files=1,
        vertices=0,
        max_weight=1,
        min_weight=1,
        min_edges=0,
        max_edges=0,
        edge_number=0,
        sign=True,
        output_format=1,
        weight=True,
        engine=1,
        direct=True,
        self_loop=True,
        multigraph=False,
        config=False,
        probability=0.5,
        blocks=1,
        inter_probability=0.75,
        intra_probability=0.25,
        attaching_edge_number=1,
        mean_degree=2,
        rewiring_probability=0.5,
        space_dimension=2,
        cutoff_threshold=0.5,
    )
    answers = _update_using_menu(defaults, input_func)
    engine_params = pyrgg.params.ENGINE_PARAM_MAP[answers['engine']]
    answers = _update_with_engine_params(answers, input_func, engine_params)
    return filter_input(_post_input_update(answers))


def _read_valid_item(result_dict: Dict[str, Any], item: str, prompt: str,
                     input_func: Callable[[str], str]) -> None:
    """
    Prompt until ITEM_HANDLERS[item] accepts the answer, then store it.

    :param result_dict: result data (updated in place)
    :param item: key in result_dict / ITEM_HANDLERS
    :param prompt: text passed to input_func
    :param input_func: input function
    """
    while True:
        try:
            result_dict[item] = ITEM_HANDLERS[item](input_func(prompt))
        except EOFError:
            # The input stream is closed: retrying would print the error
            # message forever, so let the caller see it.
            raise
        except Exception:
            print(pyrgg.params.PYRGG_INPUT_ERROR_MESSAGE)
        else:
            return


def _update_using_menu(result_dict: Dict[str, Any], input_func: Callable[[str], str]) -> Dict[str, Any]:
    """
    Update result_dict using user input from the menu.

    Items are asked in ascending order of their MENU_ITEMS index.

    :param result_dict: result data
    :param input_func: input function
    :return: the same result_dict, updated in place
    """
    for index in sorted(pyrgg.params.MENU_ITEMS):
        item1, item2 = pyrgg.params.MENU_ITEMS[index]
        _read_valid_item(result_dict, item1, item2, input_func)
    return result_dict


def _update_with_engine_params(result_dict: Dict[str, Any], input_func: Callable[[
                               str], str], engine_params: Dict[int, Tuple[str, str]]) -> Dict[str, Any]:
    """
    Update result_dict using user input based on given engine requirements.

    Weight bounds are skipped when the graph is unweighted.

    :param result_dict: result data
    :param input_func: input function
    :param engine_params: engine parameters (index -> (item, prompt))
    :return: the same result_dict, updated in place
    """
    # Engine 4 triggers the SBM warning message.
    if result_dict['engine'] == 4:
        print(pyrgg.params.PYRGG_SBM_WARNING_MESSAGE)
    for index in sorted(engine_params):
        item1, item2 = engine_params[index]
        if not result_dict["weight"] and item1 in ["max_weight", "min_weight"]:
            continue
        _read_valid_item(result_dict, item1, item2, input_func)
    return result_dict


def _post_input_update(result_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Derive block sizes and the block probability matrix from user input.

    ``vertices`` is split into ``blocks`` equal parts. Any remainder is added
    to the last block, with a warning printed. The matrix is
    ``blocks`` x ``blocks``, with ``intra_probability`` on the diagonal and
    ``inter_probability`` elsewhere.

    :param result_dict: result data (updated in place and returned)
    """
    blocks = result_dict["blocks"]
    base_size, leftover = divmod(result_dict["vertices"], blocks)
    sizes = [base_size] * blocks
    result_dict["block_sizes"] = sizes
    if leftover != 0:
        sizes[-1] += leftover
        print(pyrgg.params.PYRGG_UNDIVISIBLE_WARNING_MESSAGE)
    intra = result_dict["intra_probability"]
    inter = result_dict["inter_probability"]
    result_dict["probability_matrix"] = [
        [intra if row == col else inter for col in range(blocks)]
        for row in range(blocks)
    ]
    return result_dict


def json_to_yaml(filename: str) -> None:
    """
    Write ``<filename>.yaml`` from the contents of ``<filename>.json``.

    A missing file is reported with the YAML error message instead of raising.

    :param filename: filename
    """
    json_path = filename + ".json"
    yaml_path = filename + ".yaml"
    try:
        with open(json_path, "r") as source:
            payload = json_loads(source.read())
        with open(yaml_path, "w") as target:
            yaml_dump(payload, target, default_flow_style=False)
    except FileNotFoundError:
        print(pyrgg.params.PYRGG_YAML_ERROR_MESSAGE)


def json_to_pickle(filename: str) -> None:
    """
    Write ``<filename>.p`` (pickle) from the contents of ``<filename>.json``.

    A missing file is reported with the pickle error message instead of raising.

    :param filename: filename
    """
    json_path = filename + ".json"
    pickle_path = filename + ".p"
    try:
        with open(json_path, "r") as source:
            payload = json_loads(source.read())
        with open(pickle_path, "wb") as target:
            pickle_dump(payload, target)
    except FileNotFoundError:
        print(pyrgg.params.PYRGG_PICKLE_ERROR_MESSAGE)


def save_config(input_dict: Dict[str, Any]) -> Union[str, None]:
    """
    Save input_dict as the generation config.

    In a shallow copy of the input, the engine and output-format indices are
    replaced by their names from ``pyrgg.params`` and the pyrgg version is
    added. The copy is then written as indented JSON to the file named by
    ``CONFIG_FILE_FORMAT``.

    :param input_dict: input data
    :return: absolute path of the saved config, or None if saving failed
    """
    try:
        input_dict_temp = input_dict.copy()
        input_dict_temp['engine'] = pyrgg.params.ENGINE_MENU[input_dict_temp['engine']]
        input_dict_temp['pyrgg_version'] = pyrgg.params.PYRGG_VERSION
        input_dict_temp['output_format'] = pyrgg.params.OUTPUT_FORMAT[input_dict_temp['output_format']]
        fname = pyrgg.params.CONFIG_FILE_FORMAT.format(
            file_name=input_dict_temp['file_name'])
        with open(fname, "w") as json_file:
            json_dump(input_dict_temp, json_file, indent=2)
        return os.path.abspath(fname)
    except Exception:
        # Best-effort save: report the failure, but let KeyboardInterrupt /
        # SystemExit propagate (catching BaseException swallowed them).
        print(pyrgg.params.PYRGG_CONFIG_SAVE_ERROR_MESSAGE)
        return None


def load_config(path: str) -> Union[Dict[str, Any], None]:
    """
    Load config based on given path.

    The engine and output-format names stored in the JSON file are mapped
    back to their menu indices, and the result is passed through
    ``filter_input``.

    :param path: path to config file
    :return: the filtered config, or None if loading failed
    """
    try:
        with open(path, "r") as json_file:
            config = json_loads(json_file.read())
        config['output_format'] = pyrgg.params.OUTPUT_FORMAT_INV[config['output_format']]
        config['engine'] = pyrgg.params.ENGINE_MENU_INV[config['engine']]
        return filter_input(config)
    except Exception:
        # Best-effort load: report the failure, but let KeyboardInterrupt /
        # SystemExit propagate (catching BaseException swallowed them).
        print(pyrgg.params.PYRGG_CONFIG_LOAD_ERROR_MESSAGE)
        return None


def _print_select_config(configs: List[str], input_func: Callable[[str], str]) -> Union[Dict[str, Any], None]:
    """
    Print configs in current directory and get input from user.

    The user picks a config by its 1-based list number. Any answer that is
    not an integer in ``1..len(configs)`` returns None. Previously "0" or a
    negative number silently picked a config from the end of the list
    through Python's negative indexing.

    :param configs: configs path
    :param input_func: input function
    :return: the loaded config, or None if nothing valid was selected
    """
    if not configs:
        return None
    print(pyrgg.params.PYRGG_CONFIG_LIST_MESSAGE)
    for number, config in enumerate(configs, start=1):
        print("[{index}] - {config}".format(index=number, config=config))
    key = input_func(pyrgg.params.PYRGG_CONFIG_LOAD_MESSAGE)
    try:
        choice = int(key)
    except (TypeError, ValueError):
        return None
    if not 1 <= choice <= len(configs):
        return None
    # load_config reports its own errors and returns None on failure.
    return load_config(configs[choice - 1])


def check_for_config(input_func: Callable[[str], str] = input) -> Union[Dict[str, Any], None]:
    """
    Offer the config files found in SOURCE_DIR for loading.

    A config file is any regular file in ``pyrgg.params.SOURCE_DIR`` whose
    name ends with the suffix produced by ``CONFIG_FILE_FORMAT`` with an
    empty file name.

    :param input_func: input function
    """
    source_dir = pyrgg.params.SOURCE_DIR
    config_suffix = pyrgg.params.CONFIG_FILE_FORMAT.format(file_name="")
    candidates = []
    for entry in os.listdir(source_dir):
        full_path = os.path.join(source_dir, entry)
        if entry.endswith(config_suffix) and os.path.isfile(full_path):
            candidates.append(full_path)
    return _print_select_config(candidates, input_func)


def save_log(file: IO, file_name: str, elapsed_time: str, text: str) -> None:
    """
    Append one log entry for a generated graph to ``file``.

    The entry is: a local timestamp line, the file name, ``text`` verbatim,
    the elapsed time and a dashed separator. It is written with a single
    ``write`` call.

    :param file: file to write log into
    :param file_name: file name
    :param elapsed_time: elapsed time
    :param text: rest part of the text to write
    """
    entry = [
        datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + "\n",
        "Filename : {file_name}\n".format(file_name=file_name),
        text,
        "Elapsed Time : {elapsed_time}\n".format(elapsed_time=elapsed_time),
        "-------------------------------\n",
    ]
    file.write("".join(entry))
