Lines Matching full:during
249 # to be filled out during the backward.
286 # present at this time during forward. Restore the surrounding state
357 gradient computation during backward, forward computation in checkpointed
358 regions omits saving tensors for backward and recomputes them during the
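The docstring fragments above describe the core behavior of torch.utils.checkpoint.checkpoint: tensors inside the wrapped region are not saved for backward and are rematerialized by re-running the region. A minimal sketch of that usage, assuming the non-reentrant variant (use_reentrant=False); the layers, shapes, and names below are illustrative, not from the source file:

    import torch
    from torch.utils.checkpoint import checkpoint

    # Illustrative two-layer block; any callable works.
    lin1 = torch.nn.Linear(16, 16)
    lin2 = torch.nn.Linear(16, 16)

    def block(x):
        # Intermediate activations here are not stored for backward; the
        # block is re-run to rematerialize them when backward needs them.
        return lin2(torch.relu(lin1(x)))

    x = torch.randn(4, 16, requires_grad=True)
    out = checkpoint(block, x, use_reentrant=False)  # non-reentrant variant
    out.sum().backward()  # recomputes `block`, then backpropagates through it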
369 If the :attr:`function` invocation during the backward pass differs
391 entirety during the backward pass.
393 * The reentrant variant does not record the autograd graph during the
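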
430 the RNG state during each checkpoint. Note that under torch.compile,
456 a trace of the operators that ran during the original forward computation
531 the RNG state during each checkpoint.
619 # Why is this implied? To unpack a saved tensor X during backward we need to
628 # During unpack calling ctx.saved_tensor triggers the parent checkpoint to
632 # active. Checkpoints encountered during recomputation are still
670 # During recomputation, raise an exception if the number of recomputed tensors
691 # During recomputation the "inner pack hook" has two responsibilities:
702 # - During the original forward IF early-stop is disabled
703 # - During the original backward
709 # The example below shows what happens if during recomputation we find that some
730 # 5. Calling backward triggers another recompute of fn. During recompute, we see
747 not need to be active during backward.
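The comments above refer to the early-stop mechanism of the non-reentrant implementation: by default, recomputation halts as soon as every saved tensor has been rematerialized. A hedged sketch of disabling it with the set_checkpoint_early_stop context manager from torch.utils.checkpoint (module and input below are illustrative):

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

    lin = torch.nn.Linear(8, 8)  # illustrative
    x = torch.randn(2, 8, requires_grad=True)

    # The context manager only needs to be active around the forward call;
    # it does not need to be active during backward.
    with set_checkpoint_early_stop(False):
        out = checkpoint(lambda t: torch.relu(lin(t)), x, use_reentrant=False)
    out.sum().backward()  # recomputation now runs the region to completion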
849 # during original forward matches tensors saved during recompute
852 # 1. During recompute, more tensors were saved.
860 # 2. During recompute, fewer tensors were saved
864 # during recompute we increment recompute_counter.
867 "during the original forward and recomputation.\n"
868 f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
869 f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
872 # 3. During recompute, the same tensors were saved, but they
881 # alive when we saw it during recompute, therefore, the
905 "have different metadata than during the forward pass.\n"
921 | 2. Stack traces of the operators that ran during recomputation |
935 Operations executed during the original forward:
939 Operations executed during recomputation:
947 tensors to be saved during the original forward and differ between those saved
948 during recomputation. This can happen if different operators were run in the
953 1) Compare the operators run during the original forward and recomputation to
964 2. Stack traces of the operators that ran during recomputation
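The error text above is produced by the determinism check, and the operator traces it mentions are only collected in debug mode. A hedged sketch of enabling that mode, either per call via the debug argument or for a region via set_checkpoint_debug_enabled (the block and input are illustrative):

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled

    lin = torch.nn.Linear(8, 8)  # illustrative
    x = torch.randn(2, 8, requires_grad=True)

    def block(t):
        return torch.relu(lin(t))

    # Per-call: record operator traces so a determinism-check failure
    # can report the ops run during forward and during recomputation.
    out = checkpoint(block, x, use_reentrant=False, debug=True)
    out.sum().backward()

    # Or enable it while reproducing a failure.
    with set_checkpoint_debug_enabled(True):
        out = checkpoint(block, x, use_reentrant=False)
        out.sum().backward()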
977 # during unpack.
1066 # we check if the number of tensors saved during forward and
1071 "torch.utils.checkpoint: trying to save more tensors during "
1072 "recomputation than during the original forward pass."
1093 # the graph created during recomputation could be backwarded.
1176 "Tensor cached during selective activation checkpoint has been mutated"
1202 Context passed to policy function during selective checkpointing.
1204 This class is used to pass relevant metadata to the policy function during
1206 of the policy function is during recomputation or not.
1228 Enum for specifying the policy for checkpointing during backpropagation.
1232 - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
1233 pass and will not be recomputed during the backward pass
1234 - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
1235 forward pass and will be recomputed during the backward pass
1261 # AC inserts a different number of detaches during forward and recompute.
1263 # AC's determinism check invokes additional metadata ops during forward.
1322 raise RuntimeError(f"{func} encountered during backward, but not found in storage")
1336 Helper to avoid recomputing certain ops during activation checkpointing.
1339 operations are recomputed during the backward pass.
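Taken together, the fragments above describe selective activation checkpointing: a policy function receives a context object (whose is_recompute flag says whether the call happens during recomputation) plus the op and its arguments, returns a CheckpointPolicy, and is wrapped by create_selective_checkpoint_contexts into the context_fn expected by checkpoint. A hedged sketch, assuming those public entry points in torch.utils.checkpoint; the particular policy (save matmul outputs, recompute the rest), block, and shapes are illustrative:

    import functools
    import torch
    from torch.utils.checkpoint import (
        checkpoint,
        create_selective_checkpoint_contexts,
        CheckpointPolicy,
    )

    def policy_fn(ctx, op, *args, **kwargs):
        # ctx.is_recompute says whether the policy is being consulted during
        # recomputation rather than during the original forward.
        if op == torch.ops.aten.mm.default:
            return CheckpointPolicy.MUST_SAVE       # keep matmul outputs
        return CheckpointPolicy.PREFER_RECOMPUTE    # recompute everything else

    lin = torch.nn.Linear(16, 16)  # illustrative

    def block(x):
        return torch.relu(lin(x)) @ x.t()  # the 2D matmul hits aten.mm

    x = torch.randn(4, 16, requires_grad=True)
    out = checkpoint(
        block,
        x,
        use_reentrant=False,
        context_fn=functools.partial(create_selective_checkpoint_contexts, policy_fn),
    )
    out.sum().backward()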
1440 the RNG state during each checkpoint.
1453 a trace of the operators that ran during the original forward computation
1502 # This will be called later during recomputation. This wrapping enables