#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import collections
import functools
import random

from automaton import exceptions as excp
from automaton import machines
from automaton import runners

from testtools import testcase


class FSMTest(testcase.TestCase):
    """Tests for the flat ``automaton.machines.FiniteMachine``.

    Machines are built through :py:meth:`_create_fsm` so that subclasses can
    swap in a different factory while re-running these tests.
    """

    @staticmethod
    def _create_fsm(start_state, add_start=True, add_states=None):
        """Create a ``FiniteMachine`` for a test.

        :param start_state: name of the state that is registered (and made
                            the default start state) when ``add_start`` is
                            true.
        :param add_start: whether to add ``start_state`` and assign it as the
                          machine's ``default_start_state``.
        :param add_states: optional iterable of extra state names to add;
                           names already present in the machine are skipped.
        :returns: the newly built machine.
        """
        m = machines.FiniteMachine()
        if add_start:
            m.add_state(start_state)
            m.default_start_state = start_state
        for s in add_states or ():
            # add_state() raises Duplicate for known states, so skip names
            # that are already present (e.g. the start state listed again).
            if s not in m:
                m.add_state(s)
        return m

    def setUp(self):
        super(FSMTest, self).setUp()
        # NOTE(harlowja): this state machine will never stop if run() is used.
        self.jumper = self._create_fsm("down", add_states=['up', 'down'])
        self.jumper.add_transition('down', 'up', 'jump')
        self.jumper.add_transition('up', 'down', 'fall')
        self.jumper.add_reaction('up', 'jump', lambda *args: 'fall')
        self.jumper.add_reaction('down', 'fall', lambda *args: 'jump')

    def test_build(self):
        space = [machines.State(a) for a in 'abc']
        m = machines.FiniteMachine.build(space)
        for a in 'abc':
            self.assertIn(a, m)

    def test_build_transitions(self):
        space = [
            machines.State('down', is_terminal=False,
                           next_states={'jump': 'up'}),
            machines.State('up', is_terminal=False,
                           next_states={'fall': 'down'}),
        ]
        m = machines.FiniteMachine.build(space)
        m.default_start_state = 'down'
        expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
        self.assertEqual(expected, list(m))

    def test_build_transitions_with_callbacks(self):
        entered = collections.defaultdict(list)
        exited = collections.defaultdict(list)

        def on_enter(state, event):
            entered[state].append(event)

        def on_exit(state, event):
            exited[state].append(event)

        space = [
            machines.State('down', is_terminal=False,
                           next_states={'jump': 'up'},
                           on_enter=on_enter, on_exit=on_exit),
            machines.State('up', is_terminal=False,
                           next_states={'fall': 'down'},
                           on_enter=on_enter, on_exit=on_exit),
        ]
        m = machines.FiniteMachine.build(space)
        m.default_start_state = 'down'
        expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
        self.assertEqual(expected, list(m))

        m.initialize()
        m.process_event('jump')

        self.assertEqual({'down': ['jump']}, dict(exited))
        self.assertEqual({'up': ['jump']}, dict(entered))

        m.process_event('fall')

        self.assertEqual({'down': ['jump'], 'up': ['fall']}, dict(exited))
        self.assertEqual({'up': ['jump'], 'down': ['fall']}, dict(entered))

    def test_build_transitions_dct(self):
        space = [
            {
                'name': 'down', 'is_terminal': False,
                'next_states': {'jump': 'up'},
            },
            {
                'name': 'up', 'is_terminal': False,
                'next_states': {'fall': 'down'},
            },
        ]
        m = machines.FiniteMachine.build(space)
        m.default_start_state = 'down'
        expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')]
        self.assertEqual(expected, list(m))

    def test_build_terminal(self):
        space = [
            machines.State('down', is_terminal=False,
                           next_states={'jump': 'fell_over'}),
            machines.State('fell_over', is_terminal=True),
        ]
        m = machines.FiniteMachine.build(space)
        m.default_start_state = 'down'
        m.initialize()
        m.process_event('jump')
        self.assertTrue(m.terminated)

    def test_actionable(self):
        self.jumper.initialize()
        self.assertTrue(self.jumper.is_actionable_event('jump'))
        self.assertFalse(self.jumper.is_actionable_event('fall'))

    def test_bad_start_state(self):
        m = self._create_fsm('unknown', add_start=False)
        r = runners.FiniteRunner(m)
        self.assertRaises(excp.NotFound, r.run, 'unknown')

    def test_contains(self):
        m = self._create_fsm('unknown', add_start=False)
        self.assertNotIn('unknown', m)
        m.add_state('unknown')
        self.assertIn('unknown', m)

    def test_no_add_transition_terminal(self):
        m = self._create_fsm('up')
        m.add_state('down', terminal=True)
        self.assertRaises(excp.InvalidState,
                          m.add_transition, 'down', 'up', 'jump')

    def test_duplicate_state(self):
        m = self._create_fsm('unknown')
        self.assertRaises(excp.Duplicate, m.add_state, 'unknown')

    def test_duplicate_transition(self):
        m = self.jumper
        m.add_state('side_ways')
        self.assertRaises(excp.Duplicate,
                          m.add_transition, 'up', 'side_ways', 'fall')

    def test_duplicate_transition_replace(self):
        m = self.jumper
        m.add_state('side_ways')
        m.add_transition('up', 'side_ways', 'fall', replace=True)
        # The pre-existing 'up' --fall--> 'down' transition must have been
        # overwritten rather than kept alongside the new one.
        self.assertEqual([('down', 'jump', 'up'), ('up', 'fall', 'side_ways')],
                         sorted(m))

    def test_duplicate_transition_same_transition(self):
        m = self.jumper
        before = sorted(m)
        # Re-adding an identical transition is accepted and changes nothing.
        m.add_transition('up', 'down', 'fall')
        self.assertEqual(before, sorted(m))

    def test_duplicate_reaction(self):
        self.assertRaises(
            # Currently duplicate reactions are not allowed...
            excp.Duplicate,
            self.jumper.add_reaction, 'down', 'fall', lambda *args: 'skate')

    def test_bad_transition(self):
        m = self._create_fsm('unknown')
        m.add_state('fire')
        self.assertRaises(excp.NotFound, m.add_transition,
                          'unknown', 'something', 'boom')
        self.assertRaises(excp.NotFound, m.add_transition,
                          'something', 'unknown', 'boom')

    def test_bad_reaction(self):
        m = self._create_fsm('unknown')
        self.assertRaises(excp.NotFound, m.add_reaction, 'something', 'boom',
                          lambda *args: 'cough')

    def test_run(self):
        m = self._create_fsm('down', add_states=['up', 'down'])
        m.add_state('broken', terminal=True)
        m.add_transition('down', 'up', 'jump')
        m.add_transition('up', 'broken', 'hit-wall')
        m.add_reaction('up', 'jump', lambda *args: 'hit-wall')
        self.assertEqual(['broken', 'down', 'up'], sorted(m.states))
        self.assertEqual(2, m.events)
        m.initialize()
        self.assertEqual('down', m.current_state)
        self.assertFalse(m.terminated)
        r = runners.FiniteRunner(m)
        r.run('jump')
        self.assertTrue(m.terminated)
        self.assertEqual('broken', m.current_state)
        # Once terminated, running again without re-initializing must fail.
        self.assertRaises(excp.InvalidState, r.run,
                          'jump', initialize=False)

    def test_on_enter_on_exit(self):
        enter_transitions = []
        exit_transitions = []

        def on_exit(state, event):
            exit_transitions.append((state, event))

        def on_enter(state, event):
            enter_transitions.append((state, event))

        m = self._create_fsm('start', add_start=False)
        m.add_state('start', on_exit=on_exit)
        m.add_state('down', on_enter=on_enter, on_exit=on_exit)
        m.add_state('up', on_enter=on_enter, on_exit=on_exit)
        m.add_transition('start', 'down', 'beat')
        m.add_transition('down', 'up', 'jump')
        m.add_transition('up', 'down', 'fall')

        m.initialize('start')
        m.process_event('beat')
        m.process_event('jump')
        m.process_event('fall')
        self.assertEqual([('down', 'beat'),
                          ('up', 'jump'), ('down', 'fall')], enter_transitions)
        self.assertEqual([('start', 'beat'), ('down', 'jump'), ('up', 'fall')],
                         exit_transitions)

    def test_run_iter(self):
        up_downs = []
        runner = runners.FiniteRunner(self.jumper)
        for (old_state, new_state) in runner.run_iter('jump'):
            up_downs.append((old_state, new_state))
            # The jumper never terminates, so stop after a few transitions.
            if len(up_downs) >= 3:
                break
        self.assertEqual([('down', 'up'), ('up', 'down'), ('down', 'up')],
                         up_downs)
        self.assertFalse(self.jumper.terminated)
        self.assertEqual('up', self.jumper.current_state)
        self.jumper.process_event('fall')
        self.assertEqual('down', self.jumper.current_state)

    def test_run_send(self):
        up_downs = []
        runner = runners.FiniteRunner(self.jumper)
        it = runner.run_iter('jump')
        while True:
            up_downs.append(it.send(None))
            if len(up_downs) >= 3:
                it.close()
                break
        self.assertEqual('up', self.jumper.current_state)
        self.assertFalse(self.jumper.terminated)
        self.assertEqual([('down', 'up'), ('up', 'down'), ('down', 'up')],
                         up_downs)
        # A closed generator is exhausted.
        self.assertRaises(StopIteration, next, it)

    def test_run_send_fail(self):
        up_downs = []
        runner = runners.FiniteRunner(self.jumper)
        it = runner.run_iter('jump')
        up_downs.append(next(it))
        # Sending an event with no transition from the current state fails.
        self.assertRaises(excp.NotFound, it.send, 'fail')
        it.close()
        self.assertEqual([('down', 'up')], up_downs)

    def test_not_initialized(self):
        self.assertRaises(excp.NotInitialized,
                          self.jumper.process_event, 'jump')

    def test_copy_states(self):
        c = self._create_fsm('down', add_start=False)
        self.assertEqual(0, len(c.states))
        d = c.copy()
        c.add_state('up')
        c.add_state('down')
        self.assertEqual(2, len(c.states))
        self.assertEqual(0, len(d.states))

    def test_copy_reactions(self):
        c = self._create_fsm('down', add_start=False)
        d = c.copy()

        c.add_state('down')
        c.add_state('up')
        c.add_reaction('down', 'jump', lambda *args: 'up')
        c.add_transition('down', 'up', 'jump')

        self.assertEqual(1, c.events)
        self.assertEqual(0, d.events)
        self.assertNotIn('down', d)
        self.assertNotIn('up', d)
        self.assertEqual([], list(d))
        self.assertEqual([('down', 'jump', 'up')], list(c))

    def test_copy_initialized(self):
        j = self.jumper.copy()
        self.assertIsNone(j.current_state)
        r = runners.FiniteRunner(self.jumper)

        for i, _transition in enumerate(r.run_iter('jump')):
            if i == 4:
                break

        # Running the original must not affect the copy's state.
        self.assertIsNone(j.current_state)
        self.assertIsNotNone(self.jumper.current_state)

    def test_iter(self):
        transitions = list(self.jumper)
        self.assertEqual(2, len(transitions))
        self.assertIn(('up', 'fall', 'down'), transitions)
        self.assertIn(('down', 'jump', 'up'), transitions)

    def test_freeze(self):
        self.jumper.freeze()
        self.assertRaises(excp.FrozenMachine, self.jumper.add_state, 'test')
        self.assertRaises(excp.FrozenMachine,
                          self.jumper.add_transition, 'test', 'test', 'test')
        self.assertRaises(excp.FrozenMachine,
                          self.jumper.add_reaction,
                          'test', 'test', lambda *args: 'test')

    def test_freeze_copy_unfreeze(self):
        self.jumper.freeze()
        self.assertTrue(self.jumper.frozen)
        cp = self.jumper.copy(unfreeze=True)
        self.assertTrue(self.jumper.frozen)
        self.assertFalse(cp.frozen)

    def test_invalid_callbacks(self):
        m = self._create_fsm('working', add_states=['working', 'broken'])
        self.assertRaises(ValueError, m.add_state, 'b', on_enter=2)
        self.assertRaises(ValueError, m.add_state, 'b', on_exit=2)


