Serialization semantics
=======================

This note describes how you can save and load PyTorch tensors and module states
in Python, and how to serialize Python modules so they can be loaded in C++.

.. contents:: Table of Contents

.. _saving-loading-tensors:

Saving and loading tensors
--------------------------

:func:`torch.save` and :func:`torch.load` let you easily save and load tensors:

::

    >>> t = torch.tensor([1., 2.])
    >>> torch.save(t, 'tensor.pt')
    >>> torch.load('tensor.pt')
    tensor([1., 2.])

By convention, PyTorch files are typically written with a ‘.pt’ or ‘.pth’ extension.

:func:`torch.save` and :func:`torch.load` use Python’s pickle by default,
so you can also save multiple tensors as part of Python objects like tuples,
lists, and dicts:

::

    >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
    >>> torch.save(d, 'tensor_dict.pt')
    >>> torch.load('tensor_dict.pt')
    {'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

Custom data structures that include PyTorch tensors can also be saved if the
data structure is pickle-able.

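For example, the following sketch saves and restores a plain Python class holding
tensors (the ``TensorPair`` class here is illustrative, not part of PyTorch, and
its definition must be available when the file is loaded):

::

    # A hypothetical pickle-able container of tensors
    >>> class TensorPair:
          def __init__(self, first, second):
            self.first = first
            self.second = second

    >>> pair = TensorPair(torch.tensor([1., 2.]), torch.tensor([3., 4.]))
    >>> torch.save(pair, 'pair.pt')
    >>> loaded_pair = torch.load('pair.pt')
    >>> loaded_pair.second
    tensor([3., 4.])
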
.. _preserve-storage-sharing:

Saving and loading tensors preserves views
---------------------------------------------

Saving tensors preserves their view relationships:

::

    >>> numbers = torch.arange(1, 10)
    >>> evens = numbers[1::2]
    >>> torch.save([numbers, evens], 'tensors.pt')
    >>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
    >>> loaded_evens *= 2
    >>> loaded_numbers
    tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

Behind the scenes, these tensors share the same "storage." See
`Tensor Views <https://pytorch.org/docs/main/tensor_view.html>`_ for more
on views and storage.

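One quick way to confirm the sharing (a check added here for illustration, not
part of the original example) is to compare the data pointers of the loaded
tensors' storages:

::

    >>> loaded_numbers.storage().data_ptr() == loaded_evens.storage().data_ptr()
    True
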
When PyTorch saves tensors it saves their storage objects and tensor
metadata separately. This is an implementation detail that may change in the
future, but it typically saves space and lets PyTorch easily
reconstruct the view relationships between the loaded tensors. In the above
snippet, for example, only a single storage is written to 'tensors.pt'.

In some cases, however, saving the current storage objects may be unnecessary
and create prohibitively large files. In the following snippet a storage much
larger than the saved tensor is written to a file:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small, 'small.pt')
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    999

Instead of saving only the five values in the `small` tensor to 'small.pt',
the 999 values in the storage it shares with `large` were saved and loaded.

When saving tensors with fewer elements than their storage objects, the size of
the saved file can be reduced by first cloning the tensors. Cloning a tensor
produces a new tensor with a new storage object containing only the values
in the tensor:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    5

Since the cloned tensors are independent of each other, however, they have
none of the view relationships the original tensors did. If both file size and
view relationships are important when saving tensors smaller than their
storage objects, then care must be taken to construct new tensors that minimize
the size of their storage objects but still have the desired view relationships
before saving.

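For example, if you want to save ``small`` together with a view of its first two
elements while keeping the file small, one possible approach (a sketch; the
tensor names here are illustrative) is to clone a minimal base tensor and
recreate the view against the clone before saving:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]                 # shares large's 999-element storage
    >>> new_small = small.clone()          # new storage holding only 5 elements
    >>> new_head = new_small[0:2]          # a view recreated against the clone
    >>> torch.save([new_small, new_head], 'small_views.pt')
    >>> loaded_small, loaded_head = torch.load('small_views.pt')
    >>> loaded_head *= 2
    >>> loaded_small                       # the view relationship was preserved
    tensor([2, 4, 3, 4, 5])
    >>> loaded_small.storage().size()
    5
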
.. _saving-loading-python-modules:

Saving and loading torch.nn.Modules
-----------------------------------

See also: `Tutorial: Saving and loading modules <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_

In PyTorch, a module’s state is frequently serialized using a ‘state dict.’
A module’s state dict contains all of its parameters and persistent buffers:

::

    >>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> list(bn.named_parameters())
    [('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
     ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

    >>> list(bn.named_buffers())
    [('running_mean', tensor([0., 0., 0.])),
     ('running_var', tensor([1., 1., 1.])),
     ('num_batches_tracked', tensor(0))]

    >>> bn.state_dict()
    OrderedDict([('weight', tensor([1., 1., 1.])),
                 ('bias', tensor([0., 0., 0.])),
                 ('running_mean', tensor([0., 0., 0.])),
                 ('running_var', tensor([1., 1., 1.])),
                 ('num_batches_tracked', tensor(0))])

For compatibility reasons, it is recommended to save only a module’s state dict
rather than the module itself. Python modules even have a function,
:meth:`~torch.nn.Module.load_state_dict`, to restore their states from a state dict:

::

    >>> torch.save(bn.state_dict(), 'bn.pt')
    >>> bn_state_dict = torch.load('bn.pt')
    >>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> new_bn.load_state_dict(bn_state_dict)
    <All keys matched successfully>

Note that the state dict is first loaded from its file with :func:`torch.load`
and the state then restored with :meth:`~torch.nn.Module.load_state_dict`.

Even custom modules and modules containing other modules have state dicts and
can use this pattern:

::

    # A module with two linear layers
    >>> class MyModule(torch.nn.Module):
          def __init__(self):
            super().__init__()
            self.l0 = torch.nn.Linear(4, 2)
            self.l1 = torch.nn.Linear(2, 1)

          def forward(self, input):
            out0 = self.l0(input)
            out0_relu = torch.nn.functional.relu(out0)
            return self.l1(out0_relu)

    >>> m = MyModule()
    >>> m.state_dict()
    OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                       [-0.3289, 0.2827, 0.4588, 0.2031]])),
                 ('l0.bias', tensor([ 0.0300, -0.1316])),
                 ('l1.weight', tensor([[0.6533, 0.3413]])),
                 ('l1.bias', tensor([-0.1112]))])

    >>> torch.save(m.state_dict(), 'mymodule.pt')
    >>> m_state_dict = torch.load('mymodule.pt')
    >>> new_m = MyModule()
    >>> new_m.load_state_dict(m_state_dict)
    <All keys matched successfully>

.. _serialized-file-format:

Serialized file format for ``torch.save``
-----------------------------------------

Since PyTorch 1.6.0, ``torch.save`` writes an uncompressed ZIP64 archive by
default unless the user sets ``_use_new_zipfile_serialization=False``.

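For example (a minimal sketch showing the default and the legacy format side by side):

::

    >>> t = torch.tensor([1., 2.])
    >>> torch.save(t, 'tensor.pt')  # zipfile-based format, the default since 1.6.0
    >>> torch.save(t, 'tensor_legacy.pt', _use_new_zipfile_serialization=False)  # legacy format
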
In this archive, the files are organized as follows:

.. code-block:: text

    checkpoint.pth
    ├── data.pkl
    ├── byteorder  # added in PyTorch 2.1.0
    ├── data/
    │   ├── 0
    │   ├── 1
    │   ├── 2
    │   └── …
    └── version

The entries are as follows (see the inspection sketch after this list):
  * ``data.pkl`` is the result of pickling the object passed to ``torch.save``
    excluding ``torch.Storage`` objects that it contains
  * ``byteorder`` contains a string with the ``sys.byteorder`` when saving (“little” or “big”)
  * ``data/`` contains all the storages in the object, where each storage is a separate file
  * ``version`` contains a version number at save time that can be used at load time

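Because the file is an ordinary ZIP archive, these entries can be listed with
Python's standard ``zipfile`` module (a quick inspection sketch; the entries are
typically nested under a top-level prefix derived from the file name, and
``byteorder`` only appears in files written by PyTorch 2.1.0 or later):

::

    >>> import zipfile
    >>> torch.save(torch.tensor([1., 2.]), 'tensor.pt')
    >>> with zipfile.ZipFile('tensor.pt') as archive:
    ...     for name in archive.namelist():
    ...         print(name)  # expect entries like data.pkl, byteorder, data/0, and version
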
When saving, PyTorch will ensure that the local file header of each file is padded
to an offset that is a multiple of 64 bytes, ensuring that the offset of each file
is 64-byte aligned.

.. note::
    Tensors on certain devices such as XLA are serialized as pickled numpy arrays. As
    such, their storages are not serialized. In these cases ``data/`` might not exist
    in the checkpoint.

.. _serializing-python-modules:

Serializing torch.nn.Modules and loading them in C++
-----------------------------------------------------

See also: `Tutorial: Loading a TorchScript Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_

ScriptModules can be serialized as a TorchScript program and loaded
using :func:`torch.jit.load`.
This serialization encodes all the modules’ methods, submodules, parameters,
and attributes, and it allows the serialized program to be loaded in C++
(i.e. without Python).

The distinction between :func:`torch.jit.save` and :func:`torch.save` may not
be immediately clear. :func:`torch.save` saves Python objects with pickle.
This is especially useful for prototyping, researching, and training.
:func:`torch.jit.save`, on the other hand, serializes ScriptModules to a format
that can be loaded in Python or C++. This is useful when saving and loading C++
modules or for running modules trained in Python with C++, a common practice
when deploying PyTorch models.

To script, serialize and load a module in Python:

::

    >>> scripted_module = torch.jit.script(MyModule())
    >>> torch.jit.save(scripted_module, 'mymodule.pt')
    >>> torch.jit.load('mymodule.pt')
    RecursiveScriptModule( original_name=MyModule
                          (l0): RecursiveScriptModule(original_name=Linear)
                          (l1): RecursiveScriptModule(original_name=Linear) )


Traced modules can also be saved with :func:`torch.jit.save`, with the caveat
that only the traced code path is serialized. The following example demonstrates
this:

::

    # A module with control flow
    >>> class ControlFlowModule(torch.nn.Module):
          def __init__(self):
            super().__init__()
            self.l0 = torch.nn.Linear(4, 2)
            self.l1 = torch.nn.Linear(2, 1)

          def forward(self, input):
            if input.dim() > 1:
                return torch.tensor(0)

            out0 = self.l0(input)
            out0_relu = torch.nn.functional.relu(out0)
            return self.l1(out0_relu)

    >>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
    >>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
    >>> loaded = torch.jit.load('controlflowmodule_traced.pt')
    >>> loaded(torch.randn(2, 4))
    tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

    >>> scripted_module = torch.jit.script(ControlFlowModule())
    >>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
    >>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
    >>> loaded(torch.randn(2, 4))
    tensor(0)

The above module has an if statement that is not triggered by the traced inputs,
and so is not part of the traced module and not serialized with it.
The scripted module, however, contains the if statement and is serialized with it.
See the `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
for more on scripting and tracing.

Finally, to load the module in C++:

::

    >>> torch::jit::script::Module module;
    >>> module = torch::jit::load("controlflowmodule_scripted.pt");

See the `PyTorch C++ API documentation <https://pytorch.org/cppdocs/>`_
for details about how to use PyTorch modules in C++.

.. _saving-loading-across-versions:

Saving and loading ScriptModules across PyTorch versions
-----------------------------------------------------------

The PyTorch Team recommends saving and loading modules with the same version of
PyTorch. Older versions of PyTorch may not support newer modules, and newer
versions may have removed or modified older behavior. These changes are
explicitly described in
PyTorch’s `release notes <https://github.com/pytorch/pytorch/releases>`_,
and modules relying on functionality that has changed may need to be updated
to continue working properly. In limited cases, detailed below, PyTorch will
preserve the historic behavior of serialized ScriptModules so they do not require
an update.

torch.div performing integer division
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.div` would perform floor division when
given two integer inputs:

::

    # PyTorch 1.5 (and earlier)
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> a / b
    tensor(1)

In PyTorch 1.7, however, :func:`torch.div` will always perform a true division
of its inputs, just like division in Python 3:

::

    # PyTorch 1.7
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> a / b
    tensor(1.6667)

The behavior of :func:`torch.div` is preserved in serialized ScriptModules.
That is, ScriptModules serialized with versions of PyTorch before 1.6 will continue
to see :func:`torch.div` perform floor division when given two integer inputs
even when loaded with newer versions of PyTorch. ScriptModules using :func:`torch.div`
and serialized on PyTorch 1.6 and later cannot be loaded in earlier versions of
PyTorch, however, since those earlier versions do not understand the new behavior.

torch.full always inferring a float dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.full` always returned a float tensor,
regardless of the fill value it was given:

::

    # PyTorch 1.5 and earlier
    >>> torch.full((3,), 1)  # Note the integer fill value...
    tensor([1., 1., 1.])     # ...but float tensor!

In PyTorch 1.7, however, :func:`torch.full` will infer the returned tensor’s
dtype from the fill value:

::

    # PyTorch 1.7
    >>> torch.full((3,), 1)
    tensor([1, 1, 1])

    >>> torch.full((3,), True)
    tensor([True, True, True])

    >>> torch.full((3,), 1.)
    tensor([1., 1., 1.])

    >>> torch.full((3,), 1 + 1j)
    tensor([1.+1.j, 1.+1.j, 1.+1.j])

The behavior of :func:`torch.full` is preserved in serialized ScriptModules. That is,
ScriptModules serialized with versions of PyTorch before 1.6 will continue to see
torch.full return float tensors by default, even when given bool or
integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6
and later cannot be loaded in earlier versions of PyTorch, however, since those
earlier versions do not understand the new behavior.

.. _utility functions:

Utility functions
-----------------

The following utility functions are related to serialization:

.. currentmodule:: torch.serialization

.. autofunction:: register_package
.. autofunction:: get_default_load_endianness
.. autofunction:: set_default_load_endianness
.. autofunction:: get_default_mmap_options
.. autofunction:: set_default_mmap_options
.. autofunction:: add_safe_globals
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
.. autoclass:: skip_data