# -*- coding: utf-8 -*-

#    Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from __future__ import absolute_import

import collections
import logging

import fixtures
from oslotest import base
from oslotest import mockpatch
import six

# The import order below is deliberate: we want the mock implementation that
# works best, and oslotest registers a six.moves module for it. Depending on
# when oslotest itself got imported we may or may not see that registration
# when going through six, so the alternatives are tried one after another.
try:
    from six.moves import mock
except ImportError:
    try:
        # In python 3.3+ mock got included in the standard library...
        from unittest import mock
    except ImportError:
        import mock

from testtools import compat
from testtools import matchers
from testtools import testcase

from taskflow import exceptions
from taskflow.tests import utils
from taskflow.utils import misc


class GreaterThanEqual(object):
    """Matches if the item is geq than the matchers reference object."""

    def __init__(self, source):
        # Reference value that matched items must be greater than or equal to.
        self.source = source

    def __str__(self):
        # Matchers are expected to describe themselves (including their
        # parameters) so that verbose or composed failure output stays
        # readable instead of showing a bare object representation.
        return "%s(%r)" % (type(self).__name__, self.source)

    def match(self, other):
        """Returns ``None`` if ``other >= source``, otherwise a mismatch."""
        if other >= self.source:
            return None
        return matchers.Mismatch("%s was not >= %s" % (other, self.source))


class FailureRegexpMatcher(object):
    """Matches if the failure was caused by the given exception and message.

    This will match if a given failure contains an exception of the given
    class type and if its string message matches the given regular
    expression pattern.
    """

    def __init__(self, exc_class, pattern):
        # Exception class that one of the failure causes must check against.
        self.exc_class = exc_class
        # Regular expression the cause's exception string has to match.
        self.pattern = pattern

    def __str__(self):
        # Matchers are expected to describe themselves (including their
        # parameters) so that verbose or composed failure output stays
        # readable instead of showing a bare object representation.
        return "%s(%r, %r)" % (type(self).__name__,
                               self.exc_class, self.pattern)

    def match(self, failure):
        """Checks the first cause of ``failure`` that is an ``exc_class``.

        Returns the result of matching that cause's ``exception_str``
        against the pattern, or a mismatch when no cause checks against
        the expected exception class.
        """
        for cause in failure:
            if cause.check(self.exc_class) is not None:
                # Only the first cause of the expected type is examined.
                return matchers.MatchesRegex(
                    self.pattern).match(cause.exception_str)
        return matchers.Mismatch("The `%s` wasn't caused by the `%s`" %
                                 (failure, self.exc_class))


class ItemsEqual(object):
    """Matches the items in two sequences.

    This matcher will validate that the provided sequence has the same elements
    as a reference sequence, regardless of the order.
    """

    def __init__(self, seq):
        # The reference is kept as given (for messages) and as a list (for
        # the actual comparison).
        self._seq = seq
        self._list = list(seq)

    def __str__(self):
        # Matchers are expected to describe themselves (including their
        # parameters) so that verbose or composed failure output stays
        # readable instead of showing a bare object representation.
        return "%s(%r)" % (type(self).__name__, self._seq)

    def match(self, other):
        """Returns ``None`` if both sequences hold the same items.

        Otherwise a mismatch is returned that lists which items are only
        present in the reference (first) sequence and which are only present
        in the matched (second) sequence.
        """
        other_list = list(other)
        extra = misc.sequence_minus(other_list, self._list)
        missing = misc.sequence_minus(self._list, other_list)
        if extra or missing:
            msg = ("Sequences %s and %s do not have same items."
                   % (self._seq, other))
            # The differences are wrapped in 1-tuples so that formatting
            # stays correct whatever sequence type the helper returns.
            if missing:
                msg += " Extra items in first sequence: %s." % (missing,)
            if extra:
                msg += " Extra items in second sequence: %s." % (extra,)
            return matchers.Mismatch(msg)
        return None


