1# mypy: allow-untyped-defs 2 3import copyreg 4import os.path as _osp 5import weakref 6 7import torch 8from torch.utils import ( 9 backcompat as backcompat, 10 collect_env as collect_env, 11 data as data, 12 deterministic as deterministic, 13 hooks as hooks, 14) 15from torch.utils.backend_registration import ( 16 generate_methods_for_privateuse1_backend, 17 rename_privateuse1_backend, 18) 19from torch.utils.cpp_backtrace import get_cpp_backtrace 20from torch.utils.throughput_benchmark import ThroughputBenchmark 21 22 23def set_module(obj, mod): 24 """ 25 Set the module attribute on a python object for a given object for nicer printing 26 """ 27 if not isinstance(mod, str): 28 raise TypeError("The mod argument should be a string") 29 obj.__module__ = mod 30 31 32if torch._running_with_deploy(): 33 # not valid inside torch_deploy interpreter, no paths exists for frozen modules 34 cmake_prefix_path = None 35else: 36 cmake_prefix_path = _osp.join( 37 _osp.dirname(_osp.dirname(__file__)), "share", "cmake" 38 ) 39 40 41def swap_tensors(t1, t2): 42 """ 43 This function swaps the content of the two Tensor objects. 44 At a high level, this will make t1 have the content of t2 while preserving 45 its identity. 46 47 This will not work if t1 and t2 have different slots. 48 """ 49 # Ensure there are no weakrefs 50 if weakref.getweakrefs(t1): 51 raise RuntimeError("Cannot swap t1 because it has weakref associated with it") 52 if weakref.getweakrefs(t2): 53 raise RuntimeError("Cannot swap t2 because it has weakref associated with it") 54 t1_slots = set(copyreg._slotnames(t1.__class__)) # type: ignore[attr-defined] 55 t2_slots = set(copyreg._slotnames(t2.__class__)) # type: ignore[attr-defined] 56 if t1_slots != t2_slots: 57 raise RuntimeError("Cannot swap t1 and t2 if they have different slots") 58 59 def swap_attr(name): 60 tmp = getattr(t1, name) 61 setattr(t1, name, (getattr(t2, name))) 62 setattr(t2, name, tmp) 63 64 def error_pre_hook(grad_outputs): 65 raise RuntimeError( 66 "Trying to execute AccumulateGrad node that was poisoned by swap_tensors " 67 "this can happen when you try to run backward on a tensor that was swapped. " 68 "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` " 69 "you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) " 70 "between running forward and backward. To resolve this, please only change the " 71 "device/dtype before running forward (or after both forward and backward)." 72 ) 73 74 def check_use_count(t, name="t1"): 75 use_count = t._use_count() 76 error_str = ( 77 f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} " 78 f"make sure you are not holding references to the tensor in other places." 79 ) 80 if use_count > 1: 81 if use_count == 2 and t.is_leaf: 82 accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node 83 # Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge 84 if t._use_count() == 2: 85 accum_grad_node.register_prehook(error_pre_hook) 86 else: 87 raise RuntimeError(error_str) 88 else: 89 raise RuntimeError(error_str) 90 91 check_use_count(t1, "t1") 92 check_use_count(t2, "t2") 93 94 # Swap the types 95 # Note that this will fail if there are mismatched slots 96 swap_attr("__class__") 97 98 # Swap the dynamic attributes 99 swap_attr("__dict__") 100 101 # Swap the slots 102 for slot in t1_slots: 103 if hasattr(t1, slot) and hasattr(t2, slot): 104 swap_attr(slot) 105 elif hasattr(t1, slot): 106 setattr(t2, slot, (getattr(t1, slot))) 107 delattr(t1, slot) 108 elif hasattr(t2, slot): 109 setattr(t1, slot, (getattr(t2, slot))) 110 delattr(t2, slot) 111 112 # Swap the at::Tensor they point to 113 torch._C._swap_tensor_impl(t1, t2) 114