• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""The removable handle for cell hook function."""
16from __future__ import absolute_import
17import weakref
18from mindspore._c_expression import Tensor as Tensor_
19
20
21class _TensorHookHandle:
22    r"""
23    A handle provides the ability to remote a tensor hook.
24
25    Note:
26        It is only supported in pynative mode and works when registering or removing hook function for tensor
27
28    Supported Platforms:
29        ``Ascend`` ``GPU`` ``CPU``
30    """
31
32    def __init__(self):
33        self.id = None
34
35    def remove(self):
36        """
37        Remove the tensor hook function, which corresponds to this '_TensorHookHandle' object.
38
39        Args:
40            None.
41
42        Returns:
43            None.
44
45        Supported Platforms:
46        ``Ascend`` ``GPU`` ``CPU``
47
48        Examples:
49            >>> import mindspore as ms
50            >>> from mindspore import Tensor
51            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
52            >>> def hook_fn(grad):
53            ...     return grad * 2
54            ...
55            >>> def hook_test(x, y):
56            ...     z = x * y
57            ...     handle = z.register_hook(hook_fn)
58            ...     z = z * y
59            ...     handle.remove()
60            ...     return z
61            ...
62            >>> ms_grad = ms.grad(hook_test, grad_position=(0,1))
63            >>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
64            >>> print(output)
65            (Tensor(shape=[], dtype=Float32, value=4), Tensor(shape=[], dtype=Float32, value=4))
66        """
67        if self.id is not None:
68            Tensor_.remove_hook(self.id)
69
70
class HookHandle:
    r"""
    The handle returned when registering a forward pre hook function, forward hook function or backward hook
    function on a Cell object. It corresponds to that cell hook function and is used to remove it by calling
    'remove()'.

    Note:
        It is only supported in pynative mode and works when registering or removing hook function for Cell object.

    Args:
        hook_cell (Cell): The Cell object with hook function registered on. Only a weak reference to it is kept,
            so the handle does not keep the Cell alive. Default value: None.
        hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
                        Default value: -1.
        hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
                         Default value: "".

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    def __init__(self, hook_cell=None, hook_key=-1, hook_type=""):
        # Hold the Cell weakly so this handle does not extend the Cell's lifetime.
        self._hook_cell = weakref.ref(hook_cell) if hook_cell is not None else None
        self._hook_key = hook_key
        self._hook_type = hook_type

    def __del__(self):
        # Drop the stored references when the handle itself is finalized.
        self._hook_cell = None
        self._hook_key = None
        self._hook_type = None

    def remove(self):
        """
        Remove the cell hook function, which corresponds to this 'HookHandle' object.
        In order to prevent running failed when switching to graph mode, it is not recommended to call the `remove()`
        function in the construct function of Cell object.

        Calling this method when the handle has no Cell, or when the Cell has already been garbage collected,
        does nothing. Removing a forward (pre) hook whose key is no longer registered is also a no-op.

        Args:
            None.

        Returns:
            None.

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> import mindspore.nn as nn
            >>> from mindspore import Tensor
            >>> from mindspore.ops import GradOperation
            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
            >>> def forward_pre_hook_fn(cell_id, inputs):
            ...     print("forward inputs: ", inputs)
            ...
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.mul = nn.MatMul()
            ...         self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn)
            ...
            ...     def construct(self, x, y):
            ...         x = x + x
            ...         x = self.mul(x, y)
            ...         return x
            >>> grad = GradOperation(get_all=True)
            >>> net = Net()
            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
            forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
                            dtype=Float32, value= [ 1.00000000e+00]))
            >>> net.handle.remove()
            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
            >>> print(output)
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
            value= [ 2.00000000e+00]))
        """
        if self._hook_cell is None:
            return
        hook_cell = self._hook_cell()
        if hook_cell is None:
            # The weakly referenced Cell has been garbage collected; there is nothing left to remove from.
            return
        if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
            del hook_cell._forward_pre_hook[self._hook_key]
        elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
            del hook_cell._forward_hook[self._hook_key]
        elif self._hook_type == "_cell_backward_hook":
            hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
156