Serialization semantics
=========================

This note describes how you can save and load PyTorch tensors and module states
in Python, and how to serialize Python modules so they can be loaded in C++.

.. contents:: Table of Contents

.. _saving-loading-tensors:

Saving and loading tensors
--------------------------

:func:`torch.save` and :func:`torch.load` let you easily save and load tensors:

::

    >>> t = torch.tensor([1., 2.])
    >>> torch.save(t, 'tensor.pt')
    >>> torch.load('tensor.pt')
    tensor([1., 2.])

By convention, PyTorch files are typically written with a ‘.pt’ or ‘.pth’ extension.

:func:`torch.save` and :func:`torch.load` use Python’s pickle by default,
so you can also save multiple tensors as part of Python objects like tuples,
lists, and dicts:

::

    >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
    >>> torch.save(d, 'tensor_dict.pt')
    >>> torch.load('tensor_dict.pt')
    {'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

Custom data structures that include PyTorch tensors can also be saved if the
data structure is pickle-able.

.. _preserve-storage-sharing:

Saving and loading tensors preserves views
-------------------------------------------

Saving tensors preserves their view relationships:

::

    >>> numbers = torch.arange(1, 10)
    >>> evens = numbers[1::2]
    >>> torch.save([numbers, evens], 'tensors.pt')
    >>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
    >>> loaded_evens *= 2
    >>> loaded_numbers
    tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

Behind the scenes, these tensors share the same "storage." See
`Tensor Views <https://pytorch.org/docs/main/tensor_view.html>`_ for more
on views and storage.

When PyTorch saves tensors it saves their storage objects and tensor
metadata separately. This is an implementation detail that may change in the
future, but it typically saves space and lets PyTorch easily
reconstruct the view relationships between the loaded tensors. In the above
snippet, for example, only a single storage is written to 'tensors.pt'.

In some cases, however, saving the current storage objects may be unnecessary
and create prohibitively large files. In the following snippet a storage much
larger than the saved tensor is written to a file:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small, 'small.pt')
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    999

Instead of saving only the five values in the ``small`` tensor to 'small.pt',
the 999 values in the storage it shares with ``large`` were saved and loaded.

When saving tensors with fewer elements than their storage objects, the size of
the saved file can be reduced by first cloning the tensors. Cloning a tensor
produces a new tensor with a new storage object containing only the values
in the tensor:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    5

Since the cloned tensors are independent of each other, however, they have
none of the view relationships the original tensors did. If both file size and
view relationships are important when saving tensors smaller than their
storage objects, then care must be taken to construct new tensors that minimize
the size of their storage objects but still have the desired view relationships
before saving, as sketched below.
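
One way to do this (a minimal sketch, not part of the original example; the file
name ``small_views.pt`` is made up) is to clone a small base tensor that covers
only the needed elements and to re-create the views on the clone before saving:

::

    >>> large = torch.arange(1, 1000)
    >>> base = large[0:5].clone()      # new storage containing only the needed 5 values
    >>> torch.save([base, base[1:5]], 'small_views.pt')
    >>> loaded_base, loaded_view = torch.load('small_views.pt')
    >>> loaded_base.storage().size()   # only the small storage was saved
    5
    >>> loaded_view *= 2
    >>> loaded_base                    # the view relationship is preserved
    tensor([ 1,  4,  6,  8, 10])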

.. _saving-loading-python-modules:

Saving and loading torch.nn.Modules
-----------------------------------

See also: `Tutorial: Saving and loading modules <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_

In PyTorch, a module’s state is frequently serialized using a ‘state dict.’
A module’s state dict contains all of its parameters and persistent buffers:

::

    >>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> list(bn.named_parameters())
    [('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
     ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

    >>> list(bn.named_buffers())
    [('running_mean', tensor([0., 0., 0.])),
     ('running_var', tensor([1., 1., 1.])),
     ('num_batches_tracked', tensor(0))]

    >>> bn.state_dict()
    OrderedDict([('weight', tensor([1., 1., 1.])),
                 ('bias', tensor([0., 0., 0.])),
                 ('running_mean', tensor([0., 0., 0.])),
                 ('running_var', tensor([1., 1., 1.])),
                 ('num_batches_tracked', tensor(0))])

For compatibility reasons, it is recommended to save a module’s state dict
instead of the module itself. Modules have a method,
:meth:`~torch.nn.Module.load_state_dict`, to restore their state from a state dict:

::

    >>> torch.save(bn.state_dict(), 'bn.pt')
    >>> bn_state_dict = torch.load('bn.pt')
    >>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> new_bn.load_state_dict(bn_state_dict)
    <All keys matched successfully>

Note that the state dict is first loaded from its file with :func:`torch.load`
and the state then restored with :meth:`~torch.nn.Module.load_state_dict`.

Even custom modules and modules containing other modules have state dicts and
can use this pattern:

::

    # A module with two linear layers
    >>> class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l0 = torch.nn.Linear(4, 2)
                self.l1 = torch.nn.Linear(2, 1)

            def forward(self, input):
                out0 = self.l0(input)
                out0_relu = torch.nn.functional.relu(out0)
                return self.l1(out0_relu)

    >>> m = MyModule()
    >>> m.state_dict()
    OrderedDict([('l0.weight', tensor([[ 0.1400,  0.4563, -0.0271, -0.4406],
                                       [-0.3289,  0.2827,  0.4588,  0.2031]])),
                 ('l0.bias', tensor([ 0.0300, -0.1316])),
                 ('l1.weight', tensor([[0.6533, 0.3413]])),
                 ('l1.bias', tensor([-0.1112]))])

    >>> torch.save(m.state_dict(), 'mymodule.pt')
    >>> m_state_dict = torch.load('mymodule.pt')
    >>> new_m = MyModule()
    >>> new_m.load_state_dict(m_state_dict)
    <All keys matched successfully>

.. _serialized-file-format:

Serialized file format for ``torch.save``
--------------------------------------------

Since PyTorch 1.6.0, ``torch.save`` by default writes an uncompressed ZIP64
archive unless the user sets ``_use_new_zipfile_serialization=False``.
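
Because the saved file is an ordinary ZIP archive, it can be inspected with
Python’s standard ``zipfile`` module. A brief illustration (not part of the
original note; the entry names and prefix shown vary across PyTorch versions):

::

    >>> import zipfile
    >>> torch.save(torch.tensor([1., 2.]), 'tensor.pt')
    >>> zipfile.is_zipfile('tensor.pt')
    True
    >>> zipfile.ZipFile('tensor.pt').namelist()  # exact entries vary by version
    ['tensor/data.pkl', 'tensor/byteorder', 'tensor/data/0', 'tensor/version']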

In this archive, the files are organized as follows:

.. code-block:: text

    checkpoint.pth
    ├── data.pkl
    ├── byteorder  # added in PyTorch 2.1.0
    ├── data/
    │   ├── 0
    │   ├── 1
    │   ├── 2
    │   └── …
    └── version

The entries are as follows:
 * ``data.pkl`` is the result of pickling the object passed to ``torch.save``,
   excluding any ``torch.Storage`` objects it contains
 * ``byteorder`` contains a string with the ``sys.byteorder`` at save time (“little” or “big”)
 * ``data/`` contains all the storages in the object, where each storage is a separate file
 * ``version`` contains a version number recorded at save time that can be used at load time

When saving, PyTorch pads the local file header of each file so that the file's
data starts at an offset that is a multiple of 64 bytes, i.e. is 64-byte aligned.

.. note::
    Tensors on certain devices such as XLA are serialized as pickled numpy arrays. As
    such, their storages are not serialized. In these cases ``data/`` might not exist
    in the checkpoint.

.. _serializing-python-modules:

Serializing torch.nn.Modules and loading them in C++
------------------------------------------------------

See also: `Tutorial: Loading a TorchScript Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_

ScriptModules can be serialized as a TorchScript program with :func:`torch.jit.save`
and loaded with :func:`torch.jit.load`.
This serialization encodes all the modules’ methods, submodules, parameters,
and attributes, and it allows the serialized program to be loaded in C++
(i.e. without Python).

The distinction between :func:`torch.jit.save` and :func:`torch.save` may not
be immediately clear. :func:`torch.save` saves Python objects with pickle.
This is especially useful for prototyping, researching, and training.
:func:`torch.jit.save`, on the other hand, serializes ScriptModules to a format
that can be loaded in Python or C++. This is useful when saving and loading C++
modules or for running modules trained in Python with C++, a common practice
when deploying PyTorch models.

To script, serialize, and load a module in Python:

::

    >>> scripted_module = torch.jit.script(MyModule())
    >>> torch.jit.save(scripted_module, 'mymodule.pt')
    >>> torch.jit.load('mymodule.pt')
    RecursiveScriptModule( original_name=MyModule
                          (l0): RecursiveScriptModule(original_name=Linear)
                          (l1): RecursiveScriptModule(original_name=Linear) )
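
The loaded object is itself a ScriptModule and can be called just like the
original Python module. A short illustration (not part of the original example;
it reuses the ``MyModule`` defined above):

::

    >>> loaded = torch.jit.load('mymodule.pt')
    >>> output = loaded(torch.randn(2, 4))  # runs the scripted forward()
    >>> output.shape
    torch.Size([2, 1])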

Traced modules can also be saved with :func:`torch.jit.save`, with the caveat
that only the traced code path is serialized. The following example demonstrates
this:

::

    # A module with control flow
    >>> class ControlFlowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l0 = torch.nn.Linear(4, 2)
                self.l1 = torch.nn.Linear(2, 1)

            def forward(self, input):
                if input.dim() > 1:
                    return torch.tensor(0)

                out0 = self.l0(input)
                out0_relu = torch.nn.functional.relu(out0)
                return self.l1(out0_relu)

    >>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
    >>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
    >>> loaded = torch.jit.load('controlflowmodule_traced.pt')
    >>> loaded(torch.randn(2, 4))
    tensor([[-0.1571],
            [-0.3793]], grad_fn=<AddBackward0>)

    >>> scripted_module = torch.jit.script(ControlFlowModule())
    >>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
    >>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
    >>> loaded(torch.randn(2, 4))
    tensor(0)

The above module has an if statement that is not triggered by the traced inputs,
and so is not part of the traced module and not serialized with it.
The scripted module, however, contains the if statement and is serialized with it.
See the `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
for more on scripting and tracing.

Finally, to load the module in C++:

::

    torch::jit::script::Module module;
    module = torch::jit::load("controlflowmodule_scripted.pt");

See the `PyTorch C++ API documentation <https://pytorch.org/cppdocs/>`_
for details about how to use PyTorch modules in C++.

.. _saving-loading-across-versions:

Saving and loading ScriptModules across PyTorch versions
-----------------------------------------------------------

The PyTorch Team recommends saving and loading modules with the same version of
PyTorch. Older versions of PyTorch may not support newer modules, and newer
versions may have removed or modified older behavior. These changes are
explicitly described in
PyTorch’s `release notes <https://github.com/pytorch/pytorch/releases>`_,
and modules relying on functionality that has changed may need to be updated
to continue working properly. In limited cases, detailed below, PyTorch will
preserve the historic behavior of serialized ScriptModules so they do not require
an update.

torch.div performing integer division
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.div` would perform floor division when
given two integer inputs:

::

    # PyTorch 1.5 (and earlier)
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> a / b
    tensor(1)

In PyTorch 1.7, however, :func:`torch.div` will always perform a true division
of its inputs, just like division in Python 3:

::

    # PyTorch 1.7
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> a / b
    tensor(1.6667)

The behavior of :func:`torch.div` is preserved in serialized ScriptModules.
That is, ScriptModules serialized with versions of PyTorch before 1.6 will continue
to see :func:`torch.div` perform floor division when given two integer inputs
even when loaded with newer versions of PyTorch. ScriptModules using :func:`torch.div`
and serialized on PyTorch 1.6 and later cannot be loaded in earlier versions of
PyTorch, however, since those earlier versions do not understand the new behavior.
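
For new code (as opposed to previously serialized ScriptModules) that needs
integer-style division on current PyTorch, the rounding can be requested
explicitly. A small illustration (not part of the original note):

::

    # Current PyTorch
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> torch.div(a, b, rounding_mode='floor')
    tensor(1)
    >>> a // b
    tensor(1)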

torch.full always inferring a float dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.full` always returned a float tensor,
regardless of the fill value it’s given:

::

    # PyTorch 1.5 and earlier
    >>> torch.full((3,), 1)  # Note the integer fill value...
    tensor([1., 1., 1.])     # ...but float tensor!

In PyTorch 1.7, however, :func:`torch.full` will infer the returned tensor’s
dtype from the fill value:

::

    # PyTorch 1.7
    >>> torch.full((3,), 1)
    tensor([1, 1, 1])

    >>> torch.full((3,), True)
    tensor([True, True, True])

    >>> torch.full((3,), 1.)
    tensor([1., 1., 1.])

    >>> torch.full((3,), 1 + 1j)
    tensor([1.+1.j, 1.+1.j, 1.+1.j])

The behavior of :func:`torch.full` is preserved in serialized ScriptModules. That is,
ScriptModules serialized with versions of PyTorch before 1.6 will continue to see
:func:`torch.full` return float tensors by default, even when given bool or
integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6
and later cannot be loaded in earlier versions of PyTorch, however, since those
earlier versions do not understand the new behavior.

.. _utility functions:

Utility functions
-----------------

The following utility functions are related to serialization:

.. currentmodule:: torch.serialization

.. autofunction:: register_package
.. autofunction:: get_default_load_endianness
.. autofunction:: set_default_load_endianness
.. autofunction:: get_default_mmap_options
.. autofunction:: set_default_mmap_options
.. autofunction:: add_safe_globals
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
.. autoclass:: skip_data
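
As a hedged sketch of how the safe-globals utilities above fit together (the
``MyStateContainer`` class and file name are made up for this illustration),
a user-defined class can be allowlisted for ``weights_only`` loading:

::

    >>> import torch
    >>> class MyStateContainer:
            def __init__(self, value):
                self.value = value

    >>> torch.save(MyStateContainer(torch.tensor([1., 2.])), 'container.pt')
    # Allowlist the class only while loading this checkpoint
    >>> with torch.serialization.safe_globals([MyStateContainer]):
            loaded = torch.load('container.pt', weights_only=True)

    >>> loaded.value
    tensor([1., 2.])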