File: utils.py

package info (click to toggle)
pytest-mypy-plugins 3.2.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 296 kB
  • sloc: python: 1,287; sh: 15; makefile: 3
file content (383 lines) | stat: -rw-r--r-- 12,353 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# Borrowed from Pew.
# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82
import inspect
import os
import re
import sys
from dataclasses import dataclass
from itertools import zip_longest
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union

import jinja2
import regex
from decorator import contextmanager

# Module-wide Jinja2 environment, created with default settings and used by
# render_template below.
_rendering_env = jinja2.Environment()


@contextmanager
def temp_environ() -> Iterator[None]:
    """Snapshot ``os.environ`` on entry and put the snapshot back on exit.

    Variables added, changed or removed inside the ``with`` body are rolled
    back when the body finishes, including when it raises.
    """
    saved_vars = os.environ.copy()
    try:
        yield
    finally:
        # Restore in place: the live ``os.environ`` mapping object is kept,
        # only its contents are reset to the snapshot.
        os.environ.clear()
        os.environ.update(saved_vars)


@contextmanager
def temp_path() -> Iterator[None]:
    """Roll ``sys.path`` back to its entry-time contents when the block ends.

    The list of entries is copied on entry; on exit (normal or via an
    exception) ``sys.path`` is rebound to a fresh copy of that snapshot.
    """
    saved_entries = list(sys.path)
    try:
        yield
    finally:
        sys.path = list(saved_entries)


@contextmanager
def temp_sys_modules() -> Iterator[None]:
    """Roll ``sys.modules`` back to its entry-time contents on exit.

    Modules added, removed or replaced inside the ``with`` body are reverted
    when the body finishes, including when it raises.

    The restore mutates the existing ``sys.modules`` dictionary rather than
    rebinding ``sys.modules`` to a new one: the Python documentation warns
    that "replacing the dictionary will not necessarily work as expected",
    because other code may keep using the original dictionary object.
    """
    saved_modules = sys.modules.copy()
    try:
        yield
    finally:
        # Drop entries that appeared inside the block. Iterate over a copy of
        # the keys because the dictionary is modified while looping.
        for name in list(sys.modules):
            if name not in saved_modules:
                del sys.modules[name]
        # Put back anything that was removed or replaced inside the block.
        sys.modules.update(saved_modules)


def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]:
    """Turn a file path located under ``root_path`` into a dotted module name.

    The path is made relative to ``root_path``, its last suffix is dropped and
    the remaining path components are joined with dots. ``None`` is returned
    when pathlib raises ``ValueError`` while doing so, for example because
    ``fpath`` is not inside ``root_path``.
    """
    try:
        stem_path = fpath.relative_to(root_path).with_suffix("")
    except ValueError:
        return None
    return ".".join(stem_path.parts)


# Minimum line length for the "alignment" helper in failure messages:
# assert_expected_matched_actual only appends the aligned view of the first
# differing line when the actual or the expected text of that line has at
# least this many characters.
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5


@dataclass
class OutputMatcher:
    """One expected output line, compared literally or as a regular expression.

    The textual form is ``fname:lnum: severity: message`` or, when ``col`` is
    set, ``fname:lnum:col: severity: message``.
    """

    # File-name part of the location prefix.
    fname: str
    # Line number part of the location prefix.
    lnum: int
    # Severity word placed before the message (for example "error").
    severity: str
    # Expected message; used as a regular expression when ``regex`` is true.
    message: str
    # Whether ``message`` is a pattern rather than literal text.
    regex: bool
    # Optional column part of the location prefix.
    col: Optional[str] = None

    def _prefix(self) -> str:
        """Return everything that precedes the message, trailing space included."""
        if self.col is None:
            return f"{self.fname}:{self.lnum}: {self.severity}: "
        return f"{self.fname}:{self.lnum}:{self.col}: {self.severity}: "

    def matches(self, actual: str) -> bool:
        """Tell whether the output line ``actual`` satisfies this matcher.

        In regex mode only ``message`` is treated as a pattern; the location
        and severity prefix is escaped so it is compared literally. The
        pattern is applied with ``regex.match``. Otherwise ``actual`` must be
        equal to ``str(self)``.
        """
        if self.regex:
            pattern = regex.escape(self._prefix()) + self.message
            return bool(regex.match(pattern, actual))
        return str(self) == actual

    def __str__(self) -> str:
        return self._prefix() + self.message

    def __format__(self, format_spec: str) -> str:
        """Format ``str(self)`` according to ``format_spec``.

        Standard string format specs (``""``, ``"<45"``, ...) are honoured, so
        ``f"{matcher}"`` gives the full line instead of an empty string.
        For backward compatibility, a spec that is not a valid string format
        spec (for example ``"[{}]"``) is still used as a ``str.format``
        template, which is what this method did for every spec before.
        """
        rendered = str(self)
        try:
            return format(rendered, format_spec)
        except ValueError:
            return format_spec.format(rendered)

    def __len__(self) -> int:
        return len(str(self))


