File: pt_wrapper_module.py

import torch


class WrapperModule:
    """Wraps an instance of wrapped_type for benchmarking.

    In graph mode, the wrapped instance is traced with torch.jit.trace.
    Randomly initializes num_params input tensors, each holding a single float.

    Args:
        wrapped_type:
            - Type of the object to be wrapped.
                Expects wrapped_type to:
                   - be constructible from the pt_fn specified in module_config.
                   - provide a forward method that takes module_config.num_params args.
        module_config:
            - Specifies the pt_fn to construct wrapped_type with, whether graph_mode
              is enabled, and the number of parameters wrapped_type's forward method
              takes.
        debug:
            - Whether debug mode is enabled. When set, the graph and code of a
              traced module are printed.
        save:
            - In graph mode, whether the traced graph is saved to disk.
    """

    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        self.module = wrapped_type(pt_fn)
        self.tensor_inputs = []
        self.module_name = wrapped_type.__name__
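        # Create one single-element random tensor per forward argument.
        for _ in range(module_config.num_params):
            self.tensor_inputs.append(torch.randn(1))
        if module_config.graph_mode:
            # Replace the eager module with its traced (TorchScript) version.
            self.module = torch.jit.trace(self.module, self.tensor_inputs)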
            if save:
                file_name = self.module_name + "_" + pt_fn.__name__ + ".pt"
                torch.jit.save(self.module, file_name)
                print(f"Generated graph is saved in {file_name}")
        print(
            f"Benchmarking module {self.module_name} with fn {pt_fn.__name__}: Graph mode:{module_config.graph_mode}"
        )
        if debug and isinstance(self.module, torch.jit.ScriptModule):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        """Runs the wrapped module niters times on the fixed inputs, with autograd disabled."""
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)
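
The tracing and execution steps performed by WrapperModule can be tried in isolation. The following is a minimal sketch, not part of the original file: `AddModule` is a hypothetical stand-in for a `wrapped_type`, and `torch.add` plays the role of `pt_fn` with two forward arguments assumed. It builds the single-element random inputs, traces the module, and runs both the eager and traced versions under `torch.no_grad()`.

```python
import torch


class AddModule(torch.nn.Module):
    """Hypothetical wrapped_type: built from pt_fn, forward takes two tensors."""

    def __init__(self, pt_fn):
        super().__init__()
        self.pt_fn = pt_fn

    def forward(self, x, y):
        return self.pt_fn(x, y)


# Assumed values standing in for module_config.pt_fn and module_config.num_params.
pt_fn = torch.add
num_params = 2

# One single-element random tensor per forward argument.
tensor_inputs = tuple(torch.randn(1) for _ in range(num_params))

eager_module = AddModule(pt_fn)
# Tracing records the operations executed for the example inputs.
traced_module = torch.jit.trace(eager_module, tensor_inputs)

with torch.no_grad():
    eager_out = eager_module(*tensor_inputs)
    traced_out = traced_module(*tensor_inputs)

# The traced module is a ScriptModule, so its graph and code can be inspected
# the same way the debug branch of WrapperModule does.
print(traced_module.graph)
print(traced_module.code)
```

Because the traced module records only the operations run for the example inputs, this approach suits functions without data-dependent control flow, which is the case for a simple elementwise op such as `torch.add`.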