File: serialization.rst


Serialization semantics
=======================

This note describes how you can save and load PyTorch tensors and module states
in Python, and how to serialize Python modules so they can be loaded in C++.

.. contents:: Table of Contents

.. _saving-loading-tensors:

Saving and loading tensors
--------------------------

:func:`torch.save` and :func:`torch.load` let you easily save and load tensors:

::

    >>> t = torch.tensor([1., 2.])
    >>> torch.save(t, 'tensor.pt')
    >>> torch.load('tensor.pt')
    tensor([1., 2.])

By convention, PyTorch files are typically written with a ‘.pt’ or ‘.pth’ extension.

:func:`torch.save` and :func:`torch.load` use Python’s pickle by default,
so you can also save multiple tensors as part of Python objects like tuples,
lists, and dicts:

::

    >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
    >>> torch.save(d, 'tensor_dict.pt')
    >>> torch.load('tensor_dict.pt')
    {'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

Custom data structures that include PyTorch tensors can also be saved if the
data structure is pickle-able.
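
Any pickle-able container works. For example, a minimal sketch (the class and
field names here are illustrative, not a PyTorch API):

.. code-block:: python

    import torch

    class Checkpoint:
        """A plain pickle-able container holding tensors (illustrative)."""
        def __init__(self, weights, step):
            self.weights = weights
            self.step = step

    ckpt = Checkpoint(weights=torch.randn(3), step=100)
    torch.save(ckpt, 'ckpt.pt')

    # Since 2.6, torch.load defaults to weights_only=True, which rejects
    # arbitrary classes; pass weights_only=False only for files you trust
    # (or allowlist the class; see the weights_only section below).
    loaded = torch.load('ckpt.pt', weights_only=False)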

.. _preserve-storage-sharing:

Saving and loading tensors preserves views
---------------------------------------------

Saving tensors preserves their view relationships:

::

    >>> numbers = torch.arange(1, 10)
    >>> evens = numbers[1::2]
    >>> torch.save([numbers, evens], 'tensors.pt')
    >>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
    >>> loaded_evens *= 2
    >>> loaded_numbers
    tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

Behind the scenes, these tensors share the same "storage." See
`Tensor Views <https://pytorch.org/docs/main/tensor_view.html>`_ for more
on views and storage.

When PyTorch saves tensors it saves their storage objects and tensor
metadata separately. This is an implementation detail that may change in the
future, but it typically saves space and lets PyTorch easily
reconstruct the view relationships between the loaded tensors. In the above
snippet, for example, only a single storage is written to 'tensors.pt'.

In some cases, however, saving the current storage objects may be unnecessary
and create prohibitively large files. In the following snippet a storage much
larger than the saved tensor is written to a file:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small, 'small.pt')
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    999

Instead of saving only the five values in the ``small`` tensor to 'small.pt',
the 999 values in the storage it shares with ``large`` were saved and loaded.

When saving tensors with fewer elements than their storage objects, the size of
the saved file can be reduced by first cloning the tensors. Cloning a tensor
produces a new tensor with a new storage object containing only the values
in the tensor:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    5

Since the cloned tensors are independent of each other, however, they have
none of the view relationships the original tensors did. If both file size and
view relationships matter when saving tensors smaller than their storage
objects, then new tensors must be constructed before saving that minimize the
size of their storage objects while still preserving the desired view
relationships, as sketched below.
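
For example, a hedged sketch of one way to do this (the variable names are
illustrative): clone the smallest region the views need, then rebuild the
views on top of the clone before saving.

.. code-block:: python

    import torch

    large = torch.arange(1, 1000)
    small = large[0:5]   # view whose storage holds 999 elements
    tiny = small[0:2]    # a view of `small` that we also want to keep

    # Clone only the covering region, then recreate the view on the clone.
    small_base = small.clone()    # new storage with just 5 elements
    tiny_view = small_base[0:2]   # same view relationship, smaller storage

    torch.save([small_base, tiny_view], 'small_views.pt')
    loaded_small, loaded_tiny = torch.load('small_views.pt')
    loaded_tiny *= 10
    print(loaded_small[0:2])      # reflects the change: the view survived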

.. _saving-loading-python-modules:

Saving and loading torch.nn.Modules
-----------------------------------

See also: `Tutorial: Saving and loading modules <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_

In PyTorch, a module’s state is frequently serialized using a ‘state dict.’
A module’s state dict contains all of its parameters and persistent buffers:

::

    >>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> list(bn.named_parameters())
    [('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
     ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

    >>> list(bn.named_buffers())
    [('running_mean', tensor([0., 0., 0.])),
     ('running_var', tensor([1., 1., 1.])),
     ('num_batches_tracked', tensor(0))]

    >>> bn.state_dict()
    OrderedDict([('weight', tensor([1., 1., 1.])),
                 ('bias', tensor([0., 0., 0.])),
                 ('running_mean', tensor([0., 0., 0.])),
                 ('running_var', tensor([1., 1., 1.])),
                 ('num_batches_tracked', tensor(0))])

For compatibility reasons, it is recommended to save a module's state dict
rather than the module itself. Modules have a method,
:meth:`~torch.nn.Module.load_state_dict`, to restore their state from a state dict:

::

    >>> torch.save(bn.state_dict(), 'bn.pt')
    >>> bn_state_dict = torch.load('bn.pt')
    >>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> new_bn.load_state_dict(bn_state_dict)
    <All keys matched successfully>

Note that the state dict is first loaded from its file with :func:`torch.load`
and the state then restored with :meth:`~torch.nn.Module.load_state_dict`.

Even custom modules and modules containing other modules have state dicts and
can use this pattern:

::

    # A module with two linear layers
    >>> class MyModule(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.l0 = torch.nn.Linear(4, 2)
    ...         self.l1 = torch.nn.Linear(2, 1)
    ...
    ...     def forward(self, input):
    ...         out0 = self.l0(input)
    ...         out0_relu = torch.nn.functional.relu(out0)
    ...         return self.l1(out0_relu)

    >>> m = MyModule()
    >>> m.state_dict()
    OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                       [-0.3289, 0.2827, 0.4588, 0.2031]])),
                 ('l0.bias', tensor([ 0.0300, -0.1316])),
                 ('l1.weight', tensor([[0.6533, 0.3413]])),
                 ('l1.bias', tensor([-0.1112]))])

    >>> torch.save(m.state_dict(), 'mymodule.pt')
    >>> m_state_dict = torch.load('mymodule.pt')
    >>> new_m = MyModule()
    >>> new_m.load_state_dict(m_state_dict)
    <All keys matched successfully>


.. _serialized-file-format:

Serialized file format for ``torch.save``
-----------------------------------------

Since PyTorch 1.6.0, ``torch.save`` defaults to writing an uncompressed ZIP64
archive unless the user sets ``_use_new_zipfile_serialization=False``.

Within this archive, the files are laid out as follows:

.. code-block:: text

    checkpoint.pth
    ├── data.pkl
    ├── byteorder  # added in PyTorch 2.1.0
    ├── data/
    │   ├── 0
    │   ├── 1
    │   ├── 2
    │   └── …
    └── version

The entries are as follows:
  * ``data.pkl`` is the result of pickling the object passed to ``torch.save``,
    excluding any ``torch.Storage`` objects it contains
  * ``byteorder`` contains a string with the ``sys.byteorder`` when saving (“little” or “big”)
  * ``data/`` contains all the storages in the object, where each storage is a separate file
  * ``version`` contains a version number at save time that can be used at load time

When saving, PyTorch pads the local file header of each file so that the
file's data begins at an offset that is a multiple of 64 bytes, keeping every
record 64-byte aligned.

.. note::
    Tensors on certain devices such as XLA are serialized as pickled numpy arrays. As
    such, their storages are not serialized. In these cases ``data/`` might not exist
    in the checkpoint.
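
You can inspect this layout with Python's ``zipfile`` module. A quick sketch
(entry names are prefixed with an archive name derived from the file name, so
the exact paths may vary):

.. code-block:: python

    import zipfile

    import torch

    torch.save({"w": torch.randn(4)}, "checkpoint.pt")
    with zipfile.ZipFile("checkpoint.pt") as zf:
        for name in zf.namelist():
            print(name)  # e.g. checkpoint/data.pkl, checkpoint/byteorder, ...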

.. _layout-control:

Layout Control
--------------

The ``mmap`` argument in :func:`torch.load` allows for lazy loading of tensor storages.
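
For example, a minimal sketch (the path and key are placeholders):

.. code-block:: python

    import torch

    # Storages are memory-mapped rather than read eagerly; the bytes for a
    # given tensor are paged in from disk when it is first accessed.
    state_dict = torch.load("checkpoint.pt", mmap=True)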

In addition, there are some advanced features that allow for more fine-grained
control and manipulation of a ``torch.save`` checkpoint.

The :class:`torch.serialization.skip_data` context manager enables the following (see the sketch after this list):
  * Saving a checkpoint with ``torch.save`` that includes empty space for data bytes
    to be written later.
  * Loading a checkpoint with ``torch.load`` and filling in the data bytes of tensors later.
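
A minimal sketch of the saving flow (the path is a placeholder; the data bytes
are assumed to be filled in later by some out-of-band writer):

.. code-block:: python

    import torch
    from torch.serialization import skip_data

    sd = {"w": torch.randn(4, 4)}

    # Writes the checkpoint structure and reserves space for the storage
    # bytes, but does not write the bytes themselves.
    with skip_data():
        torch.save(sd, "reserved.pt")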

To inspect tensor metadata in a ``torch.save`` checkpoint without allocating memory for storage
data, use ``torch.load`` within the ``FakeTensorMode`` context manager. In addition to skipping the
loading of storage data, as ``skip_data`` above does, it tags each storage with its offset within
the checkpoint, enabling direct checkpoint manipulation.

.. code-block:: python

  import torch
  import torch.nn as nn
  from torch._subclasses.fake_tensor import FakeTensorMode

  m = nn.Linear(10, 10)
  torch.save(m.state_dict(), "checkpoint.pt")

  with FakeTensorMode() as mode:
      fake_sd = torch.load("checkpoint.pt")

  for k, v in fake_sd.items():
      print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
      # offset of the storage in the checkpoint
      print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")

For more information, `this tutorial <https://docs.pytorch.org/tutorials/prototype/gpu_direct_storage.html>`_
offers a comprehensive example of using these features to manipulate a checkpoint.


.. _weights-only:

``torch.load`` with ``weights_only=True``
-----------------------------------------

Starting in version 2.6, ``torch.load`` will use ``weights_only=True`` if the ``pickle_module``
argument is not passed.

As discussed in the documentation for :func:`torch.load`, ``weights_only=True`` restricts
the unpickler used in ``torch.load`` to only executing functions/building classes required for
``state_dicts`` of plain ``torch.Tensors`` as well as some other primitive types. Further,
unlike the default ``Unpickler`` provided by the ``pickle`` module, the ``weights_only`` Unpickler
is not allowed to dynamically import anything during unpickling.
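
For example, a plain-tensor ``state_dict`` loads fine under the restricted
unpickler (the path is a placeholder):

.. code-block:: python

    import torch

    torch.save(torch.nn.Linear(2, 2).state_dict(), "linear.pt")
    sd = torch.load("linear.pt")                     # weights_only=True by default
    sd = torch.load("linear.pt", weights_only=True)  # equivalent explicit form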

As mentioned above, saving a module's ``state_dict`` is a best practice when using ``torch.save``. If loading an old
checkpoint that contains an ``nn.Module``, we recommend ``weights_only=False``. When loading a checkpoint that contains
tensor subclasses, there will likely be functions/classes that need to be allowlisted, see below for further details.

If the ``weights_only`` Unpickler encounters a function or class within the pickle file that is not
allowlisted by default, you should see an actionable error like the following:

.. code::

    _pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
    to do so you have two options, do those steps only if you trust the source of the checkpoint.
        1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
            but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
        2. Alternatively, to load with `weights_only=True` please check the recommended
           steps in the following error message.
           WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
           default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
           `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
           if you trust this class/function.

Please follow the steps in the error message and allowlist the functions or classes only if you trust them.

To get all GLOBALs (functions/classes) in the checkpoint that are not yet allowlisted you can use
:func:`torch.serialization.get_unsafe_globals_in_checkpoint` which will return a list of strings of the form
``{__module__}.{__name__}``. If you trust these functions/classes, you can import them and allowlist them per
the error message either via :func:`torch.serialization.add_safe_globals` or the context manager
:class:`torch.serialization.safe_globals`.

To access the list of user-allowlisted functions/classes you can use :func:`torch.serialization.get_safe_globals` and
to clear the current list see :func:`torch.serialization.clear_safe_globals`.
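
Putting these together, a hedged sketch (``my_module.MyClass`` stands in for
whatever your checkpoint actually references):

.. code-block:: python

    import torch
    from torch.serialization import add_safe_globals, get_unsafe_globals_in_checkpoint

    # Strings of the form '{__module__}.{__name__}', e.g. ['my_module.MyClass']
    print(get_unsafe_globals_in_checkpoint("checkpoint.pt"))

    from my_module import MyClass  # hypothetical: import only if you trust it
    add_safe_globals([MyClass])
    sd = torch.load("checkpoint.pt", weights_only=True)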

Troubleshooting ``weights_only``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Getting unsafe globals
""""""""""""""""""""""

A caveat is that :func:`torch.serialization.get_unsafe_globals_in_checkpoint` analyzes the checkpoint
statically; some types might be built dynamically during the unpickling process and hence will not be
reported. One such example is ``dtypes`` in numpy. In ``numpy < 1.25``, after allowlisting all the
functions/classes reported by :func:`torch.serialization.get_unsafe_globals_in_checkpoint`, you might
still see an error like

.. code::

    WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
    but got <class 'numpy.dtype[float32]'>

This can be allowlisted via ``{add_}safe_globals([type(np.dtype(np.float32))])``.

In ``numpy >=1.25`` you would see

.. code::

    WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
    but got <class 'numpy.dtypes.Float32DType'>

This can be allowlisted via ``{add_}safe_globals([np.dtypes.Float32DType])``.

Environment Variables
"""""""""""""""""""""

There are two environment variables that will influence the behavior of ``torch.load``. These can be helpful
if one does not have access to the ``torch.load`` callsites.

* ``TORCH_FORCE_WEIGHTS_ONLY_LOAD=1`` will override all ``torch.load`` callsites to use ``weights_only=True``.
* ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` will make ``torch.load`` callsites use ``weights_only=False`` **only**
  if ``weights_only`` was not passed as an argument.
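
These are read in the environment of the process calling ``torch.load``, so one way to apply them
without touching the callsites is to set them when launching, e.g.
``TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 python script.py``. A sketch of the same effect from Python,
assuming the variable is set before any ``torch.load`` call runs (the imported script is hypothetical):

.. code-block:: python

    import os

    os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "1"

    # Hypothetical third-party code whose torch.load calls we cannot edit;
    # they will now run with weights_only=True.
    import third_party_training_script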


.. _utility functions:

Utility functions
-----------------

The following utility functions are related to serialization:

.. currentmodule:: torch.serialization

.. autofunction:: register_package
.. autofunction:: get_crc32_options
.. autofunction:: set_crc32_options
.. autofunction:: get_default_load_endianness
.. autofunction:: set_default_load_endianness
.. autofunction:: get_default_mmap_options
.. autofunction:: set_default_mmap_options
.. autofunction:: add_safe_globals
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autofunction:: get_unsafe_globals_in_checkpoint
.. autoclass:: safe_globals
.. autoclass:: skip_data

.. _serialization config:

Config
------
.. py:module:: torch.utils.serialization
.. py:module:: torch.utils.serialization.config

``torch.utils.serialization.config`` provides a global config that can control the behavior of
``torch.save`` and ``torch.load``.


``torch.utils.serialization.config.save`` contains options that control the behavior of ``torch.save``.

  * ``compute_crc32``: whether to compute and write the zip file checksum (Default: ``True``).
    See :func:`~torch.serialization.set_crc32_options`.
  * ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to
    move the storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False``, i.e. pageable)
  * ``storage_alignment``: alignment, in bytes, of storages in the checkpoint during ``torch.save``. (Default: ``64``)

``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.

  * ``mmap``: See the documentation for the ``mmap`` argument of :func:`torch.load`.
    This config sets the ``mmap`` behavior for ``torch.load`` when it is not
    explicitly passed to the ``torch.load`` call. (Default: ``False``)
  * ``endianness``: See :func:`~torch.serialization.set_default_load_endianness`.
    (Default: ``torch.serialization.LoadEndianness.NATIVE``)
  * ``mmap_flags``: See :func:`~torch.serialization.set_default_mmap_options`.
    (Default: ``MAP_PRIVATE``)
  * ``calculate_storage_offsets``: If set to ``True``, offsets for storages are
    calculated rather than read via random reads when using ``torch.load(mmap=True)``. This minimizes
    random reads, which can be helpful when the file is being loaded over a network. (Default: ``False``)
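
A short sketch of setting these globally:

.. code-block:: python

    from torch.utils.serialization import config

    config.save.compute_crc32 = False  # skip the zip checksum on save
    config.load.mmap = True            # default to memory-mapped loads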