1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence
from typing import AbstractSet, Any, Generic, Set, TypeVar, cast
from typing_extensions import ParamSpec
from polyfactory.exceptions import ParameterException
class CoverageContainerBase(ABC):
    """Abstract interface shared by all coverage containers.

    A coverage container wraps the values available for one field. It hands those values out one at a
    time via ``next_value()`` and exposes a "done" flag via ``is_done()`` that signals when every
    coverage example it holds has been produced.
    """

    @abstractmethod
    def next_value(self) -> Any:
        """Return the next value supplied by this container."""

    @abstractmethod
    def is_done(self) -> bool:
        """Return whether this container has already supplied every coverage example it has."""
# Type of the values handed out by the generic coverage containers defined in this module.
T = TypeVar("T")
class CoverageContainer(CoverageContainerBase, Generic[T]):
    """Coverage container backed by a fixed collection of values.

    Values are handed out in order. Once ``next_value()`` has been called more often than there are
    entries, the container wraps around and starts returning the same examples again.

    Entries that are themselves coverage containers are not returned directly; their own values are
    returned instead, which effectively merges them into this container.
    """

    def __init__(self, instances: Iterable[T]) -> None:
        """Materialise ``instances`` into a list and start at the first entry.

        Raises:
            ValueError: If ``instances`` yields no items.
        """
        self._instances = list(instances)
        self._pos = 0
        if len(self._instances) == 0:
            msg = "CoverageContainer must have at least one instance"
            raise ValueError(msg)

    def next_value(self) -> T:
        """Return the value at the current position, wrapping around past the end."""
        current = self._instances[self._pos % len(self._instances)]

        if not isinstance(current, CoverageContainerBase):
            # Plain value: hand it out and step to the next entry.
            self._pos += 1
            return current

        produced = current.next_value()
        # Stay on a nested container until it reports that it is done.
        if current.is_done():
            self._pos += 1
        return cast("T", produced)

    def is_done(self) -> bool:
        """Return ``True`` once the position has moved past the last entry."""
        return len(self._instances) <= self._pos

    def __repr__(self) -> str:
        """Return a debug representation showing the entries and the done flag."""
        return f"CoverageContainer(instances={self._instances}, is_done={self.is_done()})"
# Parameter specification of the callable wrapped by ``CoverageContainerCallable``.
P = ParamSpec("P")
class CoverageContainerCallable(CoverageContainerBase, Generic[T]):
    """Coverage container backed by a callable.

    Every ``next_value()`` call invokes the wrapped callable with the arguments captured at
    construction time and returns its result.
    """

    def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
        """Store the callable together with the positional and keyword arguments to call it with."""
        self._func = func
        self._args = args
        self._kwargs = kwargs

    def next_value(self) -> T:
        """Call the wrapped callable and return what it produces.

        Raises:
            ParameterException: If the callable raises any ``Exception``; the original exception is
                chained as the cause.
        """
        try:
            produced = self._func(*self._args, **self._kwargs)
        except Exception as exc:
            msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type."
            raise ParameterException(msg) from exc
        else:
            return produced

    def is_done(self) -> bool:
        """Always report done: this container never tracks any remaining examples."""
        return True
def _resolve_items(values: Iterable[Any]) -> tuple[list[Any], bool]:
    """Resolve every element of ``values`` in iteration order.

    Returns the list of resolved elements and a flag that is truthy only if every element reported
    itself as done.
    """
    resolved_items = []
    all_done = True
    for item in values:
        item_resolved, item_done = _resolve_next(item)
        resolved_items.append(item_resolved)
        all_done = all_done and item_done
    return resolved_items, all_done


def _resolve_next(unresolved: Any) -> tuple[Any, bool]:
    """Recursively replace coverage containers in ``unresolved`` with their next values.

    Handling by kind of input:

    * coverage container: its next value is taken and resolved in turn;
    * mapping: a new ``dict`` with resolved keys and values;
    * tuple: a new ``tuple`` of resolved items; other mutable sequences: a ``list`` of resolved items;
    * ``Set`` instance: a new instance of the same type, filled through ``add``;
    * other abstract sets: an empty instance of the same type unioned with the resolved items;
    * anything else: returned unchanged.

    Returns:
        A ``(resolved, done)`` pair. ``done`` is truthy only when every container encountered while
        resolving reports that it is done; values without containers count as done.
    """
    if isinstance(unresolved, CoverageContainerBase):
        inner, inner_done = _resolve_next(unresolved.next_value())
        # Ask the container only after drawing from it, so the flag reflects the value just taken.
        return inner, unresolved.is_done() and inner_done

    if isinstance(unresolved, Mapping):
        resolved_map = {}
        map_done = True
        for raw_key, raw_value in unresolved.items():
            # The value is resolved before its key.
            new_value, value_done = _resolve_next(raw_value)
            new_key, key_done = _resolve_next(raw_key)
            resolved_map[new_key] = new_value
            map_done = map_done and value_done and key_done
        return resolved_map, map_done

    if isinstance(unresolved, (tuple, MutableSequence)):
        items, items_done = _resolve_items(unresolved)
        if isinstance(unresolved, tuple):
            return tuple(items), items_done
        return items, items_done

    if isinstance(unresolved, Set):
        # Build the target first, then add members one by one while resolving them.
        resolved_set = type(unresolved)()
        set_done = True
        for member in unresolved:
            new_member, member_done = _resolve_next(member)
            resolved_set.add(new_member)
            set_done = set_done and member_done
        return resolved_set, set_done

    if issubclass(type(unresolved), AbstractSet):
        # No ``add`` is used here: collect the members and union them into an empty instance.
        empty = type(unresolved)()
        members, members_done = _resolve_items(unresolved)
        return empty.union(members), members_done

    return unresolved, True
def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]:
    """Yield resolved copies of ``kwargs`` until every coverage container in it is done.

    Each iteration resolves ``kwargs`` once, drawing the next value from every coverage container it
    contains, and yields the result. At least one resolved mapping is always yielded; iteration ends
    after the first resolution that reports all containers as done.
    """
    while True:
        resolved, finished = _resolve_next(kwargs)
        yield resolved
        if finished:
            return
|