File: test_unions.py

package info (click to toggle)
python-omegaconf 2.3.0-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,244 kB
  • sloc: python: 26,413; makefile: 38; sh: 11
file content (100 lines) | stat: -rw-r--r-- 3,240 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from pathlib import Path
from typing import Any, Union

from pytest import mark, param, raises

from omegaconf import OmegaConf, UnionNode, ValidationError
from omegaconf._utils import _get_value
from tests import Color


# Two-member type tuples; each one becomes the ``Union[...]`` ref type under test.
_UNION_MEMBER_PAIRS = [
    param((int, float), id="int_float"),
    param((float, bool), id="float_bool"),
    param((bool, str), id="bool_str"),
    param((str, bytes), id="str_bytes"),
    param((bytes, Color), id="bytes_color"),
    param((Color, int), id="color_int"),
]

# Candidate values fed to every union above. A value counts as legal only when
# its exact type (``type(value)``) is one of the union's members.
_UNION_CANDIDATE_VALUES = [
    param(123, id="123"),
    param(10.1, id="10.1"),
    param(b"binary", id="binary"),
    param(True, id="true"),
    param("abc", id="abc"),
    param("RED", id="red_str"),
    param("123", id="123_str"),
    param("10.1", id="10.1_str"),
    param(Color.RED, id="Color.RED"),
    param(Path("hello.txt"), id="path"),
    param(object(), id="object"),
]


@mark.parametrize("union_args", _UNION_MEMBER_PAIRS)
@mark.parametrize("input_", _UNION_CANDIDATE_VALUES)
class TestUnionNode:
    """Cross every candidate value with every two-member union.

    Each test expects ``UnionNode`` to accept the value (and hand it back
    unchanged via ``_get_value``) when ``type(input_)`` is a member of
    ``union_args``, and to raise ``ValidationError`` otherwise.
    """

    @staticmethod
    def _expects_success(value: Any, members: Any) -> bool:
        # Exact-type membership: ``True`` is not considered legal for a union
        # that lists ``int`` but not ``bool``, because ``type(True) is bool``.
        return type(value) in members

    def test_creation(self, input_: Any, union_args: Any) -> None:
        """Construct a node directly from ``input_``."""
        union_type = Union[union_args]  # type: ignore
        if not self._expects_success(input_, union_args):
            with raises(ValidationError):
                UnionNode(input_, union_type)
            return
        created = UnionNode(input_, union_type)
        assert _get_value(created) == input_

    def test_set_value(self, input_: Any, union_args: Any) -> None:
        """Assign ``input_`` to a node that was constructed holding ``None``."""
        union_type = Union[union_args]  # type: ignore
        target = UnionNode(None, union_type)
        if not self._expects_success(input_, union_args):
            with raises(ValidationError):
                target._set_value(input_)
            return
        target._set_value(input_)
        assert _get_value(target) == input_


@mark.parametrize(
    "optional",
    [
        param(True, id="optional"),
        param(False, id="not_optional"),
    ],
)
@mark.parametrize(
    "input_",
    [
        param("???", id="missing"),
        param("${interp}", id="interp"),
        param(None, id="none"),
    ],
)
class TestUnionNodeSpecial:
    """Special inputs (``"???"``, ``"${interp}"``, ``None``) on ``Union[int, str]``.

    The only combination expected to fail is ``None`` on a non-optional node.
    Every other combination must be stored and returned as-is by ``_value()``.
    """

    # Shared ref type for both tests.
    _REF_TYPE = Union[int, str]

    @staticmethod
    def _must_reject(value: Any, optional: bool) -> bool:
        # ``None`` is only acceptable when the node is optional.
        return value is None and not optional

    def test_creation_special(self, input_: Any, optional: bool) -> None:
        """Construct a node directly from the special value."""
        if self._must_reject(input_, optional):
            with raises(ValidationError):
                UnionNode(input_, self._REF_TYPE, is_optional=optional)
            return
        created = UnionNode(input_, self._REF_TYPE, is_optional=optional)
        assert created._value() == input_

    def test_set_value_special(self, input_: Any, optional: bool) -> None:
        """Overwrite a node initially holding ``123`` with the special value."""
        target = UnionNode(123, self._REF_TYPE, is_optional=optional)
        if self._must_reject(input_, optional):
            with raises(ValidationError):
                target._set_value(input_)
            return
        target._set_value(input_)
        assert target._value() == input_


def test_get_parent_container() -> None:
    """Check that ``_get_parent_container()`` returns the owning config.

    Three objects are checked, in order: the ``UnionNode`` stored under
    ``"foo"``, the object returned by that node's ``_value()``, and the plain
    entry stored under ``"bar"``.
    """
    cfg = OmegaConf.create({"foo": UnionNode(123, Union[int, str]), "bar": "baz"})

    union_node = cfg._get_node("foo")
    inner = union_node._value()  # type: ignore
    sibling = cfg._get_node("bar")

    for candidate in (union_node, inner, sibling):
        assert candidate._get_parent_container() is cfg  # type: ignore