# UserString is a wrapper around the native builtin string type.
# UserString instances should behave similarly to builtin string objects.

import unittest
from test import string_tests
from test import support

from collections import UserString


class UserStringSubclass(UserString):
    """Trivial UserString subclass used to check which class the results
    of mixed-type operations end up with."""

class UserStringSubclass2(UserString):
    """A second, unrelated UserString subclass, so that operations mixing
    two distinct subclasses can be checked."""


class UserStringTest(
    string_tests.StringLikeTest,
    unittest.TestCase
    ):
    """Run the shared string-like test suite against UserString.

    Also add UserString-specific checks: the ``data`` attribute, the result
    types of mixed-type ``+``/``+=``/``%``, comparison with plain str,
    encode() defaults, and how count()/index() are implemented.
    """

    type2test = UserString

    # Override the three check helpers of the shared suite: only the
    # expected result and the receiver are converted with fixtype(), while
    # the method arguments are passed through unchanged, because UserString
    # can't cope with arguments converted to UserString
    # (and we don't test with subclasses).
    def checkequal(self, result, object, methodname, *args, **kwargs):
        """Assert that ``object.methodname(*args, **kwargs)`` equals *result*.

        *result* and *object* are converted with fixtype(); *args* and
        *kwargs* are not.
        """
        result = self.fixtype(result)
        object = self.fixtype(object)
        # we don't fix the arguments, because UserString can't cope with it
        realresult = getattr(object, methodname)(*args, **kwargs)
        self.assertEqual(
            result,
            realresult
        )

    def checkraises(self, exc, obj, methodname, *args, expected_msg=None):
        """Assert that ``obj.methodname(*args)`` raises *exc*.

        The exception message must be non-empty. If *expected_msg* is
        given, the message must also equal it exactly.
        """
        obj = self.fixtype(obj)
        # we don't fix the arguments, because UserString can't cope with it
        with self.assertRaises(exc) as cm:
            getattr(obj, methodname)(*args)
        self.assertNotEqual(str(cm.exception), '')
        if expected_msg is not None:
            self.assertEqual(str(cm.exception), expected_msg)

    def checkcall(self, object, methodname, *args):
        """Call ``object.methodname(*args)``; any exception fails the test."""
        object = self.fixtype(object)
        # we don't fix the arguments, because UserString can't cope with it
        getattr(object, methodname)(*args)

    def test_data(self):
        # .data is always a plain str, whether the constructor argument is a
        # str, another UserString, or a non-string object.
        u = UserString("spam")
        self.assertEqual(u.data, "spam")
        self.assertIs(type(u.data), str)
        u = UserString(u)
        self.assertEqual(u.data, "spam")
        self.assertIs(type(u.data), str)
        u = UserString(42)
        self.assertEqual(u.data, "42")
        self.assertIs(type(u.data), str)

    def test_mixed_add(self):
        # __add__ returns an instance of the left operand's class.
        u = UserString("spam") + "eggs"
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserString)

        u = "spam" + UserString("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserString)

        u = UserString("spam") + UserStringSubclass("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserString)

        u = UserStringSubclass("spam") + UserString("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserStringSubclass)

        u = UserStringSubclass("spam") + UserStringSubclass2("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserStringSubclass)

        # Calling __radd__ directly: the result takes the class of the
        # object whose __radd__ is called. Check u2 here, not the stale u
        # from the assignment above.
        u2 = UserString("eggs").__radd__(UserString("spam"))
        self.assertEqual(u2, "spameggs")
        self.assertIs(type(u2), UserString)

        u2 = UserStringSubclass("eggs").__radd__(UserStringSubclass2("spam"))
        self.assertEqual(u2, "spameggs")
        self.assertIs(type(u2), UserStringSubclass)

    def test_mixed_iadd(self):
        # += follows the same result-type rules as +.
        u = UserString("spam")
        u += "eggs"
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserString)

        u = "spam"
        u += UserString("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserString)

        u = UserString("spam")
        u += UserStringSubclass("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserString)

        u = UserStringSubclass("spam")
        u += UserString("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserStringSubclass)

        u = UserStringSubclass("spam")
        u += UserStringSubclass2("eggs")
        self.assertEqual(u, "spameggs")
        self.assertIs(type(u), UserStringSubclass)

    def test_mixed_cmp(self):
        # Compare a UserString against plain str. Judging by the data, the
        # last argument of _assert_cmp is the expected ordering
        # (0 equal, 1 greater, -1 less).
        a = self.fixtype('ab')
        self._assert_cmp(a, 'ab', 0)
        self._assert_cmp(a, 'a', 1)
        self._assert_cmp(a, 'ac', -1)

    def test_rmod(self):
        class ustr2(UserString):
            pass

        class ustr3(ustr2):
            def __rmod__(self, other):
                return super().__rmod__(other)

        # ustr3 subclasses ustr2 and overrides __rmod__, so Python tries
        # str3.__rmod__(fmt2) before fmt2.__mod__(str3). Delegating to the
        # inherited __rmod__ via super() must still format correctly.
        fmt2 = ustr2('value is %s')
        str3 = ustr3('TEST')
        self.assertEqual(fmt2 % str3, 'value is TEST')

    def test_encode_default_args(self):
        self.checkequal(b'hello', 'hello', 'encode')
        # Check that encoding defaults to utf-8
        self.checkequal(b'\xf0\xa3\x91\x96', '\U00023456', 'encode')
        # Check that errors defaults to 'strict'
        self.checkraises(UnicodeError, '\ud800', 'encode')

    def test_encode_explicit_none_args(self):
        # Passing None explicitly for encoding and errors must give the
        # same results as omitting them.
        self.checkequal(b'hello', 'hello', 'encode', None, None)
        # Check that encoding defaults to utf-8
        self.checkequal(b'\xf0\xa3\x91\x96', '\U00023456', 'encode', None, None)
        # Check that errors defaults to 'strict'
        self.checkraises(UnicodeError, '\ud800', 'encode', None, None)

    def test_implementation(self):
        # With __iter__ (and, in the second part, __getitem__) temporarily
        # set to None on the class, count() and index() must still work.
        # That is, they must not fall back on iterating or indexing the
        # UserString itself, for either a str or a UserString argument.
        s = UserString('ababahalamaha')
        with support.swap_attr(UserString, '__iter__', None):
            c = s.count('a')
            c2 = s.count(UserString('a'))
        self.assertEqual(c, 7)
        self.assertEqual(c2, 7)
        with (support.swap_attr(UserString, '__iter__', None),
              support.swap_attr(UserString, '__getitem__', None)):
            i = s.index('h')
            i2 = s.index(UserString('h'))
        self.assertEqual(i, 5)
        self.assertEqual(i2, 5)


if __name__ == "__main__":
    # Executed as a script: discover and run the tests in this module.
    unittest.main(module="__main__")
