File: generate_inference_types.py

package info (click to toggle)
huggingface-hub 0.31.1-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 5,092 kB
  • sloc: python: 40,321; makefile: 54
file content (383 lines) | stat: -rw-r--r-- 14,043 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# coding=utf-8
# Copyright 2024-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a tool to generate `src/huggingface_hub/inference/_generated/types`."""

import argparse
import re
from pathlib import Path
from typing import Dict, List, Literal, NoReturn, Optional

import libcst as cst
from helpers import check_and_update_file_content, format_source_code


# Location of the `huggingface_hub` package sources, resolved relative to this script's parent folder.
huggingface_hub_folder_path = Path(__file__).parents[1] / "src" / "huggingface_hub"
# Folder whose `*.py` modules are read, fixed and checked by `check_inference_types`.
INFERENCE_TYPES_FOLDER_PATH = huggingface_hub_folder_path / "inference" / "_generated" / "types"
# Package-level `__init__.py` whose "inference._generated.types" entry is regenerated.
MAIN_INIT_PY_FILE = huggingface_hub_folder_path / "__init__.py"
# Documentation pages (English and Korean) regenerated from the collected dataclass names.
REFERENCE_PACKAGE_EN_PATH = (
    Path(__file__).parents[1] / "docs" / "source" / "en" / "package_reference" / "inference_types.md"
)
REFERENCE_PACKAGE_KO_PATH = (
    Path(__file__).parents[1] / "docs" / "source" / "ko" / "package_reference" / "inference_types.md"
)

# File names in the types folder that are skipped when iterating over the modules.
IGNORE_FILES = [
    "__init__.py",
    "base.py",
]

# Matches a plain `@dataclass` decorator line directly followed by a `class Name:` line (no base class).
# Group 1 captures the class name. The pattern is compiled in VERBOSE mode, so its own layout whitespace
# is ignored and only the explicit `\n` / `\s` escapes match whitespace.
BASE_DATACLASS_REGEX = re.compile(
    r"""
    ^@dataclass
    \nclass\s(\w+):\n
""",
    re.VERBOSE | re.MULTILINE,
)

# Matches a `@dataclass_with_extra` decorator line directly followed by `class Name(BaseInferenceType):`.
# Group 1 captures the class name.
INHERITED_DATACLASS_REGEX = re.compile(
    r"""
    ^@dataclass_with_extra
    \nclass\s(\w+)\(BaseInferenceType\):
""",
    re.VERBOSE | re.MULTILINE,
)

# Matches a non-indented `name = value` line. Group 1 is the assigned name, group 2 the right-hand side.
TYPE_ALIAS_REGEX = re.compile(
    r"""
    ^(?!\s) # to make sure the line does not start with whitespace (top-level)
    (\w+)
    \s*=\s*
    (.+)
    $
    """,
    re.VERBOSE | re.MULTILINE,
)
# Matches `: Optional[...]` at the end of a line and captures the wrapped type.
# NOTE(review): no code in this file appears to reference this pattern - confirm whether it is still needed.
OPTIONAL_FIELD_REGEX = re.compile(r": Optional\[(.+)\]$", re.MULTILINE)


# Text placed at the top of the generated types `__init__.py`; `create_init_py` appends one
# `from .<module> import ...` line per module after it.
INIT_PY_HEADER = """
# This file is auto-generated by `utils/generate_inference_types.py`.
# Do not modify it manually.
#
# ruff: noqa: F401

from .base import BaseInferenceType
"""

# Regex to add all dataclasses to ./src/huggingface_hub/__init__.py
# It matches the whole `"inference._generated.types": [...]` entry (DOTALL lets the list span several lines)
# so that `add_dataclasses_to_main_init` can replace it in one substitution.
MAIN_INIT_PY_REGEX = re.compile(
    r"""
\"inference\._generated\.types\":\s*\[ # module name
    (.*?) # all dataclasses listed
\] # closing bracket
""",
    re.MULTILINE | re.VERBOSE | re.DOTALL,
)


# List of classes that are shared across multiple modules
# This is used to fix the naming of the classes (to make them unique by task)
# `_fix_naming_for_shared_classes` iterates over this list and prefixes each name found in a module
# with that module's CamelCase name.
SHARED_CLASSES = [
    "BoundingBox",
    "ClassificationOutputTransform",
    "ClassificationOutput",
    "GenerationParameters",
    "TargetSize",
    "EarlyStoppingEnum",
]

