# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for user-defined `bprop` methods on nn.Cell (custom cell backprop)."""
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import context
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


grad_all = C.GradOperation(get_all=True)


class MulAdd(nn.Cell):
    """Cell computing 2*x + y with a custom bprop."""

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # The user-defined bprop is deliberately wrong, to distinguish its
        # result from the automatic-differentiation result.
        return 2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add():
    """The custom bprop result (2, 4), not the AD result, must be returned."""
    mul_add = MulAdd()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(mul_add)(x, y) == (2, 4)


class InlineMulADD(nn.Cell):
    """Cell that inlines a sub-cell (MulAdd) with a custom bprop."""

    def __init__(self):
        super(InlineMulADD, self).__init__()
        self.mul_add = MulAdd()
        self.param = 2

    def construct(self, x, y):
        return self.mul_add(x, y) + x + self.param * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_mul_add():
    """The inlined sub-cell's custom bprop contributes to the outer grads."""
    inline_mul_add = InlineMulADD()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(inline_mul_add)(x, y) == (3, 6)


class WithParameter(nn.Cell):
    """Cell with Parameters and a custom bprop; custom bprop with Parameters
    is expected to be rejected at grad time (see test_with_param)."""

    def __init__(self):
        super(WithParameter, self).__init__()
        self.param1 = Parameter(1, 'param1')
        self.param2 = Parameter(2, 'param2')

    def construct(self, x, y):
        return self.param1 * self.param2 * x + y

    def bprop(self, x, y, out, dout):
        # The user-defined bprop is deliberately wrong, to distinguish its
        # result from the automatic-differentiation result.
        return self.param1 * self.param2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param():
    """Taking grads of a cell whose custom bprop touches Parameters raises."""
    with_param = WithParameter()
    with pytest.raises(RuntimeError):
        grad_all(with_param)(1, 2)


class WithNoBprop(nn.Cell):
    """Plain cell without a custom bprop; grads come from AD."""

    def construct(self, x, y):
        return 2 * x + y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_no_bprop():
    """Without a custom bprop, AD gives d(2x + y) = (2, 1)."""
    with_no_bprop = WithNoBprop()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(with_no_bprop)(x, y) == (2, 1)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_1():
    """grad_all used inside a custom bprop; innermost cell has no bprop."""
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            # Compute grads of the inner cell inside the custom bprop.
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_2():
    """grad_all used inside a custom bprop; innermost cell also has a bprop."""
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

        def bprop(self, x, y, out, dout):
            return x * y, y + x

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            # Compute grads of the inner cell inside the custom bprop;
            # the inner cell's own custom bprop must be used.
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_3():
    """grad_all used inside a custom bprop; outermost cell also has a bprop."""
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            # Compute grads of the inner cell inside the custom bprop.
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

        def bprop(self, x, y, out, dout):
            return x + y + y + out[0], x + x + y + y + dout[0]

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()


class OneInputBprop(nn.Cell):
    """Single-input cell whose custom bprop returns a one-element tuple."""

    def __init__(self):
        super().__init__()
        self.op = P.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        return (5 * x,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_one_input_bprop():
    """The custom bprop's 5*x replaces the ReLU AD gradient."""
    net = OneInputBprop()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = grad_all(net)(input1)
    assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()


class TwoInput(nn.Cell):
    """Plain two-input cell (x * y) without a custom bprop."""

    def construct(self, x, y):
        return x * y


class InlineBpropTwoInput(nn.Cell):
    """Two-input cell whose custom bprop doubles the AD grads of a sub-cell."""

    def __init__(self):
        super().__init__()
        self.f = TwoInput()

    def construct(self, x, y):
        return self.f(x, y), grad_all(self.f)(x, y)

    def bprop(self, x, y, out, dout):
        grads = grad_all(self.f)(x, y)
        return grads[0] * 2, grads[1] * 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_two_input():
    """Custom bprop yields 2x the AD grads of x*y at x=y=1, i.e. (2, 2)."""
    net = InlineBpropTwoInput()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(input1, input2)
    assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert len(grads) == 2


class TwoInputBprop(nn.Cell):
    """Two-input Mul cell with a custom bprop returning (5x, 8y)."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return 5 * x, 8 * y


class TwoInputWithParameter(nn.Cell):
    """Two-input Mul cell that adds a Parameter to x; no custom bprop."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        x = self.inputdata + x
        return self.op(x, y)


class TwoInputWithOnlyInitParameterBprop(nn.Cell):
    """Cell with a Parameter that is only initialized (unused in construct),
    plus a custom bprop."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return 5 * x, 8 * y


class InlineMutilTwoInputParameterCell(nn.Cell):
    """Sums four sub-cells mixing custom bprops, AD, and Parameters."""

    def __init__(self):
        super().__init__()
        self.f1 = TwoInputBprop()
        self.f2 = TwoInput()
        self.f3 = TwoInputWithParameter()
        self.f4 = TwoInputWithOnlyInitParameterBprop()

    def construct(self, x, y):
        output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_multi_input():
    """Grad contributions of all four sub-cells sum to (12, 19) at x=y=1."""
    net = InlineMutilTwoInputParameterCell()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    net.init_parameters_data()
    grads = grad_all(net)(input1, input2)
    assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
    assert len(grads) == 2


class MulAddWithParam(nn.Cell):
    """Feeds a Parameter as an input into a sub-cell with a custom bprop."""

    def __init__(self):
        super(MulAddWithParam, self).__init__()
        self.mul_add = MulAdd()
        self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param')

    def construct(self, x):
        return self.mul_add(self.param, x)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
    """Grads w.r.t. a Parameter (RefKey) flow through a custom bprop."""
    grad_by_list = C.GradOperation(get_all=True, get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grads = grad_by_list(self.network, weights)(x)
            return grads

    network = GradWrap(MulAddWithParam())
    input_data = Tensor(np.array([2, 2], np.float32))
    grads = network(input_data)
    assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
    assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()


class MulAddWithWrongOutputNum(nn.Cell):
    """Custom bprop returning too few gradients (1 instead of 2)."""

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return (2 * dout,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_num():
    """With check_bprop enabled, a wrong gradient count raises TypeError."""
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputNum()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, 2)


class MulAddWithWrongOutputType(nn.Cell):
    """Custom bprop returning a scalar where a Tensor gradient is expected."""

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return 2 * dout, 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_type():
    """With check_bprop enabled, a wrong gradient type raises TypeError."""
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputType()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))


class MulAddWithWrongOutputShape(nn.Cell):
    """Custom bprop returning a gradient whose shape mismatches the input."""

    def __init__(self):
        super(MulAddWithWrongOutputShape, self).__init__()
        self.ones = Tensor(np.ones([2,]))

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return 2, self.ones


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_shape():
    """With check_bprop enabled, a wrong gradient shape raises TypeError."""
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputShape()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))