1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
|
import inspect
import json
from typing import List, Optional, Union, get_args, get_origin
import pytest
import huggingface_hub.inference._generated.types as types
from huggingface_hub.inference._generated.types import AutomaticSpeechRecognitionParameters
from huggingface_hub.inference._generated.types.base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class DummyType(BaseInferenceType):
    """Flat fixture type used by the parsing tests in this module.

    Parsed from `DUMMY_AS_*` payloads; tests below also attach extra (undeclared) keys to it.
    """

    foo: int
    bar: str
@dataclass_with_extra
class DummyNestedType(BaseInferenceType):
    """Fixture with `DummyType`-typed fields, used to exercise recursive (nested) parsing.

    `items` uses `typing.List` while `maybe_items` uses builtin `list` on purpose, so that both
    annotation spellings are covered by the tests.
    """

    item: DummyType
    items: List[DummyType]  # works both with List and list
    maybe_items: Optional[list[DummyType]] = None
# One `DummyType` payload, expressed in each input shape that `parse_obj` is tested with.
DUMMY_AS_DICT = dict(foo=42, bar="baz")
DUMMY_AS_STR = json.dumps(DUMMY_AS_DICT)
DUMMY_AS_BYTES = bytes(DUMMY_AS_STR, "utf-8")
DUMMY_AS_LIST = [DUMMY_AS_DICT]
def test_parse_from_bytes():
    """`parse_obj` accepts a UTF-8 encoded JSON bytes payload."""
    parsed = DummyType.parse_obj(DUMMY_AS_BYTES)
    assert parsed.foo == 42
    assert parsed.bar == "baz"
def test_parse_from_str():
    """`parse_obj` accepts a JSON string payload."""
    parsed = DummyType.parse_obj(DUMMY_AS_STR)
    assert parsed.foo == 42
    assert parsed.bar == "baz"
def test_parse_from_dict():
    """`parse_obj` accepts an already-decoded dict payload."""
    parsed = DummyType.parse_obj(DUMMY_AS_DICT)
    assert parsed.foo == 42
    assert parsed.bar == "baz"
def test_parse_from_list():
    """A list payload is parsed element-wise into a sequence of instances."""
    parsed = DummyType.parse_obj(DUMMY_AS_LIST)
    assert len(parsed) == 1
    first = parsed[0]
    assert first.foo == 42
    assert first.bar == "baz"
def test_parse_from_unexpected_type():
    """An integer is not a supported payload and must be rejected with ValueError."""
    unsupported_payload = 42
    with pytest.raises(ValueError):
        DummyType.parse_obj(unsupported_payload)
def test_parse_as_instance_success():
    """`parse_obj_as_instance` returns one fully-populated instance for a dict payload."""
    instance = DummyType.parse_obj_as_instance(DUMMY_AS_DICT)
    assert isinstance(instance, DummyType)
    # Check the parsed values too, like the other parse tests do; a type check alone would
    # pass even if fields were left unset.
    assert instance.foo == 42
    assert instance.bar == "baz"
def test_parse_as_instance_failure():
    """`parse_obj_as_instance` refuses a list payload with ValueError."""
    list_payload = DUMMY_AS_LIST
    with pytest.raises(ValueError):
        DummyType.parse_obj_as_instance(list_payload)
def test_parse_as_list_success():
    """`parse_obj_as_list` parses each element of a list payload into a `DummyType`."""
    instances = DummyType.parse_obj_as_list(DUMMY_AS_LIST)
    assert len(instances) == 1
    # Check the element itself, mirroring `test_parse_from_list`; a length check alone would
    # pass even if elements were returned unparsed.
    assert isinstance(instances[0], DummyType)
    assert instances[0].foo == 42
    assert instances[0].bar == "baz"
def test_parse_as_list_failure():
    """`parse_obj_as_list` refuses a single-dict payload with ValueError."""
    dict_payload = DUMMY_AS_DICT
    with pytest.raises(ValueError):
        DummyType.parse_obj_as_list(dict_payload)
def test_parse_nested_class():
    """Nested dicts and lists of dicts are parsed recursively into `DummyType` instances."""
    payload = {
        "item": DUMMY_AS_DICT,
        "items": DUMMY_AS_LIST,
        "maybe_items": None,
    }
    parsed = DummyNestedType.parse_obj(payload)

    single = parsed.item
    assert single.foo == 42
    assert single.bar == "baz"

    assert len(parsed.items) == 1
    first = parsed.items[0]
    assert first.foo == 42
    assert first.bar == "baz"

    assert parsed.maybe_items is None