# Template of the English reference page. `generate_reference_package` fills the `{types}` placeholder
# through `str.format` with the per-task `[[autodoc]]` sections.
REFERENCE_PACKAGE_EN_CONTENT = """
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<!--⚠️ Note that this file is auto-generated by `utils/generate_inference_types.py`. Do not modify it manually.-->


# Inference types

This page lists the types (e.g. dataclasses) available for each task supported on the Hugging Face Hub.
Each task is specified using a JSON schema, and the types are generated from these schemas - with some customization
due to Python requirements.
Visit [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)
to find the JSON schemas for each task.

This part of the lib is still under development and will be improved in future releases.


{types}
"""

# Template of the Korean reference page. `generate_reference_package` fills the `{types}` placeholder
# through `str.format` with the per-task `[[autodoc]]` sections.
REFERENCE_PACKAGE_KO_CONTENT = """
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<!--⚠️ Note that this file is auto-generated by `utils/generate_inference_types.py`. Do not modify it manually.-->


# 추론 타입[[inference-types]]

이 페이지에는 Hugging Face Hub에서 지원하는 타입(예: 데이터 클래스)이 나열되어 있습니다.
각 작업은 JSON 스키마를 사용하여 지정되며, 이러한 스키마에 의해서 타입이 생성됩니다. 이때 Python 요구 사항으로 인해 일부 사용자 정의가 있을 수 있습니다.

각 작업의 JSON 스키마를 확인하려면 [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)를 확인하세요.

라이브러리에서 이 부분은 아직 개발 중이며, 향후 릴리즈에서 개선될 예정입니다.


{types}
"""


def _replace_class_name(content: str, cls: str, new_cls: str) -> str:
    """
    Replace the class name `cls` with the new class name `new_cls` in the content.
    """
    pattern = rf"""
        (?<![\w'"])
        (['"]?)
        {cls}
        (['"]?)
        (?![\w'"])
    """

    def replacement(m):
        quote_start = m.group(1) or ""
        quote_end = m.group(2) or ""
        return f"{quote_start}{new_cls}{quote_end}"

    content = re.sub(pattern, replacement, content, flags=re.VERBOSE)
    return content


def _inherit_from_base(content: str) -> str:
    """Turn plain `@dataclass` classes into `@dataclass_with_extra` subclasses of `BaseInferenceType`.

    An import of `BaseInferenceType` and `dataclass_with_extra` from `.base` is inserted right before
    every `from dataclasses import` line that follows a newline.
    """
    with_base_import = content.replace(
        "\nfrom dataclasses import",
        "\nfrom .base import BaseInferenceType, dataclass_with_extra\nfrom dataclasses import",
    )
    # Rewrite the decorator and add the base class on each matched class header.
    return BASE_DATACLASS_REGEX.sub(
        r"@dataclass_with_extra\nclass \1(BaseInferenceType):\n",
        with_base_import,
    )


def _delete_empty_lines(content: str) -> str:
    return "\n".join([line for line in content.split("\n") if line.strip()])


def _fix_naming_for_shared_classes(content: str, module_name: str) -> str:
    """Rename the shared classes found in `content` so that their names are specific to `module_name`.

    The new name is the CamelCase module name followed by the shared class name,
    e.g. "ClassificationOutput" in module "audio_classification" becomes "AudioClassificationOutputElement".
    """
    # e.g. "audio_classification" -> "AudioClassification"
    task_prefix = "".join(map(str.capitalize, module_name.split("_")))

    for shared_cls in SHARED_CLASSES:
        # Only rename shared classes that actually appear in the module
        if shared_cls in content:
            if "Classification" in task_prefix:
                # Drop the leading "Classification" to avoid "ClassificationClassificationOutput"
                unique_cls = task_prefix + shared_cls.removeprefix("Classification")
            else:
                unique_cls = task_prefix + shared_cls
            if unique_cls.endswith("ClassificationOutput"):
                # Yields e.g. "AudioClassificationOutputElement"
                unique_cls += "Element"
            content = _replace_class_name(content, shared_cls, unique_cls)

    return content


def _fix_text2text_shared_parameters(content: str, module_name: str) -> str:
    if module_name in ("summarization", "translation"):
        content = content.replace(
            "Text2TextGenerationParameters",
            f"{module_name.capitalize()}GenerationParameters",
        )
        content = content.replace(
            "Text2TextGenerationTruncationStrategy",
            f"{module_name.capitalize()}GenerationTruncationStrategy",
        )
    return content


