File: create_local_hub.py

package info (click to toggle)
firefox 142.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,591,884 kB
  • sloc: cpp: 7,451,570; javascript: 6,392,463; ansic: 3,712,584; python: 1,388,569; xml: 629,223; asm: 426,919; java: 184,857; sh: 63,439; makefile: 19,150; objc: 13,059; perl: 12,983; yacc: 4,583; cs: 3,846; pascal: 3,352; lex: 1,720; ruby: 1,003; exp: 762; php: 436; lisp: 258; awk: 247; sql: 66; sed: 53; csh: 10
file content (253 lines) | stat: -rw-r--r-- 8,780 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#!/usr/bin/env python3
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


import argparse
import hashlib
import os
import shutil
import subprocess
import sys
import urllib.request
from pathlib import Path

import yaml

# Absolute, symlink-resolved directory that contains this script.
HERE = Path(__file__).resolve().parent
# Taskcluster fetch definitions (ort.jsep.wasm plus the git-based model
# repositories), located five directory levels above this script.
FETCH_FILE = HERE.joinpath(
    "..",
    "..",
    "..",
    "..",
    "..",
    "taskcluster",
    "kinds",
    "fetch",
    "onnxruntime-web-fetch.yml",
).resolve()


def is_git_lfs_installed():
    """Report whether the git-lfs extension is usable.

    Runs ``git lfs version`` (discarding stderr) and returns True only when
    the command succeeds and its output mentions ``git-lfs``. A missing
    ``git`` binary or a non-zero exit status both yield False.
    """
    version_cmd = ["git", "lfs", "version"]
    try:
        reported = subprocess.check_output(
            version_cmd, stderr=subprocess.DEVNULL, text=True
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
    return "git-lfs" in reported.lower()


def compute_sha256(file_path):
    """Return the lowercase hex SHA-256 digest of *file_path*.

    *file_path* must be a ``pathlib.Path``; the file is opened in binary
    mode and hashed incrementally in 4096-byte pieces.
    """
    digest = hashlib.sha256()
    with file_path.open("rb") as stream:
        while True:
            piece = stream.read(4096)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()


def download_wasm(fetches, fetches_dir):
    """
    Download and verify ort.jsep.wasm if needed,
    using the 'ort.jsep.wasm' entry in the YAML file.

    The file is stored in ``fetches_dir`` under the last path segment of the
    entry's ``url``. An existing copy whose SHA-256 matches the entry's
    ``sha256`` is kept as is. Otherwise the file is downloaded into a
    ``<name>.part`` file next to the destination and only moved into place
    once its checksum has been verified, so an interrupted or failed
    download never leaves a truncated file under the final name.

    Raises:
        KeyError: if the YAML has no 'ort.jsep.wasm' fetch entry with
            'url' and 'sha256'.
        ValueError: if the downloaded file's SHA-256 does not match.
        urllib.error.URLError / OSError: if the download itself fails.
    """
    wasm_fetch = fetches["ort.jsep.wasm"]["fetch"]
    url = wasm_fetch["url"]
    # hexdigest() is lowercase; normalize so an uppercase hash in the YAML matches.
    expected_sha256 = str(wasm_fetch["sha256"]).lower()

    filename = url.split("/")[-1]
    output_file = fetches_dir / filename

    # If the file exists and its checksum matches, skip re-download
    if output_file.exists():
        print(f"Found existing file {output_file}, verifying checksum...")
        if compute_sha256(output_file) == expected_sha256:
            print("Existing file's checksum matches. Skipping download.")
            return
        print("Checksum mismatch on existing file. Removing and re-downloading...")
        output_file.unlink()

    # Stream into a sibling temp file; the final name only ever holds
    # verified content.
    partial_file = output_file.with_name(f"{filename}.part")
    print(f"Downloading {url} to {output_file}...")
    try:
        with urllib.request.urlopen(url) as response, partial_file.open(
            "wb"
        ) as out_file:
            shutil.copyfileobj(response, out_file)

        # Verify SHA-256 before exposing the file under its real name.
        print(f"Verifying SHA-256 of {output_file}...")
        downloaded_sha256 = compute_sha256(partial_file)
        if downloaded_sha256 != expected_sha256:
            raise ValueError(
                f"Checksum mismatch for {filename}! "
                f"Expected: {expected_sha256}, got: {downloaded_sha256}"
            )

        # Same directory, so this is a rename that overwrites the target
        # (os.replace semantics).
        partial_file.replace(output_file)
    finally:
        # No-op after a successful replace(); removes leftovers on failure.
        partial_file.unlink(missing_ok=True)

    print(f"File {filename} downloaded and verified successfully!")


def list_models(fetches):
    """
    Print every YAML key whose fetch.type is 'git', together with the
    path-prefix its repository is cloned to (or a placeholder when the
    entry does not set one), followed by a usage hint for --model.
    """
    placeholder = "[no path-prefix specified]"
    print("Available git-based models from the YAML:\n")
    for name, entry in fetches.items():
        fetch_info = entry.get("fetch")
        # Skip entries without a fetch section or with a non-git fetch type.
        if not fetch_info or fetch_info.get("type") != "git":
            continue
        prefix = fetch_info.get("path-prefix", placeholder)
        print(f"- {name} -> path-prefix: {prefix}")
    print("\n(Use `--model <key>` to clone one of these repositories.)")


def _git_stdout(args, repo_dir):
    """Return stripped stdout of ``git <args>`` run in *repo_dir*, or None if git exits non-zero."""
    try:
        output = subprocess.check_output(
            ["git", *args], cwd=repo_dir, text=True, stderr=subprocess.DEVNULL
        )
    except subprocess.CalledProcessError:
        return None
    return output.strip()


def _head_matches_revision(repo_dir, head, revision):
    """Return True if commit *head* of *repo_dir* is the commit *revision* names.

    Accepts a full commit id, an abbreviated commit id (7+ hex digits, any
    case) or a tag name. Branch names are not resolved, so a branch checkout
    never matches and is re-cloned (picking up new upstream commits), as before.
    """
    if head is None:
        return False
    if head == revision:
        return True
    lowered = revision.lower()
    if (
        7 <= len(lowered) <= 64
        and set(lowered) <= set("0123456789abcdef")
        and head.startswith(lowered)
    ):
        return True
    # ^{commit} peels annotated tags down to the commit they point at.
    tag_commit = _git_stdout(
        ["rev-parse", "--verify", "--quiet", f"refs/tags/{revision}^{{commit}}"],
        repo_dir,
    )
    return tag_commit is not None and tag_commit == head


def _stale_checkout_reason(repo_dir, repo_url, revision):
    """Return a message explaining why *repo_dir* cannot be reused, or None if it can."""
    if not (repo_dir / ".git").is_dir():
        return f"Directory '{repo_dir}' exists but is not a git repo. Removing it."

    existing_url = _git_stdout(["remote", "get-url", "origin"], repo_dir)
    if existing_url != repo_url:
        return (
            f"Repository at '{repo_dir}' has remote '{existing_url}' "
            f"instead of '{repo_url}'. Removing it."
        )

    head = _git_stdout(["rev-parse", "HEAD"], repo_dir)
    if not _head_matches_revision(repo_dir, head, revision):
        return (
            f"Repo at '{repo_dir}' has HEAD {head}, "
            f"but we need '{revision}'. Removing it."
        )
    return None


def _remove_path(path):
    """Delete *path* whether it is a directory tree, a file or a symlink.

    Errors propagate (unlike ``rmtree(..., ignore_errors=True)``), so a
    stale checkout is never silently left in place and then reported as
    up to date.
    """
    if path.is_dir() and not path.is_symlink():
        shutil.rmtree(path)
    else:
        path.unlink()


def clone_model(key, data, fetches_dir):
    """
    Clone (or re-clone) a model if needed.

    The directory is determined by 'path-prefix' from the YAML,
    relative to --fetches-dir. Example:

      path-prefix: "onnx-models/Xenova/all-MiniLM-L6-v2/main/"

    We'll end up cloning to <fetches-dir>/onnx-models/Xenova/all-MiniLM-L6-v2/main

    An existing checkout is reused when it is a git repository whose
    'origin' remote is the YAML 'repo' and whose HEAD is the YAML
    'revision' (full or abbreviated commit id, or tag). Anything else at
    that path is deleted and cloned again.

    Raises:
        ValueError: if path-prefix does not point strictly inside
            fetches_dir. The target may be deleted, so it must never be
            fetches_dir itself or a location outside it.
        OSError: if a stale checkout cannot be removed.
        subprocess.CalledProcessError: if git clone or checkout fails.
    """
    fetch_data = data["fetch"]
    repo_url = fetch_data["repo"]
    path_prefix = fetch_data["path-prefix"]
    # str(): YAML parses an all-digit revision as an int, which git/subprocess reject.
    revision = str(fetch_data.get("revision", "main"))

    # Compute the final directory from --fetches-dir + path-prefix
    repo_dir = fetches_dir / path_prefix

    # Safety guard: the directory may be removed below, so refuse prefixes
    # like "", ".", ".." or absolute paths that escape fetches_dir.
    base_dir = Path(os.path.normpath(fetches_dir))
    if base_dir not in Path(os.path.normpath(repo_dir)).parents:
        raise ValueError(
            f"Model '{key}' has path-prefix '{path_prefix}', which does not "
            f"point inside '{fetches_dir}'."
        )

    # Ensure parent directories exist
    repo_dir.parent.mkdir(parents=True, exist_ok=True)

    # is_symlink() also catches dangling symlinks, for which exists() is False.
    if repo_dir.exists() or repo_dir.is_symlink():
        reason = _stale_checkout_reason(repo_dir, repo_url, revision)
        if reason is not None:
            print(reason)
            _remove_path(repo_dir)

    if repo_dir.exists():
        print(f"{repo_dir} already exists and is up to date. Skipping clone.")
        return

    print(f"Cloning {repo_url} into '{repo_dir}'...")
    # Normal clone first
    subprocess.run(["git", "clone", repo_url, str(repo_dir)], check=True)
    # Then checkout the desired revision (branch, commit, or tag)
    subprocess.run(["git", "checkout", revision], cwd=repo_dir, check=True)
    print(f"Checked out revision '{revision}' in '{repo_dir}'.")


def clone_models(keys, fetches, fetches_dir):
    """
    Clone each model specified by YAML key, if fetch.type == 'git'.
    Uses the path-prefix from the YAML to determine the final directory.

    Every key is validated before anything runs. A typo in a later --model
    therefore fails immediately instead of after `git lfs install` and after
    earlier, potentially large, repositories have been cloned. Keys given
    more than once are processed once, in first-seen order.

    Raises:
        ValueError: if a key is missing from the YAML or is not a git fetch.
    """
    if not keys:
        return

    # Drop duplicates while preserving command-line order.
    unique_keys = list(dict.fromkeys(keys))

    for key in unique_keys:
        if key not in fetches:
            raise ValueError(f"Model '{key}' not found in YAML.")
        # `or {}` tolerates entries whose value or 'fetch' section is null.
        fetch = (fetches[key] or {}).get("fetch") or {}
        if fetch.get("type") != "git":
            raise ValueError(f"Model '{key}' is not a git fetch type.")

    # Initialize git lfs once (if we have at least one model)
    subprocess.run(["git", "lfs", "install"], check=True)

    for key in unique_keys:
        clone_model(key, fetches[key], fetches_dir)


def main():
    """
    Command-line entry point.

    Loads FETCH_FILE, then either lists its git-based models (--list-models)
    or downloads/verifies ort.jsep.wasm into --fetches-dir and clones every
    model requested with --model.

    Exits through argparse's usage error (status 2) when no fetches directory
    is given. Exits with status 1 when models are requested but git-lfs is
    unavailable.
    """
    parser = argparse.ArgumentParser(
        description="Download ort.jsep.wasm and optionally clone specified models."
    )

    parser.add_argument(
        "--fetches-dir",
        help=(
            "Directory to store the downloaded files (and cloned repos). "
            "Uses MOZ_ML_LOCAL_DIR if present."
        ),
        default=os.getenv("MOZ_ML_LOCAL_DIR", None),
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available git-based models (keys in the YAML) and exit.",
    )
    parser.add_argument(
        "--model",
        action="append",
        help="YAML key of a model to clone (can be specified multiple times).",
    )
    args = parser.parse_args()

    # Load YAML (safe_load: plain data only)
    with FETCH_FILE.open("r", encoding="utf-8") as f:
        fetches = yaml.safe_load(f)

    # If listing models, do so and exit; this needs neither git-lfs nor a directory.
    if args.list_models:
        list_models(fetches)
        return

    if args.fetches_dir is None:
        # parser.error prints the usage line plus this message to stderr and exits.
        parser.error(
            "Missing --fetches-dir argument or MOZ_ML_LOCAL_DIR env var. "
            "Please specify a directory to store the downloaded files"
        )

    # Only cloning needs git-lfs. Check before downloading anything so a
    # missing dependency is reported up front.
    if args.model and not is_git_lfs_installed():
        print("git lfs is required to clone models (--model):", file=sys.stderr)
        print("\t$ sudo apt install git-lfs", file=sys.stderr)
        print("\t$ sudo yum install git-lfs", file=sys.stderr)
        print("\t$ brew install git-lfs", file=sys.stderr)
        print(file=sys.stderr)
        print(
            "\tor see https://github.com/git-lfs/git-lfs/blob/main/README.md",
            file=sys.stderr,
        )
        sys.exit(1)

    fetches_dir = Path(args.fetches_dir).resolve()
    fetches_dir.mkdir(parents=True, exist_ok=True)

    # Always download/verify ort.jsep.wasm
    download_wasm(fetches, fetches_dir)

    # Clone requested models
    if args.model:
        clone_models(args.model, fetches, fetches_dir)


if __name__ == "__main__":
    main()