import logging
import os
import time

from parsl.jobs.states import JobState, JobStatus
from parsl.launchers import SingleNodeLauncher
from parsl.providers.base import ExecutionProvider
from parsl.providers.errors import (
    SchedulerMissingArgs,
    ScriptPathError,
    SubmitException,
)
from parsl.utils import RepresentationMixin, execute_wait

logger = logging.getLogger(__name__)


class LocalProvider(ExecutionProvider, RepresentationMixin):
    """ Local Execution Provider

    This provider provisions execution resources on the local machine (localhost).

    Parameters
    ----------

    nodes_per_block : int
        Nodes to provision per block. Typically 1 for the local provider.
    launcher : Launcher
        Launcher used to wrap the command; defaults to SingleNodeLauncher.
    init_blocks : int
        Number of blocks to provision at initialization.
    min_blocks : int
        Minimum number of blocks to maintain.
    max_blocks : int
        Maximum number of blocks to maintain.
    worker_init : str
        Command to be run before starting a worker, such as 'module load Anaconda; source activate env'.
    cmd_timeout : int
        Timeout, in seconds, applied to the shell commands used to launch, poll and cancel jobs.
    parallelism : float
        Ratio of provisioned task slots to active tasks. A parallelism value of 1 represents aggressive
        scaling where as many resources as possible are used; parallelism close to 0 represents
        the opposite situation in which as few resources as possible (i.e., min_blocks) are used.
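
    Examples
    --------
    A minimal, illustrative configuration (the parameter values below are
    examples, not recommendations)::

        provider = LocalProvider(init_blocks=1, min_blocks=0, max_blocks=2,
                                 worker_init='source activate env')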
    """

    def __init__(self,
                 nodes_per_block=1,
                 launcher=SingleNodeLauncher(),
                 init_blocks=1,
                 min_blocks=0,
                 max_blocks=1,
                 worker_init='',
                 cmd_timeout=30,
                 parallelism=1):
        self._label = 'local'
        self.nodes_per_block = nodes_per_block
        self.launcher = launcher
        self.worker_init = worker_init
        self.init_blocks = init_blocks
        self.min_blocks = min_blocks
        self.max_blocks = max_blocks
        self.parallelism = parallelism
        self.script_dir = None
        self.cmd_timeout = cmd_timeout

        # Dictionary that keeps track of jobs, keyed on job_id
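        # Each value is a job record dict with keys 'job_id', 'status' (a JobStatus),
        # 'remote_pid' and 'script_path'; cancel() additionally sets 'cancelled'.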
        self.resources = {}

    def status(self, job_ids):
        '''  Get the status of a list of jobs identified by their ids.

        Args:
            - job_ids (List of ids) : List of identifiers for the jobs

        Returns:
            - List of JobStatus objects, one per requested job id.

        '''

        for job_id in self.resources:
            # This job dict should really be a class on its own
            job_dict = self.resources[job_id]
            if job_dict['status'] and job_dict['status'].terminal:
                # We already checked this and it can't change after that
                continue
            script_path = job_dict['script_path']

            alive = self._is_alive(job_dict)
            str_ec = self._read_job_file(script_path, '.ec').strip()

            status = None
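            # '-' is the sentinel that submit() writes to the .ec file before the
            # command runs, so it means no exit code has been recorded yet.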
            if str_ec == '-':
                if alive:
                    status = JobStatus(JobState.RUNNING)
                else:
                    # not alive but didn't get to write an exit code
                    if 'cancelled' in job_dict:
                        # because we cancelled it
                        status = JobStatus(JobState.CANCELLED)
                    else:
                        # we didn't cancel it, so it must have been killed by something outside
                        # parsl; we don't have a state for this, but we'll use CANCELLED with
                        # a specific message
                        status = JobStatus(JobState.CANCELLED, message='Killed')
            else:
                try:
                    # TODO: ensure that these files are only read once and clean them
                    ec = int(str_ec)
                    stdout_path = self._job_file_path(script_path, '.out')
                    stderr_path = self._job_file_path(script_path, '.err')
                    if ec == 0:
                        state = JobState.COMPLETED
                    else:
                        state = JobState.FAILED
                    status = JobStatus(state, exit_code=ec,
                                       stdout_path=stdout_path, stderr_path=stderr_path)
                except Exception:
                    status = JobStatus(JobState.FAILED,
                                       'Cannot parse exit code: {}'.format(str_ec))

            job_dict['status'] = status

        return [self.resources[jid]['status'] for jid in job_ids]

    def _is_alive(self, job_dict):
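        # `ps -p <pid>` exits with 0 only if the process exists; echoing that exit
        # code behind a STATUS: marker makes it easy to pick out of stdout.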
        retcode, stdout, stderr = execute_wait(
            'ps -p {} > /dev/null 2> /dev/null; echo "STATUS:$?" '.format(
                job_dict['remote_pid']), self.cmd_timeout)
        for line in stdout.split('\n'):
            if line.startswith("STATUS:"):
                status = line.split("STATUS:")[1].strip()
                return status == "0"
        # No STATUS: line was found, so assume the process is not alive
        return False

    def _job_file_path(self, script_path: str, suffix: str) -> str:
        path = '{0}{1}'.format(script_path, suffix)
        return path

    def _read_job_file(self, script_path: str, suffix: str) -> str:
        path = self._job_file_path(script_path, suffix)

        with open(path, 'r') as f:
            return f.read()

    def _write_submit_script(self, script_string, script_filename):
        '''
        Write the submit script string out to a submit script file.

        Args:
              - script_string (string) : The script content to be written to the submit script file
              - script_filename (string) : Name of the submit script

        Returns:
              - True: on success

        Raises:
              SchedulerMissingArgs : If template is missing args
              ScriptPathError : Unable to write submit script out
        '''

        try:
            with open(script_filename, 'w') as f:
                f.write(script_string)

        except KeyError as e:
            logger.error("Missing keys for submit script: %s", e)
            raise SchedulerMissingArgs(e.args, self.label)

        except IOError as e:
            logger.error("Failed writing to submit script: %s", script_filename)
            raise ScriptPathError(script_filename, e)

        return True

    def submit(self, command, tasks_per_node, job_name="parsl.localprovider"):
        ''' Submits the command to be run as a local job.
        Submit returns an ID that corresponds to the task that was just submitted.

        The tasks_per_node and nodes_per_block values are passed through to the
        configured launcher, which determines how many invocations of the command
        are started within the block.

        Args:
             - command  : (String) Command line invocation to be run on the local host
             - tasks_per_node (int) : command invocations to be launched per node

        Kwargs:
             - job_name (String): Name for job, must be unique

        Returns:
             - job_id: (string) Identifier for the job; for this provider, the PID of
                       the launched background process

        Raises:
             - SubmitException: If the launch command fails or no PID can be retrieved

        '''

        job_name = "{0}.{1}".format(job_name, time.time())

        # Set script path
        script_path = "{0}/{1}.sh".format(self.script_dir, job_name)
        script_path = os.path.abspath(script_path)

        wrap_command = self.worker_init + f'\nexport JOBNAME={job_name}\n' + self.launcher(command, tasks_per_node, self.nodes_per_block)

        self._write_submit_script(wrap_command, script_path)

        job_id = None
        remote_pid = None

        logger.debug("Launching")
        # We need to capture the exit code and the streams, so we put them in files. We also write
        # '-' to the exit code file to isolate potential problems with writing to files in the
        # script directory
        #
        # The basic flow is:
        #   1. write "-" to the exit code file. If this fails, exit
        #   2. Launch the following sequence in the background:
        #      a. the command to run
        #      b. write the exit code of the command from (a) to the exit code file
        #   3. Write the PID of the background sequence on stdout. The PID is needed if we want to
        #      cancel the task later.
        #
        # We need to do the >/dev/null 2>&1 so that bash closes stdout, otherwise
        # execute_wait hangs reading the process stdout until all the
        # background commands complete.
        cmd = '/bin/bash -c \'echo - >{0}.ec && {{ {{ bash {0} 1>{0}.out 2>{0}.err ; ' \
              'echo $? > {0}.ec ; }} >/dev/null 2>&1 & echo "PID:$!" ; }}\''.format(script_path)
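        # With a hypothetical script_path of /tmp/parsl.sh, cmd expands roughly to:
        #   /bin/bash -c 'echo - >/tmp/parsl.sh.ec && { { bash /tmp/parsl.sh 1>/tmp/parsl.sh.out 2>/tmp/parsl.sh.err ;
        #     echo $? > /tmp/parsl.sh.ec ; } >/dev/null 2>&1 & echo "PID:$!" ; }'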
        retcode, stdout, stderr = execute_wait(cmd, self.cmd_timeout)
        if retcode != 0:
            raise SubmitException(job_name, "Launch command exited with code {0}".format(retcode),
                                  stdout, stderr)
        for line in stdout.split('\n'):
            if line.startswith("PID:"):
                remote_pid = line.split("PID:")[1].strip()
                job_id = remote_pid
        if job_id is None:
            raise SubmitException(job_name, "Channel failed to start remote command/retrieve PID")

        self.resources[job_id] = {'job_id': job_id, 'status': JobStatus(JobState.RUNNING),
                                  'remote_pid': remote_pid, 'script_path': script_path}

        return job_id

    def cancel(self, job_ids):
        ''' Cancels the jobs specified by a list of job ids

        Args:
        job_ids : [<job_id> ...]

        Returns: [True]  Always returns true for every job_id, regardless of
                         whether an individual cancel failed (unless an
                         exception is raised)
        '''
        for job in job_ids:
            job_dict = self.resources[job]
            job_dict['cancelled'] = True
            logger.debug("Terminating job/process ID: {0}".format(job))
            cmd = "kill -- -$(ps -o pgid= {} | grep -o '[0-9]*')".format(job_dict['remote_pid'])
            retcode, stdout, stderr = execute_wait(cmd, self.cmd_timeout)
            if retcode != 0:
                logger.warning("Failed to kill PID: {} and child processes on {}".format(job_dict['remote_pid'],
                                                                                         self.label))

        rets = [True for i in job_ids]
        return rets

    @property
    def label(self):
        return self._label

    @property
    def status_polling_interval(self):
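        """Interval, in seconds, at which status() should be polled for this provider."""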
        return 5