class TypecheckAssertionError(AssertionError):
    """Assertion failure that carries a report text and a line number."""

    def __init__(self, error_message: Optional[str] = None, lineno: int = 0) -> None:
        # A missing or empty message is normalised to the empty string.
        self.error_message = error_message if error_message else ""
        self.lineno = lineno

    def first_line(self) -> str:
        """Return a fixed one-line summary that starts with the class name."""
        return f'{type(self).__name__}(message="Invalid output")'

    def __str__(self) -> str:
        return self.error_message


def remove_common_prefix(lines: List[str]) -> List[str]:
    """Return ``lines`` with trailing spaces and a trailing carriage return removed.

    Despite its name, this function does not strip any prefix: for each line
    it first removes a run of spaces at the end, then one carriage return at
    the end. The input list is not modified.
    """
    trailing_spaces = re.compile(" +$")
    trailing_cr = re.compile("\\r$")
    return [trailing_cr.sub("", trailing_spaces.sub("", text)) for text in lines]


def _add_aligned_message(s1: str, s2: str, error_message: str) -> str:
    """Append an aligned view of ``s1`` and ``s2`` to ``error_message``.

    The two strings are printed one above the other with a caret under the
    first column where they differ. For 'foobar' and 'fobar' the added text
    looks like:

      E: foobar
      A: fobar
           ^

    While the first 30 characters of both strings are identical, 10
    characters are dropped from the front of each and a leading '...' marks
    the cut. At most 72 characters of each string are shown. When ``s1`` has
    fewer than 4 characters ``error_message`` is returned unchanged.
    """

    # A very short expected string is easy to compare by eye; add nothing.
    if len(s1) < 4:
        return error_message

    maxw = 72  # Maximum number of characters shown

    error_message += "Alignment of first line difference:\n"

    assert s1 != s2

    # Skip over the shared beginning in steps of 10 characters.
    skipped = False
    while s1[:30] == s2[:30]:
        s1, s2 = s1[10:], s2[10:]
        skipped = True

    if skipped:
        s1, s2 = "..." + s1, "..." + s2

    longest = max(len(s1), len(s2))
    extra = "..." if longest > maxw else ""

    # Column of the first difference inside the displayed window, if any.
    window = min(maxw, longest)
    first_diff = next((j for j in range(window) if s1[j : j + 1] != s2[j : j + 1]), None)
    if first_diff is None:
        indicator = " " * window
    else:
        indicator = " " * first_diff + "^"

    parts = [
        error_message,
        f"  E: {s1[:maxw]}{extra}\n",
        f"  A: {s2[:maxw]}{extra}\n",
        "     ",
        indicator,
        "\n",
    ]
    return "".join(parts)


def remove_empty_lines(lines: List[str]) -> List[str]:
    """Return a new list holding only the non-empty strings of ``lines``."""
    return [text for text in lines if text]


def sorted_by_file_and_line(lines: List[str]) -> List[str]:
    """Return ``lines`` sorted by the ``file:line:`` prefix of each entry.

    The sort key is ``(file name, line number)``. Entries that do not have
    two colons, or whose second field is not an integer, get the key
    ``("", 0)``. The sort is stable, so equal keys keep their input order.
    """
    fallback_key: Tuple[str, int] = ("", 0)

    def sort_key(line: str) -> Tuple[str, int]:
        pieces = line.split(":", maxsplit=2)
        if len(pieces) != 3:
            return fallback_key
        try:
            return pieces[0], int(pieces[1])
        except ValueError:
            return fallback_key

    return sorted(lines, key=sort_key)


