# -*- coding: utf-8 -*-

#    Copyright (C) 2012-2013 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import logging

from taskflow import exceptions
from taskflow import states
from taskflow.utils import misc

LOG = logging.getLogger(__name__)

# Task states for which TaskAction.change_state() persists the task's result
# together with the new state (via storage.save) instead of only recording
# the state itself.
SAVE_RESULT_STATES = (
    states.SUCCESS,
    states.FAILURE,
)


class TaskAction(object):
    """Drives a task's execution and reversion through a task executor.

    Every state transition goes through ``change_state``, which records it in
    storage and announces it through the notifier. The actual running and
    reverting of tasks is delegated to the task executor, and whatever the
    executor returns is handed back to the caller.
    """

    def __init__(self, storage, task_executor, notifier):
        self._storage = storage
        self._task_executor = task_executor
        self._notifier = notifier

    def change_state(self, task, state, result=None, progress=None):
        """Move ``task`` into ``state``, persisting and announcing the change.

        If storage already records ``task`` in ``state``, nothing is written
        and no notification is sent.

        For states in ``SAVE_RESULT_STATES``, ``result`` is saved together
        with the state; for any other state only the state is recorded. When
        ``progress`` is given, it is stored and also passed to
        ``task.update_progress`` after the notification is sent.

        :param task: task whose state changes (its ``name`` keys storage)
        :param state: target state
        :param result: value (or failure) saved for result-saving states and
            included in the notification details
        :param progress: optional progress value to record
        :returns: ``True`` if the transition was applied, or if it repeated a
            non-PENDING state; ``False`` if it repeated the PENDING state
        """
        old_state = self._storage.get_atom_state(task.name)
        if old_state == state:
            # Identity transition: skip the storage write and notification.
            # NOTE(review): a repeated PENDING is reported as a failure while
            # other repeats count as success; the intent is not visible here.
            return state != states.PENDING
        if state in SAVE_RESULT_STATES:
            self._storage.save(task.name, result, state)
        else:
            self._storage.set_atom_state(task.name, state)
        if progress is not None:
            self._storage.set_task_progress(task.name, progress)
        task_uuid = self._storage.get_atom_uuid(task.name)
        details = dict(task_name=task.name,
                       task_uuid=task_uuid,
                       result=result)
        self._notifier.notify(state, details)
        if progress is not None:
            task.update_progress(progress)
        return True

    def _on_update_progress(self, task, event_data, progress, **kwargs):
        """Should be called when task updates its progress.

        This is the progress callback handed to the task executor. It stores
        ``progress`` (plus any extra keyword details) for ``task``.
        ``event_data`` is accepted for the callback signature but not used.
        It never raises: any storage error is logged instead.
        """
        try:
            self._storage.set_task_progress(task.name, progress, kwargs)
        except Exception:
            # Update progress callbacks should never fail, so capture and log
            # the emitted exception instead of raising it.
            LOG.exception("Failed setting task progress for %s to %0.3f",
                          task, progress)

    def schedule_execution(self, task):
        """Mark ``task`` RUNNING and submit it to the executor for execution.

        The task's progress is reset to 0.0 as part of the state change. Its
        arguments are fetched from storage using ``task.rebind``.

        :returns: whatever ``task_executor.execute_task`` returns (presumably
            a future accepted by ``wait_for_any`` -- confirm with executor)
        :raises exceptions.InvalidState: if ``change_state`` refuses the
            transition to RUNNING
        """
        if not self.change_state(task, states.RUNNING, progress=0.0):
            raise exceptions.InvalidState("Task %s is in invalid state and"
                                          " can't be executed" % task.name)
        kwargs = self._storage.fetch_mapped_args(task.rebind,
                                                 atom_name=task.name)
        task_uuid = self._storage.get_atom_uuid(task.name)
        return self._task_executor.execute_task(task, task_uuid, kwargs,
                                                self._on_update_progress)

    def complete_execution(self, task, result):
        """Record the outcome of executing ``task``.

        A ``misc.Failure`` result moves the task to FAILURE and saves the
        failure. Any other result moves it to SUCCESS, saves the result and
        sets progress to 1.0.
        """
        if isinstance(result, misc.Failure):
            self.change_state(task, states.FAILURE, result=result)
        else:
            self.change_state(task, states.SUCCESS,
                              result=result, progress=1.0)

    def schedule_reversion(self, task):
        """Mark ``task`` REVERTING and submit it to the executor for revert.

        The task's progress is reset to 0.0 as part of the state change. The
        executor receives the task's mapped arguments, the result currently
        stored for the task, and all failures currently held by storage.

        :returns: whatever ``task_executor.revert_task`` returns
        :raises exceptions.InvalidState: if ``change_state`` refuses the
            transition to REVERTING
        """
        if not self.change_state(task, states.REVERTING, progress=0.0):
            raise exceptions.InvalidState("Task %s is in invalid state and"
                                          " can't be reverted" % task.name)
        kwargs = self._storage.fetch_mapped_args(task.rebind,
                                                 atom_name=task.name)
        task_uuid = self._storage.get_atom_uuid(task.name)
        task_result = self._storage.get(task.name)
        failures = self._storage.get_failures()
        return self._task_executor.revert_task(task, task_uuid, kwargs,
                                               task_result, failures,
                                               self._on_update_progress)

    def complete_reversion(self, task, rev_result):
        """Record the outcome of reverting ``task``.

        A ``misc.Failure`` moves the task to FAILURE. Any other outcome moves
        it to REVERTED with progress 1.0.
        """
        if isinstance(rev_result, misc.Failure):
            # FAILURE is a result-saving state, so pass the failure along.
            # Omitting it would save ``None`` as the task's result (dropping
            # the revert failure) and notify listeners with result=None,
            # unlike complete_execution, which records its failure.
            self.change_state(task, states.FAILURE, result=rev_result)
        else:
            self.change_state(task, states.REVERTED, progress=1.0)

    def wait_for_any(self, fs, timeout):
        """Delegate to the task executor's ``wait_for_any`` for ``fs``."""
        return self._task_executor.wait_for_any(fs, timeout)
