# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
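"""Test that variable hooks (P.HookBackward) and Cell backward hooks give the
same forward results and input gradients in PyNative mode (with @ms_function)
as in Graph mode."""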
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function

# Gradient operation that returns the gradients with respect to all inputs.
grad_all = C.GradOperation(get_all=True)


def var_hook_function(grad_out):
    """Variable hook: print the gradient flowing through the hooked tensor."""
    print("grad:", grad_out)


class GraphVarHook(nn.Cell):
    """Network with a HookBackward op inserted between its operations."""

    def __init__(self):
        super(GraphVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = x + x
        x = x * x
        x = self.hook(x)
        x = self.relu(x)
        return x


class MsFuncVarHook(nn.Cell):
    """Same network as GraphVarHook, but construct is compiled by @ms_function."""

    def __init__(self):
        super(MsFuncVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    @ms_function
    def construct(self, x):
        x = x + x
        x = x * x
        x = self.hook(x)
        x = self.relu(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_forward():
    """Forward outputs of the hooked net must match between the two modes."""
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncVarHook()
    out1 = net1(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphVarHook()
    out2 = net2(input_x)
    assert np.allclose(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_grad():
    """Input gradients of the hooked net must match between the two modes."""
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncVarHook()
    grad_out1 = grad_all(net1)(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphVarHook()
    grad_out2 = grad_all(net2)(input_x)
    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), rtol=1e-5, atol=1e-5)


def cell_hook_function(cell_id, grad_input, grad_output):
    """Cell backward hook: print the hooked cell's id and its gradients."""
    print("cell id:", cell_id)
    print("grad input:", grad_input)
    print("grad output:", grad_output)


class GraphCellHook(nn.Cell):
    """Network whose ReLU cell has a registered backward hook."""

    def __init__(self):
        super(GraphCellHook, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    def construct(self, x):
        x = x + x
        x = x * x
        x = self.relu(x)
        return x


class MsFuncCellHook(nn.Cell):
    """Same network as GraphCellHook, but construct is compiled by @ms_function."""

    def __init__(self):
        super(MsFuncCellHook, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    @ms_function
    def construct(self, x):
        x = x + x
        x = x * x
        x = self.relu(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_forward():
    """Forward outputs of the hooked cell must match between the two modes."""
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncCellHook()
    out1 = net1(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphCellHook()
    out2 = net2(input_x)
    assert np.allclose(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_grad():
    """Input gradients of the hooked cell must match between the two modes."""
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncCellHook()
    grad_out1 = grad_all(net1)(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphCellHook()
    grad_out2 = grad_all(net2)(input_x)
    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), rtol=1e-5, atol=1e-5)