class TestCase(base.BaseTestCase):
    """Test case base class for all taskflow unit tests."""

    def makeTmpDir(self):
        """Sets up a temporary directory fixture and returns its path."""
        t_dir = self.useFixture(fixtures.TempDir())
        return t_dir.path

    def assertDictEqual(self, expected, check):
        """Asserts that the ``check`` dictionary matches ``expected``.

        Both arguments must be dictionaries. Values of ``expected`` that are
        not already ``matchers.Equals`` instances get wrapped in one before
        the comparison is delegated to ``matchers.MatchesDict``.
        """
        self.assertIsInstance(expected, dict,
                              'First argument is not a dictionary')
        self.assertIsInstance(check, dict,
                              'Second argument is not a dictionary')

        # Testtools seems to want equals objects instead of just keys?
        compare_dict = {}
        for k, v in six.iteritems(expected):
            if not isinstance(v, matchers.Equals):
                v = matchers.Equals(v)
            compare_dict[k] = v
        self.assertThat(matchee=check,
                        matcher=matchers.MatchesDict(compare_dict))

    def assertRaisesAttrAccess(self, exc_class, obj, attr_name):
        """Asserts that reading ``obj.<attr_name>`` raises ``exc_class``."""

        def access_func():
            getattr(obj, attr_name)

        self.assertRaises(exc_class, access_func)

    def assertRaisesRegexp(self, exc_class, pattern, callable_obj,
                           *args, **kwargs):
        """Asserts the call raises ``exc_class`` with a matching message.

        ``callable_obj`` is invoked with the given positional and keyword
        arguments. Exceptions that are not subclasses of ``exc_class`` are
        re-raised untouched. On success the captured exception (the second
        item of what the ``Raises`` matcher hands over) is returned.
        """
        # TODO(harlowja): submit a pull/review request to testtools to add
        # this method to their codebase instead of having it exist in ours
        # since it really doesn't belong here.

        class ReRaiseOtherTypes(object):
            # Lets unexpected exception types propagate as they are instead
            # of being reported as a plain mismatch.
            def match(self, matchee):
                if not issubclass(matchee[0], exc_class):
                    compat.reraise(*matchee)

        class CaptureMatchee(object):
            # Never mismatches; only remembers the raised exception so that
            # it can be handed back to the caller.
            def match(self, matchee):
                self.matchee = matchee[1]

        capture = CaptureMatchee()
        matcher = matchers.Raises(
            matchers.MatchesAll(ReRaiseOtherTypes(),
                                matchers.MatchesException(exc_class,
                                                          pattern),
                                capture))
        our_callable = testcase.Nullary(callable_obj, *args, **kwargs)
        self.assertThat(our_callable, matcher)
        return capture.matchee

    def assertGreater(self, first, second):
        """Asserts that ``second`` is greater than ``first``.

        NOTE: the argument order is the reverse of the one used by
        ``unittest.TestCase.assertGreater``; it is kept as-is since existing
        callers rely on it.
        """
        matcher = matchers.GreaterThan(first)
        self.assertThat(second, matcher)

    def assertGreaterEqual(self, first, second):
        """Asserts that ``second`` is greater than or equal to ``first``.

        NOTE: the argument order is the reverse of the one used by
        ``unittest.TestCase.assertGreaterEqual``; it is kept as-is since
        existing callers rely on it.
        """
        matcher = GreaterThanEqual(first)
        self.assertThat(second, matcher)

    def assertRegexpMatches(self, text, pattern):
        """Asserts that ``text`` matches the regular expression pattern."""
        matcher = matchers.MatchesRegex(pattern)
        self.assertThat(text, matcher)

    def assertIsSuperAndSubsequence(self, super_seq, sub_seq, msg=None):
        """Asserts that ``sub_seq`` is an in-order subsequence of ``super_seq``.

        Every element of ``sub_seq`` has to be found in ``super_seq`` after
        the position at which the previous element was found. If ``msg`` is
        given it replaces the generated failure message.
        """
        super_seq = list(super_seq)
        sub_seq = list(sub_seq)
        current_tail = super_seq
        for sub_elem in sub_seq:
            try:
                super_index = current_tail.index(sub_elem)
            except ValueError:
                # element not found
                if msg is None:
                    msg = ("%r is not subsequence of %r: "
                           "element %r not found in tail %r"
                           % (sub_seq, super_seq, sub_elem, current_tail))
                self.fail(msg)
            else:
                # Continue searching strictly after the element just found.
                current_tail = current_tail[super_index + 1:]

    def assertFailuresRegexp(self, exc_class, pattern, callable_obj, *args,
                             **kwargs):
        """Asserts the callable failed with the given exception and message."""
        try:
            # NOTE(review): presumably converts whatever gets raised into a
            # WrappedFailure -- confirm against the helper's definition.
            with utils.wrap_all_failures():
                callable_obj(*args, **kwargs)
        except exceptions.WrappedFailure as e:
            self.assertThat(e, FailureRegexpMatcher(exc_class, pattern))
        else:
            # Nothing was raised, so the expected failure never happened;
            # without this branch the assertion would pass silently.
            self.fail("%r did not fail with %r" % (callable_obj, exc_class))

    def assertItemsEqual(self, seq1, seq2, msg=None):
        """Asserts both sequences hold the same items, in any order.

        If ``msg`` is given it is attached to the failure output.
        """
        matcher = ItemsEqual(seq1)
        self.assertThat(seq2, matcher, message=msg or '')


