File: parameter.py

package info (click to toggle)
python-cyclopts 3.12.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 3,288 kB
  • sloc: python: 11,445; makefile: 24
file content (374 lines) | stat: -rw-r--r-- 12,127 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
import inspect
from collections.abc import Iterable
from copy import deepcopy
from typing import (
    Any,
    Callable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)

import attrs
from attrs import define, field

import cyclopts._env_var
import cyclopts.utils
from cyclopts._convert import ITERABLE_TYPES
from cyclopts.annotations import is_annotated, is_union, resolve_optional
from cyclopts.field_info import signature_parameters
from cyclopts.group import Group
from cyclopts.token import Token
from cyclopts.utils import (
    default_name_transform,
    frozen,
    optional_to_tuple_converter,
    record_init,
    to_tuple_converter,
)

# Collection-of-bool annotations that are treated like a plain ``bool`` when
# choosing negative-flag prefixes (see ``Parameter.get_negatives``). The
# ``typing`` spellings and the builtin-generic spellings are listed separately
# so that a membership test matches either form.
ITERATIVE_BOOL_IMPLICIT_VALUE = frozenset(
    [
        list[bool],
        List[bool],
        tuple[bool, ...],
        Tuple[bool, ...],
        Sequence[bool],
        Iterable[bool],
    ]
)


# Type variable tying an argument's type to the returned type
# (``Parameter.__call__``, ``get_parameters``).
T = TypeVar("T")

# Annotation types (or generic origins) for which ``Parameter.get_negatives``
# may produce negative flags; any other annotation gets none.
_NEGATIVE_FLAG_TYPES = frozenset({bool}).union(ITERABLE_TYPES, ITERATIVE_BOOL_IMPLICIT_VALUE)


def _not_hyphen_validator(instance, attribute, values):
    for value in values:
        if value is not None and value.startswith("-"):
            raise ValueError(f'{attribute.alias} value must NOT start with "-".')


def _negative_converter(default: tuple[str, ...]):
    def converter(value) -> tuple[str, ...]:
        if value is None:
            return default
        else:
            return to_tuple_converter(value)

    return converter