def assert_expected_matched_actual(expected: List[OutputMatcher], actual: List[str]) -> None:
    """Check the actual output lines against the expected matchers.

    ``expected`` is sorted by (file name, line number). ``actual`` has its
    empty lines dropped, is sorted with ``sorted_by_file_and_line`` and is
    cleaned with ``remove_common_prefix``. The two lists are then compared
    pairwise by position.

    Raises:
        TypecheckAssertionError: when a pair does not match or one list is
            longer than the other. The message shows the differing region of
            both lists; ``lineno`` is the line number of the first differing
            expected matcher, or 0 when that position has no matcher.
    """

    def format_mismatched_line(line: str) -> str:
        return f"  {str(line):<45} (diff)"

    def format_matched_line(line: str, width: int = 100) -> str:
        return f" {line[:width]}..." if len(line) > width else f" {line}"

    def format_error_lines(lines: List[str]) -> str:
        return "\n".join(lines) if lines else "  (empty)"

    expected = sorted(expected, key=lambda om: (om.fname, om.lnum))
    actual = sorted_by_file_and_line(remove_empty_lines(actual))

    actual = remove_common_prefix(actual)

    # zip_longest pads the shorter list with None, so either element of a
    # recorded pair may be None when the lists differ in length.
    diff_lines: Dict[int, Tuple[Optional[OutputMatcher], Optional[str]]] = {
        i: (e, a)
        for i, (e, a) in enumerate(zip_longest(expected, actual))
        if e is None or a is None or not e.matches(a)
    }

    if diff_lines:
        first_diff_line = min(diff_lines.keys())
        last_diff_line = max(diff_lines.keys())

        expected_message_lines = []
        actual_message_lines = []

        # Walk the whole span between the first and the last difference so
        # matching lines in between are shown as context.
        for i in range(first_diff_line, last_diff_line + 1):
            if i in diff_lines:
                expected_line, actual_line = diff_lines[i]
                if expected_line:
                    expected_message_lines.append(format_mismatched_line(str(expected_line)))
                if actual_line:
                    actual_message_lines.append(format_mismatched_line(actual_line))

            else:
                # Not a difference, so this index exists in both lists.
                expected_line, actual_line = expected[i], actual[i]
                actual_message_lines.append(format_matched_line(actual_line))
                expected_message_lines.append(format_matched_line(str(expected_line)))

        first_diff_expected, first_diff_actual = diff_lines[first_diff_line]

        failure_reason = "Output is not expected" if actual and not expected else "Invalid output"

        if actual_message_lines and expected_message_lines:
            # Ellipsis markers show that lines before/after the span were cut.
            if first_diff_line > 0:
                expected_message_lines.insert(0, "  ...")
                actual_message_lines.insert(0, "  ...")

            if last_diff_line < len(actual) - 1 and last_diff_line < len(expected) - 1:
                expected_message_lines.append("  ...")
                actual_message_lines.append("  ...")

        error_message = "Actual:\n{}\nExpected:\n{}\n".format(
            format_error_lines(actual_message_lines), format_error_lines(expected_message_lines)
        )

        # Decide on the FIRST differing line, the same one used for the
        # alignment below and for ``lineno``. The leftover loop variable would
        # refer to the last differing line instead.
        if first_diff_expected is not None and first_diff_expected.regex:
            error_message += "The actual output does not match the expected regex."
        elif (
            first_diff_actual is not None
            and first_diff_expected is not None
            and (
                len(first_diff_actual) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
                or len(str(first_diff_expected)) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
            )
        ):
            error_message = _add_aligned_message(str(first_diff_expected), first_diff_actual, error_message)

        raise TypecheckAssertionError(
            error_message=f"{failure_reason}: \n{error_message}",
            lineno=first_diff_expected.lnum if first_diff_expected else 0,
        )


