File: model.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (91 lines) | stat: -rw-r--r-- 3,179 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import os
import sys

import torch


# Make the shared hook-module factories importable. They are imported from
# ``jit.test_hooks_modules``, which is resolved relative to the directory two
# levels above this file (``pytorch_test_dir``; presumably PyTorch's ``test/``
# directory — confirm). That directory is appended to ``sys.path`` here, so
# these two statements must stay ABOVE the import below.
# NOTE(review): an earlier comment mentioned test_jit_hooks.cpp — presumably
# the C++ test that consumes the saved modules; verify.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit.test_hooks_modules import (
    create_forward_tuple_input,
    create_module_forward_multiple_inputs,
    create_module_forward_single_input,
    create_module_hook_return_nothing,
    create_module_multiple_hooks_multiple_inputs,
    create_module_multiple_hooks_single_input,
    create_module_no_forward_input,
    create_module_same_hook_repeated,
    create_submodule_forward_multiple_inputs,
    create_submodule_forward_single_input,
    create_submodule_hook_return_nothing,
    create_submodule_multiple_hooks_multiple_inputs,
    create_submodule_multiple_hooks_single_input,
    create_submodule_same_hook_repeated,
    create_submodule_to_call_directly_with_hooks,
)


# Create saved modules for JIT forward hooks and pre-hooks
def main(argv=None):
    """Script every hook test module and save it to disk.

    Each module returned by the ``create_*`` factories is compiled with
    ``torch.jit.script`` and written with ``torch.jit.save`` to
    ``<prefix>_<test name>.pt``, where ``<prefix>`` is the value passed to
    ``--export-script-module-to``.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which matches the previous zero-argument behavior.
    """
    parser = argparse.ArgumentParser(
        description="Serialize script modules with hooks attached"
    )
    parser.add_argument(
        "--export-script-module-to",
        required=True,
        help="path prefix for the saved modules; each module is written to "
        "<prefix>_<test name>.pt",
    )
    options = parser.parse_args(argv)

    # Kept as a module-level global for backward compatibility: other code
    # may read ``save_name`` after main() runs (not visible from here —
    # verify before turning this into a local).
    global save_name
    save_name = options.export_script_module_to + "_"

    # (test name, module) pairs. All modules are constructed up front, in
    # this order, before any scripting or saving happens.
    tests = [
        (
            "test_submodule_forward_single_input",
            create_submodule_forward_single_input(),
        ),
        (
            "test_submodule_forward_multiple_inputs",
            create_submodule_forward_multiple_inputs(),
        ),
        (
            "test_submodule_multiple_hooks_single_input",
            create_submodule_multiple_hooks_single_input(),
        ),
        (
            "test_submodule_multiple_hooks_multiple_inputs",
            create_submodule_multiple_hooks_multiple_inputs(),
        ),
        ("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
        ("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
        ("test_module_forward_single_input", create_module_forward_single_input()),
        (
            "test_module_forward_multiple_inputs",
            create_module_forward_multiple_inputs(),
        ),
        (
            "test_module_multiple_hooks_single_input",
            create_module_multiple_hooks_single_input(),
        ),
        (
            "test_module_multiple_hooks_multiple_inputs",
            create_module_multiple_hooks_multiple_inputs(),
        ),
        ("test_module_hook_return_nothing", create_module_hook_return_nothing()),
        ("test_module_same_hook_repeated", create_module_same_hook_repeated()),
        ("test_module_no_forward_input", create_module_no_forward_input()),
        ("test_forward_tuple_input", create_forward_tuple_input()),
        (
            "test_submodule_to_call_directly_with_hooks",
            create_submodule_to_call_directly_with_hooks(),
        ),
    ]

    for name, model in tests:
        scripted = torch.jit.script(model)
        # ``save_name`` is a filename prefix (it ends in "_"), not a
        # directory, so plain concatenation is intended here.
        torch.jit.save(scripted, f"{save_name}{name}.pt")

    print("OK: completed saving modules with hooks!")


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()