File: parameterized_utils.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (50 lines) | stat: -rw-r--r-- 1,816 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import json
from itertools import product

from parameterized import param, parameterized

from .data_utils import get_asset_path


def load_params(*paths):
    """Load test parameters from a JSON Lines asset file.

    Each non-blank line of the file is decoded with ``json.loads`` and wrapped in a
    ``parameterized.param`` as its single positional argument.

    Args:
        *paths: Path components forwarded to ``get_asset_path`` to locate the asset file.

    Returns:
        list of ``parameterized.param``: One entry per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(get_asset_path(*paths), "r", encoding="utf-8") as file:
        # Skip blank / whitespace-only lines (e.g. an empty last line), which
        # ``json.loads`` would otherwise reject.
        return [param(json.loads(line)) for line in file if line.strip()]


def _name_func(func, _, params):
    """Build a readable test-case name from the positional arguments of ``params``.

    Every positional argument is converted with ``str``; an argument that is a tuple is
    flattened by joining its items with ``_``. The resulting pieces are joined with ``_``,
    prefixed by the test function's name, and sanitized with
    ``parameterized.to_safe_name``. The second argument (the case index passed by
    ``parameterized``) is not used.
    """

    def _render(value):
        # Tuples (e.g. one element of a cartesian product) become "a_b_c".
        if isinstance(value, tuple):
            return "_".join(map(str, value))
        return str(value)

    suffix = "_".join(_render(value) for value in params.args)
    return parameterized.to_safe_name(func.__name__ + "_" + suffix)


def nested_params(*params_set, name_func=_name_func):
    """Generate the cartesian product of the given list of parameters.

    Args:
        params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
            all the parameters have to be specified with the class, only using kwargs.
        name_func (callable, optional): Test-name function forwarded to ``parameterized.expand``.
            For plain objects the default ``_name_func`` is used. For ``param`` inputs,
            leaving it at the default keeps ``parameterized``'s own default naming; a
            custom function is forwarded as-is.

    Returns:
        The decorator returned by ``parameterized.expand``.

    Raises:
        TypeError: If ``param`` objects are mixed with plain objects.
        ValueError: If any ``param`` object carries positional arguments.
    """
    # Materialize each group so one-shot iterables (e.g. generators) are not exhausted
    # by the type checks below before the product is built.
    params_set = [list(params) for params in params_set]
    flatten = [p for params in params_set for p in params]

    # Parameters to be nested are given as list of plain objects
    if all(not isinstance(p, param) for p in flatten):
        args = list(product(*params_set))
        return parameterized.expand(args, name_func=name_func)

    # Parameters to be nested are given as list of `parameterized.param`
    if not all(isinstance(p, param) for p in flatten):
        raise TypeError("When using ``parameterized.param``, all the parameters have to be of the ``param`` type.")
    if any(p.args for p in flatten):
        raise ValueError(
            "When using ``parameterized.param``, all the parameters have to be provided as keyword argument."
        )
    # Merge the kwargs of one entry from each group; a key repeated across groups makes
    # the ``param(...)`` call raise TypeError (multiple values for a keyword argument).
    args = [param()]
    for params in params_set:
        args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
    # ``_name_func`` only inspects positional args, which are always empty here, so it
    # would give every case the same name. Keep ``parameterized``'s default naming
    # unless the caller explicitly supplied a different ``name_func``.
    expand_kwargs = {} if name_func is _name_func else {"name_func": name_func}
    return parameterized.expand(args, **expand_kwargs)