# coding=utf-8
# Copyright 2024-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a tool to generate `src/huggingface_hub/inference/_generated/types`."""
import argparse
import re
from pathlib import Path
from typing import Dict, List, Literal, NoReturn, Optional
import libcst as cst
from helpers import check_and_update_file_content, format_source_code
# Path to ./src/huggingface_hub, resolved relative to this script (which lives in utils/).
huggingface_hub_folder_path = Path(__file__).parents[1] / "src" / "huggingface_hub"
# Folder holding the auto-generated inference type modules (one file per task).
INFERENCE_TYPES_FOLDER_PATH = huggingface_hub_folder_path / "inference" / "_generated" / "types"
# Top-level package __init__.py, updated to re-export all generated dataclasses.
MAIN_INIT_PY_FILE = huggingface_hub_folder_path / "__init__.py"
# Auto-generated English docs page listing the inference types.
REFERENCE_PACKAGE_EN_PATH = (
    Path(__file__).parents[1] / "docs" / "source" / "en" / "package_reference" / "inference_types.md"
)
# Auto-generated Korean docs page listing the inference types.
REFERENCE_PACKAGE_KO_PATH = (
    Path(__file__).parents[1] / "docs" / "source" / "ko" / "package_reference" / "inference_types.md"
)
# Hand-written files in the types folder that must never be regenerated.
IGNORE_FILES = [
    "__init__.py",
    "base.py",
]
# Matches a plain `@dataclass`-decorated class definition; captures the class name.
BASE_DATACLASS_REGEX = re.compile(
    r"""
^@dataclass
\nclass\s(\w+):\n
""",
    re.VERBOSE | re.MULTILINE,
)
# Matches a class already rewritten to `@dataclass_with_extra` inheriting from
# BaseInferenceType; captures the class name.
INHERITED_DATACLASS_REGEX = re.compile(
    r"""
^@dataclass_with_extra
\nclass\s(\w+)\(BaseInferenceType\):
""",
    re.VERBOSE | re.MULTILINE,
)
# Matches a top-level `Name = value` type alias; captures the name and the value.
TYPE_ALIAS_REGEX = re.compile(
    r"""
^(?!\s) # to make sure the line does not start with whitespace (top-level)
(\w+)
\s*=\s*
(.+)
$
""",
    re.VERBOSE | re.MULTILINE,
)
# Matches a `: Optional[...]` annotation at end of line.
# NOTE(review): not referenced anywhere in this file — possibly dead code; confirm before removing.
OPTIONAL_FIELD_REGEX = re.compile(r": Optional\[(.+)\]$", re.MULTILINE)
# Header prepended to the generated types/__init__.py (see `create_init_py`).
INIT_PY_HEADER = """
# This file is auto-generated by `utils/generate_inference_types.py`.
# Do not modify it manually.
#
# ruff: noqa: F401
from .base import BaseInferenceType
"""
# Regex to add all dataclasses to ./src/huggingface_hub/__init__.py
MAIN_INIT_PY_REGEX = re.compile(
    r"""
\"inference\._generated\.types\":\s*\[ # module name
(.*?) # all dataclasses listed
\] # closing bracket
""",
    re.MULTILINE | re.VERBOSE | re.DOTALL,
)
# List of classes that are shared across multiple modules
# This is used to fix the naming of the classes (to make them unique by task)
SHARED_CLASSES = [
    "BoundingBox",
    "ClassificationOutputTransform",
    "ClassificationOutput",
    "GenerationParameters",
    "TargetSize",
    "EarlyStoppingEnum",
]
# Template for the English reference docs page; `{types}` is filled with the
# per-task autodoc sections by `generate_reference_package`.
REFERENCE_PACKAGE_EN_CONTENT = """
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<!--⚠️ Note that this file is auto-generated by `utils/generate_inference_types.py`. Do not modify it manually.-->
# Inference types
This page lists the types (e.g. dataclasses) available for each task supported on the Hugging Face Hub.
Each task is specified using a JSON schema, and the types are generated from these schemas - with some customization
due to Python requirements.
Visit [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)
to find the JSON schemas for each task.
This part of the lib is still under development and will be improved in future releases.
{types}
"""
# Korean translation of the same template; `{types}` is filled identically.
REFERENCE_PACKAGE_KO_CONTENT = """
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<!--⚠️ Note that this file is auto-generated by `utils/generate_inference_types.py`. Do not modify it manually.-->
# 추론 타입[[inference-types]]
이 페이지에는 Hugging Face Hub에서 지원하는 타입(예: 데이터 클래스)이 나열되어 있습니다.
각 작업은 JSON 스키마를 사용하여 지정되며, 이러한 스키마에 의해서 타입이 생성됩니다. 이때 Python 요구 사항으로 인해 일부 사용자 정의가 있을 수 있습니다.
각 작업의 JSON 스키마를 확인하려면 [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)를 확인하세요.
라이브러리에서 이 부분은 아직 개발 중이며, 향후 릴리즈에서 개선될 예정입니다.
{types}
"""
def _replace_class_name(content: str, cls: str, new_cls: str) -> str:
"""
Replace the class name `cls` with the new class name `new_cls` in the content.
"""
pattern = rf"""
(?<![\w'"])
(['"]?)
{cls}
(['"]?)
(?![\w'"])
"""
def replacement(m):
quote_start = m.group(1) or ""
quote_end = m.group(2) or ""
return f"{quote_start}{new_cls}{quote_end}"
content = re.sub(pattern, replacement, content, flags=re.VERBOSE)
return content
def _inherit_from_base(content: str) -> str:
    """Rewrite plain `@dataclass` classes so they inherit from `BaseInferenceType`.

    Also inserts the `.base` import right before the existing
    `from dataclasses import` line so the new names are in scope.
    """
    base_import = "\nfrom .base import BaseInferenceType, dataclass_with_extra\nfrom dataclasses import"
    content = content.replace("\nfrom dataclasses import", base_import)
    return BASE_DATACLASS_REGEX.sub(
        r"@dataclass_with_extra\nclass \1(BaseInferenceType):\n",
        content,
    )
