File: convert_to_external_data.py

package info (click to toggle)
thunderbird 1%3A143.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 4,703,968 kB
  • sloc: cpp: 7,770,492; javascript: 5,943,842; ansic: 3,918,754; python: 1,418,263; xml: 653,354; asm: 474,045; java: 183,079; sh: 111,238; makefile: 20,410; perl: 14,359; objc: 13,059; yacc: 4,583; pascal: 3,405; lex: 1,720; ruby: 999; exp: 762; sql: 715; awk: 580; php: 436; lisp: 430; sed: 69; csh: 10
file content (81 lines) | stat: -rw-r--r-- 2,637 bytes parent folder | download | duplicates (11)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*-
# vim: set filetype=python:
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import glob
import json
import os
import pathlib
import shutil
import tempfile
from typing import Annotated

import onnx
import typer

# To use:
# pip install onnx typer
# python convert_to_external_data.py --base-path "/path/to/directory"


def convert(model_path, save_path):
    """Re-save an ONNX model using the external data format.

    The model is loaded from ``model_path`` and written into the directory of
    ``save_path``: the model file is named after ``save_path`` and any
    external data goes to ``<stem of save_path>.onnx_data`` beside it.

    Args:
        model_path: Path of the ONNX model to load.
        save_path: Path the converted model is written to. May be the same as
            ``model_path``; missing parent directories are created.
    """
    model = onnx.load(model_path)

    # Name both outputs after save_path so the caller gets exactly the file
    # it asked for, and the data file name matches the saved model name.
    model_file_name = os.path.basename(save_path)
    external_data_name = f"{pathlib.Path(save_path).stem}.onnx_data"

    # Create the new model in a temporary directory and copy all its content back to save_path.
    # Doing this because if save_path is same as model_path & we directly write to model_path
    # onnx will append to the external data path which would make it grow more than expected.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx.save_model(
            model,
            os.path.join(tmp_dir, model_file_name),
            save_as_external_data=True,
            location=external_data_name,
        )

        target_dir = str(pathlib.Path(save_path).parent)
        os.makedirs(target_dir, exist_ok=True)
        # Copy everything onnx produced (model file plus any data file).
        for file_name in os.listdir(tmp_dir):
            shutil.copy2(os.path.join(tmp_dir, file_name), target_dir)


def main(base_path: Annotated[str, typer.Option()]):
    """
    Recursively convert all onnx models under the given directory to the external data format,
    then enable use_external_data_format in every config.json found there.
    """
    # Convert all models in place (the converted model replaces the original).
    for model_path in glob.glob(
        os.path.join(base_path, "**/*.onnx"),
        recursive=True,
    ):
        print("Converting", model_path)
        convert(model_path, model_path)

    # Find all config.json and enable use_external_data_format
    for config_path in glob.glob(
        os.path.join(base_path, "**/config.json"),
        recursive=True,
    ):
        print("Modifying", config_path)
        # Load the JSON file. The encoding is explicit because the file is
        # rewritten with ensure_ascii=False, so non-ASCII text must round-trip
        # independently of the platform's locale encoding.
        with open(config_path, encoding="utf-8") as infile:
            config_data = json.load(infile)

        # Keep any existing transformers.js settings and only add the flag.
        js_config = config_data.setdefault("transformers.js_config", {})
        js_config["use_external_data_format"] = True

        # Save the JSON file with additional config
        with open(config_path, "w", encoding="utf-8") as outfile:
            json.dump(config_data, outfile, indent=4, ensure_ascii=False)


# Script entry point: only when executed directly (not imported), hand main
# to typer.run so --base-path is taken from the command line.
if __name__ == "__main__":
    typer.run(main)