1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
|
import pytest
import sys
import types
import functools as ft
import threading
import os
import os.path
import importlib
import collections
import ctypes
import re
# list of function names which shall not be
# traced when compression keyword hardening
# test run is executed
non_core_loader_functions = {
'type_legacy_mro',
'load_pickled_data',
'recover_custom_dataset',
#'register_compact_expand',
'_moc_numpy_array_object_lambda',
'fix_lambda_obj_type',
'LoaderManager.load_loader',
'RecoverGroupContainer.convert',
'NoContainer.convert',
'_DictItemContainer.convert',
'ExpandReferenceContainer.convert',
'RecoverGroupContainer.filter',
'ExpandReferenceContainer.filter',
'ReferenceManager.resolve_type',
'RecoverGroupContainer._append'
}
def pytest_addoption(parser):
"""
adds enable_compression keywort to pytest commandline options
for enabling h5py compression keyword hardening testing of
dump functions of hikcle.loaders and hickle core loaders
"""
parser.addoption(
"--enable-compression",
action='store',
nargs='?',
const=6,
type=int,
choices=range(0,10),
help="run all tests with bzip compression enabled. Optionally specify compression level 0-9 (default 6)",
dest="enable_compression"
)
def _get_trace_function(trace_function):
"""
try to get hold of FunctionType object of passed in Method, Function or callable
"""
while not isinstance(trace_function,(types.FunctionType,types.LambdaType,types.BuiltinFunctionType)):
if isinstance(trace_function,(types.MethodType,types.BuiltinMethodType)):
trace_function = getattr(trace_function,'__func__')
continue
if isinstance(trace_function,ft.partial):
trace_function = trace_function.func
continue
return (
getattr(trace_function,'__call__',trace_function)
if callable(trace_function) and not isinstance(trace_function,type) else
None
)
return trace_function
# keyword arguments to yield from compression_kwargs fixture below
# may in future become a list of dictionaries to be yieled for
# running same test with different sets of compression keywords
# (implizit parametrization of tests)
_compression_args = dict(
compression='gzip',
compression_opts=6
)
_test_compression = None
def pytest_configure(config):
"""
make no_compression mark available from pytest.mark.
if not yet activated enable profiling of dump methods and functions
and set compression_level selected on commandline if explicitly
specified.
"""
global _test_compression
config.addinivalue_line(
"markers","no_compression: do not enforce h5py compression hardening testing"
)
if _test_compression is not None:
return
compression_level = config.getoption("enable_compression",default=-1)
if compression_level is None or compression_level < 0:
return
_compression_args['compression_opts'] = compression_level
_test_compression = True
# local handle of no_compression mark
no_compression = pytest.mark.no_compression
@pytest.fixture#(scope='session')
def compression_kwargs(request):
"""
fixture providing the compressoin related keyword arguments
to be passed to any test not marked with no_compression mark
and expecting compression_kwargs as one of its parameters
"""
global _test_compression
yield ( _compression_args if _test_compression else {} )
# list of distinct copyies of LoaderManager.register_class function
# keys are either "<filename>::LoaderManager.register_class" or
# copy of code object executed when LoaderManager.register_class method
# is called
_trace_register_class = {}
# list of dump_functions to be traced with respect to being
# passed the compression related keywords provided through compression_kwargs
# fixture above. In case a call to any of these does not include at least these
# keywords an AssertionError Exception is raised.
_trace_functions = collections.OrderedDict()
# profiling function to be called after execution of _trace_loader_funcs
# below
_trace_profile_call = None
# index of dump_function argument in argument list of LoaderManager.register_class
# method.
_trace_function_argument_default = -1
def _chain_profile_call(frame,event,arg):
global _trace_profile_call
if _trace_profile_call:
next_call = _trace_profile_call(frame,event,arg)
if next_call:
_trace_profile_call = next_call
# argument names which correspond to argument being passed dump_function
# object
_trace_function_arg_names = {'dump_function'}
# the pytest session tracing of proper handling of compression related
# keywords is activated for
_traced_session = None
_loader_file_pattern = re.compile(r'^load_\w+\.py$')
def pytest_sessionstart(session):
"""
pytest hook called at start of session.
- collects all functions exported by hickle.lookup module (for now) and
records inserts "<filename>::<function.__qualname__>" strings into
_trace_functions list for any not listed in above non_core_loader_functions
- collects all dump_functions listed in class_register tables of all
hickle.loaders.load_*.py modules.
"""
global _test_compression,_traced_session,_trace_register_class,_trace_functions,_trace_profile_call
if _test_compression is None:
pytest_configure(session.config)
if not _test_compression:
return None
# extract all loader function from hickle.lookup
lookup_module = sys.modules.get('hickle.lookup',None)
if not isinstance(lookup_module,types.ModuleType):
lookup_module_spec = importlib.util.find_spec("hickle.lookup")
lookup_module = importlib.util.module_from_spec(lookup_module_spec)
lookup_module_spec.loader.exec_module(lookup_module)
register_class = lookup_module.LoaderManager.register_class
register_class_code = register_class.__func__.__code__
trace_function_argument = register_class_code.co_argcount + register_class_code.co_kwonlyargcount
for argid,trace_function in ( (count,varname) for count,varname in enumerate(register_class_code.co_varnames[:(register_class_code.co_argcount + register_class_code.co_kwonlyargcount)]) if varname in _trace_function_arg_names ):
trace_function_argument = argid
break
if trace_function_argument < 0:
return None
_trace_function_argument_default = trace_function_argument
qualname = getattr(register_class,'__qualname__',register_class.__name__)
code_name = qualname if qualname.rsplit('.',1) == register_class_code.co_name else register_class_code.co_name
_trace_register_class.update({"{}::{}".format(register_class_code.co_filename,code_name):trace_function_argument})
for loader_func_name,loader_func in (
(func_name,func)
for name, item in lookup_module.__dict__.items()
if isinstance(item,(types.FunctionType,type))
for func_name,func in (
((name,item),)
if isinstance(item,types.FunctionType) else
(
( meth_name,meth)
for meth_name,meth in item.__dict__.items()
if isinstance(meth,types.FunctionType)
)
)
if func_name[:2] != '__' and func_name[-2:] != '__'
):
loader_func = _get_trace_function(loader_func)
if loader_func is not None and loader_func.__module__ == lookup_module.__name__:
code = loader_func.__code__
qualname = getattr(loader_func,'__qualname__',loader_func.__name__)
if qualname not in non_core_loader_functions:
code_name = qualname if qualname.rsplit('.',1) == code.co_name else code.co_name
_trace_functions["{}::{}".format(code.co_filename,code_name)] = (loader_func.__module__,qualname)
# extract all dump functions from any known loader module
hickle_loaders_path = os.path.join(os.path.dirname(lookup_module.__file__),'loaders')
for loader in os.scandir(hickle_loaders_path):
if not loader.is_file() or _loader_file_pattern.match(loader.name) is None:
continue
loader_module_name = "hickle.loaders.{}".format(loader.name.rsplit('.',1)[0])
loader_module = sys.modules.get(loader_module_name,None)
if loader_module is None:
loader_module_spec = importlib.util.find_spec("hickle.loaders.{}".format(loader.name.rsplit('.',1)[0]))
if loader_module_spec is None:
continue
loader_module = importlib.util.module_from_spec(loader_module_spec)
try:
loader_module_spec.loader.exec_module(loader_module)
except ModuleNotFoundError:
continue
except ImportError:
if sys.version_info[0] > 3 or sys.version_info[1] > 5:
raise
continue
class_register_table = getattr(loader_module,'class_register',())
# trace function has cls/self
for dump_function in ( entry[trace_function_argument-1] for entry in class_register_table ):
dump_function = _get_trace_function(dump_function)
if dump_function is not None:
code = dump_function.__code__
qualname = getattr(dump_function,'__qualname__',dump_function.__name__)
code_name = qualname if qualname.rsplit('.',1) == code.co_name else code.co_name
_trace_functions["{}::{}".format(code.co_filename,code_name)] = (dump_function.__module__,qualname)
# activate compression related profiling
_trace_profile_call = sys.getprofile()
_traced_session = session
sys.setprofile(_trace_loader_funcs)
return None
# List of test functions which are marked by no_compression mark
_never_trace_compression = set()
def traceback_from_frame(frame,stopafter):
"""
helper function used in Python >= 3.7 to beautify traceback
of AssertionError exceptoin thrown by _trace_loader_funcs
"""
tb = types.TracebackType(None,frame,frame.f_lasti,frame.f_lineno)
while frame.f_back is not stopafter.f_back:
frame = frame.f_back
tb = types.TracebackType(tb,frame,frame.f_lasti,frame.f_lineno)
return tb
def pytest_collection_finish(session):
"""
collect all test functions for which compression related keyword monitoring
shall be disabled.
"""
if not sys.getprofile() == _trace_loader_funcs:
return
listed = set()
listemodules = set()
for item in session.items:
func = item.getparent(pytest.Function)
if func not in listed:
listed.add(func)
for marker in func.iter_markers(no_compression.name):
never_trace_code = func.function.__code__
qualname = getattr(func.function,'__qualname__',func.function.__name__)
code_name = qualname if qualname.rsplit('.',1) == never_trace_code.co_name else never_trace_code.co_name
_never_trace_compression.add("{}::{}".format(never_trace_code.co_filename,code_name))
break
def _trace_loader_funcs(frame,event,arg,nochain=False):
"""
does the actuatual profiling with respect to proper passing compression keywords
to dump_functions
"""
global _chain_profile_call, _trace_functions,_never_trace_compression,_trace_register_class,_trace_function_argument_default
try:
if event not in {'call','c_call'}:
return _trace_loader_funcs
# check if LoaderManager.register_class has been called
# if get position of dump_function argument and extract
# code object for dump_function to be registered if not None
code_block = frame.f_code
trace_function_argument = _trace_register_class.get(code_block,None)
if trace_function_argument is not None:
trace_function = frame.f_locals.get(code_block.co_varnames[trace_function_argument],None)
load_function = frame.f_locals.get(code_block.co_varnames[trace_function_argument+1],None)
if load_function is not None:
load_function = _get_trace_function(load_function)
_trace_functions.pop("{}::{}".format(load_function.__code__.co_filename,load_function.__code__.co_name),None)
if trace_function is None:
return _trace_loader_funcs
trace_function = _get_trace_function(trace_function)
if trace_function is None:
return _trace_loader_funcs
trace_function_code = getattr(trace_function,'__code__',None)
if trace_function_code is not None:
# store code object corresponding to dump_function in _trace_functions list
# if not yet present there.
qualname = getattr(trace_function,'__qualname__',trace_function.__name__)
code_name = qualname if qualname.rsplit('.',1) == trace_function_code.co_name else trace_function_code.co_name
trace_function_code_name = "{}::{}".format(trace_function_code.co_filename,code_name)
if (
trace_function_code_name not in _trace_register_class and
(
trace_function_code_name not in _trace_functions or
trace_function_code not in _trace_functions
)
):
trace_function_spec = (trace_function.__module__,qualname)
_trace_functions[trace_function_code] = trace_function_spec
_trace_functions[trace_function_code_name] = trace_function_spec
return _trace_loader_funcs
# estimate qualname from local variable stored in frame.f_local corresponding
# to frame.f_code.co_varnames[0] if any.
object_self_name = frame.f_code.co_varnames[:1]
if object_self_name:
self = frame.f_locals.get(object_self_name[0],None)
module = getattr(self,'__module__','')
if isinstance(module,str) and module.split('.',1)[0] == 'hickle' and isinstance(getattr(self,'__name__',None),str):
method = getattr(self,frame.f_code.co_name,None)
if method is not None and getattr(method,'__code__',None) == frame.f_code:
code_name = "{}::{}.{}".format(
frame.f_code.co_filename,
getattr(self,'__qualname__',self.__name__),
frame.f_code.co_name
)
else:
code_name = "{}::{}".format(frame.f_code.co_filename,frame.f_code.co_name)
else:
code_name = "{}::{}".format(frame.f_code.co_filename,frame.f_code.co_name)
else:
code_name = "{}::{}".format(frame.f_code.co_filename,frame.f_code.co_name)
# check if frame could encode a clall to a new incarnation of LoaderManager.register_class
# method. Add its code object to the list of known incarnations and rerun above code
if code_block.co_name == 'register_class':
trace_function_argument = _trace_register_class.get(code_name,None)
if trace_function_argument is not None:
_trace_register_class[code_block] = trace_function_argument
return _trace_loader_funcs(frame,event,arg,True)
if (
code_block.co_filename.rsplit('/',2) == ['hickle','lookup.py'] and
code_block.co_varnames > trace_function_argument and
code_block.co_varnames[_trace_function_argument_default] in _trace_function_arg_names
):
_trace_register_class[code_name] = _trace_function_argument_default
_trace_register_class[code_block] = _trace_function_argument_default
return _trace_loader_funcs(frame,event,arg,True)
# frame encodes a call to any other function or method.
# If the function or method is listed in _trace_functions list check
# if it received the appropriate set of compresson related keywords
function_object_spec = _trace_functions.get(frame.f_code,None)
if function_object_spec is None:
function_object_spec = _trace_functions.get(code_name,None)
if function_object_spec is None:
return _trace_loader_funcs
_trace_functions[frame.f_code] = function_object_spec
baseargs = (
(arg,frame.f_locals[arg])
for arg in frame.f_code.co_varnames[:(frame.f_code.co_argcount + frame.f_code.co_kwonlyargcount)]
)
kwargs = frame.f_locals.get('kwargs',None)
if kwargs is not None:
fullargs = ( (name,arg) for arglist in (kwargs.items(),baseargs) for name,arg in arglist )
else:
fullargs = baseargs
seen_compression_args = set()
for arg,value in fullargs:
if arg in seen_compression_args:
continue
if _compression_args.get(arg,None) is not None:
seen_compression_args.add(arg)
if len(seen_compression_args) == len(_compression_args):
return _trace_loader_funcs
# keywords not passed or filtered prematurely.
# walk the stack until reaching executed test function.
# if test function is not marked with no_compression raise
# AssertionError stating that dump_function did not
# receive expected compression keywords defined above
# For Python <= 3.6 collect all functions called between current
# frame and frame of executed test function. For Python > 3.6 use
# above traceback_from_frame function to build traceack showing appropriate
# callstack and context excluding this function to ensure AssertionError
# exception appears thrown on behlaf of function triggering call encoded by
# passed frame
function_object_spec = _trace_functions[frame.f_code]
if _traced_session is not None:
test_list = {
"{}::{}".format(
item.function.__code__.co_filename,
getattr(item.function,'__qualname__',
item.function.__name__)
):item
for item in _traced_session.items
}
collect_call_tree = []
next_frame = frame
while next_frame is not None:
object_self_name = frame.f_code.co_varnames[:1]
if object_self_name:
self = frame.f_locals.get(object_self_name[0])
module = getattr(self,'__module__','')
if (
isinstance(module,str) and
module.split('.',1)[0] == 'hickle' and
isinstance(getattr(self,'__name__',None),str)
):
method = getattr(self,frame.f_code.co_name,None)
if method is not None and getattr(method,'__code__',None) == frame.f_code:
frame_name = "{}::{}".format(
next_frame.f_code.co_filename,
getattr(method,'__qualname__',method.__name__)
)
else:
frame_name = "{}::{}".format(next_frame.f_code.co_filename,next_frame.f_code.co_name)
else:
frame_name = "{}::{}".format(next_frame.f_code.co_filename,next_frame.f_code.co_name)
else:
frame_name = "{}::{}".format(next_frame.f_code.co_filename,next_frame.f_code.co_name)
if frame_name in _never_trace_compression:
return _trace_loader_funcs
in_test = test_list.get(frame_name,None)
collect_call_tree.append((next_frame.f_code.co_filename,frame_name,next_frame.f_lineno))
if in_test is not None:
try:
tb = traceback_from_frame(frame,next_frame)
except TypeError:
pass
else:
raise AssertionError(
"'{}': compression_kwargs lost in call".format("::".join(function_object_spec))
).with_traceback(tb)
raise AssertionError(
"'{}': compression_kwargs lost in call:\n\t{}\n".format(
"::".join(function_object_spec),
"\n\t".join("{}::{} ({})".format(*call) for call in collect_call_tree[:0:-1])
)
)
next_frame = next_frame.f_back
except AssertionError as ae:
# check that first entry in traceback does not refer to this function
if ae.__traceback__.tb_frame.f_code == _trace_loader_funcs.__code__:
ae.__traceback__ = ae.__traceback__.tb_next
raise
#except Exception as e:
# import traceback;traceback.print_exc()
# import pdb;pdb.set_trace()
finally:
if not nochain:
_chain_profile_call(frame,event,arg)
def pytest_sessionfinish(session):
sys.setprofile(_trace_profile_call)
|