1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
|
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder, nodes
from astroid.brain.brain_numpy_utils import (
NUMPY_VERSION_TYPE_HINTS_SUPPORT,
numpy_supports_type_hints,
)
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainNdarrayTest(unittest.TestCase):
    """Test that calls to numpy functions returning arrays are correctly inferred."""

    # Names of ndarray methods that are called with no arguments in the tests
    # below and are expected to be inferred as returning an ndarray.
    ndarray_returning_ndarray_methods = (
        "__abs__",
        "__add__",
        "__and__",
        "__array__",
        "__array_wrap__",
        "__copy__",
        "__deepcopy__",
        "__eq__",
        "__floordiv__",
        "__ge__",
        "__gt__",
        "__iadd__",
        "__iand__",
        "__ifloordiv__",
        "__ilshift__",
        "__imod__",
        "__imul__",
        "__invert__",
        "__ior__",
        "__ipow__",
        "__irshift__",
        "__isub__",
        "__itruediv__",
        "__ixor__",
        "__le__",
        "__lshift__",
        "__lt__",
        "__matmul__",
        "__mod__",
        "__mul__",
        "__ne__",
        "__neg__",
        "__or__",
        "__pos__",
        "__pow__",
        "__rshift__",
        "__sub__",
        "__truediv__",
        "__xor__",
        "all",
        "any",
        "argmax",
        "argmin",
        "argpartition",
        "argsort",
        "astype",
        "byteswap",
        "choose",
        "clip",
        "compress",
        "conj",
        "conjugate",
        "copy",
        "cumprod",
        "cumsum",
        "diagonal",
        "dot",
        "flatten",
        "getfield",
        "max",
        "mean",
        "min",
        "newbyteorder",
        "prod",
        "ptp",
        "ravel",
        "repeat",
        "reshape",
        "round",
        "searchsorted",
        "squeeze",
        "std",
        "sum",
        "swapaxes",
        "take",
        "trace",
        "transpose",
        "var",
        "view",
    )

    # Allowed ``pytype()`` values for an inferred ndarray.  This is a tuple on
    # purpose: ``x in ".ndarray"`` against a bare string is a *substring* test
    # and would also accept "", ".nd", "array", ... instead of the exact name.
    licit_array_types = (".ndarray",)

    def _inferred_ndarray_method_call(self, func_name):
        """Return the inference iterator for ``test_array.<func_name>()``.

        :param func_name: name of the ndarray method to call (without arguments)
        :returns: the result of ``node.infer()`` on the extracted call node
        """
        node = builder.extract_node(
            f"""
import numpy as np
test_array = np.ndarray((2, 2))
test_array.{func_name:s}()
"""
        )
        return node.infer()

    def _inferred_ndarray_attribute(self, attr_name):
        """Return the inference iterator for ``test_array.<attr_name>``.

        :param attr_name: name of the ndarray attribute to access
        :returns: the result of ``node.infer()`` on the extracted attribute node
        """
        node = builder.extract_node(
            f"""
import numpy as np
test_array = np.ndarray((2, 2))
test_array.{attr_name:s}
"""
        )
        return node.infer()

    def _assert_inferred_as_single_ndarray(self, name, inferred):
        """Check that *inferred* yields exactly one value typed as an ndarray.

        :param name: method or attribute name, used only in failure messages
        :param inferred: iterable of inferred values (as returned by ``infer()``)
        """
        inferred_values = list(inferred)
        # assertEqual (rather than assertTrue on a comparison) reports the
        # actual number of inferred values when the check fails.
        self.assertEqual(
            len(inferred_values),
            1,
            msg=f"Too much inferred value for {name:s}",
        )
        # Exact membership in the tuple of allowed type names.
        self.assertIn(
            inferred_values[-1].pytype(),
            self.licit_array_types,
            msg=f"Illicit type for {name:s} ({inferred_values[-1].pytype()})",
        )

    def test_numpy_function_calls_inferred_as_ndarray(self):
        """Test that some calls to numpy functions are inferred as numpy.ndarray."""
        for func_ in self.ndarray_returning_ndarray_methods:
            # One sub-test per method so a single failure does not hide the others.
            with self.subTest(typ=func_):
                self._assert_inferred_as_single_ndarray(
                    func_, self._inferred_ndarray_method_call(func_)
                )

    def test_numpy_ndarray_attribute_inferred_as_ndarray(self):
        """Test that some numpy ndarray attributes are inferred as numpy.ndarray."""
        for attr_ in ("real", "imag", "shape", "T"):
            with self.subTest(typ=attr_):
                self._assert_inferred_as_single_ndarray(
                    attr_, self._inferred_ndarray_attribute(attr_)
                )

    @unittest.skipUnless(
        HAS_NUMPY and numpy_supports_type_hints(),
        f"This test requires the numpy library with a version above {NUMPY_VERSION_TYPE_HINTS_SUPPORT}",
    )
    def test_numpy_ndarray_class_support_type_indexing(self):
        """Test that numpy ndarray class can be subscripted (type hints)."""
        src = """
import numpy as np
np.ndarray[int]
"""
        node = builder.extract_node(src)
        # The subscripted expression must still infer to the ndarray class itself.
        cls_node = node.inferred()[0]
        self.assertIsInstance(cls_node, nodes.ClassDef)
        self.assertEqual(cls_node.name, "ndarray")
|