class MockTestCase(TestCase):
    """Test case base class that gathers created mocks on a master mock."""

    def setUp(self):
        super(MockTestCase, self).setUp()
        # Parent mock that patched objects get attached to.
        self.master_mock = mock.Mock(name='master_mock')

    def patch(self, target, autospec=True, **kwargs):
        """Patch target and attach it to the master mock.

        The optional ``attach_as`` keyword argument names the attribute under
        which the created mock is attached to the master mock (when it is
        not given the mock is not attached). All other keyword arguments are
        forwarded to the patch fixture. Returns the created mock.
        """
        # Take out our own option *before* forwarding the remaining keyword
        # arguments, otherwise ``attach_as`` leaks into the patcher.
        attach_as = kwargs.pop('attach_as', None)
        f = self.useFixture(mockpatch.Patch(target,
                                            autospec=autospec, **kwargs))
        mocked = f.mock
        if attach_as is not None:
            self.master_mock.attach_mock(mocked, attach_as)
        return mocked

    def patchClass(self, module, name, autospec=True, attach_as=None):
        """Patches a modules class.

        This will create a class instance mock (using the provided name to
        find the class in the module), patch the class in the module with a
        class mock that returns that instance mock when called, and attach
        both mocks to the master mock.

        When ``attach_as`` is not given the class mock is attached under
        ``name`` and the instance mock under the lowercased ``name``;
        otherwise they are attached under ``attach_as + '_class'`` and
        ``attach_as`` respectively.

        Returns a ``(class_mock, instance_mock)`` tuple.
        """
        if autospec:
            instance_mock = mock.Mock(spec_set=getattr(module, name))
        else:
            instance_mock = mock.Mock()

        f = self.useFixture(mockpatch.PatchObject(module, name,
                                                  autospec=autospec))
        class_mock = f.mock
        class_mock.return_value = instance_mock

        if attach_as is None:
            attach_class_as = name
            attach_instance_as = name.lower()
        else:
            attach_class_as = attach_as + '_class'
            attach_instance_as = attach_as

        self.master_mock.attach_mock(class_mock, attach_class_as)
        self.master_mock.attach_mock(instance_mock, attach_instance_as)
        return class_mock, instance_mock

    def resetMasterMock(self):
        """Resets the master mock."""
        self.master_mock.reset_mock()


class CapturingLoggingHandler(logging.Handler):
    """Logging handler that keeps every record it is asked to emit.

    The kept records can afterwards be inspected by level (how many there
    were and what their messages are) or by their exception information.
    """

    def __init__(self, level=logging.DEBUG):
        # The base initializer is called explicitly (old style) on purpose;
        # this can be modernized once only py3.x is supported.
        logging.Handler.__init__(self, level=level)
        self._records = []

    def _gather(self, sink, feed):
        """Calls ``feed(sink, record)`` per kept record while locked."""
        self.acquire()
        try:
            for record in self._records:
                feed(sink, record)
            return sink
        finally:
            self.release()

    @property
    def counts(self):
        """Maps each level number to the amount of records kept for it."""

        def bump(tally, record):
            tally[record.levelno] += 1

        return self._gather(collections.defaultdict(int), bump)

    @property
    def messages(self):
        """Maps each level number to the list of messages kept for it."""

        def collect(by_level, record):
            by_level[record.levelno].append(record.getMessage())

        return self._gather(collections.defaultdict(list), collect)

    @property
    def exc_infos(self):
        """Lists the exc_info of every kept record that carries one."""

        def keep(found, record):
            if record.exc_info:
                found.append(record.exc_info)

        return self._gather([], keep)

    def emit(self, record):
        """Stores the given record for later inspection."""
        self.acquire()
        try:
            self._records.append(record)
        finally:
            self.release()

    def reset(self):
        """Drops *all* of the records kept so far."""
        self.acquire()
        try:
            self._records = []
        finally:
            self.release()

    def close(self):
        """Closes the handler and then drops the kept records."""
        logging.Handler.close(self)
        self.reset()
