File: typed_env.py

package info (click to toggle)
python-plumbum 1.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,300 kB
  • sloc: python: 10,016; makefile: 130; sh: 8
file content (172 lines) | stat: -rw-r--r-- 4,949 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from __future__ import annotations

import inspect
import os
from collections.abc import MutableMapping

# Unique sentinel meaning "no default was supplied / no value was found".
# It is compared by identity (``is``), so ``None`` remains usable as a real default.
NO_DEFAULT = object()


# Deliberately a KeyError and NOT an AttributeError: an AttributeError escaping
# from a descriptor would be swallowed by Python's attribute-lookup machinery
# (which would then fall back to __getattr__) instead of reaching the caller.
class EnvironmentVariableError(KeyError):
    """Raised when none of the requested environment variable names is set."""


class TypedEnv(MutableMapping):
    """
    A typed, attribute-style view over a mapping of environment variables.

    This object can be used in 'exploratory' mode:

        nv = TypedEnv()
        print(nv.HOME)

    It can also be used as a parser and validator of environment variables:

    class MyEnv(TypedEnv):
        username = TypedEnv.Str("USER", default='anonymous')
        path = TypedEnv.CSV("PATH", separator=":")
        tmp = TypedEnv.Str("TMP TEMP".split())  # support 'fallback' var-names

    nv = MyEnv()

    print(nv.username)

    for p in nv.path:
        print(p)

    try:
        print(nv.tmp)
    except EnvironmentVariableError:
        print("TMP/TEMP is not defined")
    """

    __slots__ = ["_env", "_defined_keys"]

    class _BaseVar:
        """
        Descriptor that reads one environment variable (with optional fallback
        names) from the owning ``TypedEnv`` and converts it with ``convert``.
        """

        def __init__(self, name, default=NO_DEFAULT):
            """
            :param name: a variable name, or a tuple/list of names tried in order
            :param default: value returned when none of the names is set; when
                omitted, a missing variable raises ``EnvironmentVariableError``
            """
            self.names = tuple(name) if isinstance(name, (tuple, list)) else (name,)
            # the first name is the one used when writing the variable back
            self.name = self.names[0]
            self.default = default

        def convert(self, value):  # pylint:disable=no-self-use
            """Convert the raw value; the base implementation returns it unchanged."""
            return value

        def __get__(self, instance, owner):
            """
            Return the converted value, or the descriptor itself on class access.

            :raises EnvironmentVariableError: if no name is set and there is no default
            """
            # Must be an identity test: TypedEnv defines __len__, so an instance
            # wrapping an empty mapping is falsy and ``not instance`` would
            # wrongly treat it as class-level access.
            if instance is None:
                return self
            try:
                return self.convert(instance._raw_get(*self.names))
            except EnvironmentVariableError:
                if self.default is NO_DEFAULT:
                    raise
                # the default is returned as-is (it is not passed through convert)
                return self.default

        def __set__(self, instance, value):
            """Store ``value`` under the primary name (``TypedEnv`` stringifies it)."""
            instance[self.name] = value

    class Str(_BaseVar):
        """A variable returned as-is, without conversion."""

    class Bool(_BaseVar):
        """
        Converts 'yes|true|1|no|false|0' to the appropriate boolean value.
        Case-insensitive. Throws a ``ValueError`` for any other value.
        """

        def convert(self, value):
            """Map the (lower-cased) raw value to ``True``/``False``."""
            value = value.lower()
            if value not in {"yes", "no", "true", "false", "1", "0"}:
                raise ValueError(f"Unrecognized boolean value: {value!r}")
            return value in {"yes", "true", "1"}

        def __set__(self, instance, value):
            """Store 'yes' for a truthy ``value`` and 'no' otherwise."""
            instance[self.name] = "yes" if value else "no"

    class Int(_BaseVar):
        """A variable converted with ``int``."""

        convert = staticmethod(int)

    class Float(_BaseVar):
        """A variable converted with ``float``."""

        convert = staticmethod(float)

    class CSV(_BaseVar):
        """
        Comma-separated-strings get split using the ``separator`` (',' by default) into
        a list of objects of type ``type`` (``str`` by default).
        """

        def __init__(self, name, default=NO_DEFAULT, type=str, separator=","):  # pylint:disable=redefined-builtin
            """
            :param name: a variable name, or a tuple/list of names tried in order
            :param default: value returned when none of the names is set
            :param type: callable applied to every (stripped) element
            :param separator: string the raw value is split/joined on
            """
            super().__init__(name, default=default)
            self.type = type
            self.separator = separator

        def __set__(self, instance, value):
            """Join the stringified elements of ``value`` with the separator and store them."""
            instance[self.name] = self.separator.join(map(str, value))

        def convert(self, value):
            """Split on the separator, strip each piece and convert it with ``type``."""
            return [self.type(v.strip()) for v in value.split(self.separator)]

    # =========

    def __init__(self, env=None):
        """
        :param env: the mapping to read from and write to; ``os.environ`` when omitted
        """
        if env is None:
            env = os.environ
        self._env = env
        # attribute names of all variable descriptors declared on the class
        self._defined_keys = {
            k
            for (k, v) in inspect.getmembers(self.__class__)
            if isinstance(v, self._BaseVar)
        }

    def __iter__(self):
        """Iterate over the names reported by ``dir(self)`` (see ``__dir__``)."""
        return iter(dir(self))

    def __len__(self):
        """
        Return the number of entries in the underlying mapping.

        Note that this is not necessarily the number of names yielded by iteration.
        """
        return len(self._env)

    def __delitem__(self, name):
        """Remove ``name`` from the underlying mapping."""
        del self._env[name]

    def __setitem__(self, name, value):
        """Store ``str(value)`` under ``name`` in the underlying mapping."""
        self._env[name] = str(value)

    def _raw_get(self, *key_names):
        """
        Return the raw value of the first of ``key_names`` present in the mapping.

        :raises EnvironmentVariableError: (carrying the first name) if none is present
        """
        for key in key_names:
            value = self._env.get(key, NO_DEFAULT)
            if value is not NO_DEFAULT:
                return value
        raise EnvironmentVariableError(key_names[0])

    def __contains__(self, key):
        """Return whether ``key`` is present in the underlying mapping."""
        try:
            self._raw_get(key)
        except EnvironmentVariableError:
            return False
        return True

    def __getattr__(self, name):
        """
        Fallback for names without a descriptor: return the raw variable ``name``.

        :raises AttributeError: if the variable is not set
        """
        # if we're here then there was no descriptor defined
        if name in ("_env", "_defined_keys"):
            # An unset slot (e.g. on an instance created without __init__, as
            # ``copy`` does). Looking it up in the environment would read
            # ``self._env`` again and recurse without end.
            raise AttributeError(f"{self.__class__} has no attribute {name!r}")
        try:
            return self._raw_get(name)
        except EnvironmentVariableError:
            raise AttributeError(
                f"{self.__class__} has no attribute {name!r}"
            ) from None

    def __getitem__(self, key):
        """
        Return ``getattr(self, key)`` so that declared variables are converted.

        A declared variable that is missing (and has no default) raises
        ``EnvironmentVariableError``; an undeclared missing name raises
        ``AttributeError`` (from ``__getattr__``).
        """
        return getattr(self, key)  # delegate through the descriptors

    def get(self, key, default=None):
        """Return ``self[key]``, or ``default`` if the variable is not available."""
        try:
            return self[key]
        except (EnvironmentVariableError, AttributeError):
            # undeclared names are resolved by __getattr__, which reports a
            # missing variable as AttributeError rather than as a KeyError
            return default

    def __dir__(self):
        """
        Return the sorted declared variable names if any were declared; otherwise
        the sorted union of the mapping's keys and the class's own attributes.
        """
        if self._defined_keys:
            # return only defined
            return sorted(self._defined_keys)
        # return whatever is in the environment (for convenience)
        members = set(self._env.keys())
        members.update(dir(self.__class__))
        return sorted(members)