1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
"""Utility for checking Python module imports triggered by any code snippet.
This module was developed to monitor the import footprint of the ase CLI
command: the CLI can become unnecessarily slow and unresponsive if too
many modules are imported before it has even started, or before it is
known which modules will actually be needed.
See https://gitlab.com/ase/ase/-/issues/1124 for more discussion.
The utility here is general, so it can be used for checking and
monitoring other code snippets too.
"""
import json
import os
import re
import sys
from pprint import pprint
from subprocess import run
from typing import List, Optional, Set
def exec_and_check_modules(expression: str) -> Set[str]:
    """Run *expression* in a fresh interpreter and report what it imported.

    The expression is executed through ``python -c`` in a child process, so
    the modules already loaded by the calling process do not affect the
    result. Anything the expression prints to ``sys.stdout`` is sent to the
    null device. Afterwards the child prints the names in its
    ``sys.modules`` as JSON, and that output is parsed here.

    Parameters
    ----------
    expression
        Python source to execute. It is spliced into a single
        ``;``-separated command line, so use ``;`` between statements.

    Returns
    -------
    Set of module names found in the child's ``sys.modules``.

    Raises
    ------
    subprocess.CalledProcessError
        If the child interpreter exits with a non-zero status.
    """
    # Resolve the null-device path here, in the parent process, so the
    # child command does not need an explicit `import os` ahead of the
    # expression being measured.
    devnull_path = os.devnull
    snippet = ''.join([
        'import sys;',
        f" stdout = sys.stdout; sys.stdout = open({devnull_path!r}, 'w');",
        f' {expression};',
        ' sys.stdout = stdout;',
        ' modules = list(sys.modules);',
        ' import json; print(json.dumps(modules))',
    ])
    completed = run(
        [sys.executable, '-c', snippet],
        capture_output=True,
        text=True,
        check=True,
    )
    loaded = json.loads(completed.stdout)
    return set(loaded)
def check_imports(
    expression: str,
    *,
    forbidden_modules: Optional[List[str]] = None,
    max_module_count: Optional[int] = None,
    max_nonstdlib_module_count: Optional[int] = None,
    do_print: bool = False,
) -> None:
    """Check modules imported by the execution of a Python expression.

    The expression is run in a separate interpreter via
    ``exec_and_check_modules``, and the resulting module names are checked
    against the given limits. Failures are reported with an explicit
    ``raise AssertionError`` rather than an ``assert`` statement, so the
    checks still run when Python is started with ``-O``.

    Parameters
    ----------
    expression
        Python expression
    forbidden_modules
        Regular expressions. Raises if any loaded module name fully matches
        one of them. ``None`` (the default) forbids nothing.
    max_module_count
        Raises if the total number of loaded modules exceeds this value.
    max_nonstdlib_module_count
        Raises if the number of modules whose top-level package name is not
        in ``sys.stdlib_module_names`` exceeds this value. Requires
        Python 3.10+.
    do_print
        Print loaded modules if set.

    Raises
    ------
    AssertionError
        If any check fails, or if the non-stdlib check is requested on
        Python older than 3.10.
    """
    modules = exec_and_check_modules(expression)
    if do_print:
        print('all modules:')
        pprint(sorted(modules))

    for module_pattern in forbidden_modules or ():
        pattern = re.compile(module_pattern)
        # Sorted iteration makes the reported offender deterministic.
        for module in sorted(modules):
            if pattern.fullmatch(module):
                raise AssertionError(f'{module} was imported')

    if max_nonstdlib_module_count is not None:
        if sys.version_info < (3, 10):
            raise AssertionError('Python 3.10+ required')
        stdlib_names = sys.stdlib_module_names  # type: ignore[attr-defined]
        # A submodule counts as stdlib when its top-level package is stdlib.
        nonstdlib_modules = [
            module for module in modules
            if module.split('.')[0] not in stdlib_names
        ]
        if do_print:
            print('nonstdlib modules:')
            pprint(sorted(nonstdlib_modules))
        module_count = len(nonstdlib_modules)
        if module_count > max_nonstdlib_module_count:
            raise AssertionError(
                'too many nonstdlib modules loaded:'
                f' {module_count}/{max_nonstdlib_module_count}'
            )

    if max_module_count is not None:
        module_count = len(modules)
        if module_count > max_module_count:
            raise AssertionError(
                f'too many modules loaded: {module_count}/{max_module_count}'
            )
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Check which modules are imported by a Python expression.',
    )
    parser.add_argument(
        'expression',
        help='Python code to execute in a fresh interpreter',
    )
    parser.add_argument(
        '--forbidden_modules', nargs='+', default=[],
        help='regular expressions; fail if any loaded module fully matches',
    )
    parser.add_argument(
        '--max_module_count', type=int, default=None,
        help='fail if more than this many modules are loaded',
    )
    parser.add_argument(
        '--max_nonstdlib_module_count', type=int, default=None,
        help='fail if more than this many non-stdlib modules are loaded '
        '(Python 3.10+)',
    )
    parser.add_argument(
        '--do_print', action='store_true',
        help='print the loaded modules',
    )
    args = parser.parse_args()
    # The argument dests match check_imports' keyword parameters one-to-one.
    check_imports(**vars(args))
|