def extract_output_matchers_from_comments(fname: str, input_lines: List[str], regex: bool) -> List[OutputMatcher]:
    """Build output matchers from comments such as '# E: message' or
    '# E:3: message' found in ``input_lines``.

    The comment letter selects the severity: ``E`` is "error", ``N`` is
    "note" and ``W`` is "warning". An ``R`` right after the letter (``# ER:``)
    turns that single matcher into a regex matcher; the ``regex`` argument
    turns all of them into regex matchers. An optional number between the
    colons is kept as the column. The matcher's line number is the 1-based
    index of the line in ``input_lines``.

    A trailing ``.py`` is removed from ``fname``; any other occurrence of
    ``.py`` in the name is left alone.

    The result is a list of output matchers.
    """
    # Strip only the extension. A blanket replace would also eat ".py" in the
    # middle of the name ("stubs.pyi" -> "stubsi", "my.pypkg/a.py" -> "mypkg/a").
    if fname.endswith(".py"):
        fname = fname[: -len(".py")]

    severity_names = {"E": "error", "N": "note", "W": "warning"}
    matchers = []
    for index, line in enumerate(input_lines):
        # The first in the split things isn't a comment
        for possible_err_comment in line.split(" # ")[1:]:
            match = re.search(
                r"^([ENW])(?P<regex>[R])?:((?P<col>\d+):)? (?P<message>.*)$", possible_err_comment.strip()
            )
            if match is None:
                continue
            letter = match.group(1)
            matchers.append(
                OutputMatcher(
                    fname,
                    index + 1,
                    severity_names.get(letter, letter),
                    message=match.group("message"),
                    regex=regex or bool(match.group("regex")),
                    col=match.group("col"),
                )
            )
    return matchers


def extract_output_matchers_from_out(out: str, params: Mapping[str, Any], regex: bool) -> List[OutputMatcher]:
    """Build output matchers from lines such as 'function:9: E: message'.

    ``out`` is first passed through ``render_template`` with ``params`` and
    split on newlines. Lines that do not have the expected shape are skipped.
    The severity codes ``E``, ``N`` and ``W`` are expanded to "error", "note"
    and "warning"; any other severity word is kept as written. Every matcher
    gets the ``regex`` flag passed in.

    The result is a list of output matchers.
    """
    line_pattern = re.compile(
        r"^(?P<fname>.+):(?P<lnum>\d+): (?P<severity>[A-Za-z]+):((?P<col>\d+):)? (?P<message>.*)$"
    )
    short_codes = {"E": "error", "N": "note", "W": "warning"}

    matchers = []
    for raw_line in render_template(out, params).split("\n"):
        found = line_pattern.search(raw_line.strip())
        if found is None:
            continue
        code = found.group("severity")
        matcher = OutputMatcher(
            found.group("fname"),
            int(found.group("lnum")),
            short_codes.get(code, code),
            message=found.group("message"),
            regex=regex,
            col=found.group("col"),
        )
        matchers.append(matcher)
    return matchers


def render_template(template: str, data: Mapping[str, Any]) -> str:
    """Render ``template`` with the shared Jinja2 environment.

    Text that does not contain the environment's variable start string is
    returned unchanged without being rendered. Otherwise the template is
    rendered with ``data``, where every ``None`` value is replaced by the
    string "None".
    """
    needs_rendering = _rendering_env.variable_start_string in template
    if needs_rendering:
        compiled: jinja2.environment.Template = _rendering_env.from_string(template)
        context = {key: ("None" if value is None else value) for key, value in data.items()}
        return compiled.render(context)
    return template


def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]:
    """Locate the ``def`` line of ``attr`` inside its own source code.

    Returns a pair: the zero-based index of the first source line containing
    ``def <name>``, and the list of source lines that follow it.

    Raises:
        ValueError: if no source line contains ``def <name>``.
    """
    source_lines, _ = inspect.getsourcelines(attr)
    marker = f"def {attr.__name__}"
    hit = next((idx for idx, text in enumerate(source_lines) if marker in text.strip()), None)
    if hit is None:
        raise ValueError(f'No line "def {attr.__name__}" found')
    return hit, source_lines[hit + 1 :]


@contextmanager
def cd(path: Union[str, Path]) -> Iterator[None]:
    """Context manager to temporarily change working directories.

    The previous working directory is restored when the ``with`` body ends,
    including when it raises. A falsy ``path`` (such as ``""``) leaves the
    working directory untouched and just runs the body.
    """
    if not path:
        # A generator-based context manager has to yield exactly once.
        # Returning without yielding makes ``with cd(""):`` fail with
        # RuntimeError("generator didn't yield").
        yield
        return
    prev_cwd = Path.cwd().as_posix()
    if isinstance(path, Path):
        path = path.as_posix()
    os.chdir(str(path))
    try:
        yield
    finally:
        os.chdir(prev_cwd)