# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function

# Gradient operator configured with get_all=True; the tests below read
# element [0] of what it returns.
grad_all = C.GradOperation(get_all=True)


def var_hook_function(grad_out):
    """Print the value handed to the variable hook; returns nothing."""
    print("grad:", grad_out)


class GraphVarHook(nn.Cell):
    """Cell computing relu(hook((x + x) * (x + x))) with a plain construct.

    The hook is a ``P.HookBackward`` built around ``var_hook_function``.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        doubled = x + x
        squared = doubled * doubled
        hooked = self.hook(squared)
        return self.relu(hooked)


class MsFuncVarHook(nn.Cell):
    """Same computation as ``GraphVarHook``, with construct under ``ms_function``."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    @ms_function
    def construct(self, x):
        doubled = x + x
        squared = doubled * doubled
        hooked = self.hook(squared)
        return self.relu(hooked)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_forward():
    """Forward outputs of the ms_function and graph-mode variable-hook nets agree."""
    sample = Tensor(np.random.randn(2, 2).astype(np.float32))

    # ms_function-decorated net is run with the context in PyNative mode.
    context.set_context(mode=context.PYNATIVE_MODE)
    ms_func_net = MsFuncVarHook()
    ms_func_out = ms_func_net(sample)

    # Reference net is run with the context in graph mode.
    context.set_context(mode=context.GRAPH_MODE)
    graph_net = GraphVarHook()
    graph_out = graph_net(sample)

    assert np.allclose(ms_func_out.asnumpy(), graph_out.asnumpy(),
                       rtol=0.00001, atol=0.00001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_grad():
    """First gradients of the ms_function and graph-mode variable-hook nets agree."""
    sample = Tensor(np.random.randn(2, 2).astype(np.float32))

    context.set_context(mode=context.PYNATIVE_MODE)
    ms_func_net = MsFuncVarHook()
    ms_func_grads = grad_all(ms_func_net)(sample)

    context.set_context(mode=context.GRAPH_MODE)
    graph_net = GraphVarHook()
    graph_grads = grad_all(graph_net)(sample)

    assert np.allclose(ms_func_grads[0].asnumpy(), graph_grads[0].asnumpy(),
                       rtol=0.00001, atol=0.00001)


def cell_hook_function(cell_id, grad_input, grad_output):
    """Print the three arguments handed to the cell hook; returns nothing."""
    print("cell id:", cell_id)
    print("grad input:", grad_input)
    print("grad output:", grad_output)


class GraphCellHook(nn.Cell):
    """Cell computing relu((x + x) * (x + x)) with a plain construct.

    ``cell_hook_function`` is registered on the ReLU sub-cell through
    ``register_backward_hook``.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    def construct(self, x):
        doubled = x + x
        squared = doubled * doubled
        return self.relu(squared)


class MsFuncCellHook(nn.Cell):
    """Same computation as ``GraphCellHook``, with construct under ``ms_function``."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    @ms_function
    def construct(self, x):
        doubled = x + x
        squared = doubled * doubled
        return self.relu(squared)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_forward():
    """Forward outputs of the ms_function and graph-mode cell-hook nets agree."""
    sample = Tensor(np.random.randn(2, 2).astype(np.float32))

    context.set_context(mode=context.PYNATIVE_MODE)
    ms_func_net = MsFuncCellHook()
    ms_func_out = ms_func_net(sample)

    context.set_context(mode=context.GRAPH_MODE)
    graph_net = GraphCellHook()
    graph_out = graph_net(sample)

    assert np.allclose(ms_func_out.asnumpy(), graph_out.asnumpy(),
                       rtol=0.00001, atol=0.00001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_grad():
    """First gradients of the ms_function and graph-mode cell-hook nets agree."""
    sample = Tensor(np.random.randn(2, 2).astype(np.float32))

    context.set_context(mode=context.PYNATIVE_MODE)
    ms_func_net = MsFuncCellHook()
    ms_func_grads = grad_all(ms_func_net)(sample)

    context.set_context(mode=context.GRAPH_MODE)
    graph_net = GraphCellHook()
    graph_grads = grad_all(graph_net)(sample)

    assert np.allclose(ms_func_grads[0].asnumpy(), graph_grads[0].asnumpy(),
                       rtol=0.00001, atol=0.00001)