From c543f3640d33c2df2a46bcf85dc274e62a548d68 Mon Sep 17 00:00:00 2001
From: David Wolever <david@wolever.net>
Date: Thu, 9 Jun 2022 00:56:11 +0200
Subject: [PATCH] Enable pytest4

---
 parameterized/parameterized.py | 143 +++++++++++++++++++++++++++++----
 parameterized/test.py          |  74 +++++++++++------
 2 files changed, 177 insertions(+), 40 deletions(-)

--- a/parameterized/parameterized.py
+++ b/parameterized/parameterized.py
@@ -19,9 +19,14 @@
     class SkipTest(Exception):
         pass
 
+try:
+    import pytest
+except ImportError:
+    pytest = None
+
 PY3 = sys.version_info[0] == 3
 PY2 = sys.version_info[0] == 2
-
+PYTEST4 = pytest and int(pytest.__version__.split('.')[0]) >= 4
 
 if PY3:
     # Python 3 doesn't have an InstanceType, so just use a dummy type.
@@ -352,6 +357,120 @@
     def __call__(self, test_func):
         self.assert_not_in_testcase_subclass()
 
+        input = self.get_input()
+        wrapper = self._wrap_test_func(test_func, input)
+        wrapper.parameterized_input = input
+        wrapper.parameterized_func = test_func
+        test_func.__name__ = "_parameterized_original_%s" %(test_func.__name__, )
+
+        return wrapper
+
+    def _wrap_test_func(self, test_func, input):
+        """ Wraps a test function so that it will appropriately handle
+            parameterization.
+
+            In the general case, the wrapper will enumerate the input, yielding
+            test cases.
+
+            In the case of pytest4, the wrapper will use
+            ``@pytest.mark.parametrize`` to parameterize the test function. """
+
+        if not input:
+            if not self.skip_on_empty:
+                raise ValueError(
+                    "Parameters iterable is empty (hint: use "
+                    "`parameterized([], skip_on_empty=True)` to skip "
+                    "this test when the input is empty)"
+                )
+            return wraps(test_func)(skip_on_empty_helper)
+
+        if PYTEST4 and detect_runner() == 'pytest':
+            # pytest >= 4 compatibility is... a bit of work. Basically, the
+            # only way (I can find) of implementing parameterized testing with
+            # pytest >= 4 is through the ``@pytest.mark.parameterized``
+            # decorator. This decorator has some strange requirements around
+            # the name and number of arguments to the test function, so this
+            # wrapper works around that by:
+            # 1. Introspecting the original test function to determine the
+            #    names and default values of all arguments.
+            # 2. Creating a new function with the same arguments, but none
+            #    of them are optional::
+            #
+            #        def foo(a, b=42): ...
+            #
+            #    Becomes:
+            #
+            #        def parameterized_pytest_wrapper_foo(a, b): ...
+            #
+            # 3. Merging the ``@parameterized`` parameters with the argument
+            #    default values.
+            # 4. Generating a list of ``pytest.param(...)`` values, and passing
+            #    that into ``@pytest.mark.parameterized``.
+            # Some work also needs to be done to support the documented usage
+            # of ``mock.patch``, which also adds complexity.
+            Undefined = object()
+            test_func_wrapped = test_func
+            test_func_real, mock_patchings = unwrap_mock_patch_func(test_func_wrapped)
+            func_argspec = getargspec(test_func_real)
+
+            func_args = func_argspec.args
+            if mock_patchings:
+                func_args = func_args[:-len(mock_patchings)]
+
+            func_args_no_self = func_args
+            if func_args_no_self[:1] == ["self"]:
+                func_args_no_self = func_args_no_self[1:]
+
+            args_with_default = dict(
+                (arg, Undefined)
+                for arg in func_args_no_self
+            )
+            for (arg, default) in zip(reversed(func_args_no_self), reversed(func_argspec.defaults or [])):
+                args_with_default[arg] = default
+
+            pytest_params = []
+            for i in input:
+                p = dict(args_with_default)
+                for (arg, val) in zip(func_args_no_self, i.args):
+                    p[arg] = val
+                p.update(i.kwargs)
+
+                # Sanity check: all arguments should now be defined
+                if any(v is Undefined for v in p.values()):
+                    raise ValueError(
+                        "When parameterizing function %r: no value for "
+                        "arguments: %s with parameters %r "
+                        "(see: 'no value for arguments' in "
+                        "https://github.com/wolever/parameterized#faq)" %(
+                            test_func,
+                            ", ".join(
+                                repr(arg)
+                                for (arg, val) in p.items()
+                                if val is Undefined
+                            ),
+                            i,
+                        )
+                    )
+
+                pytest_params.append(pytest.param(*[
+                    p.get(arg) for arg in func_args_no_self
+                ]))
+
+            namespace = {
+                "__parameterized_original_test_func": test_func_wrapped,
+            }
+            wrapper_name = "parameterized_pytest_wrapper_%s" %(test_func.__name__, )
+            exec(
+                "def %s(%s, *__args): return __parameterized_original_test_func(%s, *__args)" %(
+                    wrapper_name,
+                    ",".join(func_args),
+                    ",".join(func_args),
+                ),
+                namespace,
+                namespace,
+            )
+            return pytest.mark.parametrize(",".join(func_args_no_self), pytest_params)(namespace[wrapper_name])
+
         @wraps(test_func)
         def wrapper(test_self=None):
             test_cls = test_self and type(test_self)