def _delete_empty_lines(content: str) -> str:
return "\n".join([line for line in content.split("\n") if line.strip()])
def _fix_naming_for_shared_classes(content: str, module_name: str) -> str:
    """Prefix shared class names with the task name to make them unique per task.

    E.g. in module `audio_classification`, the shared `ClassificationOutput`
    becomes `AudioClassificationOutputElement`.

    Args:
        content: source code of the generated module.
        module_name: snake_case task name (the module's file stem).

    Returns:
        `content` with every used shared class renamed.
    """
    # "audio_classification" -> "AudioClassification". Loop-invariant, so it is
    # computed once instead of once per shared class (fixed: was inside the loop).
    module_prefix = "".join(part.capitalize() for part in module_name.split("_"))
    for cls in SHARED_CLASSES:
        # No need to fix the naming of a shared class if it's not used in the module
        if cls not in content:
            continue
        if "Classification" in module_prefix:
            # to avoid "ClassificationClassificationOutput"
            new_cls = module_prefix + cls.removeprefix("Classification")
        else:
            new_cls = module_prefix + cls
        if new_cls.endswith("ClassificationOutput"):
            # to get "AudioClassificationOutputElement"
            new_cls += "Element"
        content = _replace_class_name(content, cls, new_cls)
    return content
def _fix_text2text_shared_parameters(content: str, module_name: str) -> str:
if module_name in ("summarization", "translation"):
content = content.replace(
"Text2TextGenerationParameters",
f"{module_name.capitalize()}GenerationParameters",
)
content = content.replace(
"Text2TextGenerationTruncationStrategy",
f"{module_name.capitalize()}GenerationTruncationStrategy",
)
return content
def _make_optional_fields_default_to_none(content: str):
lines = []
for line in content.split("\n"):
if "Optional[" in line and not line.endswith("None"):
line += " = None"
lines.append(line)
return "\n".join(lines)
def _list_dataclasses(content: str) -> List[str]:
    """Return the names of all `BaseInferenceType` dataclasses defined in the module."""
    names = INHERITED_DATACLASS_REGEX.findall(content)
    return names
def _list_type_aliases(content: str) -> List[str]:
    """Return the names of all top-level type aliases defined in the module."""
    # Each findall match is a (name, value) tuple; only the name is kept.
    return [match[0] for match in TYPE_ALIAS_REGEX.findall(content)]
