#!/usr/bin/env python

import os
import ssl
import time
import sys
import logging
import contextlib
import concurrent.futures
import threading
import multiprocessing

from irods.data_object import iRODSDataObject
from irods.exception import DataObjectDoesNotExist
import irods.keywords as kw
from queue import Queue, Full, Empty


logger = logging.getLogger(__name__)
_nullh = logging.NullHandler()
logger.addHandler(_nullh)


MINIMUM_SERVER_VERSION = (4, 2, 9)
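# Parallel transfer coordinates multiple connections on one replica via replica access
# tokens; servers older than the version above are assumed here not to support that.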


class deferred_call:
    """
    A callable object that stores a function to be called later, along
    with its parameters.
    """

    def __init__(self, function, args, keywords):
        """Initialize the object with a function and its call parameters."""
        self.function = function
        self.args = args
        self.keywords = keywords

    def __setitem__(self, key, val):
        """Allow changing a keyword option for the deferred function call."""
        self.keywords[key] = val

    def __call__(self):
        """Call the stored function, using the arguments and keywords also stored
        in the instance."""
        return self.function(*self.args, **self.keywords)
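
# A minimal usage sketch for `deferred_call' (illustrative only; `some_open_function'
# and its arguments are hypothetical):
#
#   d = deferred_call(some_open_function, ("/tempZone/home/alice/obj",), {"mode": "w"})
#   d[kw.NUM_THREADS_KW] = "3"   # adjust a keyword option before the call is made
#   handle = d()                 # the stored function finally runs here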


from threading import Barrier

RECOMMENDED_NUM_THREADS_PER_TRANSFER = 3

verboseConnection = False
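# When True (set via the -v option when this module is run from the command line), each
# transfer thread prints a short per-buffer marker to stderr as it copies.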


class BadCallbackTarget(TypeError):
    pass


class AsyncNotify:
    """A type returned when the PUT or GET operation passed includes NONBLOCKING.
    If enabled, the callback function (or callable object) will be triggered
    when all parts of the parallel transfer are complete.  It should accept
    exactly one argument, the irods.parallel.AsyncNotify instance that
    is calling it.
    """

    def set_transfer_done_callback(self, callback):
        if callback is not None:
            if not callable(callback):
                raise BadCallbackTarget(
                    '"callback" must be a callable accepting at least 1 argument'
                )
        self.done_callback = callback

    def __init__(
        self, futuresList, callback=None, progress_Queue=None, total=None, keep_=()
    ):
        """AsyncNotify initialization (used internally to the io.parallel library).
        The casual user will only be concerned with the callback parameter, called when all threads
        of the parallel PUT or GET have been terminated and the data object closed.
        """
        self._futures = set(futuresList)
        self._futures_done = dict()
        self.keep = dict(keep_)
        self._lock = threading.Lock()
        self.set_transfer_done_callback(callback)
        self.__done = False
        if self._futures:
            for future in self._futures:
                future.add_done_callback(self)
        else:
            self.__invoke_done_callback()

        self.progress = [0, 0]
        if (progress_Queue) and (total is not None):
            self.progress[1] = total

            def _progress(Q, this):  # - thread to update progress indicator
                while this.progress[0] < this.progress[1]:
                    i = None
                    try:
                        i = Q.get(timeout=0.1)
                    except Empty:
                        pass
                    if i is not None:
                        if isinstance(i, int) and i >= 0:
                            this.progress[0] += i
                        else:
                            break

            self._progress_fn = _progress
            self._progress_thread = threading.Thread(
                target=self._progress_fn, args=(progress_Queue, self)
            )
            self._progress_thread.start()

    @staticmethod
    def asciiBar(lst, memo=[1]):  # memo: a deliberate mutable default, advancing the spinner across calls
        """Render a one-line text progress indicator from lst == [bytes_done, bytes_total]."""
        memo[0] += 1
        spinner = "|/-\\"[memo[0] % 4]
        if lst[1] != 0:
            percent = "%5.1f%%" % (lst[0] * 100.0 / lst[1])
            mbytes = "%9.1f MB / %9.1f MB" % (lst[0] / 1e6, lst[1] / 1e6)
            s = "  {spinner} {percent} [ {mbytes} ] "
        else:
            s = "  {spinner} "
        return s.format(**locals())

    def wait_until_transfer_done(self, timeout=float("inf"), progressBar=False):
        """Wait for the parallel transfer to finish or the timeout to elapse; return True if it finished."""
        carriageReturn = "\r"
        begin = t = time.time()
        end = begin + timeout
        while not self.__done:
            time.sleep(min(0.1, max(0.0, end - t)))
            t = time.time()
            if t >= end:
                break
            if progressBar:
                print(
                    "  " + self.asciiBar(self.progress) + carriageReturn,
                    end="",
                    file=sys.stderr,
                )
                sys.stderr.flush()
        return self.__done

    def __call__(
        self, future
    ):  # Our instance is called by each future (individual file part) when done.
        # When all futures are done, we invoke the configured callback.
        with self._lock:
            self._futures_done[future] = future.result()
            if len(self._futures) == len(self._futures_done):
                self.__invoke_done_callback()

    def __invoke_done_callback(self):
        try:
            if callable(self.done_callback):
                self.done_callback(self)
        finally:
            self.keep.pop("mgr", None)
            self.__done = True
        self.set_transfer_done_callback(None)

    @property
    def futures(self):
        return list(self._futures)

    @property
    def futures_done(self):
        return dict(self._futures_done)
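
# Rough sketch of how an AsyncNotify instance might be consumed; `notifier' stands for
# the object returned by a NONBLOCKING call to io_main (the name is illustrative):
#
#   notifier.set_transfer_done_callback(lambda n: print("transfer finished", n.progress))
#   if notifier.wait_until_transfer_done(timeout=30.0):
#       total_bytes_moved = sum(notifier.futures_done.values())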


class Oper:
    """A custom enum-type class with utility methods.

    It makes some logic clearer, including succinct calculation of file and data
    object open() modes based on whether the operation is a PUT or GET and whether
    we are doing the "initial" open of the file or object.
    """

    GET = 0
    PUT = 1
    NONBLOCKING = 2

    def __int__(self):
        """Return the stored flags as an integer bitmask."""
        return self._opr

    def __init__(self, rhs):
        """Initialize with a bit mask of flags ie. whether Operation PUT or GET,
        and whether NONBLOCKING."""
        self._opr = int(rhs)

    def isPut(self):
        return 0 != (self._opr & self.PUT)

    def isGet(self):
        return not self.isPut()

    def isNonBlocking(self):
        return 0 != (self._opr & self.NONBLOCKING)

    def data_object_mode(self, initial_open=False):
        if self.isPut():
            return "w" if initial_open else "a"
        else:
            return "r"

    def disk_file_mode(self, initial_open=False, binary=True):
        if self.isPut():
            mode = "r"
        else:
            mode = "w" if initial_open else "r+"
        return (mode + "b") if binary else mode
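
# A brief illustration of Oper's mode calculations; the values shown follow directly from
# the methods above:
#
#   op = Oper(Oper.PUT | Oper.NONBLOCKING)
#   op.isPut(), op.isNonBlocking()                      # -> (True, True)
#   op.data_object_mode(initial_open=True)              # -> "w"  (subsequent opens use "a")
#   Oper(Oper.GET).disk_file_mode(initial_open=True)    # -> "wb"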


def _io_send_bytes_progress(queueObject, item):
    """Offer a byte-count increment to the progress queue; report whether it was accepted."""
    try:
        queueObject.put(item)
        return True
    except Full:
        return False


COPY_BUF_SIZE = (1024**2) * 4
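# i.e. each individual read/write in the copy loop below moves at most 4 MiB.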


def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
    """
    The work-horse for performing the copy between file and data object.

    It also helps determine whether there has been a large enough increment of
    bytes to inform the progress bar of a need to update.
    """
    from irods.manager.data_object_manager import do_progress_updates

    bytecount = 0
    accum = 0
    while bytecount < length:
        buf = src.read(min(COPY_BUF_SIZE, length - bytecount))
        buf_len = len(buf)
        if 0 == buf_len:
            break
        dst.write(buf)
        bytecount += buf_len
        accum += buf_len
        if queueObject and accum and _io_send_bytes_progress(queueObject, accum):
            accum = 0
        do_progress_updates(updatables, buf_len)
        if verboseConnection:
            print("(" + debug_info + ")", end="", file=sys.stderr)
            sys.stderr.flush()

    # In a put or get, exactly one of (src,dst) is a file. Find which and close that one first.
    (file_, obj_) = (src, dst) if dst in mgr else (dst, src)
    file_.close()
    mgr.remove_io(obj_)  # 1. closes obj if it is not the mgr's initial descriptor
    # 2. blocks at barrier until all transfer threads are done copying
    # 3. closes with finalize if obj is mgr's initial descriptor
    return bytecount


class _Multipart_close_manager:
    """An object used to ensure that the initial transfer thread is also the last one to
    call the close method on its `Io' object.  The caller is responsible for setting up the
    conditions that the initial thread's close() is the one performing the catalog update.

    All non-initial transfer threads just call close() as soon as they are done transferring
    the byte range for which they are responsible, whereas we block the initial thread
    using a threading Barrier until we know all other threads have called close().

    """

    def __init__(self, initial_io_, exit_barrier_):
        self.exit_barrier = exit_barrier_
        self.initial_io = initial_io_
        self.__lock = threading.Lock()
        self.aux = []

    def __contains__(self, Io):
        with self.__lock:
            return Io is self.initial_io or Io in self.aux

    # `add_io' - add an i/o object to be managed
    # note: `remove_io' should only be called for managed i/o objects

    def add_io(self, Io):
        with self.__lock:
            if Io is not self.initial_io:
                self.aux.append(Io)

    # `remove_io' is for closing a channel of parallel i/o and allowing the
    # data object to flush write operations (if any) in a timely fashion.  It also
    # synchronizes all of the parallel threads just before exit, so that we know
    # exactly when to perform a finalizing close on the data object

    def remove_io(self, Io):
        is_initial = True
        with self.__lock:
            if Io is not self.initial_io:
                Io.close()
                self.aux.remove(Io)
                is_initial = False
        self.exit_barrier.wait()
        if is_initial:
            self.finalize()

    def finalize(self):
        self.initial_io.close()
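
# Rough life-cycle sketch (names are illustrative; in this module the manager is created
# by _io_multipart_threaded and driven by each thread's call to _copy_part):
#
#   mgr = _Multipart_close_manager(initial_io, Barrier(num_threads))
#   mgr.add_io(other_io)       # register a non-initial i/o channel
#   mgr.remove_io(other_io)    # non-initial: closed at once, then waits at the barrier
#   mgr.remove_io(initial_io)  # initial: waits at the barrier, then closes last (finalize)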


def _io_part(
    objHandle,
    range_,
    file_,
    opr_,
    mgr_,
    thread_debug_id="",
    queueObject=None,
    updatables=None,
):
    """
    Runs in a separate thread to manage the transfer of a range of bytes within the data object.

    The particular range is defined by the end of the range_ parameter, which should be of type
    (Py2) xrange or (Py3) range.
    """
    if 0 == len(range_):
        return 0
    Operation = Oper(opr_)
    (offset, length) = (range_[0], len(range_))
    objHandle.seek(offset)
    file_.seek(offset)
    if thread_debug_id == "":  # for more succinct thread identifiers while debugging.
        thread_debug_id = str(threading.current_thread().ident)
    return (
        _copy_part(
            file_, objHandle, length, queueObject, thread_debug_id, mgr_, updatables
        )
        if Operation.isPut()
        else _copy_part(
            objHandle, file_, length, queueObject, thread_debug_id, mgr_, updatables
        )
    )


def _io_multipart_threaded(
    operation_,
    dataObj_and_IO,
    replica_token,
    hier_str,
    session,
    fname,
    total_size,
    num_threads,
    **extra_options
):
    """Called by _io_main.

    Carve up (0,total_size) range into `num_threads` parts and initiate a transfer thread for each one.
    """
    (Data_object, Io) = dataObj_and_IO
    Operation = Oper(operation_)

    def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
        begin_offs = i * chunk
        if i + 1 < num_threads:
            end_offs = (i + 1) * chunk
        else:
            end_offs = total_bytes
        return range(begin_offs, end_offs)
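
    # Worked example of the carving (plain arithmetic): with total_size == 10 and
    # num_threads == 3, bytes_per_thread == 3, yielding the byte ranges [0, 3), [3, 6)
    # and [6, 10); the final thread absorbs the remainder.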

    bytes_per_thread = total_size // num_threads

    ranges = [
        bytes_range_for_thread(i, num_threads, total_size, bytes_per_thread)
        for i in range(num_threads)
    ]

    logger.info(
        "num_threads = %s ; bytes_per_thread = %s", num_threads, bytes_per_thread
    )

    queueLength = extra_options.get("queueLength", 0)
    if queueLength > 0:
        queueObject = Queue(queueLength)
    else:
        queueObject = None

    futures = []
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
    num_threads = min(num_threads, len(ranges))
    mgr = _Multipart_close_manager(Io, Barrier(num_threads))
    counter = 1
    gen_file_handle = lambda: open(
        fname, Operation.disk_file_mode(initial_open=(counter == 1))
    )
    File = gen_file_handle()

    thread_opts = {
        "updatables": extra_options.get("updatables", ()),
        "queueObject": queueObject,
    }

    for byte_range in ranges:
        if Io is None:
            Io = session.data_objects.open(
                Data_object.path,
                Operation.data_object_mode(initial_open=False),
                create=False,
                finalize_on_close=False,
                allow_redirect=False,
                **{
                    kw.NUM_THREADS_KW: str(num_threads),
                    kw.DATA_SIZE_KW: str(total_size),
                    kw.RESC_HIER_STR_KW: hier_str,
                    kw.REPLICA_TOKEN_KW: replica_token,
                }
            )
        mgr.add_io(Io)
        logger.debug("target_host = %s", Io.raw.session.pool.account.host)
        if File is None:
            File = gen_file_handle()
        futures.append(
            executor.submit(
                _io_part,
                Io,
                byte_range,
                File,
                Operation,
                mgr,
                thread_debug_id=str(counter),
                **thread_opts
            )
        )
        counter += 1
        Io = File = None

    if Operation.isNonBlocking():
        if queueLength:
            return futures, queueObject, mgr
        else:
            return futures, mgr
    else:
        bytecounts = [f.result() for f in futures]
        return sum(bytecounts), total_size


def io_main(session, Data, opr_, fname, R="", **kwopt):
    """
    The entry point for parallel transfers (multithreaded PUT and GET operations).

    Here, we do the following:
    * instantiate the data object, if this has not already been done.
    * determine replica information and the appropriate number of threads.
    * call the multithread manager to initiate multiple data transfer threads

    """
    total_bytes = kwopt.pop("total_bytes", -1)
    Operation = Oper(opr_)
    d_path = None
    Io = None
    if isinstance(Data, tuple):
        (Data, Io) = Data[:2]

    if isinstance(Data, str):
        d_path = Data
        try:
            Data = session.data_objects.get(Data)
            d_path = Data.path
        except DataObjectDoesNotExist:
            if Operation.isGet():
                raise

    R_via_libcall = kwopt.pop("target_resource_name", "")
    if R_via_libcall:
        R = R_via_libcall

    num_threads = kwopt.get("num_threads", None)
    if num_threads is None:
        num_threads = int(kwopt.get("N", "0"))
    if num_threads < 1:
        num_threads = RECOMMENDED_NUM_THREADS_PER_TRANSFER
    num_threads = max(1, min(multiprocessing.cpu_count(), num_threads))

    open_options = {}
    if Operation.isPut():
        if R:
            open_options[kw.RESC_NAME_KW] = R
            open_options[kw.DEST_RESC_NAME_KW] = R
        open_options[kw.NUM_THREADS_KW] = str(num_threads)
        open_options[kw.DATA_SIZE_KW] = str(total_bytes)

    output_values = {}
    if not Io:
        (Io, rawfile) = session.data_objects.open_with_FileRaw(
            (d_path or Data.path),
            Operation.data_object_mode(initial_open=True),
            finalize_on_close=True,
            returned_values=output_values,
            **open_options
        )
    else:
        if type(Io) is deferred_call:
            Io[kw.NUM_THREADS_KW] = str(num_threads)
            Io[kw.DATA_SIZE_KW] = str(total_bytes)
            Io["returned_values"] = output_values
            Io = Io()
        rawfile = Io.raw

    if not output_values:
        output_values = kwopt.get("data_open_returned_values", {})

    if "session" in output_values:
        session = output_values["session"]

    # At this point, the data object's existence in the catalog is guaranteed,
    # whether the Operation is a GET or PUT.

    if not isinstance(Data, iRODSDataObject) or "session" in output_values:
        Data = session.data_objects.get(d_path)

    # Determine total number of bytes for transfer.

    if Operation.isGet():
        total_bytes = Io.seek(0, os.SEEK_END)
        Io.seek(0, os.SEEK_SET)
    else:  # isPut
        if total_bytes < 0:
            with open(fname, "rb") as f:
                f.seek(0, os.SEEK_END)
                total_bytes = f.tell()

    # Get necessary info and initiate threaded transfers.

    (replica_token, resc_hier) = rawfile.replica_access_info()

    queueLength = kwopt.get("queueLength", 0)

    pass_thru_options = ("updatables", "queueLength")
    retval = _io_multipart_threaded(
        Operation,
        (Data, Io),
        replica_token,
        resc_hier,
        session,
        fname,
        total_bytes,
        num_threads=num_threads,
        **{k: v for k, v in kwopt.items() if k in pass_thru_options}
    )

    # SessionObject.data_objects.parallel_{put,get} will return:
    #   - immediately with an AsyncNotify instance, if Oper.NONBLOCKING flag is used.
    #   - upon completion with a boolean completion status, otherwise.

    if Operation.isNonBlocking():

        if queueLength > 0:
            (futures, chunk_notify_queue, mgr) = retval
        else:
            (futures, mgr) = retval
            chunk_notify_queue = total_bytes = None

        return AsyncNotify(
            futures,  # individual futures, one per transfer thread
            progress_Queue=chunk_notify_queue,  # for notifying the progress indicator thread
            total=total_bytes,  # total number of bytes for parallel transfer
            keep_={"mgr": mgr},
        )  # an open raw i/o object needing to be persisted, if any
    else:
        (_bytes_transferred, _bytes_total) = retval
        return _bytes_transferred == _bytes_total
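
# A hedged sketch of calling io_main directly as a library routine (the object path,
# local path and resource name below are hypothetical):
#
#   ok = io_main(sess, "/tempZone/home/alice/obj", Oper.PUT, "/tmp/local_file",
#                target_resource_name="demoResc", num_threads=3)
#   notifier = io_main(sess, "/tempZone/home/alice/obj", Oper.GET | Oper.NONBLOCKING,
#                      "/tmp/local_file", queueLength=64)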


if __name__ == "__main__":

    import getopt
    import atexit
    from irods.session import iRODSSession

    def setupLoggingWithDateTimeHeader(name, level=logging.DEBUG):
        if _nullh in logger.handlers:
            logger.removeHandler(_nullh)
        if name:
            handler = logging.FileHandler(name)
        else:
            handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)-15s - %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(level)

    try:
        env_file = os.environ["IRODS_ENVIRONMENT_FILE"]
    except KeyError:
        env_file = os.path.expanduser("~/.irods/irods_environment.json")
    ssl_context = ssl.create_default_context(
        purpose=ssl.Purpose.SERVER_AUTH, cafile=None, capath=None, cadata=None
    )
    ssl_settings = {"ssl_context": ssl_context}
    sess = iRODSSession(irods_env_file=env_file, **ssl_settings)
    atexit.register(lambda: sess.cleanup())

    opt, arg = getopt.getopt(sys.argv[1:], "vL:l:aR:N:")

    opts = dict(opt)

    logFilename = opts.pop(
        "-L", None
    )  # '' for console, non-empty for filesystem destination
    logLevel = logging.INFO if logFilename is None else logging.DEBUG
    logFilename = logFilename or opts.pop("-l", None)

    if logFilename is not None:
        setupLoggingWithDateTimeHeader(logFilename, logLevel)

    verboseConnection = opts.pop("-v", None) is not None

    async_xfer = opts.pop("-a", None)

    kwarg = {k.lstrip("-"): v for k, v in opts.items()}

    arg[1] = Oper.PUT if arg[1].lower() in ("w", "put", "a") else Oper.GET
    if async_xfer is not None:
        arg[1] |= Oper.NONBLOCKING

    ret = io_main(sess, *arg, **kwarg)  # arg[0] = data object or path
    # arg[1] = operation: or'd flags : [PUT|GET] NONBLOCKING
    # arg[2] = file path on local filesystem
    # kwarg['queueLength'] sets progress-queue length (0 if no progress indication needed)
    # kwarg options 'N' (num threads) and 'R' (target resource name) are via command-line
    # kwarg['num_threads'] (overrides 'N' when called as a library)
    # kwarg['target_resource_name'] (overrides 'R' when called as a library)
    if isinstance(ret, AsyncNotify):
        print("waiting on completion...", file=sys.stderr)
        ret.set_transfer_done_callback(
            lambda r: print("Async transfer done for:", r, file=sys.stderr)
        )
        done = ret.wait_until_transfer_done(
            timeout=10.0
        )  # - or do other useful work here
        if done:
            bytes_transferred = sum(ret.futures_done.values())
            print(
                "Asynch transfer complete. Total bytes transferred:",
                bytes_transferred,
                file=sys.stderr,
            )
        else:
            print(
                "Asynch transfer was not completed before timeout expired.",
                file=sys.stderr,
            )
    else:
        print(
            "Synchronous transfer {}".format("succeeded" if ret else "failed"),
            file=sys.stderr,
        )

# Note : This module depends on concurrent.futures, which is part of the Python 3 standard library.
# Demonstration :
#
# $ dd if=/dev/urandom bs=1k count=150000 of=$HOME/puttest
# $ time python -m irods.parallel -R demoResc -N 3 `ipwd`/test.dat put $HOME/puttest  # add -v,-a for verbose, asynch
# $ time python -m irods.parallel -R demoResc -N 3 `ipwd`/test.dat get $HOME/gettest  # add -v,-a for verbose, asynch
# $ diff puttest gettest
