File: serialization.py

# mypy: allow-untyped-defs
from __future__ import annotations

import io
import logging
import os
from typing import TYPE_CHECKING

import torch
from torch.onnx import _type_utils as jit_type_utils


if TYPE_CHECKING:
    import onnx

log = logging.getLogger(__name__)


def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor,
    name: str,
    location: str,
    basepath: str,
    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
) -> onnx.TensorProto:  # type: ignore[name-defined]
    """Create a TensorProto with external data from a PyTorch tensor.
    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Relative location of the external data file
            (e.g., "initializers/weight_0"); the data is written to
            os.path.join(basepath, location).
        basepath: Base path of the external data file (e.g., "/tmp"). For ONNX
            to resolve the external data, the model must be saved under the same
            base path.
        dtype_override: Optional ONNX type to store the tensor as. When it
            differs from tensor.dtype, the tensor is cast to that type before
            being saved.

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    scalar_type = (
        jit_type_utils.JitScalarType.from_onnx_type(
            dtype_override.tensor_type.elem_type
        )
        if dtype_override is not None
        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
    )

    # Checkpoints can be stored with a dtype different from the one the model
    # expects, because the user script may have explicitly cast the original
    # type, or PyTorch's type promotion may have done so.
    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
        tensor = tensor.to(scalar_type.dtype())

    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
    tensor_proto.name = name
    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]

    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write content of tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data folder if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    os.makedirs(external_data_dir_path, exist_ok=True)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy(force=True).tobytes())

    return tensor_proto
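

# A minimal sketch of how the helper above might be used (the tensor, the
# "weight" name, and the "/tmp" paths are illustrative assumptions, not part
# of this module's API). For a 2x3 float32 tensor, the returned proto carries
# no raw bytes, only a pointer to the side file:
#
#     proto = _create_tensor_proto_with_external_data(
#         torch.zeros(2, 3),
#         name="weight",
#         location="initializers/weight",
#         basepath="/tmp",
#     )
#     # proto.data_location == onnx.TensorProto.EXTERNAL
#     # proto.external_data records location="initializers/weight", offset="0",
#     # and length="24" (2 * 3 * 4 bytes, stored as strings), while
#     # /tmp/initializers/weight now holds those 24 raw bytes.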


def _convert_safetensors_to_torch_format(safetensors_file):
    # If this function is called, the safetensors package is guaranteed to be
    # installed, because the HF model with safetensors was already loaded and
    # exported to ONNX.
    from safetensors import safe_open  # type: ignore[import-not-found, import-untyped]

    tensors = {}
    with safe_open(safetensors_file, framework="pt", device="cpu") as f:  # type: ignore[attr-defined]
        for k in f.keys():
            tensors[k] = f.get_tensor(k).cpu()
    return tensors
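

# Illustrative only (the checkpoint path below is hypothetical): the helper
# above returns a plain {name: torch.Tensor} mapping on CPU, so a
# ".safetensors" checkpoint can feed the same name-matching loop as a torch
# checkpoint in save_model_with_external_data below:
#
#     tensors = _convert_safetensors_to_torch_format("model.safetensors")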


# TODO: generalize to allow more checkpoint formats (e.g., torch or gguf)
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

    Output files:
        ONNX model file path: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "<basepath>/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "<basepath>/initializers/".
            Note: When initializers are >2GB, `initializer_location` must be the same as `model_location`.
        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
            as ONNX initializers. Non-dict arguments are loaded with `torch.load`
            from file-like objects, except ".safetensors" paths, which are read
            with the `safetensors` library.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_state_dicts",
            the tensor will be saved as that input's external initializer.
        rename_initializer: Replaces "." with "_" in all ONNX initializer names.
            Not needed by the official torch.onnx.dynamo_export. This is a hack
            to support the `FXSymbolicTracer` tracer with fake tensor mode.
            In short, `FXSymbolicTracer` lifts FX parameters (e.g., self.linear_weight)
            as inputs (`def forward(self, linear_weight)`), so "." cannot appear in their names.
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    initializers_to_be_deleted = {}  # Using dict because it is **ordered**
    existing_initializers = {
        k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
    }
    onnx_input_names = {input.name for input in onnx_model.graph.input}
    for el in torch_state_dicts:
        if isinstance(el, dict):
            # Useful when the user has already loaded the state_dict with
            # torch.load(..., mmap=True, map_location="cpu"); passing the dict
            # through directly preserves the memory-mapping, whereas round-tripping
            # it through torch.save would not, leading to higher memory usage.
            state_dict = el
        else:
            if isinstance(el, str) and el.endswith(".safetensors"):
                state_dict = _convert_safetensors_to_torch_format(el)
            else:
                try:
                    # Loads checkpoint using memory-map on CPU to support really large models
                    # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
                    state_dict = torch.load(el, map_location="cpu", mmap=True)
                except (RuntimeError, ValueError) as e:
                    if "mmap can only be used with files saved with" in str(
                        e
                    ) or isinstance(el, io.BytesIO):
                        log.warning(
                            "Failed to load the checkpoint with memory-map enabled; retrying without memory-map. "
                            "Consider re-saving the checkpoint with torch.save() on PyTorch version >= 1.6 to enable memory-mapped loading."
                        )
                        if isinstance(el, io.BytesIO):
                            el.seek(0)  # torch.load from `try:` has read the file.
                        state_dict = torch.load(el, map_location="cpu")
                    else:
                        raise

        for name, tensor in state_dict.items():
            if rename_initializer:
                # Basically, "transformer.attention.self.query.weight" is mapped
                # to "transformer_attention_self_query_weight" for mimicking the
                # name-modifying code in FX-to-ONNX exporter.
                # See function _replace_get_attr_with_placeholder for details.
                name = name.replace(".", "_")

            # This block tries to match the ONNX initializer name with a torch
            # parameter/buffer, e.g., a PyTorch buffer "transformer.h.0.attn.bias"
            # may be named "h.0.attn.bias" in an ONNX initializer.
            # For each PyTorch tensor name loaded by torch.load,
            #  1.  Search for its best match among the ONNX inputs. E.g., the match
            #      of "transformer_attention_weight" could be "attention_weight".
            #  2.  Set "tensor" as the initializer of the matched ONNX input.
            #      E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because tensor names in the dictionary loaded by
            # torch.load are sometimes stored with an extra prefix.
            if name in onnx_input_names:
                # Same input name shouldn't be matched again
                onnx_input_names.remove(name)
            else:
                for onnx_input_name in onnx_input_names:
                    if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
                        # Found a match. Switch to the matched ONNX input name so
                        # that the initializer is created with the right ONNX name.
                        name = onnx_input_name
                        onnx_input_names.remove(onnx_input_name)
                        break

            relative_tensor_file_path = os.path.join(initializer_location, name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to external file at
            # os.path.join(basepath, relative_tensor_file_path).
            model_input_types = {k.name: k.type for k in onnx_model.graph.input}

            # Mark for deletion - a replacement will be appended next
            if name in existing_initializers:
                initializers_to_be_deleted[existing_initializers[name]] = name
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor,
                name,
                relative_tensor_file_path,
                basepath,
                model_input_types.pop(name, None),
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model.graph.initializer.append(tensor_proto)
    # Remove old duplicated initializers, if any. Delete in descending order so
    # that deletions do not invalidate the indices of the remaining entries.
    initializers_to_be_deleted = dict(
        sorted(initializers_to_be_deleted.items(), reverse=True)
    )
    for idx in initializers_to_be_deleted.keys():
        del onnx_model.graph.initializer[idx]

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]
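

# A minimal runnable sketch of the flow above, guarded so it never executes on
# import. The one-node Identity graph, the temporary directory, and the
# "weight" name are illustrative assumptions, not part of this module's API;
# in normal use the exporter produces `onnx_model` and a real checkpoint
# supplies the state dict.
if __name__ == "__main__":
    import tempfile

    import onnx.helper
    import onnx.numpy_helper

    # Build a toy graph whose single input name matches a state_dict key.
    weight_info = onnx.helper.make_tensor_value_info(
        "weight", onnx.TensorProto.FLOAT, [2, 3]
    )
    out_info = onnx.helper.make_tensor_value_info("out", onnx.TensorProto.FLOAT, [2, 3])
    identity = onnx.helper.make_node("Identity", ["weight"], ["out"])
    demo_model = onnx.helper.make_model(
        onnx.helper.make_graph([identity], "demo", [weight_info], [out_info])
    )

    demo_basepath = tempfile.mkdtemp()
    save_model_with_external_data(
        basepath=demo_basepath,
        model_location="model.onnx",  # A pure file name, per the note above.
        initializer_location="initializers",
        torch_state_dicts=({"weight": torch.arange(6.0).reshape(2, 3)},),
        onnx_model=demo_model,
    )

    # onnx.load resolves external data relative to the model's directory;
    # onnxruntime.InferenceSession on the same path would work the same way.
    reloaded = onnx.load(os.path.join(demo_basepath, "model.onnx"))
    (weight_initializer,) = reloaded.graph.initializer
    print(onnx.numpy_helper.to_array(weight_initializer, demo_basepath))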