File: update_production_ops.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (45 lines) | stat: -rw-r--r-- 1,563 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
This is a script to aggregate production ops from xplat/pytorch_models/build/all_mobile_model_configs.yaml.
Specify the file path in the first argument. The results will be dump to model_ops.yaml
"""

import sys

import yaml


root_operators = {}
traced_operators = {}
kernel_metadata = {}

with open(sys.argv[1]) as input_yaml_file:
    model_infos = yaml.safe_load(input_yaml_file)
    for info in model_infos:
        for op in info["root_operators"]:
            # aggregate occurance per op
            root_operators[op] = 1 + (root_operators[op] if op in root_operators else 0)
        for op in info["traced_operators"]:
            # aggregate occurance per op
            traced_operators[op] = 1 + (
                traced_operators[op] if op in traced_operators else 0
            )
        # merge dtypes for each kernel
        for kernal, dtypes in info["kernel_metadata"].items():
            new_dtypes = dtypes + (
                kernel_metadata[kernal] if kernal in kernel_metadata else []
            )
            kernel_metadata[kernal] = list(set(new_dtypes))


# Only test these built-in ops. No custom ops or non-CPU ops.
namespaces = ["aten", "prepacked", "prim", "quantized"]
root_operators = {
    x: root_operators[x] for x in root_operators if x.split("::")[0] in namespaces
}
traced_operators = {
    x: traced_operators[x] for x in traced_operators if x.split("::")[0] in namespaces
}

out_path = "test/mobile/model_test/model_ops.yaml"
with open(out_path, "w") as f:
    yaml.safe_dump({"root_operators": root_operators}, f)