# TODO: Breaking change; all fields after ``name`` should be ``kw_only=True``.
@record_init("_provided_args")
@frozen
class Parameter:
    """Cyclopts configuration for individual function parameters with :obj:`~typing.Annotated`.

    Example usage:

    .. code-block:: python

        from cyclopts import App, Parameter
        from typing import Annotated

        app = App()


        @app.default
        def main(foo: Annotated[int, Parameter(name="bar")]):
            print(foo)


        app()

    .. code-block:: console

        $ my-script 100
        100

        $ my-script --bar 100
        100
    """

    # All attribute docstrings have been moved to ``docs/api.rst`` for greater control with attrs.

    # This can ONLY ever be a Tuple[str, ...]
    # Usually starts with "--" or "-"
    name: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=lambda x: cast(tuple[str, ...], to_tuple_converter(x)),
    )

    converter: Optional[Callable[[Any, Sequence[Token]], Any]] = field(default=None)

    # This can ONLY ever be a Tuple[Callable, ...]
    validator: Union[None, Callable[[Any, Any], Any], Iterable[Callable[[Any, Any], Any]]] = field(
        default=(),
        converter=lambda x: cast(tuple[Callable[[Any, Any], Any], ...], to_tuple_converter(x)),
    )

    # This can ONLY ever be ``None`` or ``Tuple[str, ...]``
    negative: Union[None, str, Iterable[str]] = field(default=None, converter=optional_to_tuple_converter)

    # This can ONLY ever be a Tuple[Union[Group, str], ...]
    group: Union[None, Group, str, Iterable[Union[Group, str]]] = field(
        default=None, converter=to_tuple_converter, hash=False
    )

    parse: bool = field(default=None, converter=attrs.converters.default_if_none(True))

    _show: Optional[bool] = field(default=None, alias="show")

    show_default: Optional[bool] = field(default=None)

    show_choices: bool = field(default=None, converter=attrs.converters.default_if_none(True))

    help: Optional[str] = field(default=None)

    show_env_var: bool = field(default=None, converter=attrs.converters.default_if_none(True))

    # This can ONLY ever be a Tuple[str, ...]
    env_var: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=lambda x: cast(tuple[str, ...], to_tuple_converter(x)),
    )

    env_var_split: Callable = cyclopts._env_var.env_var_split

    # This can ONLY ever be a Tuple[str, ...]
    negative_bool: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=_negative_converter(("no-",)),
        validator=_not_hyphen_validator,
    )

    # This can ONLY ever be a Tuple[str, ...]
    negative_iterable: Union[None, str, Iterable[str]] = field(
        default=None,
        converter=_negative_converter(("empty-",)),
        validator=_not_hyphen_validator,
    )

    required: Optional[bool] = field(default=None)

    allow_leading_hyphen: bool = field(default=False)

    _name_transform: Optional[Callable[[str], str]] = field(
        alias="name_transform",
        default=None,
        kw_only=True,
    )

    # Should not get inherited
    accepts_keys: Optional[bool] = field(default=None)

    # Should not get inherited
    consume_multiple: bool = field(default=None, converter=attrs.converters.default_if_none(False))

    json_dict: Optional[bool] = field(default=None, kw_only=True)

    json_list: Optional[bool] = field(default=None, kw_only=True)

    # Aliases of the ``__init__`` arguments that were explicitly provided.
    # Populated by the ``record_init`` decorator; used by ``__repr__`` and ``combine``.
    _provided_args: tuple[str, ...] = field(factory=tuple, init=False, eq=False)

    @property
    def show(self) -> bool:
        """The explicit ``show`` value if one was given, otherwise the value of ``parse``."""
        return self._show if self._show is not None else self.parse

    @property
    def name_transform(self):
        """The user-supplied ``name_transform``, or ``default_name_transform`` when unset/falsy."""
        return self._name_transform if self._name_transform else default_name_transform

    def get_negatives(self, type_) -> tuple[str, ...]:
        """Derive the negative-flag names for this parameter given its annotation ``type_``.

        Only ``bool``, types in ``ITERABLE_TYPES`` and collections of ``bool`` (or
        annotations whose generic origin is one of those) get negatives. For a union,
        the first non-``None`` member is considered.

        Entries of ``negative`` starting with ``-`` are used verbatim. Other entries are
        treated as replacement names, combined with every ``--``-prefixed entry of
        ``name`` while keeping any dotted prefix. An explicitly empty ``negative``
        therefore disables negatives. When ``negative`` is ``None``, negatives are
        generated from the ``negative_bool`` prefixes for bool-like types, or from
        the ``negative_iterable`` prefixes otherwise.

        Returns
        -------
        tuple[str, ...]
            Negative flag names; possibly empty.
        """
        if is_union(type_):
            # ``get_args`` reports ``None`` members as ``NoneType`` (never ``None``), so
            # comparing against ``None`` would select ``NoneType`` for ``Union[None, bool]``.
            type_ = next(x for x in get_args(type_) if x is not type(None))

        origin = get_origin(type_)
        if type_ not in _NEGATIVE_FLAG_TYPES and (not origin or origin not in _NEGATIVE_FLAG_TYPES):
            return ()

        out, user_negatives = [], []
        if self.negative:
            # "-"-prefixed entries are complete flags; the rest are names that get
            # attached to each long-option name below.
            for negative in self.negative:
                (out if negative.startswith("-") else user_negatives).append(negative)

            if not user_negatives:
                return tuple(out)

        assert isinstance(self.name, tuple)
        for name in self.name:
            if not name.startswith("--"):  # Only provide negation for option-like long flags.
                continue
            name = name[2:]
            name_components = name.split(".")

            if type_ is bool or type_ in ITERATIVE_BOOL_IMPLICIT_VALUE:
                negative_prefixes = self.negative_bool
            else:
                negative_prefixes = self.negative_iterable
            # Keep the dotted prefix (e.g. "--foo.bar" -> "foo.") so the negative
            # targets the same nested name.
            name_prefix = ".".join(name_components[:-1])
            if name_prefix:
                name_prefix += "."
            assert isinstance(negative_prefixes, tuple)
            if self.negative is None:
                for negative_prefix in negative_prefixes:
                    out.append(f"--{name_prefix}{negative_prefix}{name_components[-1]}")
            else:
                # An explicit ``negative`` replaces the generated prefix-based names.
                for negative in user_negatives:
                    out.append(f"--{name_prefix}{negative}")
        return tuple(out)

    def __repr__(self):
        """Only shows non-default values."""
        content = ", ".join(
            [
                f"{a.alias}={getattr(self, a.name)!r}"
                for a in self.__attrs_attrs__  # pyright: ignore[reportAttributeAccessIssue]
                if a.alias in self._provided_args
            ]
        )
        return f"{type(self).__name__}({content})"

    @classmethod
    def combine(cls, *parameters: Optional["Parameter"]) -> "Parameter":
        """Returns a new Parameter with combined values of all provided ``parameters``.

        Only attributes that were explicitly provided to each Parameter are merged;
        later parameters override earlier ones. ``None`` entries are ignored.

        Parameters
        ----------
        `*parameters`: Optional[Parameter]
             Parameters whose explicitly-provided attributes are merged.
             Ordered from lowest to highest attribute priority.

        Returns
        -------
        Parameter
            The single non-``None`` parameter unchanged, ``EMPTY_PARAMETER`` when
            there are none, otherwise a newly constructed Parameter.
        """
        filtered = [x for x in parameters if x is not None]
        # In the common case of 0/1 parameters to combine, we can avoid
        # instantiating a new Parameter object.
        if len(filtered) == 1:
            return filtered[0]
        elif not filtered:
            return EMPTY_PARAMETER

        # Later entries overwrite earlier keys, giving later parameters priority.
        kwargs = {
            alias: getattr(parameter, _parameter_alias_to_name[alias])
            for parameter in filtered
            for alias in parameter._provided_args
        }
        return cls(**kwargs)

    @classmethod
    def default(cls) -> "Parameter":
        """Create a Parameter with all Cyclopts-default values.

        This is different from just :class:`Parameter` because the default
        values will be recorded and override all upstream parameter values.
        """
        return cls(
            **{a.alias: a.default for a in cls.__attrs_attrs__ if a.init}  # pyright: ignore[reportAttributeAccessIssue]
        )

    @classmethod
    def from_annotation(cls, type_: Any, *default_parameters: Optional["Parameter"]) -> tuple[Any, "Parameter"]:
        """Resolve the immediate Parameter from a type hint.

        ``default_parameters`` have lower priority than Parameters attached to
        ``type_`` itself. Returns the hint as resolved by ``get_parameters``
        together with the combined Parameter.
        """
        if type_ is inspect.Parameter.empty:
            if default_parameters:
                return type_, cls.combine(*default_parameters)
            else:
                return type_, EMPTY_PARAMETER
        else:
            type_, parameters = get_parameters(type_)
            return type_, cls.combine(*default_parameters, *parameters)

    def __call__(self, obj: T) -> T:
        """Decorator interface for annotating a function/class with a :class:`Parameter`.

        Most commonly used for directly configuring a class:

        .. code-block:: python

            @Parameter(...)
            class Foo: ...
        """
        if not hasattr(obj, "__cyclopts__"):
            obj.__cyclopts__ = CycloptsConfig(obj=obj)  # pyright: ignore[reportAttributeAccessIssue]
        elif obj.__cyclopts__.obj != obj:  # pyright: ignore[reportAttributeAccessIssue]
            # Create a copy so that children class Parameter decorators don't impact the parent.
            obj.__cyclopts__ = deepcopy(obj.__cyclopts__)  # pyright: ignore[reportAttributeAccessIssue]
        obj.__cyclopts__.parameters.append(self)  # pyright: ignore[reportAttributeAccessIssue]
        return obj


