File: checkimports.py

package info (click to toggle)
python-ase 3.26.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 15,484 kB
  • sloc: python: 148,112; xml: 2,728; makefile: 110; javascript: 47
file content (127 lines) | stat: -rw-r--r-- 3,867 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Utility for checking Python module imports triggered by any code snippet.

This module was developed to monitor the import footprint of the ase CLI
command: The CLI command can become unnecessarily slow and unresponsive
if too many modules are imported even before the CLI is launched or
it is known what modules will be actually needed.
See https://gitlab.com/ase/ase/-/issues/1124 for more discussion.

The utility here is general, so it can be used for checking and
monitoring other code snippets too.
"""

import json
import os
import re
import sys
from pprint import pprint
from subprocess import run
from typing import List, Optional, Set


def exec_and_check_modules(expression: str) -> Set[str]:
    """Return modules loaded by the execution of a Python expression.

    The expression is executed in a fresh interpreter (``sys.executable``),
    so the result is not affected by modules already imported in the
    current process.  While the expression runs, ``sys.stdout`` is pointed
    at ``os.devnull`` so ordinary ``print`` output from the snippet does not
    mix with the reported module list.

    Parameters
    ----------
    expression
        Python expression (or ``;``-separated statements) to execute.

    Returns
    -------
    Set of module names present in ``sys.modules`` right after the
    expression has run.

    Raises
    ------
    subprocess.CalledProcessError
        If the child interpreter exits with a non-zero status, e.g. because
        the expression raised an exception.
    """
    # Take null outside command to avoid
    # `import os` before expression
    null = os.devnull
    command = (
        'import sys;'
        f" stdout = sys.stdout; sys.stdout = open({null!r}, 'w');"
        f' {expression};'
        ' sys.stdout = stdout;'
        ' modules = list(sys.modules);'
        # The leading print() guarantees the JSON starts on its own line even
        # if the expression wrote to file descriptor 1 directly (bypassing
        # sys.stdout) without a trailing newline.
        ' import json; print(); print(json.dumps(modules))'
    )
    proc = run(
        [sys.executable, '-c', command],
        capture_output=True,
        text=True,
        check=True,
    )
    # Only the last line holds the module list; anything before it is output
    # that escaped the sys.stdout redirection.
    return set(json.loads(proc.stdout.splitlines()[-1]))


def check_imports(
    expression: str,
    *,
    forbidden_modules: Optional[List[str]] = None,
    max_module_count: Optional[int] = None,
    max_nonstdlib_module_count: Optional[int] = None,
    do_print: bool = False,
) -> None:
    """Check modules imported by the execution of a Python expression.

    Failures are reported by raising :class:`AssertionError` explicitly
    rather than via ``assert`` statements. The checks are therefore still
    enforced when Python runs with ``-O``, which strips ``assert``.

    Parameters
    ----------
    expression
        Python expression
    forbidden_modules
        Regular expressions; raises an error if any loaded module name
        fully matches one of them.  ``None`` (the default) means no module
        is forbidden.
    max_module_count
        Raises an error if the number of modules exceeds this value.
    max_nonstdlib_module_count
        Raises an error if the number of non-stdlib modules exceeds this
        value.  Requires Python 3.10+ (uses ``sys.stdlib_module_names``).
    do_print:
        Print loaded modules if set.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    modules = exec_and_check_modules(expression)
    # Sort once so printing and error reporting are deterministic
    # (set iteration order is not).
    sorted_modules = sorted(modules)

    if do_print:
        print('all modules:')
        pprint(sorted_modules)

    for module_pattern in forbidden_modules or []:
        pattern = re.compile(module_pattern)
        for module in sorted_modules:
            if pattern.fullmatch(module):
                raise AssertionError(f'{module} was imported')

    if max_nonstdlib_module_count is not None:
        if sys.version_info < (3, 10):
            raise AssertionError('Python 3.10+ required')

        stdlib_names = sys.stdlib_module_names  # type: ignore[attr-defined]
        # A submodule counts as stdlib when its top-level package is stdlib.
        nonstdlib_modules = [
            module
            for module in sorted_modules
            if module.partition('.')[0] not in stdlib_names
        ]

        if do_print:
            print('nonstdlib modules:')
            pprint(nonstdlib_modules)

        module_count = len(nonstdlib_modules)
        if module_count > max_nonstdlib_module_count:
            raise AssertionError(
                'too many nonstdlib modules loaded:'
                f' {module_count}/{max_nonstdlib_module_count}'
            )

    if max_module_count is not None:
        module_count = len(modules)
        if module_count > max_module_count:
            raise AssertionError(
                f'too many modules loaded: {module_count}/{max_module_count}'
            )


if __name__ == '__main__':
    import argparse

    # Command-line front end for check_imports(); option names map 1:1 onto
    # its keyword arguments so vars(args) can be passed straight through.
    parser = argparse.ArgumentParser(
        description='Check which modules are imported by executing a Python '
        'expression in a fresh interpreter.',
    )
    parser.add_argument(
        'expression',
        help='Python expression (or ;-separated statements) to execute',
    )
    parser.add_argument(
        '--forbidden_modules',
        nargs='+',
        default=[],
        help='regular expressions; fail if any loaded module name fully '
        'matches one of them',
    )
    parser.add_argument(
        '--max_module_count',
        type=int,
        default=None,
        help='fail if more than this many modules are loaded',
    )
    parser.add_argument(
        '--max_nonstdlib_module_count',
        type=int,
        default=None,
        help='fail if more than this many non-stdlib modules are loaded '
        '(requires Python 3.10+)',
    )
    parser.add_argument(
        '--do_print',
        action='store_true',
        help='print the loaded modules',
    )
    args = parser.parse_args()

    check_imports(**vars(args))