def _make_optional_fields_default_to_none(content: str):
    lines = []
    for line in content.split("\n"):
        if "Optional[" in line and not line.endswith("None"):
            line += " = None"

        lines.append(line)

    return "\n".join(lines)


def _list_dataclasses(content: str) -> List[str]:
    """Return the names of the classes declared as `@dataclass_with_extra` subclasses of `BaseInferenceType`."""
    return [match.group(1) for match in INHERITED_DATACLASS_REGEX.finditer(content)]


def _list_type_aliases(content: str) -> List[str]:
    """Return the names assigned by the top-level `name = value` lines of the module."""
    # Each match is a `(name, value)` tuple; only the name is of interest.
    return [groups[0] for groups in TYPE_ALIAS_REGEX.findall(content)]


class DeprecatedRemover(cst.CSTTransformer):
    """libcst transformer dropping the classes and annotated fields whose docstring mentions `@deprecated`."""

    def is_deprecated(self, docstring: Optional[str]) -> bool:
        """Tell whether `docstring` contains the `@deprecated` marker (case-insensitive)."""
        if docstring is None:
            return False
        return "@deprecated" in docstring.lower()

    def get_docstring(self, body: List[cst.BaseStatement]) -> Optional[str]:
        """Return the value of the string expression that opens `body`, or `None` if there is none."""
        if body:
            head = body[0]
            if isinstance(head, cst.SimpleStatementLine):
                candidate = head.body[0]
                if isinstance(candidate, cst.Expr) and isinstance(candidate.value, cst.SimpleString):
                    return candidate.value.evaluated_value
        return None

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> Optional[cst.ClassDef]:
        """Remove the class if it is deprecated or left empty; otherwise strip its deprecated fields."""
        if self.is_deprecated(self.get_docstring(original_node.body.body)):
            return cst.RemoveFromParent()

        members = list(updated_node.body.body)
        kept = []
        skip_next = False
        for index, member in enumerate(members):
            if skip_next:
                # This statement is the docstring of a deprecated field dropped on the previous iteration
                skip_next = False
                continue

            is_field = isinstance(member, cst.SimpleStatementLine) and isinstance(member.body[0], cst.AnnAssign)
            if is_field and index + 1 < len(members):
                # The statement right after an annotated field is treated as its docstring
                if self.is_deprecated(self.get_docstring([members[index + 1]])):
                    skip_next = True
                    continue

            kept.append(member)

        if not kept:
            return cst.RemoveFromParent()

        return updated_node.with_changes(body=updated_node.body.with_changes(body=kept))


def _clean_deprecated_fields(content: str) -> str:
    """Parse `content` with libcst, apply `DeprecatedRemover` and return the resulting source code."""
    return cst.parse_module(content).visit(DeprecatedRemover()).code


def fix_inference_classes(content: str, module_name: str) -> str:
    """Apply all the source-level fixes to the content of an inference types module.

    The steps run in a fixed order: base-class inheritance, empty-line removal, the two module-specific
    renaming passes and finally the `None` defaults for optional fields.
    """
    content = _delete_empty_lines(_inherit_from_base(content))
    for rename in (_fix_naming_for_shared_classes, _fix_text2text_shared_parameters):
        content = rename(content, module_name)
    return _make_optional_fields_default_to_none(content)


def create_init_py(dataclasses: Dict[str, List[str]]):
    """Build the content of the types `__init__.py`: the header followed by one import line per module."""
    import_lines = []
    for module, exported_names in dataclasses.items():
        import_lines.append(f"from .{module} import {', '.join(exported_names)}")
    return INIT_PY_HEADER + "\n" + "\n".join(import_lines)


def add_dataclasses_to_main_init(content: str, dataclasses: Dict[str, List[str]]):
    """Replace the `"inference._generated.types"` entry of `content` with the sorted, de-duplicated class names."""
    unique_names = set()
    for classes in dataclasses.values():
        unique_names.update(classes)

    quoted_names = [f"'{name}'" for name in sorted(unique_names)]
    joined_names = ", ".join(quoted_names)
    return MAIN_INIT_PY_REGEX.sub(f'"inference._generated.types": [{joined_names}]', content)


