Frequently Asked Questions
==========================

My model reports "cuda runtime error(2): out of memory"
--------------------------------------------------------

As the error message suggests, you have run out of memory on your
GPU. Since we often deal with large amounts of data in PyTorch,
small mistakes can rapidly cause your program to use up all of your
GPU memory; fortunately, the fixes in these cases are often simple.
Here are a few common things to check:

**Don't accumulate history across your training loop.**
By default, computations involving variables that require gradients
will keep history. This means that you should avoid using such
variables in computations which will live beyond your training loops,
e.g., when tracking statistics. Instead, you should detach the variable
or access its underlying data.

Sometimes, it can be non-obvious when differentiable variables
occur. Consider the following training loop (abridged from `source
<https://discuss.pytorch.org/t/high-memory-usage-while-training/162>`_):

.. code-block:: python

    total_loss = 0
    for i in range(10000):
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output)
        loss.backward()
        optimizer.step()
        total_loss += loss

Here, ``total_loss`` is accumulating history across your training loop, since
``loss`` is a differentiable variable with autograd history. You can fix this by
writing ``total_loss += float(loss)`` instead.

Other instances of this problem:
`1 <https://discuss.pytorch.org/t/resolved-gpu-out-of-memory-error-with-batch-size-1/3719>`_.

**Don't hold onto tensors and variables you don't need.**
If you assign a Tensor or Variable to a local, Python will not
deallocate it until the local goes out of scope. You can free
this reference by using ``del x``. Similarly, if you assign
a Tensor or Variable to a member variable of an object, it will
not deallocate until the object goes out of scope. You will
get the best memory usage if you don't hold onto temporaries
you don't need.

The scopes of locals can be larger than you expect. For example:

.. code-block:: python

    for i in range(5):
        intermediate = f(input[i])
        result += g(intermediate)
    output = h(result)
    return output

Here, ``intermediate`` remains live even while ``h`` is executing,
because its scope extends past the end of the loop. To free it
earlier, you should ``del intermediate`` when you are done with it.

**Avoid running RNNs on sequences that are too large.**
The amount of memory required to backpropagate through an RNN scales
linearly with the length of the RNN input; thus, you will run out of memory
if you try to feed an RNN a sequence that is too long.

The technical term for this phenomenon is `backpropagation through time
<https://en.wikipedia.org/wiki/Backpropagation_through_time>`_,
and there are plenty of references for how to implement truncated
BPTT, including in the `word language model
<https://github.com/pytorch/examples/tree/master/word_language_model>`_ example;
truncation is handled by the ``repackage`` function as described in
`this forum post <https://discuss.pytorch.org/t/help-clarifying-repackage-hidden-in-word-language-model/226>`_.

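
As a rough illustration, truncated BPTT amounts to detaching the hidden state
between fixed-length chunks so that gradients never flow back through the
entire history. The following is only a minimal sketch of that idea, not the
example's actual ``repackage`` code; ``get_chunks``, ``sequence``, and
``bptt_len`` are hypothetical stand-ins for whatever produces your
fixed-length slices:

.. code-block:: python

    # Minimal sketch of truncated BPTT. ``get_chunks``, ``sequence``, and
    # ``bptt_len`` are hypothetical stand-ins for your own chunking logic.
    hidden = None
    for chunk_input, chunk_target in get_chunks(sequence, bptt_len):
        optimizer.zero_grad()
        output, hidden = model(chunk_input, hidden)
        loss = criterion(output, chunk_target)
        loss.backward()
        optimizer.step()
        # Detach so the next chunk does not backpropagate through this one;
        # an LSTM's hidden state is an (h, c) tuple, other RNNs return a tensor.
        if isinstance(hidden, tuple):
            hidden = tuple(h.detach() for h in hidden)
        else:
            hidden = hidden.detach()
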

**Don't use linear layers that are too large.**
A linear layer ``nn.Linear(m, n)`` uses :math:`O(nm)` memory: that is to say,
the memory requirements of the weights
scale quadratically with the number of features. It is very easy
to `blow through your memory <https://github.com/pytorch/pytorch/issues/958>`_
this way (and remember that you will need at least twice the size of the
weights, since you also need to store the gradients).

**Consider checkpointing.**
You can trade off memory for compute by using
`checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

My GPU memory isn't freed properly
----------------------------------
PyTorch uses a caching memory allocator to speed up memory allocations. As a
result, the values shown in ``nvidia-smi`` usually don't reflect the true
memory usage. See :ref:`cuda-memory-management` for more details about GPU
memory management.

If your GPU memory isn't freed even after Python quits, it is very likely that
some Python subprocesses are still alive. You may find them via
``ps -elf | grep python`` and manually kill them with ``kill -9 [pid]``.

My out of memory exception handler can't allocate memory
---------------------------------------------------------
You may have some code that tries to recover from out of memory errors.

.. code-block:: python

    try:
        run_model(batch_size)
    except RuntimeError:  # Out of memory
        for _ in range(batch_size):
            run_model(1)

But you may find that when you do run out of memory, your recovery code can't
allocate either. That's because the Python exception object holds a reference
to the stack frame where the error was raised, which prevents the original
tensor objects from being freed. The solution is to move your OOM recovery code
outside of the ``except`` clause.

.. code-block:: python

    oom = False
    try:
        run_model(batch_size)
    except RuntimeError:  # Out of memory
        oom = True

    if oom:
        for _ in range(batch_size):
            run_model(1)


.. _dataloader-workers-random-seed:

My data loader workers return identical random numbers
-------------------------------------------------------
You are likely using other libraries to generate random numbers in the dataset,
and worker subprocesses are started via ``fork``. See
:class:`torch.utils.data.DataLoader`'s documentation for how to
properly set up random seeds in workers with its :attr:`worker_init_fn` option.

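
For example, a minimal sketch of such a :attr:`worker_init_fn` (assuming NumPy
and the standard library's ``random`` module are the other sources of
randomness, and ``dataset`` is your existing :class:`~torch.utils.data.Dataset`)
could look like:

.. code-block:: python

    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    def worker_init_fn(worker_id):
        # torch.initial_seed() already differs per worker (base seed + worker_id),
        # so reuse it to seed the other libraries' generators.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    loader = DataLoader(dataset, batch_size=64, num_workers=4,
                        worker_init_fn=worker_init_fn)
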

.. _pack-rnn-unpack-with-data-parallelism:

My recurrent network doesn't work with data parallelism
--------------------------------------------------------
There is a subtlety in using the
``pack sequence -> recurrent network -> unpack sequence`` pattern in a
:class:`~torch.nn.Module` with :class:`~torch.nn.DataParallel` or
:func:`~torch.nn.parallel.data_parallel`. The input to :meth:`forward` on
each device will only be part of the entire input. Because the unpack operation
:func:`torch.nn.utils.rnn.pad_packed_sequence` by default only pads up to the
longest input it sees, i.e., the longest on that particular device, size
mismatches will happen when results are gathered together. Therefore, you can
instead take advantage of the :attr:`total_length` argument of
:func:`~torch.nn.utils.rnn.pad_packed_sequence` to make sure that the
:meth:`forward` calls return sequences of the same length. For example, you can
write::

    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    class MyModule(nn.Module):
        # ... __init__, other methods, etc.

        # padded_input is of shape [B x T x *] (batch_first mode) and contains
        # the sequences sorted by length
        # B is the batch size
        # T is the max sequence length
        def forward(self, padded_input, input_lengths):
            total_length = padded_input.size(1)  # get the max sequence length
            packed_input = pack_padded_sequence(padded_input, input_lengths,
                                                batch_first=True)
            packed_output, _ = self.my_lstm(packed_input)
            output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                            total_length=total_length)
            return output


    m = MyModule().cuda()
    dp_m = nn.DataParallel(m)


Additionally, extra care needs to be taken when the batch dimension is dim ``1``
(i.e., ``batch_first=False``) with data parallelism. In this case, the first
argument of :func:`~torch.nn.utils.rnn.pack_padded_sequence`, ``padded_input``,
will be of shape ``[T x B x *]`` and should be scattered along dim ``1``, but
the second argument ``input_lengths`` will be of shape ``[B]`` and should be
scattered along dim ``0``. Extra code to manipulate the tensor shapes will be
needed; one possible approach is sketched below.

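
One simple option (a sketch, assuming the ``MyModule`` above, which already
packs and unpacks in ``batch_first=True`` mode internally) is to transpose the
time-first input to batch-first before the data-parallel call, so that both
arguments are scattered along dim ``0``, and transpose the result back
afterwards::

    # padded_input is time-first here: [T x B x *]
    batch_first_input = padded_input.transpose(0, 1).contiguous()  # [B x T x *]
    output = dp_m(batch_first_input, input_lengths)  # scattered along dim 0
    output = output.transpose(0, 1)  # back to [T x B x *]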