# -*- coding: utf-8 -*-

#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import contextlib
import threading

from zake import fake_client

from taskflow.conductors import single_threaded as stc
from taskflow import engines
from taskflow.jobs.backends import impl_zookeeper
from taskflow.jobs import jobboard
from taskflow.patterns import linear_flow as lf
from taskflow.persistence.backends import impl_memory
from taskflow import states as st
from taskflow import test
from taskflow.tests import utils as test_utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu


@contextlib.contextmanager
def close_many(*closeables):
    """Context manager that calls ``close()`` on every argument on exit.

    Closing happens in argument order and runs whether or not the body of
    the ``with`` statement raised. Every closeable gets its ``close()``
    called even if an earlier one raises. Once all have been attempted, the
    first ``Exception`` raised by a ``close()`` call is re-raised.

    :param closeables: objects exposing a no-argument ``close()`` method.
    """
    try:
        yield
    finally:
        first_error = None
        for c in closeables:
            try:
                c.close()
            except Exception as e:
                # Remember the failure but keep going so later closeables
                # (e.g. the client after the conductor) are not leaked.
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error


def test_factory(blowup):
    """Create the linear flow that the conductor tests post as a job.

    :param blowup: when truthy the flow's only task is a
        ``test_utils.FailingTask``; otherwise it is a
        ``test_utils.SaveOrderTask``. Either way the task is named
        ``test1`` and the flow is named ``test``.
    :returns: the constructed ``linear_flow.Flow``.
    """
    flow = lf.Flow("test")
    task_cls = test_utils.FailingTask if blowup else test_utils.SaveOrderTask
    flow.add(task_cls('test1'))
    return flow


def make_thread(conductor):
    """Return an unstarted daemon thread whose target is ``conductor.run``.

    The thread is marked as a daemon so that a conductor whose ``run`` never
    returns cannot keep the interpreter alive once the tests finish.
    """
    worker = threading.Thread(target=conductor.run)
    worker.daemon = True
    return worker


class SingleThreadedConductorTest(test_utils.EngineTestBase, test.TestCase):
    """Tests for ``SingleThreadedConductor`` wired to fake/in-memory parts.

    Each test builds its own ``zake`` fake client, in-memory persistence
    backend and zookeeper jobboard (see ``make_components``), so tests do
    not share state with each other.
    """

    def make_components(self, name='testing', wait_timeout=0.1):
        """Build a fresh conductor together with the components it uses.

        :param name: name given to both the jobboard and the conductor.
        :param wait_timeout: passed positionally to the conductor
            constructor (presumably the wait between jobboard polls --
            TODO confirm against ``SingleThreadedConductor``).
        :returns: a ``misc.AttrDict`` with ``board``, ``client``,
            ``persistence`` and ``conductor`` attributes.
        """
        client = fake_client.FakeClient()
        persistence = impl_memory.MemoryBackend()
        board = impl_zookeeper.ZookeeperJobBoard(name, {},
                                                 client=client,
                                                 persistence=persistence)
        engine_conf = {
            'engine': 'default',
        }
        conductor = stc.SingleThreadedConductor(name, board, engine_conf,
                                                persistence, wait_timeout)
        return misc.AttrDict(board=board,
                             client=client,
                             persistence=persistence,
                             conductor=conductor)

    def _run_flow_via_conductor(self, blowup):
        """Run ``test_factory(blowup)`` as a job and return its flow detail.

        The conductor runs in a background daemon thread. A single job
        referencing a freshly saved flow detail is posted, and the method
        waits (up to 1s) for the jobboard's REMOVAL notification. The
        conductor is then stopped and its thread joined, all before the
        conductor and client are closed.

        :param blowup: forwarded to ``test_factory``; truthy makes the
            flow's task fail.
        :returns: the flow detail re-read from persistence after the run.
        """
        components = self.make_components()
        components.conductor.connect()
        consumed_event = threading.Event()

        def on_consume(state, details):
            consumed_event.set()

        components.board.notifier.register(jobboard.REMOVAL, on_consume)
        with close_many(components.conductor, components.client):
            t = make_thread(components.conductor)
            t.start()
            lb, fd = pu.temporary_flow_detail(components.persistence)
            engines.save_factory_details(fd, test_factory,
                                         [blowup], {},
                                         backend=components.persistence)
            components.board.post('poke', lb,
                                  details={'flow_uuid': fd.uuid})
            consumed_event.wait(1.0)
            self.assertTrue(consumed_event.is_set())
            self.assertTrue(components.conductor.stop(1.0))
            self.assertFalse(components.conductor.dispatching)
            # Wait for run() to fully return before close_many closes the
            # conductor and client out from under the worker thread
            # (mirrors test_run_empty).
            t.join()

        persistence = components.persistence
        with contextlib.closing(persistence.get_connection()) as conn:
            lb = conn.get_logbook(lb.uuid)
            fd = lb.find(fd.uuid)
        self.assertIsNotNone(fd)
        return fd

    def test_connection(self):
        """Connecting/closing the conductor toggles board and client."""
        components = self.make_components()
        components.conductor.connect()
        with close_many(components.conductor, components.client):
            self.assertTrue(components.board.connected)
            self.assertTrue(components.client.connected)
        self.assertFalse(components.board.connected)
        self.assertFalse(components.client.connected)

    def test_run_empty(self):
        """A conductor with no posted jobs can be stopped cleanly."""
        components = self.make_components()
        components.conductor.connect()
        with close_many(components.conductor, components.client):
            t = make_thread(components.conductor)
            t.start()
            self.assertTrue(components.conductor.stop(0.5))
            self.assertFalse(components.conductor.dispatching)
            t.join()

    def test_run(self):
        """A job whose flow succeeds leaves the flow in SUCCESS state."""
        fd = self._run_flow_via_conductor(False)
        self.assertEqual(st.SUCCESS, fd.state)

    def test_fail_run(self):
        """A job whose flow fails leaves the flow in REVERTED state."""
        fd = self._run_flow_via_conductor(True)
        self.assertEqual(st.REVERTED, fd.state)
