1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
|
import copy
from typing import Any, Dict, List
from pytest import fixture, lazy_fixture, mark, param
from omegaconf import OmegaConf
from omegaconf._utils import ValueKind, _is_missing_literal, get_value_kind
def build_dict(
    d: Dict[str, Any], depth: int, width: int, leaf_value: Any = 1
) -> Dict[str, Any]:
    """Fill ``d`` in place with a complete tree of nested dicts and return it.

    Every level gets ``width`` keys named ``key_0`` .. ``key_{width-1}``.
    At ``depth == 0`` each key maps to ``leaf_value``; otherwise each key maps
    to a fresh dict that is filled recursively with ``depth - 1``. The same
    ``leaf_value`` object is stored at every leaf (it is not copied). A
    negative ``depth`` never reaches the base case and, for ``width > 0``,
    recurses until ``RecursionError``.
    """
    for index in range(width):
        key = f"key_{index}"
        if depth != 0:
            # Insert the child before filling it so key order matches insertion.
            subtree: Dict[str, Any] = {}
            d[key] = subtree
            build_dict(subtree, depth - 1, width, leaf_value)
        else:
            d[key] = leaf_value
    return d
def build_list(length: int, val: Any = 1) -> List[Any]:
    """Return a list containing ``length`` references to ``val``.

    ``val`` may be any object, so the element type is ``Any`` (not ``int``).
    ``[val] * length`` repeats the same object, so a mutable ``val`` is shared
    by every slot. A non-positive ``length`` yields an empty list.
    """
    return [val] * length
@fixture(scope="module")
def large_dict() -> Any:
    """Plain nested dict (depth=11, width=2): 2**12 == 4096 leaves, each the int 1."""
    return build_dict(d={}, depth=11, width=2)
@fixture(scope="module")
def small_dict() -> Any:
    """Plain nested dict (depth=5, width=2): 2**6 == 64 leaves, each the int 1."""
    return build_dict(d={}, depth=5, width=2)
@fixture(scope="module")
def dict_with_list_leaf() -> Any:
    """Plain nested dict (depth=5, width=2) whose 64 leaves are lists.

    ``build_dict`` stores the same object at every leaf, so all leaves refer
    to one shared ``[1, 2]`` list.
    """
    shared_leaf = [1, 2]
    return build_dict(d={}, depth=5, width=2, leaf_value=shared_leaf)
@fixture(scope="module")
def small_dict_config(small_dict: Any) -> Any:
    """Config created from the ``small_dict`` fixture via ``OmegaConf.create``."""
    config = OmegaConf.create(small_dict)
    return config
@fixture(scope="module")
def dict_config_with_list_leaf(dict_with_list_leaf: Any) -> Any:
    """Config created from the ``dict_with_list_leaf`` fixture (list-valued leaves)."""
    config = OmegaConf.create(dict_with_list_leaf)
    return config
@fixture(scope="module")
def large_dict_config(large_dict: Any) -> Any:
    """Config created from the ``large_dict`` fixture via ``OmegaConf.create``."""
    config = OmegaConf.create(large_dict)
    return config
@fixture(scope="module")
def merge_data(small_dict: Any) -> Any:
    """Five independently created configs, all built from ``small_dict``."""
    copies = 5
    return [OmegaConf.create(small_dict) for _index in range(copies)]
@fixture(scope="module")
def small_list() -> Any:
    """Plain three-element list ``[1, 1, 1]``."""
    return build_list(length=3, val=1)
@fixture(scope="module")
def small_listconfig(small_list: Any) -> Any:
    """Config created from the ``small_list`` fixture via ``OmegaConf.create``."""
    listconfig = OmegaConf.create(small_list)
    return listconfig
@mark.parametrize(
    "data",
    [
        lazy_fixture(fixture_name)  # type: ignore
        for fixture_name in (
            "small_dict",
            "large_dict",
            "small_dict_config",
            "large_dict_config",
            "dict_config_with_list_leaf",
        )
    ],
)
def test_omegaconf_create(data: Any, benchmark: Any) -> None:
    """Benchmark ``OmegaConf.create`` on plain dicts and on existing configs."""
    create = OmegaConf.create
    benchmark(create, data)
