File: tensorflow.py

"""
Generates the protobuf stubs for the given tensorflow version using mypy-protobuf.
Generally, new minor versions are a good time to update the stubs.
"""

from __future__ import annotations

import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

from _utils import MYPY_PROTOBUF_VERSION, download_file, extract_archive, run_protoc
from ts_utils.metadata import read_metadata, update_metadata
from ts_utils.paths import distribution_path

PACKAGE_VERSION = read_metadata("tensorflow").version_spec.version

STUBS_FOLDER = distribution_path("tensorflow").absolute()
ARCHIVE_FILENAME = f"v{PACKAGE_VERSION}.zip"
ARCHIVE_URL = f"https://github.com/tensorflow/tensorflow/archive/refs/tags/{ARCHIVE_FILENAME}"
EXTRACTED_PACKAGE_DIR = f"tensorflow-{PACKAGE_VERSION}"

PROTOS_TO_REMOVE = (
    "compiler/xla/autotune_results_pb2.pyi",
    "compiler/xla/autotuning_pb2.pyi",
    "compiler/xla/service/buffer_assignment_pb2.pyi",
    "compiler/xla/service/hlo_execution_profile_data_pb2.pyi",
    "core/protobuf/autotuning_pb2.pyi",
    "core/protobuf/conv_autotuning_pb2.pyi",
    "core/protobuf/critical_section_pb2.pyi",
    "core/protobuf/eager_service_pb2.pyi",
    "core/protobuf/master_pb2.pyi",
    "core/protobuf/master_service_pb2.pyi",
    "core/protobuf/replay_log_pb2.pyi",
    "core/protobuf/tpu/compile_metadata_pb2.pyi",
    "core/protobuf/worker_pb2.pyi",
    "core/protobuf/worker_service_pb2.pyi",
    "core/util/example_proto_fast_parsing_test_pb2.pyi",
)
"""
These protos exist in a folder with protos used in python,
but are not included in the python wheel.
They are likely only used for other language builds.
stubtest was used to identify them by looking for ModuleNotFoundError.
(comment out ".*_pb2.*" from the allowlist)
"""

TSL_IMPORT_PATTERN = re.compile(r"(\[|\s)tsl\.")
XLA_IMPORT_PATTERN = re.compile(r"(\[|\s)xla\.")
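# Example of the rewrite applied in post_creation() below (module names are
# illustrative): "import tsl.protobuf.foo_pb2" becomes
# "import tensorflow.tsl.protobuf.foo_pb2", and "[xla.Shape]" becomes
# "[tensorflow.compiler.xla.Shape]". The leading "[" or whitespace captured by
# the group is preserved in the replacement.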


def move_tree(source: Path, destination: Path) -> None:
    """Move directory and merge if destination already exists.

    Can't use shutil.move because it can't merge existing directories."""
    print(f"Moving '{source}' to '{destination}'")
    shutil.copytree(source, destination, dirs_exist_ok=True)
    shutil.rmtree(source)


def post_creation() -> None:
    """Move third-party and fix imports"""
    print()
    move_tree(STUBS_FOLDER / "tsl", STUBS_FOLDER / "tensorflow" / "tsl")
    move_tree(STUBS_FOLDER / "xla", STUBS_FOLDER / "tensorflow" / "compiler" / "xla")

    for path in STUBS_FOLDER.rglob("*_pb2.pyi"):
        print(f"Fixing imports in '{path}'")
        with open(path) as file:
            filedata = file.read()

        # Rewrite bare tsl./xla. references to their locations inside the tensorflow package
        filedata = re.sub(TSL_IMPORT_PATTERN, r"\1tensorflow.tsl.", filedata)
        filedata = re.sub(XLA_IMPORT_PATTERN, r"\1tensorflow.compiler.xla.", filedata)

        # Write the file out again
        with open(path, "w") as file:
            file.write(filedata)

    print()
    for to_remove in PROTOS_TO_REMOVE:
        file_path = STUBS_FOLDER / "tensorflow" / to_remove
        os.remove(file_path)
        print(f"Removed '{file_path}'")


def main() -> None:
    temp_dir = Path(tempfile.mkdtemp())
    # Fetch tensorflow (which contains all the .proto files)
    archive_path = temp_dir / ARCHIVE_FILENAME
    download_file(ARCHIVE_URL, archive_path)
    extract_archive(archive_path, temp_dir)

    # Remove existing pyi
    for old_stub in STUBS_FOLDER.rglob("*_pb2.pyi"):
        old_stub.unlink()

    PROTOC_VERSION = run_protoc(
        proto_paths=(
            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/third_party/tsl",
            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla",
            f"{EXTRACTED_PACKAGE_DIR}",
        ),
        mypy_out=STUBS_FOLDER,
        proto_globs=(
            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/service/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/tsl/protobuf/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/example/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/framework/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/protobuf/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/protobuf/tpu/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/util/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/tensorflow/python/keras/protobuf/*.proto",
            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/third_party/tsl/tsl/protobuf/*.proto",
        ),
        cwd=temp_dir,
    )
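    # run_protoc is assumed to return a human-readable protoc version string;
    # it is embedded in the METADATA.toml description below.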

    # Clean up after ourselves; it's a temp dir, but it can still grow quickly if the script is run multiple times
    shutil.rmtree(temp_dir)

    post_creation()

    update_metadata(
        "tensorflow",
        extra_description=f"""Partially generated using \
[mypy-protobuf=={MYPY_PROTOBUF_VERSION}](https://github.com/nipunn1313/mypy-protobuf/tree/v{MYPY_PROTOBUF_VERSION}) \
and {PROTOC_VERSION} on `tensorflow=={PACKAGE_VERSION}`.""",
    )
    print("Updated tensorflow/METADATA.toml")

    # Run pre-commit to clean up the generated stubs (the exit code isn't checked, since pre-commit exits non-zero when hooks modify files)
    subprocess.run((sys.executable, "-m", "pre_commit", "run", "--files", *STUBS_FOLDER.rglob("*_pb2.pyi")))


if __name__ == "__main__":
    main()