def test_all_fields_are_optional():
    """Fields missing from the payload come back as None instead of failing to parse.

    This lets parsing silently tolerate a server response with less data than expected.
    """
    parsed = DummyNestedType.parse_obj({"maybe_items": [{}, DUMMY_AS_BYTES]})
    assert isinstance(parsed, DummyNestedType)
    assert parsed.item is None
    assert parsed.items is None

    nested = parsed.maybe_items
    assert len(nested) == 2
    empty_entry, full_entry = nested[0], nested[1]
    # `{}` -> every field None; bytes payload -> parsed normally.
    assert empty_entry.foo is None
    assert empty_entry.bar is None
    assert full_entry.foo == 42
    assert full_entry.bar == "baz"
def test_normalize_keys():
    """Payload keys are normalized onto the dataclass field names.

    By convention, the dataclass field names are already normalized, so keys such as "ItEm" or
    "Maybe-Items" in a server response must map onto `item` and `maybe_items`.
    """
    raw = {"ItEm": DUMMY_AS_DICT, "Maybe-Items": [DUMMY_AS_DICT]}
    parsed = DummyNestedType.parse_obj(raw)
    assert isinstance(parsed.item, DummyType)

    normalized_list = parsed.maybe_items
    assert isinstance(normalized_list, list)
    assert len(normalized_list) == 1
    assert isinstance(normalized_list[0], DummyType)
def test_optional_are_set_to_none():
    """Every parameter annotated as optional on an inference type must default to None.

    The whole subclass tree of `BaseInferenceType` is walked, not only its direct subclasses.
    This ensures that types deriving from another inference type are checked as well.
    """
    # `__subclasses__()` only returns immediate children, so traverse the tree explicitly.
    to_visit = list(types.BaseInferenceType.__subclasses__())
    seen = set()
    while to_visit:
        _type = to_visit.pop()
        if _type in seen:  # guard against revisiting a class reachable via several bases
            continue
        seen.add(_type)
        to_visit.extend(_type.__subclasses__())

        parameters = inspect.signature(_type).parameters
        for parameter in parameters.values():
            if _is_optional(parameter.annotation):
                assert parameter.default is None, f"Parameter {parameter} of {_type} should be set to None"
def test_none_inferred():
    """Regression test for https://github.com/huggingface/huggingface_hub/pull/2095

    Building the parameters object with no arguments at all must succeed. Before the fix this raised:
    TypeError: __init__() missing 2 required positional arguments: 'generate' and 'return_timestamps'
    """
    AutomaticSpeechRecognitionParameters()
def test_other_fields_are_set():
    """Undeclared keys are kept as attributes and rendered after the declared fields."""
    payload = {
        "item": DUMMY_AS_DICT,
        "extra": "value",
        "items": [{"foo": 42, "another_extra": "value", "bar": "baz"}],
        "maybe_items": None,
    }
    parsed = DummyNestedType.parse_obj(payload)
    nested = parsed.items[0]

    assert parsed.extra == "value"
    assert nested.another_extra == "value"
    # Declared fields are rendered first; extra fields always come last.
    assert str(nested) == "DummyType(foo=42, bar='baz', another_extra='value')"

    expected_repr = (
        "DummyNestedType("
        "item=DummyType(foo=42, bar='baz'), "
        "items=[DummyType(foo=42, bar='baz', another_extra='value')], "
        "maybe_items=None, extra='value'"
        ")"
    )
    # The same rendering is used by both __str__ and __repr__.
    assert repr(parsed) == expected_repr
def test_other_fields_not_proper_dataclass_fields():
    """Extra attributes live on the instance but are invisible to the dataclass machinery."""
    first = DummyType.parse_obj({"foo": 42, "bar": "baz", "extra": "value1"})
    second = DummyType.parse_obj({"foo": 42, "bar": "baz", "extra": "value2", "another_extra": "value2.1"})

    assert first.extra == "value1"
    assert second.extra == "value2"
    assert second.another_extra == "value2.1"

    # Extras are not dataclass fields: every dataclass method except __repr__ (here __eq__)
    # behaves as if they were absent, so the two instances compare equal.
    assert first == second
def _is_optional(field) -> bool:
    """Return True if `field` is a union type annotation that admits None.

    Recognizes `typing.Optional[X]` / `typing.Union[X, ..., None]` and, on Python 3.10+, the
    PEP 604 spelling `X | None`, whose origin is `types.UnionType` rather than `typing.Union`.
    """
    # Imported locally under an alias: at module level, `types` is bound to huggingface_hub's
    # generated-types package, which shadows the standard-library module.
    import types as _stdlib_types

    # Adapted from https://stackoverflow.com/a/58841311
    # `types.UnionType` does not exist before Python 3.10; fall back to `Union` there.
    union_origins = (Union, getattr(_stdlib_types, "UnionType", Union))
    return get_origin(field) in union_origins and type(None) in get_args(field)
|