# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Implements ProcessPoolExecutor.

The following diagram and text describe the data-flow through the system::

    |==================== In-process ====================|== Out-of-process ==|

    +----------+     +-------+       +--------+     +----------+    +---------+
    |          |  => | Work  |    => |        |  => | Call Q   | => |         |
    |          |     | Ids   |       |        |     |          |    |         |
    |          |     +-------+       |        |     +----------+    |         |
    |          |     | ...   |       |        |     |...       |    |         |
    |          |     | 6     |       |        |     |5, call() |    |         |
    |          |     | 7     |       |        |     |...       |    |         |
    | Process  |     | ...   |       | Local  |     +----------+    | Process |
    |  Pool    |     +-------+       | Worker |                     |  #1..n  |
    | Executor |                     | Thread |                     |         |
    |          |     +---------+     |        |     +----------+    |         |
    |          | <=> |  Work   | <=> |        | <=  | Result Q | <= |         |
    |          |     |  Items  |     |        |     |          |    |         |
    |          |     +---------+     |        |     +----------+    |         |
    |          |     |6: call()|     |        |     |...       |    |         |
    |          |     |   future|     |        |     |4, result |    |         |
    |          |     |...      |     |        |     |3, except |    |         |
    +----------+     +---------+     +--------+     +----------+    +---------+

Executor.submit() called:
    - creates a uniquely numbered _WorkItem and adds it to the 'Work Items'
      dict
    - adds the id of the _WorkItem to the 'Work Ids' queue

Local worker thread:
    - reads work ids from the 'Work Ids' queue and looks up the corresponding
      WorkItem from the 'Work Items' dict: if the work item has been cancelled
      then it is simply removed from the dict, otherwise it is repackaged as a
      _CallItem and put in the 'Call Q'. New _CallItems are put in the 'Call Q'
      until 'Call Q' is full. NOTE: the size of the 'Call Q' is kept small
      because calls placed in the 'Call Q' can no longer be cancelled with
      Future.cancel().
    - reads _ResultItems from 'Result Q', updates the future stored in the
      'Work Items' dict and deletes the dict entry

Process #1..n:
    - reads _CallItems from 'Call Q', executes the calls, and puts the
      resulting _ResultItems in 'Result Q'