def generate_reference_package(dataclasses: Dict[str, List[str]], language: Literal["en", "ko"]) -> str:
    """Generate the reference package content.

    Args:
        dataclasses: Mapping from task (module) name to the names of the classes to document.
        language: Language of the page template, either "en" or "ko".

    Returns:
        The page template with its `{types}` placeholder replaced by one section per task (tasks and
        class names sorted alphabetically), each listing an `[[autodoc]]` entry per class.

    Raises:
        ValueError: If `language` is neither "en" nor "ko".
    """
    # Validate the language before doing any work so that an unsupported value is rejected even when
    # there is no task to document (instead of silently falling back to the Korean template).
    if language == "en":
        template = REFERENCE_PACKAGE_EN_CONTENT
    elif language == "ko":
        template = REFERENCE_PACKAGE_KO_CONTENT
    else:
        raise ValueError(f"Language {language} is not supported.")

    per_task_docs = []
    for task in sorted(dataclasses):
        class_names = sorted(dataclasses[task])
        lines_str = "\n\n".join(f"[[autodoc]] huggingface_hub.{cls}" for cls in class_names)

        # e.g. '## audio_classification'
        heading = f"## {task}"
        if language == "ko" and class_names:
            # e.g. '## audio_classification[[huggingface_hub.AudioClassificationInput]]'
            # The anchor points to the first class; it is omitted when the task has no class at all.
            heading += f"[[huggingface_hub.{class_names[0]}]]"

        per_task_docs.append(f"\n{heading}\n\n{lines_str}\n\n")

    return template.format(types="\n".join(per_task_docs))


def check_inference_types(update: bool) -> NoReturn:
    """Check and update inference types.

    This script is used in the `make style` and `make quality` checks.

    Every module of the types folder (except the ignored ones) is cleaned from deprecated items, fixed and
    formatted. The types `__init__.py`, the package `__init__.py` and the English/Korean reference pages
    are then regenerated from the collected class and alias names. Each expected content is handed to
    `check_and_update_file_content` together with `update`.

    Args:
        update: Forwarded to `check_and_update_file_content` for every file.

    Raises:
        SystemExit: Always, with exit code 0, once all files have been processed.
    """
    dataclasses = {}
    aliases = {}
    # `Path.glob` gives no ordering guarantee: sort the files so that modules are always processed (and
    # listed) in the same order, whatever the platform or filesystem.
    for file in sorted(INFERENCE_TYPES_FOLDER_PATH.glob("*.py")):
        if file.name in IGNORE_FILES:
            continue
        content = file.read_text()
        content = _clean_deprecated_fields(content)
        fixed_content = fix_inference_classes(content, module_name=file.stem)
        formatted_content = format_source_code(fixed_content)
        dataclasses[file.stem] = _list_dataclasses(formatted_content)
        aliases[file.stem] = _list_type_aliases(formatted_content)
        check_and_update_file_content(file, formatted_content, update)

    # The `__init__.py` files export both dataclasses and type aliases...
    all_classes = {module: dataclasses[module] + aliases[module] for module in dataclasses}
    init_py_content = create_init_py(all_classes)
    init_py_content = format_source_code(init_py_content)
    init_py_file = INFERENCE_TYPES_FOLDER_PATH / "__init__.py"
    check_and_update_file_content(init_py_file, init_py_content, update)

    main_init_py_content = MAIN_INIT_PY_FILE.read_text()
    updated_main_init_py_content = add_dataclasses_to_main_init(main_init_py_content, all_classes)
    updated_main_init_py_content = format_source_code(updated_main_init_py_content)
    check_and_update_file_content(MAIN_INIT_PY_FILE, updated_main_init_py_content, update)

    # ...whereas the reference pages only document the dataclasses.
    reference_package_content_en = generate_reference_package(dataclasses, "en")
    check_and_update_file_content(REFERENCE_PACKAGE_EN_PATH, reference_package_content_en, update)

    reference_package_content_ko = generate_reference_package(dataclasses, "ko")
    check_and_update_file_content(REFERENCE_PACKAGE_KO_PATH, reference_package_content_ko, update)

    print("✅ All good! (inference types)")
    # The `exit()` builtin is injected by the `site` module for interactive use and may be missing
    # (e.g. `python -S`); raising `SystemExit` is the program-safe equivalent.
    raise SystemExit(0)

if __name__ == "__main__":
    # Command-line entry point: parse the `--update` flag and forward it to `check_inference_types`.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--update",
        action="store_true",
        help="Whether to re-generate files in `./src/huggingface_hub/inference/_generated/types/` if a change is detected.",
    )
    check_inference_types(update=cli.parse_args().update)