@mark.parametrize(
    "merge_function",
    [
        param(OmegaConf.merge, id="merge"),
        param(OmegaConf.unsafe_merge, id="unsafe_merge"),
    ],
)
def test_omegaconf_merge(merge_function: Any, merge_data: Any, benchmark: Any) -> None:
    """Benchmark merging the five configs held in ``merge_data`` into one.

    ``merge_data`` is a list, while ``OmegaConf.merge`` / ``unsafe_merge``
    take the configs as separate positional arguments (``*configs``). The
    list must therefore be unpacked. Passing it as-is hands the merge
    function a single list argument, and no five-way merge is measured.
    """
    # NOTE(review): unsafe_merge presumably may modify merge_data[0] in place
    # across rounds; since all inputs are built from the same dict this should
    # not change the merged result — confirm.
    benchmark(merge_function, *merge_data)
@mark.parametrize(
    "lst",
    [
        lazy_fixture(name)  # type: ignore
        for name in ("small_list", "small_listconfig")
    ],
)
def test_list_in(lst: List[Any], benchmark: Any) -> None:
    """Benchmark ``10 in lst`` for a plain list and for a ListConfig.

    The list fixtures contain only 1s, so 10 is never found and every
    element is compared.
    """
    def contains(seq: Any, val: Any) -> bool:
        return val in seq

    benchmark(contains, lst, 10)
@mark.parametrize(
    "lst",
    [
        lazy_fixture(name)  # type: ignore
        for name in ("small_list", "small_listconfig")
    ],
)
def test_list_iter(lst: List[Any], benchmark: Any) -> None:
    """Benchmark a full Python-level iteration over a plain list and a ListConfig."""
    def consume(sequence: Any) -> None:
        # Visit every element and do nothing with it, so only the iteration
        # itself is timed.
        for _item in sequence:
            continue

    benchmark(consume, lst)
@mark.parametrize("strict_interpolation_validation", [True, False])
@mark.parametrize(
    ("value", "expected"),
    [("simple", ValueKind.VALUE)]
    + [
        (interpolation, ValueKind.INTERPOLATION)
        for interpolation in (
            "${a}",
            "${a:b,c,d}",
            "${${b}}",
            "${a:${b}}",
            "${long_string1xxx}_${long_string2xxx:${key}}",
            "${a[1].a[1].a[1].a[1].a[1].a[1].a[1].a[1].a[1].a[1].a[1]}",
        )
    ],
)
def test_get_value_kind(
    strict_interpolation_validation: bool, value: Any, expected: Any, benchmark: Any
) -> None:
    """Benchmark ``get_value_kind`` on one plain string and several interpolations.

    Each case runs with strict interpolation validation both on and off, and
    the timed call's result is checked against the expected ``ValueKind``.
    """
    kind = benchmark(get_value_kind, value, strict_interpolation_validation)
    assert kind == expected
def test_is_missing_literal(benchmark: Any) -> None:
    """Benchmark ``_is_missing_literal`` on ``"???"`` and require a truthy result."""
    result = benchmark(_is_missing_literal, "???")
    assert result
@mark.parametrize("force_add", [False, True])
@mark.parametrize("key", ["a", "a.a.a.a.a.a.a.a.a.a.a"])
def test_update_force_add(
    large_dict_config: Any, key: str, force_add: bool, benchmark: Any
) -> None:
    """Benchmark ``OmegaConf.update(cfg, key, 10, force_add=...)`` on the large config.

    Neither parametrized key exists in the fixture, whose keys are all
    ``key_<n>``. When ``force_add`` is set, the copy is first put in struct
    mode and the update is issued with ``force_add=True``.
    """
    # The fixture is module-scoped and this test modifies the config, so use
    # a private deep copy.
    cfg = copy.deepcopy(large_dict_config)
    if force_add:
        OmegaConf.set_struct(cfg, True)

    def visit(node: Any) -> None:
        # Query the struct flag on every config node before timing starts.
        # NOTE(review): presumably this warms up flag lookups so they are not
        # part of the measurement — confirm.
        if not OmegaConf.is_config(node):
            return
        OmegaConf.is_struct(node)
        for child in node.values():
            visit(child)

    visit(cfg)
    # NOTE(review): benchmark() repeats the update on the same cfg, so
    # presumably only the first round inserts `key` and later rounds overwrite
    # it — confirm this is the intended measurement.
    benchmark(OmegaConf.update, cfg, key, 10, force_add=force_add)
|