@@ -366,7 +485,7 @@
                     ) %(test_self, ))
 
             original_doc = wrapper.__doc__
-            for num, args in enumerate(wrapper.parameterized_input):
+            for num, args in enumerate(input):
                 p = param.from_decorator(args)
                 unbound_func, nose_tuple = self.param_as_nose_tuple(test_self, test_func, num, p)
                 try:
@@ -383,21 +502,6 @@
                     if test_self is not None:
                         delattr(test_cls, test_func.__name__)
                     wrapper.__doc__ = original_doc
-
-        input = self.get_input()
-        if not input:
-            if not self.skip_on_empty:
-                raise ValueError(
-                    "Parameters iterable is empty (hint: use "
-                    "`parameterized([], skip_on_empty=True)` to skip "
-                    "this test when the input is empty)"
-                )
-            wrapper = wraps(test_func)(skip_on_empty_helper)
-
-        wrapper.parameterized_input = input
-        wrapper.parameterized_func = test_func
-        test_func.__name__ = "_parameterized_original_%s" %(test_func.__name__, )
-
         return wrapper
 
     def param_as_nose_tuple(self, test_self, func, num, p):
@@ -618,6 +722,11 @@
 
     return decorator
 
+def unwrap_mock_patch_func(f):
+    if not hasattr(f, "patchings"):
+        return (f, [])
+    real_func, patchings = unwrap_mock_patch_func(f.__wrapped__)
+    return (real_func, patchings + f.patchings)
 
 def get_class_name_suffix(params_dict):
     if "name" in params_dict:
--- a/parameterized/test.py
+++ b/parameterized/test.py
@@ -5,8 +5,9 @@
 from unittest import TestCase
 
 from .parameterized import (
-    PY3, PY2, parameterized, param, parameterized_argument_value_pairs,
-    short_repr, detect_runner, parameterized_class, SkipTest,
+    PY3, PY2, PYTEST4, parameterized, param,
+    parameterized_argument_value_pairs, short_repr, detect_runner,
+    parameterized_class, SkipTest,
 )
 
 def assert_equal(*args, **kwds):
@@ -45,6 +46,7 @@
 
 test_params = [
     (42, ),
+    (42, "bar_val"),
     "foo0",
     param("foo1"),
     param("foo2", bar=42),
@@ -55,6 +57,7 @@
     "test_naked_function('foo1', bar=None)",
     "test_naked_function('foo2', bar=42)",
     "test_naked_function(42, bar=None)",
+    "test_naked_function(42, bar='bar_val')",
 ])
 
 @parameterized(test_params)
@@ -68,6 +71,7 @@
         "test_instance_method('foo1', bar=None)",
         "test_instance_method('foo2', bar=42)",
         "test_instance_method(42, bar=None)",
+        "test_instance_method(42, bar='bar_val')",
     ])
 
     @parameterized(test_params)
@@ -100,10 +104,16 @@
             missing_tests.remove("test_setup(%s)" %(self.actual_order, ))
 
 
-def custom_naming_func(custom_tag):
+def custom_naming_func(custom_tag, kw_name):
     def custom_naming_func(testcase_func, param_num, param):
-        return testcase_func.__name__ + ('_%s_name_' % custom_tag) + str(param.args[0])
-
+        return (
+            testcase_func.__name__ +
+            '_%s_name_' %(custom_tag, ) +
+            str(param.args[0]) + 
+            # This ... is a bit messy, to properly handle the values in
+            # `test_params`, but ... it should work.
+            '_%s' %(param.args[1] if len(param.args) > 1 else param.kwargs.get(kw_name), )
+        )
     return custom_naming_func
 
 