class DeprecatedRemover(cst.CSTTransformer):
    """libcst transformer that removes classes and fields tagged `@deprecated`.

    A class is removed when its docstring contains `@deprecated`; a field is
    removed when the string literal immediately following it contains it.
    A class left with no body is removed entirely.
    """

    def is_deprecated(self, docstring: Optional[str]) -> bool:
        """Check if a docstring contains @deprecated (case-insensitive)."""
        if docstring is None:
            return False
        return "@deprecated" in docstring.lower()

    def get_docstring(self, body: List[cst.BaseStatement]) -> Optional[str]:
        """Extract the leading string literal of `body`, if any."""
        if not body:
            return None
        head = body[0]
        if not isinstance(head, cst.SimpleStatementLine):
            return None
        leading = head.body[0]
        if isinstance(leading, cst.Expr) and isinstance(leading.value, cst.SimpleString):
            return leading.value.evaluated_value
        return None

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> Optional[cst.ClassDef]:
        """Drop the class if it is deprecated; otherwise drop its deprecated fields."""
        if self.is_deprecated(self.get_docstring(original_node.body.body)):
            return cst.RemoveFromParent()
        statements = list(updated_node.body.body)
        total = len(statements)
        kept = []
        index = 0
        while index < total:
            current = statements[index]
            # A field is an annotated assignment wrapped in a simple statement line.
            is_field = isinstance(current, cst.SimpleStatementLine) and isinstance(current.body[0], cst.AnnAssign)
            if is_field and index + 1 < total:
                # A string literal right after a field acts as that field's docstring.
                trailing_doc = self.get_docstring([statements[index + 1]])
                if self.is_deprecated(trailing_doc):
                    index += 2  # Skip both the field and its docstring
                    continue
            kept.append(current)
            index += 1
        if not kept:
            # Everything inside was deprecated: remove the now-empty class.
            return cst.RemoveFromParent()
        return updated_node.with_changes(body=updated_node.body.with_changes(body=kept))
def _clean_deprecated_fields(content: str) -> str:
    """Remove deprecated classes and fields using libcst."""
    tree = cst.parse_module(content)
    return tree.visit(DeprecatedRemover()).code
def fix_inference_classes(content: str, module_name: str) -> str:
    """Run the full chain of source fixes on a generated types module.

    Order matters: base-class rewriting first, then formatting/naming fixes,
    then defaulting of optional fields.
    """
    pipeline = (
        _inherit_from_base,
        _delete_empty_lines,
        lambda text: _fix_naming_for_shared_classes(text, module_name),
        lambda text: _fix_text2text_shared_parameters(text, module_name),
        _make_optional_fields_default_to_none,
    )
    for step in pipeline:
        content = step(content)
    return content
def create_init_py(dataclasses: Dict[str, List[str]]) -> str:
    """Create the content of the generated `types/__init__.py`.

    Args:
        dataclasses: mapping of module name -> public names defined in that module.

    Returns:
        The full file content: header, then one `from .<module> import ...`
        line per module.

    Fixed: added the missing `-> str` return annotation (consistency with the
    other functions in this file) and replaced repeated string `+=` with a
    single join.
    """
    import_lines = [
        f"from .{module} import {', '.join(dataclasses_list)}"
        for module, dataclasses_list in dataclasses.items()
    ]
    return INIT_PY_HEADER + "\n" + "\n".join(import_lines)
def add_dataclasses_to_main_init(content: str, dataclasses: Dict[str, List[str]]) -> str:
    """Inject all generated class names into the main `__init__.py` content.

    Replaces the `"inference._generated.types": [...]` entry with the sorted,
    deduplicated list of names collected from all generated modules.

    Fixed: added the missing `-> str` return annotation (consistency with the
    other functions in this file).
    """
    unique_names = sorted({cls for classes in dataclasses.values() for cls in classes})
    quoted_names = ", ".join(f"'{cls}'" for cls in unique_names)
    return MAIN_INIT_PY_REGEX.sub(f'"inference._generated.types": [{quoted_names}]', content)
