# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import importlib
import inspect
import json
import os
import sys

from nsys_recipe import nsys_constants
from nsys_recipe.lib import recipe
from nsys_recipe.log import logger


def get_metadata_dict(recipe_dir, recipe_name):
    """Return the parsed '<recipe_dir>/<recipe_name>/metadata.json', or None if absent."""
    json_path = os.path.join(recipe_dir, recipe_name, "metadata.json")
    if not os.path.exists(json_path):
        return None

    with open(json_path) as f:
        return json.load(f)
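

# Example (sketch): a minimal metadata.json this loader can consume, assuming
# only the keys read in this file. "module_name" is required; "class_name" is
# optional and selects the Recipe subclass when the module defines several.
# Real recipe metadata files may carry additional fields. The module and
# class names below are hypothetical.
#
#     {
#         "module_name": "my_recipe",
#         "class_name": "MyRecipe"
#     }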


def is_recipe_subclass(obj):
    """Return True if `obj` is a class derived from, but not equal to, recipe.Recipe."""
    return (
        inspect.isclass(obj) and issubclass(obj, recipe.Recipe) and obj != recipe.Recipe
    )
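

def get_recipe_class_from_module(module, class_name=None):
    """Return the Recipe subclass defined in `module`, or None on failure.

    If `class_name` is given, that class is used; otherwise the first Recipe
    subclass found in the module is returned.
    """
    if class_name is not None:
        # Check existence first so that a wrong 'class_name' is logged
        # instead of raising AttributeError.
        if not hasattr(module, class_name):
            logger.error(f"{class_name} not found in module {module.__name__}.")
            return None
        recipe_class = getattr(module, class_name)
        if is_recipe_subclass(recipe_class):
            return recipe_class
        logger.error(f"{class_name} is not a Recipe class.")
        return None

    # inspect.getmembers() returns members sorted by name, so "first" below
    # means the alphabetically first Recipe subclass.
    members = inspect.getmembers(module, is_recipe_subclass)
    if not members:
        logger.error("No Recipe class found.")
        return None

    name, recipe_class = members[0]
    if len(members) > 1:
        logger.warning(
            f"Multiple Recipe classes detected. Using the first class '{name}' as default."
            " To choose a different class, please set the 'class_name' field in the metadata file."
        )
    return recipe_class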
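

def get_recipe_module(search_path, recipe_name, module_name):
    """Import '<recipe_name>.<module_name>' from `search_path`, or return None."""
    # Avoid appending the same entry to sys.path on every lookup.
    if search_path not in sys.path:
        sys.path.append(search_path)
    recipe_module_path = f"{recipe_name}.{module_name}"
    try:
        return importlib.import_module(recipe_module_path)
    except Exception as e:
        logger.error(f"Could not import {recipe_module_path}: {e}")
        return None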
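

# Example (sketch): the on-disk layout that get_metadata_dict() and
# get_recipe_module() expect for a hypothetical recipe named "my_recipe"
# located under some search path <search_path>:
#
#     <search_path>/
#         my_recipe/
#             metadata.json   # {"module_name": "my_recipe", ...}
#             my_recipe.py    # defines a recipe.Recipe subclass
#
# With <search_path> on sys.path, get_recipe_class_from_name("my_recipe")
# imports the module "my_recipe.my_recipe" and picks its Recipe subclass.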


def get_recipe_class_from_name(recipe_name):
    """Locate the recipe named `recipe_name` and return its Recipe class, or None."""
    # Search for the recipe in the following order:
    # 1. The 'nsys_recipe' package.
    # 2. The current directory.
    # 3. The directory set by the environment variable NSYS_RECIPE_PATH.
    # The first search path containing '<recipe_name>/metadata.json' is used.
    recipe_search_paths = [nsys_constants.NSYS_RECIPE_RECIPES_PATH, ""]

    recipe_path_env_var = os.getenv("NSYS_RECIPE_PATH")
    if recipe_path_env_var is not None:
        recipe_search_paths.append(recipe_path_env_var)

    for search_path in recipe_search_paths:
        metadata = get_metadata_dict(search_path, recipe_name)
        if metadata is None:
            continue

        module_name = metadata.get("module_name")
        if module_name is None:
            logger.error(
                f"'module_name' not found in the metadata file of recipe '{recipe_name}'."
            )
            return None

        module = get_recipe_module(search_path, recipe_name, module_name)
        if module is None:
            return None

        return get_recipe_class_from_module(module, metadata.get("class_name"))

    logger.error(f"Unknown recipe '{recipe_name}'.")
    return None