@@ -142,19 +152,20 @@
                               mock_fdopen._mock_name, mock_getpid._mock_name))
 
 
-@mock.patch("os.getpid")
-class TestParameterizedExpandWithNoExpand(object):
-    expect("generator", [
-        "test_patch_class_no_expand(42, 51, 'umask', 'getpid')",
-    ])
+if not (PYTEST4 and detect_runner() == 'pytest'):
+    @mock.patch("os.getpid")
+    class TestParameterizedExpandWithNoExpand(object):
+        expect("generator", [
+            "test_patch_class_no_expand(42, 51, 'umask', 'getpid')",
+        ])
 
-    @parameterized([(42, 51)])
-    @mock.patch("os.umask")
-    def test_patch_class_no_expand(self, foo, bar, mock_umask, mock_getpid):
-        missing_tests.remove("test_patch_class_no_expand"
-                             "(%r, %r, %r, %r)" %
-                             (foo, bar, mock_umask._mock_name,
-                              mock_getpid._mock_name))
+        @parameterized([(42, 51)])
+        @mock.patch("os.umask")
+        def test_patch_class_no_expand(self, foo, bar, mock_umask, mock_getpid):
+            missing_tests.remove("test_patch_class_no_expand"
+                                 "(%r, %r, %r, %r)" %
+                                 (foo, bar, mock_umask._mock_name,
+                                  mock_getpid._mock_name))
 
 
 class TestParameterizedExpandWithNoMockPatchForClass(TestCase):
@@ -219,6 +230,7 @@
         "test_on_TestCase('foo1', bar=None)",
         "test_on_TestCase('foo2', bar=42)",
         "test_on_TestCase(42, bar=None)",
+        "test_on_TestCase(42, bar='bar_val')",
     ])
 
     @parameterized.expand(test_params)
@@ -226,20 +238,21 @@
         missing_tests.remove("test_on_TestCase(%r, bar=%r)" %(foo, bar))
 
     expect([
-        "test_on_TestCase2_custom_name_42(42, bar=None)",
-        "test_on_TestCase2_custom_name_foo0('foo0', bar=None)",
-        "test_on_TestCase2_custom_name_foo1('foo1', bar=None)",
-        "test_on_TestCase2_custom_name_foo2('foo2', bar=42)",
+        "test_on_TestCase2_custom_name_42_None(42, bar=None)",
+        "test_on_TestCase2_custom_name_42_bar_val(42, bar='bar_val')",
+        "test_on_TestCase2_custom_name_foo0_None('foo0', bar=None)",
+        "test_on_TestCase2_custom_name_foo1_None('foo1', bar=None)",
+        "test_on_TestCase2_custom_name_foo2_42('foo2', bar=42)",
     ])
 
     @parameterized.expand(test_params,
-                          name_func=custom_naming_func("custom"))
+                          name_func=custom_naming_func("custom", "bar"))
     def test_on_TestCase2(self, foo, bar=None):
         stack = inspect.stack()
         frame = stack[1]
         frame_locals = frame[0].f_locals
         nose_test_method_name = frame_locals['a'][0]._testMethodName
-        expected_name = "test_on_TestCase2_custom_name_" + str(foo)
+        expected_name = "test_on_TestCase2_custom_name_" + str(foo) + "_" + str(bar)
         assert_equal(nose_test_method_name, expected_name,
                      "Test Method name '%s' did not get customized to expected: '%s'" %
                      (nose_test_method_name, expected_name))
@@ -379,6 +392,8 @@
 def test_old_style_classes():
     if PY3:
         raise SkipTest("Py3 doesn't have old-style classes")
+    if PYTEST4 and detect_runner() == 'pytest':
+        raise SkipTest("We're not going to worry about old style classes with pytest 4")
     class OldStyleClass:
         @parameterized(["foo"])
         def parameterized_method(self, param):
@@ -444,6 +459,19 @@
     pass
 
 
+if PYTEST4 and detect_runner() == 'pytest':
+    def test_missing_argument_error():
+        try:
+            @parameterized([
+                (1, ),
+            ])
+            def foo(a, b):
+                pass
+        except ValueError as e:
+            assert_contains(repr(e), "no value for arguments: 'b'")
+        else:
+            raise AssertionError("Expected exception not raised")
+
 cases_over_10 = [(i, i+1) for i in range(11)]
 
 @parameterized(cases_over_10)
