File: patcher.py

# mypy: allow-untyped-defs
import copy
import functools
from typing import List, TYPE_CHECKING, Union

import torch


if TYPE_CHECKING:
    import io


# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
@functools.lru_cache(None)
def has_safetensors_and_transformers():
    try:
        # safetensors is not an exporter requirement, but needed for some huggingface models
        import safetensors  # type: ignore[import]  # noqa: F401
        import transformers  # type: ignore[import]  # noqa: F401
        from safetensors import torch as safetensors_torch  # noqa: F401

        return True
    except ImportError:
        return False
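
# Example (illustrative): because of functools.lru_cache(None), the import
# probe above runs at most once per process; later calls return the memoized
# boolean.
#
#     has_safetensors_and_transformers()              # performs the imports
#     has_safetensors_and_transformers()              # served from the cache
#     has_safetensors_and_transformers.cache_clear()  # force a re-probe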


class ONNXTorchPatcher:
    """Context manager to temporarily patch PyTorch during FX-to-ONNX export.

    This class is a collection of "patches" required by the FX-to-ONNX
    exporter. While active, it overrides several torch functions to support
    symbolic export of large-scale models.

    torch.load:
        This function is patched to record the files from which PyTorch loads
        model parameters and buffers, so that the downstream FX-to-ONNX
        exporter can create initializers from those files.
    torch.fx._symbolic_trace._wrapped_methods_to_patch:
        This list is extended with (torch.Tensor, "__getitem__") so that
        weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.
    safetensors.torch.load_file:
        This function is patched to allow safetensors to be loaded within
        FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318

    Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for
    example usage.

    TODO: Should this really be a global patcher? Can we make it a local patcher?
        A reason for splitting this into several patchers is to avoid patching
        one part of the code as collateral damage of patching another. For
        example, when tracing a model with torch._dynamo.export, we don't need
        to patch `torch.fx._symbolic_trace._wrapped_methods_to_patch`.
    """

    def __init__(self) -> None:
        # File paths (or file-like objects) that torch.load has processed.
        self.paths: List[Union[str, io.BufferedIOBase]] = []

        def torch_load_wrapper(f, *args, **kwargs):
            # Record path for later serialization into ONNX proto
            self.paths.append(f)
            # Then, call the original torch.load.
            return self.torch_load(f, *args, **kwargs)

        # Original version of torch.load.
        self.torch_load = torch.load

        # Path-recording wrapper around torch.load.
        self.torch_load_wrapper = torch_load_wrapper

        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            def safetensors_load_file_wrapper(filename, device="cpu"):
                # Record path for later serialization into ONNX proto
                self.paths.append(filename)
                result = {}
                with safetensors.torch.safe_open(  # type: ignore[attr-defined]
                    filename, framework="pt", device=device
                ) as f:
                    for k in f.keys():
                        fake_mode = torch._guards.detect_fake_mode()
                        if not fake_mode:
                            result[k] = f.get_tensor(k)
                        else:
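                            # Under FakeTensorMode, avoid materializing real
                            # tensor data: get_slice() exposes only metadata
                            # (shape and dtype), which is enough to allocate a
                            # matching uninitialized tensor.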
                            empty_tensor = f.get_slice(k)
                            result[k] = torch.empty(
                                tuple(empty_tensor.get_shape()),
                                dtype=safetensors.torch._getdtype(
                                    empty_tensor.get_dtype()
                                ),
                            )
                return result

            self.safetensors_torch_load_file = safetensors.torch.load_file
            self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper
            self.transformers_modeling_utils_safe_load_file = (
                transformers.modeling_utils.safe_load_file
            )

    def __enter__(self):
        torch.load = self.torch_load_wrapper

        self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        desired_wrapped_methods = copy.deepcopy(
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
            # Adding `__getitem__` to the patching list will make tensor indexing traceable via
            # torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced.
            # This happens because `__getitem__` is neither under torch domain nor an aten operator,
            # so the patching (or similar Proxy-generating mechanism) doesn't happen automatically.
            # Note that torch.fx.symbolic_trace reads the FX_PATCH_GETITEM
            # environment variable to enable the same patching.
            desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
        torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
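
        # Illustrative example of what the patch enables: with
        # (torch.Tensor, "__getitem__") in the list, torch.fx.symbolic_trace
        # can proxy indexing such as
        #
        #     def forward(self, weight, x, y):
        #         return weight[x, :, y]
        #
        # instead of failing to trace Tensor.__getitem__.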

        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
            transformers.modeling_utils.safe_load_file = (
                self.safetensors_torch_load_file_wrapper
            )

    def __exit__(self, exc_type, exc_value, traceback):
        torch.load = self.torch_load
        torch.fx._symbolic_trace._wrapped_methods_to_patch = (
            self.torch_fx__symbolic_trace__wrapped_methods_to_patch
        )
        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            safetensors.torch.load_file = self.safetensors_torch_load_file
            transformers.modeling_utils.safe_load_file = (
                self.transformers_modeling_utils_safe_load_file
            )
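

# Usage sketch (illustrative; see test_fx_to_onnx_with_onnxruntime.py for the
# real usage). Patching is scoped to the `with` block: torch.load is swapped
# for the recording wrapper on entry and restored on exit. "model.pt" is a
# hypothetical path.
#
#     patcher = ONNXTorchPatcher()
#     with patcher:
#         state_dict = torch.load("model.pt")   # path recorded
#     assert patcher.paths == ["model.pt"]      # available to the exporter
#     assert torch.load is patcher.torch_load   # original torch.load restored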