import sys

import click
import h5py
import numpy as np
import pandas as pd
import simplejson as json
from cytoolz import compose
from multiprocess import Pool

from ..create import (
    HDF5Aggregator,
    PairixAggregator,
    TabixAggregator,
    aggregate_records,
    create_cooler,
    sanitize_records,
)
from . import cli, get_logger
from ._util import parse_bins, parse_field_param, parse_kv_list_param

_pandas_version = pd.__version__.split(".")
if int(_pandas_version[0]) > 0:
    from pandas.io.common import get_handle


# Copied from pairtools._headerops
def get_header(instream, comment_char="#"):
    """Returns a header from the stream and an the reaminder of the stream
    with the actual data.
    Parameters
    ----------
    instream : a file object
        An input stream.
    comment_char : str
        The character prepended to header lines (use '@' when parsing sams,
        '#' when parsing pairsams).
    Returns
    -------
    header : list
        The header lines, stripped of terminal spaces and newline characters.
    remainder_stream : stream/file-like object
        Stream with the remaining lines.

    """
    header = []
    if not comment_char:
        raise ValueError("Please, provide a comment char!")
    comment_byte = comment_char.encode()
    # get peekable buffer for the instream
    read_f, peek_f = None, None
    if hasattr(instream, "buffer"):
        peek_f = instream.buffer.peek
        readline_f = instream.buffer.readline
    elif hasattr(instream, "peek"):
        peek_f = instream.peek
        readline_f = instream.readline
    else:
        raise ValueError("Cannot find the peek() function of the provided stream!")

    current_peek = peek_f(1)
    while current_peek.startswith(comment_byte):
        # consuming a line from buffer guarantees
        # that the remainder of the buffer starts
        # with the beginning of the line.
        line = readline_f()
        if isinstance(line, bytes):
            line = line.decode()
        # append line to header, since it does start with header
        header.append(line.strip())
        # peek into the remainder of the instream
        current_peek = peek_f(1)
    # apparently, next line does not start with the comment
    # return header and the instream, advanced to the beginning of the data
    return header, instream


@cli.group()
def cload():
    """
    Create a cooler from genomic pairs and bins.

    Choose a subcommand based on the format of the input contact list.

    """
    pass


# flake8: noqa
def register_subcommand(func):
    return cload.command()(
        click.argument("bins", type=str, metavar="BINS")(
            click.argument(
                "pairs_path",
                type=click.Path(exists=True, allow_dash=True),
                metavar="PAIRS_PATH",
            )(
                click.argument("cool_path", metavar="COOL_PATH")(
                    click.option(
                        "--metadata", help="Path to JSON file containing user metadata."
                    )(
                        click.option(
                            "--assembly",
                            help="Name of genome assembly (e.g. hg19, mm10)",
                        )(func)
                    )
                )
            )
        )
    )


def add_arg_help(func):
    func.__doc__ = func.__doc__.format(
        """
    BINS : One of the following

        <TEXT:INTEGER> : 1. Path to a chromsizes file, 2. Bin size in bp

        <TEXT> : Path to BED file defining the genomic bin segmentation.

    PAIRS_PATH : Path to contacts (i.e. read pairs) file.

    COOL_PATH : Output COOL file path or URI."""
    )
    return func


@register_subcommand
@add_arg_help
@click.option(
    "--chunksize",
    "-c",
    help="Control the number of pixels handled by each worker process at a time.",
    type=int,
    default=int(100e6),
    show_default=True,
)
def hiclib(bins, pairs_path, cool_path, metadata, assembly, chunksize):
    """
    Bin a hiclib HDF5 contact list (frag) file.

    {}

    hiclib on BitBucket: <https://github.com/mirnylab/hiclib-legacy>.

    """

    chromsizes, bins = parse_bins(bins)

    if metadata is not None:
        with open(metadata) as f:
            metadata = json.load(f)

    with h5py.File(pairs_path, "r") as h5pairs:
        iterator = HDF5Aggregator(h5pairs, chromsizes, bins, chunksize)
        create_cooler(
            cool_path,
            bins,
            iterator,
            metadata=metadata,
            assembly=assembly,
            ordered=True,
        )