"""

import atexit
import multiprocessing
import threading
import weakref
import sys

from . import _base

try:
    import queue
except ImportError:
    import Queue as queue


__author__ = 'Brian Quinlan (brian@sweetapp.com)'

# Workers are created as daemon threads and processes. This is done to allow
# the interpreter to exit when there are still idle processes in a
# ProcessPoolExecutor's process pool (i.e. shutdown() was not called). However,
# allowing workers to die with the interpreter has two undesirable properties:
#   - The workers would still be running during interpreter shutdown,
#     meaning that they would fail in unpredictable ways.
#   - The workers could be killed while evaluating a work item, which could
#     be bad if the callable being evaluated has external side-effects e.g.
#     writing to a file.
#
# To work around this problem, an exit handler is installed which tells the
# workers to exit when their work queues are empty and then waits until the
# threads/processes finish.

# Weak references to the queue management threads started by executors in
# this module; _python_exit() joins every one that is still alive.
_thread_references = set()
# Set to True by _python_exit(); the queue management threads read it to
# learn that the interpreter is shutting down.
_shutdown = False


def _python_exit():
    """Flag interpreter shutdown and wait for the management threads.

    Sets the module-level ``_shutdown`` flag and then joins every thread in
    ``_thread_references`` whose weak reference is still alive.
    """
    global _shutdown
    _shutdown = True
    for ref in _thread_references:
        worker_thread = ref()
        if worker_thread is None:
            # The thread object has already been garbage collected.
            continue
        worker_thread.join()


def _remove_dead_thread_references():
    """Drop weak references whose thread no longer exists.

    Should be called periodically so that _thread_references does not keep
    growing in scenarios such as:
        >>> while True:
        ...    t = ProcessPoolExecutor(max_workers=5)
        ...    t.map(int, ['1', '2', '3', '4', '5'])
    """
    # Scan a snapshot so that the set itself can be modified afterwards.
    stale = [ref for ref in tuple(_thread_references) if ref() is None]
    for ref in stale:
        _thread_references.discard(ref)


# Controls how many more calls than processes will be queued in the call queue.
# A smaller number will mean that processes spend more time idle waiting for
# work while a larger number will make Future.cancel() succeed less frequently
# (Futures in the call queue cannot be cancelled).
# ProcessPoolExecutor sizes its call queue as max_workers + EXTRA_QUEUED_CALLS.
EXTRA_QUEUED_CALLS = 1


class _WorkItem(object):
    def __init__(self, future, fn, args, kwargs):
        self.future = future
        self.fn = fn
        self.args = args
        self.kwargs = kwargs


class _ResultItem(object):
    def __init__(self, work_id, exception=None, result=None):
        self.work_id = work_id
        self.exception = exception
        self.result = result


class _CallItem(object):
    def __init__(self, work_id, fn, args, kwargs):
        self.work_id = work_id
        self.fn = fn
        self.args = args
        self.kwargs = kwargs


def _process_worker(call_queue, result_queue, shutdown):
    """Evaluates calls from call_queue and places the results in result_queue.

    This worker is run in a separate process.

    Parameters
    ----------
    call_queue
        A `multiprocessing.Queue` of _CallItems that will be read and
        evaluated by the worker.

    result_queue
        A `multiprocessing.Queue` of _ResultItems that will written
        to by the worker.

    shutdown
        A `multiprocessing.Event` that will be set as a signal to the
        worker that it should exit when call_queue is empty.
    """
    while True:
        try:
            call_item = call_queue.get(block=True, timeout=0.1)
        except queue.Empty:
            if shutdown.is_set():
                return
        else:
            try:
                r = call_item.fn(*call_item.args, **call_item.kwargs)
            except BaseException:
                e = sys.exc_info()[1]
                result_queue.put(_ResultItem(call_item.work_id,
                                             exception=e))
            else:
                result_queue.put(_ResultItem(call_item.work_id,
                                             result=r))


def _add_call_item_to_queue(pending_work_items,
                            work_ids,
                            call_queue):
    """Fills call_queue with _WorkItems from pending_work_items.

    This function never blocks.

    Parameters
    ----------
    pending_work_items
        A dict mapping work ids to _WorkItems e.g.
        {5: <_WorkItem...>, 6: <_WorkItem...>, ...}

    work_ids
        A queue.Queue of work ids e.g. Queue([5, 6, ...]). Work ids
        are consumed and the corresponding _WorkItems from
        pending_work_items are transformed into _CallItems and put in
        call_queue.

    call_queue
        A multiprocessing.Queue that will be filled with _CallItems
        derived from _WorkItems.
    """
    while True:
        if call_queue.full():
            return
        try:
            work_id = work_ids.get(block=False)
        except queue.Empty:
            return
        else:
            work_item = pending_work_items[work_id]

            if work_item.future.set_running_or_notify_cancel():
                call_queue.put(_CallItem(work_id,
                                         work_item.fn,
                                         work_item.args,
                                         work_item.kwargs),
                               block=True)
            else:
                del pending_work_items[work_id]
                continue


def _queue_manangement_worker(executor_reference,
                              processes,
                              pending_work_items,
                              work_ids_queue,
                              call_queue,
                              result_queue,
                              shutdown_process_event):
    """Manages the communication between this process and the worker processes.

    This function is run in a local thread. It loops until no more work can
    be submitted and no work is pending, then signals the worker processes
    to exit, joins them and returns.

    Parameters
    ----------
    executor_reference
        A weakref.ref to the ProcessPoolExecutor that owns
        this thread. Used to determine if the ProcessPoolExecutor has been
        garbage collected and that this function can exit.

    processes
        A collection of the multiprocessing.Process instances used as
        workers.

    pending_work_items
        A dict mapping work ids to _WorkItems e.g.
        {5: <_WorkItem...>, 6: <_WorkItem...>, ...}

    work_ids_queue
        A queue.Queue of work ids e.g. Queue([5, 6, ...]).

    call_queue
        A multiprocessing.Queue that will be filled with _CallItems
        derived from _WorkItems for processing by the process workers.

    result_queue
        A multiprocessing.Queue of _ResultItems generated by the
        process workers.

    shutdown_process_event
        A multiprocessing.Event used to signal the
        process workers that they should exit when their work queue is
        empty.
    """
    while True:
        _add_call_item_to_queue(pending_work_items,
                                work_ids_queue,
                                call_queue)

        try:
            result_item = result_queue.get(block=True, timeout=0.1)
        except queue.Empty:
            executor = executor_reference()
            # No more work items can be added if:
            #   - The interpreter is shutting down OR
            #   - The executor that owns this worker has been collected OR
            #   - The executor that owns this worker has been shutdown.
            no_new_work = (_shutdown or
                           executor is None or
                           executor._shutdown_thread)
            # Drop the strong reference straight away so that this thread
            # never keeps the executor alive between polls.
            del executor

            # Since no new work items can be added, it is safe to shutdown
            # this thread if there are no pending work items.
            if no_new_work and not pending_work_items:
                shutdown_process_event.set()

                # If .join() is not called on the created processes then
                # some multiprocessing.Queue methods may deadlock on Mac OS
                # X.
                for p in processes:
                    p.join()
                return
        else:
            work_item = pending_work_items.pop(result_item.work_id)

            # Compare against None instead of relying on truthiness: an
            # exception object that evaluates false (e.g. one defining
            # __len__ or __bool__) must still be reported as a failure.
            if result_item.exception is not None:
                work_item.future.set_exception(result_item.exception)
            else:
                work_item.future.set_result(result_item.result)


class ProcessPoolExecutor(_base.Executor):
    """Executor that evaluates submitted calls in a pool of worker processes.

    Worker processes and the queue management thread are started on the
    first call to submit(), not when the executor is created.
    """

    def __init__(self, max_workers=None):
        """Initializes a new ProcessPoolExecutor instance.

        Parameters
        ----------
        max_workers
            The maximum number of processes that can be used to
            execute the given calls. If None or not given then as many
            worker processes will be created as the machine has processors.

        Raises
        ------
        ValueError
            If max_workers is given and is less than or equal to 0.
        """
        _remove_dead_thread_references()

        if max_workers is None:
            self._max_workers = multiprocessing.cpu_count()
        else:
            # With no worker processes nothing would ever consume the call
            # queue and every submitted future would wait forever.
            if max_workers <= 0:
                raise ValueError('max_workers must be greater than 0')
            self._max_workers = max_workers

        # Make the call queue slightly larger than the number of processes to
        # prevent the worker processes from idling. But don't make it too big
        # because futures in the call queue cannot be cancelled.
        self._call_queue = multiprocessing.Queue(self._max_workers +
                                                 EXTRA_QUEUED_CALLS)
        self._result_queue = multiprocessing.Queue()
        self._work_ids = queue.Queue()
        self._queue_management_thread = None
        self._processes = set()

        # Shutdown is a two-step process.
        self._shutdown_thread = False
        self._shutdown_process_event = multiprocessing.Event()
        self._shutdown_lock = threading.Lock()
        # Next work id to hand out; also the key used in _pending_work_items.
        self._queue_count = 0
        self._pending_work_items = {}

    def _start_queue_management_thread(self):
        """Start the queue management thread if it is not already running."""
        if self._queue_management_thread is None:
            # The thread only gets a weak reference to the executor so that
            # it can notice when the executor has been garbage collected.
            self._queue_management_thread = threading.Thread(
                target=_queue_manangement_worker,
                args=(weakref.ref(self),
                      self._processes,
                      self._pending_work_items,
                      self._work_ids,
                      self._call_queue,
                      self._result_queue,
                      self._shutdown_process_event))
            self._queue_management_thread.daemon = True
            self._queue_management_thread.start()
            _thread_references.add(weakref.ref(self._queue_management_thread))

    def _adjust_process_count(self):
        """Start worker processes until there are _max_workers of them."""
        for _ in range(len(self._processes), self._max_workers):
            p = multiprocessing.Process(target=_process_worker,
                                        args=(self._call_queue,
                                              self._result_queue,
                                              self._shutdown_process_event))
            p.start()
            self._processes.add(p)

    def submit(self, fn, *args, **kwargs):
        with self._shutdown_lock:
            if self._shutdown_thread:
                raise RuntimeError(
                    'cannot schedule new futures after shutdown')

            f = _base.Future()
            w = _WorkItem(f, fn, args, kwargs)

            # Register the work item before publishing its id so that the
            # management thread can always look the id up.
            self._pending_work_items[self._queue_count] = w
            self._work_ids.put(self._queue_count)
            self._queue_count += 1

            self._start_queue_management_thread()
            self._adjust_process_count()
            return f
    submit.__doc__ = _base.Executor.submit.__doc__

    def shutdown(self, wait=True):
        with self._shutdown_lock:
            self._shutdown_thread = True
        if wait:
            if self._queue_management_thread:
                self._queue_management_thread.join()
        # To reduce the risk of opening too many files, remove references to
        # objects that use file descriptors.
        self._queue_management_thread = None
        self._call_queue = None
        self._result_queue = None
        self._shutdown_process_event = None
        self._processes = None
    shutdown.__doc__ = _base.Executor.shutdown.__doc__


# Run _python_exit() at interpreter exit so that the queue management threads
# are told to stop and are joined.
atexit.register(_python_exit)