class HFSMTest(FSMTest):
    """Re-runs the ``FSMTest`` suite and adds hierarchical-machine tests.

    The inherited tests still build flat machines, because
    :py:meth:`_create_fsm` defaults to ``hierarchical=False``. The tests
    defined here opt in to ``HierarchicalFiniteMachine`` where they need
    nesting.
    """

    @staticmethod
    def _create_fsm(start_state,
                    add_start=True, hierarchical=False, add_states=None):
        """Create a flat or hierarchical machine for a test.

        :param start_state: name of the state that is registered (and made
                            the default start state) when ``add_start`` is
                            true.
        :param add_start: whether to add ``start_state`` and assign it as the
                          machine's ``default_start_state``.
        :param hierarchical: build a ``HierarchicalFiniteMachine`` instead of
                             a plain ``FiniteMachine``.
        :param add_states: optional iterable of extra state names to add;
                           names already present in the machine are skipped.
        :returns: the newly built machine.
        """
        if hierarchical:
            m = machines.HierarchicalFiniteMachine()
        else:
            m = machines.FiniteMachine()
        if add_start:
            m.add_state(start_state)
            m.default_start_state = start_state
        for s in add_states or ():
            # Skip names already registered (e.g. the start state repeated).
            if s not in m:
                m.add_state(s)
        return m

    def _make_phone_call(self, talk_time=1.0):
        """Build a hierarchical 'phone call' machine.

        ``handler`` goes 'begin' --call--> 'phone'. The 'phone' state nests
        the ``talker`` machine, which loops on 'talk' until its chat iterator
        is exhausted. It then emits 'finish', which matches the parent's
        'phone' --finish--> 'hangup' (terminal) transition.

        NOTE(review): ``talk_time`` is unused. The conversation length is
        fixed by the 10-item chat iterator below. The parameter is kept only
        so existing callers keep working.

        :returns: the top-level hierarchical ``handler`` machine.
        """

        def phone_reaction(old_state, new_state, event, chat_iter):
            try:
                next(chat_iter)
            except StopIteration:
                return 'finish'
            else:
                # Talk until the iterator expires...
                return 'chat'

        talker = self._create_fsm("talk")
        talker.add_transition("talk", "talk", "pickup")
        talker.add_transition("talk", "talk", "chat")
        talker.add_reaction("talk", "pickup", lambda *args: 'chat')
        chat_iter = iter(range(10))
        talker.add_reaction("talk", "chat", phone_reaction, chat_iter)

        handler = self._create_fsm('begin', hierarchical=True)
        handler.add_state("phone", machine=talker)
        handler.add_state('hangup', terminal=True)
        handler.add_transition("begin", "phone", "call")
        handler.add_reaction("phone", 'call', lambda *args: 'pickup')
        handler.add_transition("phone", "hangup", "finish")

        return handler

    def _make_phone_dialer(self):
        """Build a hierarchical 'phone dialer' machine.

        The 'pickup' state nests a ``digits`` machine. Each entry into its
        'accumulate' state appends one random digit to ``number_calling``.
        Once 10 digits are collected, the nested machine moves to 'dial' and
        reacts with 'ringing'. The parent then goes 'pickup' -> 'talk' ->
        'hangup'.

        :returns: ``(dialer, number_calling)``, where ``number_calling`` is
                  the list that the nested machine fills with digits.
        """
        dialer = self._create_fsm("idle", hierarchical=True)
        digits = self._create_fsm("idle")

        dialer.add_state("pickup", machine=digits)
        dialer.add_transition("idle", "pickup", "dial")
        dialer.add_reaction("pickup", "dial", lambda *args: 'press')
        dialer.add_state("hangup", terminal=True)

        def react_to_press(last_state, new_state, event, number_calling):
            if len(number_calling) >= 10:
                return 'call'
            else:
                return 'press'

        digit_maker = functools.partial(random.randint, 0, 9)
        number_calling = []
        digits.add_state(
            "accumulate",
            on_enter=lambda *args: number_calling.append(digit_maker()))
        digits.add_transition("idle", "accumulate", "press")
        digits.add_transition("accumulate", "accumulate", "press")
        digits.add_reaction("accumulate", "press",
                            react_to_press, number_calling)
        digits.add_state("dial", terminal=True)
        digits.add_transition("accumulate", "dial", "call")
        digits.add_reaction("dial", "call", lambda *args: 'ringing')
        dialer.add_state("talk")
        dialer.add_transition("pickup", "talk", "ringing")
        dialer.add_reaction("talk", "ringing", lambda *args: 'hangup')
        dialer.add_transition("talk", "hangup", 'hangup')
        return dialer, number_calling

    def test_nested_machines(self):
        dialer, _number_calling = self._make_phone_dialer()
        self.assertEqual(1, len(dialer.nested_machines))

    def test_nested_machine_initializers(self):
        dialer, _number_calling = self._make_phone_dialer()
        queried_for = []

        def init_with(nested_machine):
            queried_for.append(nested_machine)
            # Returning None leaves the nested machine's start state alone.
            return None

        dialer.initialize(nested_start_state_fetcher=init_with)
        self.assertEqual(1, len(queried_for))

    def test_phone_dialer_iter(self):
        dialer, number_calling = self._make_phone_dialer()
        self.assertEqual(0, len(number_calling))
        r = runners.HierarchicalRunner(dialer)
        transitions = list(r.run_iter('dial'))
        self.assertEqual(('talk', 'hangup'), transitions[-1])
        # Every entry into 'accumulate' should have produced exactly one digit.
        accumulated = sum(1 for (_old_state, new_state) in transitions
                          if new_state == 'accumulate')
        self.assertEqual(len(number_calling), accumulated)
        self.assertEqual(10, len(number_calling))

    def test_phone_call(self):
        handler = self._make_phone_call()
        r = runners.HierarchicalRunner(handler)
        r.run('call')
        self.assertTrue(handler.terminated)

    def test_phone_call_iter(self):
        handler = self._make_phone_call()
        r = runners.HierarchicalRunner(handler)
        transitions = list(r.run_iter('call'))
        self.assertEqual(('talk', 'hangup'), transitions[-1])
        self.assertEqual(("begin", 'phone'), transitions[0])
        # At least one self-transition inside the nested talker must occur.
        self.assertIn(("talk", "talk"), transitions)