@register_subcommand
@add_arg_help
@click.option(
    "--nproc",
    "-p",
    help="Number of processes to split the work between.",
    type=int,
    default=8,
    show_default=True,
)
@click.option(
    "--chrom2", "-c2", help="chrom2 field number (one-based)", type=int, default=4
)
@click.option(
    "--pos2", "-p2", help="pos2 field number (one-based)", type=int, default=5
)
@click.option(
    "--zero-based",
    "-0",
    help="Positions are zero-based",
    is_flag=True,
    default=False,
    show_default=True,
)
@click.option(
    "--max-split",
    "-s",
    help="Divide the pairs from each chromosome into at most this many chunks. "
    "Smaller chromosomes will be split less frequently or not at all. "
    "Increase ths value if large chromosomes dominate the workload on "
    "multiple processors.",
    type=int,
    default=2,
    show_default=True,
)
def tabix(
    bins,
    pairs_path,
    cool_path,
    metadata,
    assembly,
    nproc,
    zero_based,
    max_split,
    **kwargs,
):
    """
    Bin a tabix-indexed contact list file.

    {}

    See also: 'cooler csort' to sort and index a contact list file

    Tabix manpage: <http://www.htslib.org/doc/tabix.html>.

    """
    logger = get_logger(__name__)
    chromsizes, bins = parse_bins(bins)

    if metadata is not None:
        with open(metadata) as f:
            metadata = json.load(f)

    try:
        map_func = map
        if nproc > 1:
            pool = Pool(nproc)
            logger.info(f"Using {nproc} cores")
            map_func = pool.imap

        opts = {}
        if "chrom2" in kwargs:
            opts["C2"] = kwargs["chrom2"] - 1
        if "pos2" in kwargs:
            opts["P2"] = kwargs["pos2"] - 1

        iterator = TabixAggregator(
            pairs_path,
            chromsizes,
            bins,
            map=map_func,
            is_one_based=(not zero_based),
            n_chunks=max_split,
            **opts,
        )

        create_cooler(
            cool_path,
            bins,
            iterator,
            metadata=metadata,
            assembly=assembly,
            ordered=True,
        )
    finally:
        if nproc > 1:
            pool.close()


@register_subcommand
@add_arg_help
@click.option(
    "--nproc",
    "-p",
    help="Number of processes to split the work between.",
    type=int,
    default=8,
    show_default=True,
)
@click.option(
    "--zero-based",
    "-0",
    help="Positions are zero-based",
    is_flag=True,
    default=False,
    show_default=True,
)
@click.option(
    "--max-split",
    "-s",
    help="Divide the pairs from each chromosome into at most this many chunks. "
    "Smaller chromosomes will be split less frequently or not at all. "
    "Increase ths value if large chromosomes dominate the workload on "
    "multiple processors.",
    type=int,
    default=2,
    show_default=True,
)
@click.option(
    "--block-char",
    help="Character separating contig names in the block names of the pairix "
    "index.",
    type=str,
    default="|",
    show_default=True,
)
def pairix(
    bins,
    pairs_path,
    cool_path,
    metadata,
    assembly,
    nproc,
    zero_based,
    max_split,
    block_char,
):
    """
    Bin a pairix-indexed contact list file.

    {}

    See also: 'cooler csort' to sort and index a contact list file

    Pairix on GitHub: <https://github.com/4dn-dcic/pairix>.

    """
    logger = get_logger(__name__)
    chromsizes, bins = parse_bins(bins)

    if metadata is not None:
        with open(metadata) as f:
            metadata = json.load(f)

    try:
        map_func = map
        if nproc > 1:
            pool = Pool(nproc)
            logger.info(f"Using {nproc} cores")
            map_func = pool.imap

        iterator = PairixAggregator(
            pairs_path,
            chromsizes,
            bins,
            map=map_func,
            is_one_based=(not zero_based),
            n_chunks=max_split,
            block_char=block_char,
        )

        create_cooler(
            cool_path,
            bins,
            iterator,
            metadata=metadata,
            assembly=assembly,
            ordered=True,
        )
    finally:
        if nproc > 1:
            pool.close()


