1# Copyright 2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""The removable handle for cell hook function.""" 16from __future__ import absolute_import 17import weakref 18from mindspore._c_expression import Tensor as Tensor_ 19 20 21class _TensorHookHandle: 22 r""" 23 A handle provides the ability to remote a tensor hook. 24 25 Note: 26 It is only supported in pynative mode and works when registering or removing hook function for tensor 27 28 Supported Platforms: 29 ``Ascend`` ``GPU`` ``CPU`` 30 """ 31 32 def __init__(self): 33 self.id = None 34 35 def remove(self): 36 """ 37 Remove the tensor hook function, which corresponds to this '_TensorHookHandle' object. 38 39 Args: 40 None. 41 42 Returns: 43 None. 44 45 Supported Platforms: 46 ``Ascend`` ``GPU`` ``CPU`` 47 48 Examples: 49 >>> import mindspore as ms 50 >>> from mindspore import Tensor 51 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 52 >>> def hook_fn(grad): 53 ... return grad * 2 54 ... 55 >>> def hook_test(x, y): 56 ... z = x * y 57 ... handle = z.register_hook(hook_fn) 58 ... z = z * y 59 ... handle.remove() 60 ... return z 61 ... 62 >>> ms_grad = ms.grad(hook_test, grad_position=(0,1)) 63 >>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32)) 64 >>> print(output) 65 (Tensor(shape=[], dtype=Float32, value=4), Tensor(shape=[], dtype=Float32, value=4)) 66 """ 67 if self.id is not None: 68 Tensor_.remove_hook(self.id) 69 70 71class HookHandle: 72 r""" 73 It is the return object of forward pre hook function, forward hook function and backward hook function of Cell 74 object. It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'. 75 76 Note: 77 It is only supported in pynative mode and works when registering or removing hook function for Cell object. 78 79 Args: 80 hook_cell (Cell): The Cell object with hook function registered on. Default value: None. 81 hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration. 82 Default value: -1. 83 hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'. 84 Default value: "". 85 86 Supported Platforms: 87 ``Ascend`` ``GPU`` ``CPU`` 88 """ 89 def __init__(self, hook_cell=None, hook_key=-1, hook_type=""): 90 if hook_cell is not None: 91 self._hook_cell = weakref.ref(hook_cell) 92 else: 93 self._hook_cell = hook_cell 94 self._hook_key = hook_key 95 self._hook_type = hook_type 96 97 def __del__(self): 98 self._hook_cell = None 99 self._hook_key = None 100 self._hook_type = None 101 102 def remove(self): 103 """ 104 Remove the cell hook function, which corresponds to this 'HookHandle' object. 105 In order to prevent running failed when switching to graph mode, it is not recommended to call the `remove()` 106 function in the construct function of Cell object. 107 108 Args: 109 None. 110 111 Returns: 112 None. 113 114 Supported Platforms: 115 ``Ascend`` ``GPU`` ``CPU`` 116 117 Examples: 118 >>> import numpy as np 119 >>> import mindspore as ms 120 >>> import mindspore.nn as nn 121 >>> from mindspore import Tensor 122 >>> from mindspore.ops import GradOperation 123 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 124 >>> def forward_pre_hook_fn(cell_id, inputs): 125 ... print("forward inputs: ", inputs) 126 ... 127 >>> class Net(nn.Cell): 128 ... def __init__(self): 129 ... super(Net, self).__init__() 130 ... self.mul = nn.MatMul() 131 ... self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn) 132 ... 133 ... def construct(self, x, y): 134 ... x = x + x 135 ... x = self.mul(x, y) 136 ... return x 137 >>> grad = GradOperation(get_all=True) 138 >>> net = Net() 139 >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))) 140 forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], 141 dtype=Float32, value= [ 1.00000000e+00])) 142 >>> net.handle.remove() 143 >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))) 144 >>> print(output) 145 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, 146 value= [ 2.00000000e+00])) 147 """ 148 if self._hook_cell is not None: 149 hook_cell = self._hook_cell() 150 if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook: 151 del hook_cell._forward_pre_hook[self._hook_key] 152 elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook: 153 del hook_cell._forward_hook[self._hook_key] 154 elif self._hook_type == "_cell_backward_hook": 155 hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key) 156