# Maps each ``__init__`` keyword (attrs alias) to its attribute name,
# e.g. ``"show" -> "_show"`` and ``"name_transform" -> "_name_transform"``.
_parameter_alias_to_name = dict(
    (attribute.alias, attribute.name)
    for attribute in Parameter.__attrs_attrs__  # pyright: ignore[reportAttributeAccessIssue]
    if attribute.init
)

# Shared Parameter with no explicitly provided arguments; returned by
# ``Parameter.combine`` / ``Parameter.from_annotation`` when there is nothing to combine.
EMPTY_PARAMETER = Parameter()


def validate_command(f: Callable):
    """Check that ``f`` abides by Cyclopts's rules for command functions.

    Currently this enforces that any parameter whose resolved Parameter has
    ``parse=False`` is declared keyword-only.

    Raises
    ------
    ValueError
        Function has naming or parameter/signature inconsistencies.
    """
    module_name = f.__module__ or ""
    if module_name.startswith("cyclopts"):
        # Speed optimization: functions defined inside cyclopts itself are skipped.
        return

    # Only ``Annotated`` hints can carry a Parameter. Checking for an annotation is
    # much cheaper than building a Parameter, so plain hints are filtered out lazily.
    annotated_fields = (info for info in signature_parameters(f).values() if is_annotated(info.annotation))
    for info in annotated_fields:
        _, cparam = Parameter.from_annotation(info.annotation)
        if cparam.parse or info.kind is info.KEYWORD_ONLY:
            continue
        raise ValueError("Parameter.parse=False must be used with a KEYWORD_ONLY function parameter.")


def get_parameters(hint: T) -> tuple[T, list[Parameter]]:
    """Gather the :class:`Parameter` objects attached to ``hint`` at its root level.

    Two sources are consulted, in this order: the ``__cyclopts__`` attribute
    (populated when an object is decorated with a :class:`Parameter`), then
    any :class:`Parameter` instances in :obj:`~typing.Annotated` metadata.

    Returns
    -------
    hint
        Annotation hint with :obj:`Annotated` and :obj:`Optional` resolved.
    list[Parameter]
        List of parameters discovered, in the order described above.
    """
    hint = resolve_optional(hint)

    config = getattr(hint, "__cyclopts__", None)
    found: list[Parameter] = list(config.parameters) if config else []

    if is_annotated(hint):
        args = get_args(hint)
        hint = args[0]
        found.extend(meta for meta in args[1:] if isinstance(meta, Parameter))

    return hint, found


@define
class CycloptsConfig:
    """Data stored on an object's ``__cyclopts__`` attribute by :meth:`Parameter.__call__`.

    Attributes
    ----------
    obj
        The object this config was created for. ``Parameter.__call__`` compares
        against it to detect a config inherited from a parent class.
    parameters
        Parameters applied via decoration, in the order they were applied.
    """

    obj: Any = None
    parameters: list[Parameter] = field(init=False, factory=list)