# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PyNative-mode tests that take several gradients in one process.

Each test executes the forward pass of one or more networks, then builds
gradient cells over them and checks the returned gradients against the
analytic values of the toy networks.
"""
import pytest
import numpy as np
from mindspore import context, nn, Tensor, Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C


@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Run this module in PyNative mode on Ascend, restoring GRAPH_MODE afterwards."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    yield
    context.set_context(mode=context.GRAPH_MODE)


class _Grad(nn.Cell):
    """Cell that applies a ``GradOperation`` to ``network`` when called.

    Args:
        grad: A ``C.GradOperation`` instance; its ``sens_param`` flag is read.
        network: The cell to differentiate.
        wrt_params: If True, also differentiate with respect to the network's
            trainable parameters (collected into a ``ParameterTuple``).
        real_inputs_count: When set and ``sens_param`` is on, the first
            ``real_inputs_count`` call arguments are forwarded to the network
            and all remaining arguments are passed together, as one tuple,
            as the sensitivity.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sens_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sens_param_inputs)

        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        sens_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sens_param_inputs)


class GradOfFirstInput(_Grad):
    """
    get grad of first input
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class GradOfAllInputs(_Grad):
    """
    get grads of all inputs
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


def _assert_grad_close(grad, expected):
    """Assert every element of the gradient tensor ``grad`` is close to ``expected``."""
    np.testing.assert_allclose(grad.asnumpy(), expected, rtol=1e-5, atol=1e-6)


def test_multi_grad():
    """Forward and then differentiate two independent networks in sequence."""
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b

    class ForwardNetAdd(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x + x + x
            b = y + y
            return a * b

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd()
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]) * 2, dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    # NOTE: forward passes are executed before any gradient call; keep this ordering.
    mulnet(x, y)
    addnet(x, y)
    grad_mul = GradOfAllInputs(mulnet)
    grad_add = GradOfAllInputs(addnet)

    # mulnet: out = x^2 * y^2 -> d/dx = 2*x*y^2 = 8, d/dy = 2*x^2*y = 4 (x=1, y=2)
    dx, dy = grad_mul(x, y, sens)
    _assert_grad_close(dx, 8.0)
    _assert_grad_close(dy, 4.0)

    # addnet: out = 3x * 2y = 6xy -> d/dx = 6y = 12, d/dy = 6x = 6 (x=1, y=2)
    dx, dy = grad_add(x, y, sens)
    _assert_grad_close(dx, 12.0)
    _assert_grad_close(dy, 6.0)


def test_multi_same_grad():
    """Take two different gradients (all inputs, first input) of the same network."""
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b

    class ForwardNetAdd(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * 3
            b = y * 2
            return a + b

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd()
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]), dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    # NOTE: forward passes are executed before any gradient call; keep this ordering.
    mulnet(x, y)
    addnet(x, y)
    # Both gradient cells wrap mulnet; addnet is only run forward.
    grad_mul_all = GradOfAllInputs(mulnet)
    grad_mul_first = GradOfFirstInput(mulnet)

    # mulnet: out = x^2 * y^2 -> d/dx = 2*x*y^2 = 2, d/dy = 2*x^2*y = 2 (x=y=1)
    dx, dy = grad_mul_all(x, y, sens)
    _assert_grad_close(dx, 2.0)
    _assert_grad_close(dy, 2.0)

    dx = grad_mul_first(x, y, sens)
    _assert_grad_close(dx, 2.0)


def test_net_inner_grad():
    """Differentiate an outer network and the inner network it calls."""
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b

    class ForwardNetAdd(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, x, y):
            a = x + x
            b = y + y
            res = self.net(a, b)
            return res

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd(mulnet)
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]), dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    # NOTE: forward passes are executed before any gradient call; keep this ordering.
    mulnet(x, y)
    addnet(x, y)
    grad_outer = GradOfAllInputs(addnet)
    grad_inner = GradOfAllInputs(mulnet)

    # addnet: out = (2x)^2 * (2y)^2 = 16 x^2 y^2 -> d/dx = 32*x*y^2 = 32, d/dy = 32 (x=y=1)
    dx, dy = grad_outer(x, y, sens)
    _assert_grad_close(dx, 32.0)
    _assert_grad_close(dy, 32.0)

    # mulnet: out = x^2 * y^2 -> d/dx = 2, d/dy = 2 (x=y=1)
    dx, dy = grad_inner(x, y, sens)
    _assert_grad_close(dx, 2.0)
    _assert_grad_close(dy, 2.0)


def test_net_inner_first_run_grad():
    """Differentiate outer and inner networks that both hold Parameters."""
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z1 = Parameter(Tensor(np.ones([32]) * 2, dtype=mstype.float32), name='z1')

        def construct(self, x, y):
            a = x * self.z1
            b = y * y
            return a * b

    class ForwardNetAdd(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
            # Each Parameter needs its own name; this one previously reused 'z2'.
            self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z3')

        def construct(self, x, y):
            a = x + x * self.z3
            b = y + y * self.z2
            res = self.net(a, b)
            return res

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd(mulnet)
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]), dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    # NOTE: forward passes are executed before any gradient call; keep this ordering.
    mulnet(x, y)
    addnet(x, y)
    grad_outer = GradOfAllInputs(addnet)
    grad_inner_first = GradOfFirstInput(mulnet)

    # addnet with z1=2, z2=z3=1: a = 2x, b = 2y, out = a*z1*b^2 = 16*x*y^2
    # -> d/dx = 16*y^2 = 16, d/dy = 32*x*y = 32 (x=y=1)
    dx, dy = grad_outer(x, y, sens)
    _assert_grad_close(dx, 16.0)
    _assert_grad_close(dy, 32.0)

    # mulnet: out = x*z1*y^2 -> d/dx = z1*y^2 = 2 (y=1)
    dx = grad_inner_first(x, y, sens)
    _assert_grad_close(dx, 2.0)