File: recipe_loader.py

package info (click to toggle)
nvidia-cuda-toolkit 12.4.1-3
  • links: PTS, VCS
  • area: non-free
  • in suites: forky, sid
  • size: 18,505,836 kB
  • sloc: ansic: 203,477; cpp: 64,769; python: 34,699; javascript: 22,006; xml: 13,410; makefile: 3,085; sh: 2,343; perl: 352
file content (97 lines) | stat: -rw-r--r-- 3,212 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import importlib
import inspect
import json
import os
import sys

from nsys_recipe import nsys_constants
from nsys_recipe.lib import recipe
from nsys_recipe.log import logger


def get_metadata_dict(recipe_dir, recipe_name):
    """Load ``<recipe_dir>/<recipe_name>/metadata.json``.

    Args:
        recipe_dir: Directory that may contain the recipe folder. An empty
            string means a path relative to the current working directory.
        recipe_name: Name of the recipe folder.

    Returns:
        The parsed JSON content, or None when no such regular file exists.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        OSError: If the file exists but cannot be read.
    """
    json_path = os.path.join(recipe_dir, recipe_name, "metadata.json")
    # isfile (not exists) so a directory that happens to be named
    # "metadata.json" is treated as "no metadata" instead of crashing open().
    if not os.path.isfile(json_path):
        return None

    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default
    # encoding, which differs between platforms.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)


def is_recipe_subclass(obj):
    """Tell whether *obj* is a strict subclass of ``recipe.Recipe``.

    Non-class objects are rejected up front, because ``issubclass`` would
    raise TypeError for them. The ``Recipe`` base class itself does not
    count as a match.
    """
    if not inspect.isclass(obj):
        return False
    base_class = recipe.Recipe
    return issubclass(obj, base_class) and obj != base_class


def get_recipe_class_from_module(module, class_name=None):
    """Return the Recipe subclass to use from an imported recipe *module*.

    Args:
        module: The imported recipe module.
        class_name: Name of the class to use (the metadata ``class_name``
            field). When None, the module is scanned for Recipe subclasses.

    Returns:
        The selected class. Returns None, after logging an error, in any
        of these cases:
        - *class_name* is not defined in the module;
        - *class_name* is not a Recipe subclass;
        - the scan finds no Recipe subclass.
    """
    if class_name is not None:
        # A missing attribute is a configuration error like the others here:
        # log it and return None rather than letting AttributeError escape.
        if not hasattr(module, class_name):
            logger.error(f"{class_name} not found.")
            return None

        recipe_class = getattr(module, class_name)
        if is_recipe_subclass(recipe_class):
            return recipe_class

        logger.error(f"{class_name} is not a Recipe class.")
        return None

    # inspect.getmembers() returns (name, value) pairs sorted by name, so
    # "first" below means alphabetically first. Recipe subclasses that the
    # module merely imports are included in the scan as well.
    members = inspect.getmembers(module, is_recipe_subclass)
    if not members:
        logger.error("No Recipe class found.")
        return None

    name, recipe_class = members[0]
    if len(members) > 1:
        logger.warning(
            f"Multiple Recipe classes detected. Using the first class '{name}' as default."
            " To choose a different class, please set the 'class_name' field in the metadata file."
        )

    return recipe_class


def get_recipe_module(search_path, recipe_name, module_name):
    """Import ``<recipe_name>.<module_name>`` with *search_path* on sys.path.

    Args:
        search_path: Directory containing the recipe package. An empty
            string refers to the current working directory.
        recipe_name: Package (directory) name of the recipe.
        module_name: Module inside the recipe package to import.

    Returns:
        The imported module, or None after logging an error if the import
        raised.
    """
    # Make the recipe package importable. Append only when the entry is
    # absent, so repeated lookups don't keep growing sys.path with
    # duplicates. A duplicate later in the list never changes resolution
    # anyway, because the earlier occurrence is searched first.
    # NOTE(review): the path is appended, so a same-named package found
    # earlier on sys.path (or already in sys.modules) takes precedence.
    if search_path not in sys.path:
        sys.path.append(search_path)

    recipe_module_path = f"{recipe_name}.{module_name}"
    try:
        return importlib.import_module(recipe_module_path)
    except Exception as e:
        # Importing runs the module's top-level code, which can raise
        # anything. Report the failure instead of propagating it.
        logger.error(f"Could not import {recipe_module_path}: {e}")
        return None


def get_recipe_class_from_name(recipe_name):
    """Resolve *recipe_name* to its Recipe class.

    Candidate directories are tried in the order below. The first one that
    contains ``<recipe_name>/metadata.json`` is used, and later ones are
    not consulted:

    1. the recipes directory of the ``nsys_recipe`` package
       (``nsys_constants.NSYS_RECIPE_RECIPES_PATH``);
    2. the current working directory (the empty path ``""``);
    3. the directory named by the ``NSYS_RECIPE_PATH`` environment
       variable, if it is set.

    Returns:
        The Recipe class. Returns None, after logging an error, in any of
        these cases:
        - no candidate directory holds the recipe;
        - its metadata lacks ``module_name``;
        - the module import fails;
        - no suitable class is found in the module.
    """
    candidate_dirs = [nsys_constants.NSYS_RECIPE_RECIPES_PATH, ""]
    extra_dir = os.getenv("NSYS_RECIPE_PATH")
    if extra_dir is not None:
        candidate_dirs.append(extra_dir)

    for base_dir in candidate_dirs:
        metadata = get_metadata_dict(base_dir, recipe_name)
        if metadata is not None:
            break
    else:
        logger.error("Unknown recipe.")
        return None

    # Only the first matching directory is used. A metadata file without
    # module_name is an error rather than a reason to keep searching.
    module_name = metadata.get("module_name")
    if module_name is None:
        logger.error("'module_name' not found.")
        return None

    module = get_recipe_module(base_dir, recipe_name, module_name)
    if module is None:
        return None

    return get_recipe_class_from_module(module, metadata.get("class_name"))