def generate_reference_package(dataclasses: Dict[str, List[str]], language: Literal["en", "ko"]) -> str:
    """Generate the reference docs page content for the given language.

    Args:
        dataclasses: mapping of task name -> dataclass names for that task.
        language: `"en"` or `"ko"`.

    Returns:
        The markdown page content with one `## <task>` section per task.

    Raises:
        ValueError: if `language` is not supported.

    Fixed: the language check used to live inside the per-task loop, so an
    unsupported language combined with an empty `dataclasses` dict silently
    returned the Korean template instead of raising. Validation now happens
    up-front. Also hoisted the duplicate `sorted(dataclasses[task])` call.
    """
    if language not in ("en", "ko"):
        raise ValueError(f"Language {language} is not supported.")
    per_task_docs = []
    for task in sorted(dataclasses.keys()):
        sorted_classes = sorted(dataclasses[task])
        lines_str = "\n\n".join(f"[[autodoc]] huggingface_hub.{cls}" for cls in sorted_classes)
        if language == "en":
            # e.g. '## audio_classification'
            per_task_docs.append(f"\n## {task}\n\n{lines_str}\n\n")
        else:
            # e.g. '## audio_classification[[huggingface_hub.AudioClassificationInput]]'
            per_task_docs.append(f"\n## {task}[[huggingface_hub.{sorted_classes[0]}]]\n\n{lines_str}\n\n")
    template = REFERENCE_PACKAGE_EN_CONTENT if language == "en" else REFERENCE_PACKAGE_KO_CONTENT
    return template.format(types="\n".join(per_task_docs))
def check_inference_types(update: bool) -> NoReturn:
    """Check and update inference types.

    This script is used in the `make style` and `make quality` checks.

    Args:
        update: forwarded to `check_and_update_file_content`; presumably True
            rewrites files in place while False only checks — TODO confirm
            against `helpers.check_and_update_file_content`.
    """
    dataclasses: Dict[str, List[str]] = {}
    aliases: Dict[str, List[str]] = {}
    # Step 1: regenerate each per-task module in the generated types folder,
    # collecting the dataclass and type-alias names it defines.
    for file in INFERENCE_TYPES_FOLDER_PATH.glob("*.py"):
        if file.name in IGNORE_FILES:
            continue
        content = file.read_text()
        content = _clean_deprecated_fields(content)
        fixed_content = fix_inference_classes(content, module_name=file.stem)
        formatted_content = format_source_code(fixed_content)
        dataclasses[file.stem] = _list_dataclasses(formatted_content)
        aliases[file.stem] = _list_type_aliases(formatted_content)
        check_and_update_file_content(file, formatted_content, update)
    # Step 2: regenerate types/__init__.py exporting both dataclasses and aliases.
    all_classes = {module: dataclasses[module] + aliases[module] for module in dataclasses.keys()}
    init_py_content = create_init_py(all_classes)
    init_py_content = format_source_code(init_py_content)
    init_py_file = INFERENCE_TYPES_FOLDER_PATH / "__init__.py"
    check_and_update_file_content(init_py_file, init_py_content, update)
    # Step 3: re-export everything from the main huggingface_hub/__init__.py.
    main_init_py_content = MAIN_INIT_PY_FILE.read_text()
    updated_main_init_py_content = add_dataclasses_to_main_init(main_init_py_content, all_classes)
    updated_main_init_py_content = format_source_code(updated_main_init_py_content)
    check_and_update_file_content(MAIN_INIT_PY_FILE, updated_main_init_py_content, update)
    # Step 4: regenerate the EN and KO docs pages (dataclasses only, no aliases).
    reference_package_content_en = generate_reference_package(dataclasses, "en")
    check_and_update_file_content(REFERENCE_PACKAGE_EN_PATH, reference_package_content_en, update)
    reference_package_content_ko = generate_reference_package(dataclasses, "ko")
    check_and_update_file_content(REFERENCE_PACKAGE_KO_PATH, reference_package_content_ko, update)
    print("✅ All good! (inference types)")
    # Builtin `exit` terminates the process here, matching the `NoReturn` annotation.
    exit(0)
# CLI entry point: `python utils/generate_inference_types.py [--update]`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--update",
        action="store_true",
        help=(
            "Whether to re-generate files in `./src/huggingface_hub/inference/_generated/types/` if a change is detected."
        ),
    )
    args = parser.parse_args()
    check_inference_types(update=args.update)