@register_subcommand
@add_arg_help
@click.option(
    "--chrom1", "-c1", help="chrom1 field number (one-based)", type=int, required=True
)
@click.option(
    "--pos1", "-p1", help="pos1 field number (one-based)", type=int, required=True
)
@click.option(
    "--chrom2", "-c2", help="chrom2 field number (one-based)", type=int, required=True
)
@click.option(
    "--pos2", "-p2", help="pos2 field number (one-based)", type=int, required=True
)
@click.option(
    "--zero-based",
    "-0",
    help="Positions are zero-based",
    is_flag=True,
    default=False,
    show_default=True,
)
@click.option(
    "--comment-char",
    type=str,
    default="#",
    show_default=True,
    help="Comment character that indicates lines to ignore.",
)
@click.option(
    "--no-symmetric-upper",
    "-N",
    help="Create a complete square matrix without implicit symmetry. "
    "This allows for distinct upper- and lower-triangle values",
    is_flag=True,
    default=False,
)
@click.option(
    "--input-copy-status",
    type=click.Choice(["unique", "duplex"]),
    default="unique",
    help="Copy status of input data when using symmetric-upper storage. | "
    "`unique`: Incoming data comes from a unique half of a symmetric "
    "map, regardless of how the coordinates of a pair are ordered. "
    "`duplex`: Incoming data contains upper- and lower-triangle duplicates. "
    "All input records that map to the lower triangle will be discarded! | "
    "If you wish to treat lower- and upper-triangle input data as "
    "distinct, use the ``--no-symmetric-upper`` option. ",
    show_default=True,
)
@click.option(
    "--field",
    help="Specify quantitative input fields to aggregate into value columns "
    "using the syntax ``--field <field-name>=<field-number>``. "
    "Optionally, append ``:`` followed by ``dtype=<dtype>`` to specify "
    "the data type (e.g. float), and/or ``agg=<agg>`` to "
    "specify an aggregation function different from sum (e.g. mean). "
    "Field numbers are 1-based. Passing 'count' as the target name will "
    "override the default behavior of storing pair counts. "
    "Repeat the ``--field`` option for each additional field.",
    type=str,
    multiple=True,
)
# @click.option(
#     "--no-count",
#     help="Do not store the pair counts. Use this only if you use `--field` to "
#          "specify at least one input field for aggregation as an alternative.",
#     is_flag=True,
#     default=False)
@click.option(
    "--chunksize",
    "-c",
    help="Size in number of lines/records of data chunks to read and process "
    "from the input stream at a time. These chunks will be saved as temporary "
    "partial coolers and then merged.",
    type=int,
    default=15_000_000,
)
@click.option(
    "--mergebuf",
    help="Total number of pixel records to buffer per epoch of merging data. "
    "Defaults to the same value as `chunksize`.",
    type=int,
)
@click.option(
    "--max-merge",
    help="Maximum number of chunks to merge in a single pass.",
    type=int,
    default=200,
    show_default=True,
)
@click.option(
    "--temp-dir",
    help="Create temporary files in a specified directory. Pass ``-`` to use "
    "the platform default temp dir.",
    type=click.Path(exists=True, file_okay=False, dir_okay=True, allow_dash=True),
)
@click.option(
    "--no-delete-temp",
    help="Do not delete temporary files when finished.",
    is_flag=True,
    default=False,
)
@click.option(
    "--storage-options",
    help="Options to modify the data filter pipeline. Provide as a "
    "comma-separated list of key-value pairs of the form 'k1=v1,k2=v2,...'. "
    "See http://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline "
    "for more details.",
)
@click.option(
    "--append",
    "-a",
    is_flag=True,
    default=False,
    help="Pass this flag to append the output cooler to an existing file "
    "instead of overwriting the file.",
)
# @click.option(
#     "--format", "-f",
#     help="Preset data format.",
#     type=click.Choice(['4DN', 'BEDPE']))
# --sep
def pairs(
    bins,
    pairs_path,
    cool_path,
    metadata,
    assembly,
    zero_based,
    comment_char,
    input_copy_status,
    no_symmetric_upper,
    field,
    chunksize,
    mergebuf,
    temp_dir,
    no_delete_temp,
    max_merge,
    storage_options,
    append,
    **kwargs,
):
    """
    Bin any text file or stream of pairs.

    Pairs data need not be sorted. Accepts compressed files.
    To pipe input from stdin, set PAIRS_PATH to '-'.

    {}

    """
    chromsizes, bins = parse_bins(bins)
    if mergebuf is None:
        mergebuf = chunksize

    symmetric_upper = not no_symmetric_upper
    tril_action = None
    if symmetric_upper:
        if input_copy_status == "unique":
            tril_action = "reflect"
        elif input_copy_status == "duplex":
            tril_action = "drop"

    if metadata is not None:
        with open(metadata) as f:
            metadata = json.load(f)

    input_field_names = [
        "chrom1",
        "pos1",
        "chrom2",
        "pos2",
    ]
    input_field_dtypes = {
        "chrom1": str,
        "pos1": np.int64,
        "chrom2": str,
        "pos2": np.int64,
    }
    input_field_numbers = {}
    for name in ["chrom1", "pos1", "chrom2", "pos2"]:
        if kwargs[name] == 0:
            raise click.BadParameter("Field numbers start at 1", param_hint=name)
        input_field_numbers[name] = kwargs[name] - 1

    # Include input value columns
    output_field_names = []
    output_field_dtypes = {}
    aggregations = {}
    if len(field):
        for arg in field:
            name, colnum, dtype, agg = parse_field_param(arg)

            # Special cases: these do not have input fields.
            # Omit field number and agg to change standard dtypes.
            if colnum is None:
                if (
                    agg is None
                    and dtype is not None
                    and name in {"bin1_id", "bin2_id", "count"}
                ):
                    output_field_dtypes[name] = dtype
                    continue
                else:
                    raise click.BadParameter(
                        "A field number is required.", param_hint=arg
                    )

            if name not in input_field_names:
                input_field_names.append(name)

            if name not in output_field_names:
                output_field_names.append(name)

            input_field_numbers[name] = colnum

            if dtype is not None:
                input_field_dtypes[name] = dtype
                output_field_dtypes[name] = dtype

            if agg is not None:
                aggregations[name] = agg
            else:
                aggregations[name] = "sum"

    # # Pairs counts are always produced, unless supressed explicitly
    # do_count = not no_count
    # if do_count:
    #     if 'count' not in output_field_names:
    #         output_field_names.append('count')  # default dtype and agg
    # else:
    #     if not len(output_field_names):
    #         click.BadParameter(
    #             "To pass `--no-count`, specify at least one input "
    #             "value-column using `--field`.")
    if "count" not in output_field_names:
        output_field_names.append("count")

    # Customize the HDF5 filters
    if storage_options is not None:
        h5opts = parse_kv_list_param(storage_options)
        for key in h5opts:
            if isinstance(h5opts[key], list):
                h5opts[key] = tuple(h5opts[key])
    else:
        h5opts = None

    # Initialize the input stream
    # TODO: we could save the header into metadata
    kwargs = {}
    if pairs_path == "-":
        f_in = sys.stdin
    elif int(_pandas_version[0]) == 1 and int(_pandas_version[1]) < 2:
        # get_handle returns a pair of objects in pandas 1.0 and 1.1
        f_in = get_handle(pairs_path, mode="r", compression="infer")[0]
    else:
        # get_handle returns a single wrapper object in pandas 1.2+ and 2.*
        f_in = get_handle(pairs_path, mode="r", compression="infer").handle

    _, f_in = get_header(f_in)

    reader = pd.read_csv(
        f_in,
        sep="\t",
        usecols=[input_field_numbers[name] for name in input_field_names],
        names=input_field_names,
        dtype=input_field_dtypes,
        iterator=True,
        chunksize=chunksize,
        **kwargs,
    )

    sanitize = sanitize_records(
        bins,
        schema="pairs",
        decode_chroms=True,
        is_one_based=not zero_based,
        tril_action=tril_action,
        sort=True,
        validate=True,
    )
    aggregate = aggregate_records(agg=aggregations, count=True, sort=False)
    pipeline = compose(aggregate, sanitize)

    create_cooler(
        cool_path,
        bins,
        map(pipeline, reader),
        columns=output_field_names,
        dtypes=output_field_dtypes,
        metadata=metadata,
        assembly=assembly,
        mergebuf=mergebuf,
        max_merge=max_merge,
        temp_dir=temp_dir,
        delete_temp=not no_delete_temp,
        boundscheck=False,
        triucheck=False,
        dupcheck=False,
        ensure_sorted=False,
        symmetric_upper=symmetric_upper,
        h5opts=h5opts,
        ordered=False,
        mode="a" if append else "w",
    )
