# Copyright 2020-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cell_bprop """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import context, mutable
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import ops
from mindspore._extends import cell_attr_register

context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)


class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # In this test case, the user-defined bprop is deliberately incorrect so its
        # result can be distinguished from the automatic-differentiation result.
        return 2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add():
    mul_add = MulAdd()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(mul_add)(x, y) == (2, 4)


class InlineMulADD(nn.Cell):
    def __init__(self):
        super(InlineMulADD, self).__init__()
        self.mul_add = MulAdd()
        self.param = 2

    def construct(self, x, y):
        return self.mul_add(x, y) + x + self.param * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_mul_add():
    inline_mul_add = InlineMulADD()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(inline_mul_add)(x, y) == (3, 6)


class WithParameter(nn.Cell):
    def __init__(self):
        super(WithParameter, self).__init__()
        self.param1 = Parameter(1, 'param1')
        self.param2 = Parameter(2, 'param2')

    def construct(self, x, y):
        return self.param1 * self.param2 * x + y

    def bprop(self, x, y, out, dout):
        # In this test case, the user-defined bprop is deliberately incorrect so its
        # result can be distinguished from the automatic-differentiation result.
        return self.param1 * self.param2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param():
    with_param = WithParameter()
    with pytest.raises(RuntimeError):
        grad_all(with_param)(mutable(1), 2)


class WithNoBprop(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_no_bprop():
    with_no_bprop = WithNoBprop()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(with_no_bprop)(x, y) == (2, 1)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_1():
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_2():
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

        def bprop(self, x, y, out, dout):
            return x * y, y + x

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_3():
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

        def bprop(self, x, y, out, dout):
            return x + y + y + out[0], x + x + y + y + dout[0]

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()


class OneInputBprop(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        return (5 * x,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_one_input_bprop():
    net = OneInputBprop()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = grad_all(net)(input1)
    assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()


class TwoInput(nn.Cell):
    def construct(self, x, y):
        return x * y


class InlineBpropTwoInput(nn.Cell):
    def __init__(self):
        super().__init__()
        self.f = TwoInput()

    def construct(self, x, y):
        return self.f(x, y), grad_all(self.f)(x, y)

    def bprop(self, x, y, out, dout):
        grads = grad_all(self.f)(x, y)
        return grads[0] * 2, grads[1] * 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_two_input():
    net = InlineBpropTwoInput()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(input1, input2)
    assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert len(grads) == 2


class TwoInputBprop(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return 5 * x, 8 * y


class TwoInputWithParameter(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        x = self.inputdata + x
        return self.op(x, y)


class TwoInputWithOnlyInitParameterBprop(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return 5 * x, 8 * y


class InlineMutilTwoInputParameterCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.f1 = TwoInputBprop()
        self.f2 = TwoInput()
        self.f3 = TwoInputWithParameter()
        self.f4 = TwoInputWithOnlyInitParameterBprop()

    def construct(self, x, y):
        output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_multi_input():
    net = InlineMutilTwoInputParameterCell()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    net.init_parameters_data()
    grads = grad_all(net)(input1, input2)
    # With x = y = 1:
    #   dx = 5*x (f1 bprop) + y (f2) + y (f3) + 5*x (f4 bprop) = 5 + 1 + 1 + 5 = 12
    #   dy = 8*y (f1 bprop) + x (f2) + (x + 1) (f3) + 8*y (f4 bprop) = 8 + 1 + 2 + 8 = 19
    assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
    assert len(grads) == 2


class MulAddWithParam(nn.Cell):
    def __init__(self):
        super(MulAddWithParam, self).__init__()
        self.mul_add = MulAdd()
        self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param')

    def construct(self, x):
        return self.mul_add(self.param, x)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
    grad_by_list = C.GradOperation(get_all=True, get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grads = grad_by_list(self.network, weights)(x)
            return grads

    network = GradWrap(MulAddWithParam())
    input_data = Tensor(np.array([2, 2], np.float32))
    grads = network(input_data)
    assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
    assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()


class MulAddWithWrongOutputNum(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return (2 * dout,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_num():
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputNum()
    with pytest.raises(ValueError):
        grad_all(mul_add)(mutable(1), 2)


class MulAddWithWrongOutputType(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return 2 * dout, 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_type():
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputType()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))


class MulAddWithWrongOutputShape(nn.Cell):
    def __init__(self):
        super(MulAddWithWrongOutputShape, self).__init__()
        self.ones = Tensor(np.ones([2,]))

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return 2, self.ones


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_shape():
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputShape()
    with pytest.raises(ValueError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of the inputs when the forward net uses a Parameter.
    Expectation: Get the correct gradients.
451 """ 452 453 class Net(nn.Cell): 454 def __init__(self): 455 super(Net, self).__init__() 456 self.matmul = P.MatMul() 457 self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') 458 459 def construct(self, x, y): 460 x = x * self.z 461 out = self.matmul(x, y) 462 return out 463 464 def bprop(self, x, y, out, dout): 465 dx = x + x 466 dy = y + y 467 return dx, dy 468 469 class GradNet(nn.Cell): 470 def __init__(self, net): 471 super(GradNet, self).__init__() 472 self.net = net 473 474 def construct(self, x, y): 475 grad_f = grad_all(self.net) 476 return grad_f(x, y) 477 478 x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) 479 y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) 480 out = GradNet(Net())(x, y) 481 expect_dx = np.array([[1.0, 1.2, 0.8], 482 [2.4, 2.6, 2.2]]).astype(np.float32) 483 expect_dy = np.array([[0.02, 0.6, 2.2], 484 [0.2, 0.4, 2.6], 485 [4.2, 2.4, 6.6]]).astype(np.float32) 486 assert np.allclose(out[0].asnumpy(), expect_dx) 487 assert np.allclose(out[1].asnumpy(), expect_dy) 488 489 490@pytest.mark.level1 491@pytest.mark.platform_x86_cpu 492@pytest.mark.env_onecard 493def test_forward_with_parameter_in_sub_cell(): 494 """ 495 Feature: Custom cell bprop 496 Description: Get the gradients of inputs when the forward net using Parameter in the sub-cell. 497 Expectation: Get the correct gradients. 498 """ 499 500 class Net(nn.Cell): 501 def __init__(self): 502 super(Net, self).__init__() 503 self.net = Net1() 504 505 def construct(self, x, y): 506 return self.net(x, y) 507 508 class Net1(nn.Cell): 509 def __init__(self): 510 super(Net1, self).__init__() 511 self.matmul = P.MatMul() 512 self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') 513 514 def construct(self, x, y): 515 x = x * self.z 516 out = self.matmul(x, y) 517 return out 518 519 def bprop(self, x, y, out, dout): 520 dx = x + x 521 dy = y + y 522 return dx, dy 523 524 class GradNet(nn.Cell): 525 def __init__(self, net): 526 super(GradNet, self).__init__() 527 self.net = net 528 529 def construct(self, x, y): 530 grad_f = grad_all(self.net) 531 return grad_f(x, y) 532 533 x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) 534 y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) 535 out = GradNet(Net())(x, y) 536 expect_dx = np.array([[1.0, 1.2, 0.8], 537 [2.4, 2.6, 2.2]]).astype(np.float32) 538 expect_dy = np.array([[0.02, 0.6, 2.2], 539 [0.2, 0.4, 2.6], 540 [4.2, 2.4, 6.6]]).astype(np.float32) 541 assert np.allclose(out[0].asnumpy(), expect_dx) 542 assert np.allclose(out[1].asnumpy(), expect_dy) 543 544 545@pytest.mark.level1 546@pytest.mark.platform_x86_cpu 547@pytest.mark.env_onecard 548def test_forward_with_parameter_in_sub_cell_get_by_list(): 549 """ 550 Feature: Custom cell bprop 551 Description: Get the gradients of inputs and Parameters when the forward net using Parameter in the sub-cell. 552 Expectation: Get the correct gradients. 
553 """ 554 555 class Net(nn.Cell): 556 def __init__(self): 557 super(Net, self).__init__() 558 self.net = Net1() 559 560 def construct(self, x, y): 561 return self.net(x, y) 562 563 class Net1(nn.Cell): 564 def __init__(self): 565 super(Net1, self).__init__() 566 self.matmul = P.MatMul() 567 self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') 568 569 def construct(self, x, y): 570 x = x * self.z 571 out = self.matmul(x, y) 572 return out 573 574 def bprop(self, x, y, out, dout): 575 dx = x + x 576 dy = y + y 577 return dx, dy 578 579 class GradNet(nn.Cell): 580 def __init__(self, net): 581 super(GradNet, self).__init__() 582 self.net = net 583 self.params = ParameterTuple(net.trainable_params()) 584 self.grad_op = C.GradOperation(get_by_list=True, get_all=True) 585 586 def construct(self, x, y): 587 grad_f = self.grad_op(self.net, self.params) 588 return grad_f(x, y) 589 590 x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) 591 y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) 592 out = GradNet(Net())(x, y) 593 expect_dx = np.array([[1.0, 1.2, 0.8], 594 [2.4, 2.6, 2.2]]).astype(np.float32) 595 expect_dy = np.array([[0.02, 0.6, 2.2], 596 [0.2, 0.4, 2.6], 597 [4.2, 2.4, 6.6]]).astype(np.float32) 598 expect_dz = np.array([0.0]).astype(np.float32) 599 assert np.allclose(out[0][0].asnumpy(), expect_dx) 600 assert np.allclose(out[0][1].asnumpy(), expect_dy) 601 assert np.allclose(out[1][0].asnumpy(), expect_dz) 602 603 604@pytest.mark.level1 605@pytest.mark.platform_x86_cpu 606@pytest.mark.env_onecard 607def test_pynative_forward_with_parameter(): 608 """ 609 Feature: Custom cell bprop 610 Description: Get the gradients of inputs when the forward net using Parameter. 611 Expectation: Get the correct gradients. 612 """ 613 context.set_context(mode=context.PYNATIVE_MODE) 614 615 class Net(nn.Cell): 616 def __init__(self): 617 super(Net, self).__init__() 618 self.matmul = P.MatMul() 619 self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') 620 621 def construct(self, x, y): 622 x = x * self.z 623 out = self.matmul(x, y) 624 return out 625 626 def bprop(self, x, y, out, dout): 627 dx = x + x 628 dy = y + y 629 return dx, dy 630 631 class GradNet(nn.Cell): 632 def __init__(self, net): 633 super(GradNet, self).__init__() 634 self.net = net 635 636 def construct(self, x, y): 637 grad_f = grad_all(self.net) 638 return grad_f(x, y) 639 640 x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) 641 y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) 642 out = GradNet(Net())(x, y) 643 expect_dx = np.array([[1.0, 1.2, 0.8], 644 [2.4, 2.6, 2.2]]).astype(np.float32) 645 expect_dy = np.array([[0.02, 0.6, 2.2], 646 [0.2, 0.4, 2.6], 647 [4.2, 2.4, 6.6]]).astype(np.float32) 648 assert np.allclose(out[0].asnumpy(), expect_dx) 649 assert np.allclose(out[1].asnumpy(), expect_dy) 650 context.set_context(mode=context.GRAPH_MODE) 651 652 653@pytest.mark.level1 654@pytest.mark.platform_x86_cpu 655@pytest.mark.env_onecard 656def test_pynative_forward_with_parameter_in_sub_cell(): 657 """ 658 Feature: Custom cell bprop 659 Description: Get the gradients of inputs when the forward net using Parameter in the sub-cell. 660 Expectation: Get the correct gradients. 
661 """ 662 context.set_context(mode=context.PYNATIVE_MODE) 663 664 class Net(nn.Cell): 665 def __init__(self): 666 super(Net, self).__init__() 667 self.net = Net1() 668 669 def construct(self, x, y): 670 return self.net(x, y) 671 672 class Net1(nn.Cell): 673 def __init__(self): 674 super(Net1, self).__init__() 675 self.matmul = P.MatMul() 676 self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') 677 678 def construct(self, x, y): 679 x = x * self.z 680 out = self.matmul(x, y) 681 return out 682 683 def bprop(self, x, y, out, dout): 684 dx = x + x 685 dy = y + y 686 return dx, dy 687 688 class GradNet(nn.Cell): 689 def __init__(self, net): 690 super(GradNet, self).__init__() 691 self.net = net 692 693 def construct(self, x, y): 694 grad_f = grad_all(self.net) 695 return grad_f(x, y) 696 697 x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) 698 y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) 699 out = GradNet(Net())(x, y) 700 expect_dx = np.array([[1.0, 1.2, 0.8], 701 [2.4, 2.6, 2.2]]).astype(np.float32) 702 expect_dy = np.array([[0.02, 0.6, 2.2], 703 [0.2, 0.4, 2.6], 704 [4.2, 2.4, 6.6]]).astype(np.float32) 705 assert np.allclose(out[0].asnumpy(), expect_dx) 706 assert np.allclose(out[1].asnumpy(), expect_dy) 707 context.set_context(mode=context.GRAPH_MODE) 708 709 710@pytest.mark.level1 711@pytest.mark.platform_x86_cpu 712@pytest.mark.env_onecard 713def test_pynative_forward_with_parameter_in_sub_cell_get_by_list(): 714 """ 715 Feature: Custom cell bprop 716 Description: Get the gradients of inputs and Parameters when the forward net using Parameter in the sub-cell. 717 Expectation: Get the correct gradients. 718 """ 719 context.set_context(mode=context.PYNATIVE_MODE) 720 721 class Net(nn.Cell): 722 def __init__(self): 723 super(Net, self).__init__() 724 self.net = Net1() 725 726 def construct(self, x, y): 727 return self.net(x, y) 728 729 class Net1(nn.Cell): 730 def __init__(self): 731 super(Net1, self).__init__() 732 self.matmul = P.MatMul() 733 self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') 734 735 def construct(self, x, y): 736 x = x * self.z 737 out = self.matmul(x, y) 738 return out 739 740 def bprop(self, x, y, out, dout): 741 dx = x + x 742 dy = y + y 743 return dx, dy 744 745 class GradNet(nn.Cell): 746 def __init__(self, net): 747 super(GradNet, self).__init__() 748 self.net = net 749 self.params = ParameterTuple(net.trainable_params()) 750 self.grad_op = C.GradOperation(get_by_list=True, get_all=True) 751 752 def construct(self, x, y): 753 grad_f = self.grad_op(self.net, self.params) 754 return grad_f(x, y) 755 756 x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) 757 y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) 758 out = GradNet(Net())(x, y) 759 expect_dx = np.array([[1.0, 1.2, 0.8], 760 [2.4, 2.6, 2.2]]).astype(np.float32) 761 expect_dy = np.array([[0.02, 0.6, 2.2], 762 [0.2, 0.4, 2.6], 763 [4.2, 2.4, 6.6]]).astype(np.float32) 764 expect_dz = np.array([0.0]).astype(np.float32) 765 assert np.allclose(out[0][0].asnumpy(), expect_dx) 766 assert np.allclose(out[0][1].asnumpy(), expect_dy) 767 assert np.allclose(out[1][0].asnumpy(), expect_dz) 768 context.set_context(mode=context.GRAPH_MODE) 769 770 771@pytest.mark.level1 772@pytest.mark.platform_x86_gpu_training 773@pytest.mark.env_onecard 774def test_dde_self_define_cell_output_not_use(): 775 """ 776 Feature: Custom cell bprop 777 Description: Fprop output[1] only used by bprop, 
        it should not be erased by DDE.
    Expectation: Get the correct gradients.
    """

    class SelfDefineCell(ms.nn.Cell):
        def construct(self, x):
            return x + 1, x + 2

        def bprop(self, x, out, dout):
            return (out[1],)

    class ForwardNet(ms.nn.Cell):
        def __init__(self):
            super(ForwardNet, self).__init__()
            self.self_defined_cell = SelfDefineCell()

        def construct(self, x):
            # Keep the second output unused in the forward pass.
            out0, _ = self.self_defined_cell(x)
            return out0

    class TestNet(ms.nn.Cell):
        def __init__(self):
            super(TestNet, self).__init__()
            self.forward_net = ForwardNet()
            self.grad_op = ops.GradOperation(get_all=True)

        def construct(self, x):
            grad_out = self.grad_op(self.forward_net)(x)
            return grad_out

    net = TestNet()
    x_input = ms.Tensor([1])
    out = net(x_input)
    assert out[0] == ms.Tensor([3])


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bprop_defined_in_cell_attr_register():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of the input for a cell decorated with @cell_attr_register.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        @cell_attr_register
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(2, mstype.float32), name='z')

        def construct(self, x, y):
            x = x * self.z
            return x * y

        def bprop(self, x, y, out, dout):
            return y, x

    net = Net()
    x = Tensor(3, mstype.float32)
    y = Tensor(4, mstype.float32)
    output = ops.grad(net)(x, y)
    assert output == 4
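
# -----------------------------------------------------------------------------
# Illustrative sketch (an addition, not one of the original test cases): the
# custom-bprop contract the tests above rely on. `bprop` receives every
# `construct` input plus the forward output `out` and the incoming gradient
# `dout`, and must return one gradient per input; GradOperation then uses those
# values directly instead of the auto-differentiated gradients. The names
# `SquareWithCustomBprop` and `_sketch_square_custom_bprop` are hypothetical.
class SquareWithCustomBprop(nn.Cell):
    def construct(self, x):
        return x * x

    def bprop(self, x, out, dout):
        # Analytic gradient of x * x, scaled by the incoming gradient.
        return (2 * x * dout,)


def _sketch_square_custom_bprop():
    x = Tensor(np.array([3.0], np.float32))
    dx = grad_all(SquareWithCustomBprop())(x)[0]
    # d(x * x)/dx at x = 3 is 6.
    assert np.allclose(dx.asnumpy(), np.array([6.0], np.float32))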