# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_framstruct """
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.api import ms_function
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.check_gradient import (
    check_jacobian, Tensor, NNGradChecker,
    OperationGradChecker, check_gradient)

context.set_context(mode=context.PYNATIVE_MODE)


def setup_module(module):
    """Module-level setup hook: select PyNative mode before this module's tests run."""
    context.set_context(mode=context.PYNATIVE_MODE)


# Shared gradient operators: ``grad_all`` returns one gradient per positional
# input; ``grad_by_list`` returns gradients for an explicit list of weights.
grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)


@ms_function
def while_upper_bound(upper):
    """Start at 2 and square until the value is no longer below ``upper``."""
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval


def test_while_upper_bound():
    # 2 -> 4 -> 16; 16 is the first value >= 10.
    res = while_upper_bound(10)
    assert res == 16


@ms_function
def while_lower_bound(lower):
    """Start at ``lower`` and square until the value is no longer below 100."""
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval


def test_while_lower_bound():
    # 2 -> 4 -> 16 -> 256.
    res = while_lower_bound(2)
    assert res == 256


@ms_function
def dynamic_make_tuple(x, lower, upper):
    """Build, in a while loop, a tuple holding one copy of ``x`` per step from ``lower`` up to ``upper``."""
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out


def test_dynamic_make_tuple():
    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)


def test_make_tuple():
    # Growing a tuple inside a loop whose trip count is a constant (range(3))
    # is supported inside ms_function.
    @ms_function
    def make_tuple(x):
        out = ()
        for i in range(3):
            out = out + (x,)
        return out

    res = make_tuple(5)
    assert res == (5, 5, 5)


@ms_function
def add(x, y):
    """Return ``x + y`` (compiled with ms_function)."""
    return x + y


def mul(x, y):
    """Return ``x * y``."""
    return x * y


def add_mul(x, y):
    """Return ``(x + y) * y``."""
    return (x + y) * y


def mainf(x, y):
    """Return the gradients of ``mul`` w.r.t. both inputs."""
    return grad_all(mul)(x, y)


def grad_add_mul(x, y):
    """Return the gradients of ``add_mul`` w.r.t. both inputs."""
    return grad_all(add_mul)(x, y)


@ms_function
def sub(x, y):
    """Return ``x - y`` (compiled with ms_function)."""
    return x - y


# pylint: disable=using-constant-test
@ms_function
def if_always_true(x):
    """Return ``x`` through an ``if True`` branch; the else branch is dead by construction."""
    if True:
        return x
    else:
        return 0


def test_add():
    """ test_add """
    res = add(2.5, 3)
    assert res == 5.5


def test_sub():
    """ test_sub """
    res = sub(3.5, 3)
    assert res == 0.5


@non_graph_engine
def test_if_always_true():
    """ test_if_always_true """
    res = if_always_true(1)
    assert res == 1


@non_graph_engine
def test_f():
    """Gradient of x * y is (y, x)."""
    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
    assert res == (2, 3)


@non_graph_engine
def test_grad_add_mul():
    """Gradient of (x + y) * y is (y, x + 2y) = (2, 7) at x=3, y=2."""
    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
    assert res == (2, 7)


def f(x):
    """Recurse down while ``x > 0`` and return the first non-positive value reached."""
    if x > 0:
        return f(x - 1)
    return x


@ms_function
def list_subscript():
    """Index into a constant list inside a compiled function."""
    x = [1, 2, 3]
    return x[0] * x[1]


def test_list_subscript():
    """ test_list_subscript """
    res = list_subscript()
    assert res == 2


@ms_function
def ms_infer_for(xs, y):
    """Return ``y`` plus the sum of the elements of ``xs``, accumulated in a for loop."""
    rval = y
    for x in xs:
        rval = rval + x
    return rval


def test_infer_for():
    """ test_infer_for """
    t = (1, 2, 3)
    y = 4
    res = ms_infer_for(t, y)
    assert res == 10


@ms_function
def if_construct(a, b):
    """Two chained if/else statements on scalar inputs."""
    z = a
    if a > b:
        z = a + b
    else:
        z = a * b
    if z > b:
        return z - a
    else:
        return a - b


def test_if_construct():
    """ test_if_construct """
    # 3 > 6 is false -> z = 18; 18 > 6 -> 18 - 3 = 15.
    res = if_construct(3, 6)
    assert res == 15


@ms_function
def if_scalar(a, b):
    """Return ``a`` when it is truthy, else ``b``."""
    if a:
        return a
    return b


def test_if_scalar1():
    """ test_if_scalar1 """
    res = if_scalar(3, 6)
    assert res == 3


def test_if_scalar2():
    """ test_if_scalar2 """
    res = if_scalar(0, 6)
    assert res == 6


@ms_function
def if_tensor(a, b):
    """Nested if/else on tensor conditions; returns ``c + c`` for the branch-selected ``c``."""
    c = a
    if a < b:
        c = a + a
        if c < b:
            c = a + c
        else:
            c = a + b
    else:
        c = b + b
    out = c + c
    return out


def test_if_tensor():
    # a == b == 1 takes the outer else branch: c = 2, out = 4.
    res = if_tensor(Tensor(np.ones([1]).astype(np.int32)), Tensor(np.ones([1]).astype(np.int32)))
    assert res == Tensor(np.ones([1]).astype(np.int32) * 4)


def rec(x):
    """Recurse down while ``x > 0`` and return the first non-positive value reached."""
    if x > 0:
        return rec(x - 1)
    return x


def test_me_rec():
    """ test_me_rec """
    res = rec(10)
    assert res == 0


def t2_while(x, y):
    """Loop ten times, each time overwriting ``out`` with ``mul(x, y)``."""
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out


def test_while2():
    res = t2_while(2, 3)
    assert res == 6


def if_test(a, b):
    """Return ``3 * a`` when ``a > b``, else ``2 * b``."""
    if a > b:
        return 3 * a
    return 2 * b


def grad_if(x, y):
    """Return the gradients of ``if_test`` w.r.t. both inputs."""
    return grad_all(if_test)(x, y)


def test_grad_if():
    """ test_grad_if """
    # 5 > 4 selects 3 * a, so d/da = 3 and d/db = 0.
    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)


class ConvNet(nn.Cell):
    """Dilated 3x3 Conv2D (16 output channels, pad 0, dilation 2) with an all-ones weight Parameter."""

    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)


conv = ConvNet()
# Loop bounds/step shared by the tensor-controlled while-loop tests.
c1 = Tensor([2], mstype.float32)
c2 = Tensor([10], mstype.float32)
c3 = Tensor([1], mstype.float32)


@ms_function
def t1_while(x, y, z):
    """Add ``conv(z)`` to ``x`` once per step from c1 to c2, then double the result; ``y`` is unused."""
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out


def test_while_net():
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    # Each conv(z) element sums 16 channels x 9 taps of ones = 144; the loop runs
    # for i = 2..9 (8 times), so out = 1 + 8 * 144 = 1153, doubled -> 2306.
    assert np.all(res.asnumpy() == np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)


@ms_function
def if_while(a, b, x, z):
    """While loop nested in the true branch of a tensor ``if``; the result is ``c + c`` either way."""
    c = a
    i = c1
    out = x
    if a < b:
        c = a + a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b + b
    out = c + c
    return out


def test_if_while():
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
    # a == b, so the else branch runs: c = 1 + 1 = 2 and out = c + c = 4,
    # with the same shape as ``b``.
    expect = np.ones([1]).astype(np.float32) * 4.0
    assert res.asnumpy().shape == expect.shape
    assert np.all(res.asnumpy() == expect)


def _while(x):
    """Return ``x * x * 2 * 3``, with the constant factors applied in a while loop."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def grad_while(x):
    """Return the gradient of ``_while`` w.r.t. its input."""
    return grad_all(_while)(x)


def test_grad_while():
    """ test_grad_while """
    # d/dx (6 * x^2) = 12 * x = 60 at x = 5.
    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)


@ms_function
def factorial(n):
    """Recursive factorial."""
    if n == 0:
        return 1
    return n * factorial(n - 1)


def test_factorial():
    res = factorial(3)
    assert res == 6


@ms_function
def factorial2(n):
    """Recursive factorial written with an if/elif/else chain."""
    if n != 0:
        return n * factorial2(n - 1)
    # NOTE(review): this branch is unreachable (n == 1 already satisfies
    # n != 0); presumably kept to exercise elif handling -- confirm.
    elif n == 1:
        return 1 * factorial2(n - 1)
    else:
        return 1


def test_factorial2():
    res = factorial2(3)
    assert res == 6


@ms_function
def foo(n):
    """Recurse down to 0 through nested if/else branches and return 1."""
    if n <= 1:
        if n == 1:
            return foo(n - 1)
        else:
            return 1
    else:
        return foo(n - 1)


def test_foo():
    res = foo(5)
    assert res == 1


@ms_function
def double_nested_loop(x):
    """For each of ``x`` outer iterations, add 1 + 2 + 3 to the sum via an inner while loop."""
    i = 0
    s = 0
    while i < x:
        j = 0
        i = i + 1
        while j < 3:
            j = j + 1
            s = s + j
    return s


def test_nested_loop():
    res = double_nested_loop(3)
    assert res == 18


@ms_function
def double_nested_loop2(x):
    """For each of ``x`` outer iterations, add 0 + 1 + 2 to the sum via nested for/range loops."""
    s = 0
    for i in range(x):
        for j in range(3):
            s = s + j
    return s


def test_nested_loop2():
    res = double_nested_loop2(1)
    # One outer iteration over range(3): s = 0 + 1 + 2.
    assert res == 3


def _for(x):
    """Return ``x * x * 2 * 3``, with the constant factors applied in a for loop."""
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret


@ms_function
def grad_for(x):
    """Return the gradient of ``_for`` w.r.t. its input."""
    return grad_all(_for)(x)


@ms_function
def try_tail(x):
    """Return ``C.tail(x)``."""
    return C.tail(x)


@non_graph_engine
def test_tail():
    """ test_tail """
    res = try_tail((0, 1, 2, 3))
    # tail drops the first element of the tuple.
    assert res == (1, 2, 3)


@ms_function
def zero_like_tensor(x):
    """ zero_like_tensor """
    return C.zeros_like(x)


def test_zeros():
    """ test_zeros """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    res = zero_like_tensor(x)
    assert np.all(res.asnumpy() == np.zeros([2, 3]).astype(np.int32))


@ms_function
def arithmetic_simplify_01(x, y):
    """zeros_like(x) * y, which should simplify to zeros."""
    return C.zeros_like(x) * y


def test_arithmetic_simplify_01():
    """ test_arithmetic_simplify_01 """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_01(x, y)
    expect = np.zeros([2, 3]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_02(x, y):
    """ones_like(x) * y, which should simplify to y."""
    return C.ones_like(x) * y


def test_arithmetic_simplify_02():
    """ test_arithmetic_simplify_02 """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_02(x, y)
    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_03(x, y):
    """x * ones_like(y), which should simplify to x."""
    return x * C.ones_like(y)


def test_arithmetic_simplify_03():
    """ test_arithmetic_simplify_03 """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_03(x, y)
    expect = np.ones([2, 3]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_04(x):
    """x + 0."""
    return x + 0


def test_arithmetic_simplify_04():
    """ test_arithmetic_simplify_04 """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_04(x)
    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_05(x):
    """x * 1."""
    return x * 1


def test_arithmetic_simplify_05():
    """ test_arithmetic_simplify_05 """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_05(x)
    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_06(x):
    """x * 2 * 5 (chained constant multiplication)."""
    return x * 2 * 5


def test_arithmetic_simplify_06():
    """ test_arithmetic_simplify_06 """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_06(x)
    expect = np.array([[10, 20, 30], [40, 50, 60]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_07(x):
    """(x + 1) * 2 * 5."""
    return (x + 1) * 2 * 5


def test_arithmetic_simplify_07():
    """ test_arithmetic_simplify_07 """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_07(x)
    expect = np.array([[20, 30, 40], [50, 60, 70]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_08(x, y):
    """A mix of *1, *0 and +0 terms that should reduce to x + y."""
    return 1 * x * 1 * 1 + 1 * 0 * 1 + 0 + y * 1


def test_arithmetic_simplify_08():
    """ test_arithmetic_simplify_08 """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    y = Tensor(np.ones([2, 3]).astype(np.int32))
    res = arithmetic_simplify_08(x, y)
    expect = np.array([[2, 3, 4], [5, 6, 7]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


def test_GradCheckerPrimitive():
    """ test_GradCheckerPrimitive """
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)

    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)


def test_NNGradChecker():
    """ test_NNGradChecker """

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out

    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)


def test_OperationGradChecker():
    """ test_OperationGradChecker """

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)), grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_OperationJacobianChecker():
    """ test_OperationJacobianChecker """

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out

    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """ test_NNJacobianChecker """

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x

    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])


def multi_outputs(x, y):
    """Return ``2 * (x + y)`` twice, as a two-element tuple."""
    z = x + y
    return 2 * z, 2 * z


@ms_function
def while_sp(x, y, z):
    """Multiply ``out`` (initially ``x``) by ``x`` once per step from c3 to c2; ``y`` and ``z`` are unused."""
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out


def test_while_sp():
    y = Tensor(np.ones([1, 3]).astype(np.float32))
    z = Tensor(np.ones([1, 3]).astype(np.float32))
    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
    res = while_sp(x, y, z)
    # The counter runs 1..9 (nine multiplications by 2), so out = 2 * 2**9 = 1024.
    assert np.all(res.asnumpy() == np.ones([1, 3]).astype(np.float32) * 1024.0)


def grad_refactor_simple_1(x, y):
    """Return ``x * x + 2 * y``; its gradient is (2x, 2)."""
    return x * x + 2 * y


def test_grad_refactor_simple_1():
    assert grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """Polynomial mixing repeated products; gradient is (2y + yz + 1, 2x + xz, 1 + xy)."""
    return x * y + z + x * y * z + x + x * y


def test_grad_refactor_simple_2():
    x = Tensor(2, dtype=ms.int32)
    y = Tensor(3, dtype=ms.int32)
    z = Tensor(0, dtype=ms.int32)
    assert grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)


def grad_refactor_1(a, b):
    """Return ``a * b`` computed through a nested helper function."""

    def inner(x, y):
        return x * y

    return inner(a, b)


def test_grad_refactor_1():
    assert grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)


def grad_refactor_2(a, b):
    """Return ``(b * b) * (a * b)`` via a closure over ``b``; gradient is (b^3, 3ab^2)."""

    def inner(x):
        return x * b

    return inner(b) * inner(a)


def test_grad_refactor_2():
    assert grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)


def grad_refactor_3(a):
    """Return 0 when ``a > 3``, else ``3 * a``."""
    if a > 3:
        return 0
    return 3 * a


def grad_refactor_4(a):
    """Return ``3 * a`` when ``a > 3``, else 0."""
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    assert grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)


def grad_refactor_5(a):
    """Return 1 when ``a > 3``, else ``a``."""
    if a > 3:
        return 1
    return a


def grad_refactor_6(a, b):
    """Return ``3 * a + b`` when ``a > b``, else ``2 * b * a``."""
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    assert grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)


def grad_refactor_while(x):
    """Square ``x`` until the value is no longer below 4."""
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval


def grad_refactor__while_1(x):
    """Return ``x * x * 2 * 3``, with the constant factors applied in a while loop."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """Gradient through a while loop: d/dx (6 * x^2) = 12x = 60 at x = 5."""
    assert grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)


def test_grad_refactor_11():
    """Gradients of a Cell computing ``x * y * y`` w.r.t. both inputs."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            return x * y * y

    net = Net()
    grads = grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))
    # d/dx = y * y = 1 and d/dy = 2 * x * y = 2 at x = y = 1.
    assert np.allclose(grads[0].asnumpy(), np.ones([2], np.float32))
    assert np.allclose(grads[1].asnumpy(), np.full([2], 2.0, np.float32))


def test_grad_refactor_12():
    """Input gradients of a Cell computing ``x * z * y`` with a Parameter ``z``."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    grads = grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))
    # With x = 1, y = 0, z = 1: d/dx = z * y = 0 and d/dy = x * z = 1.
    assert np.allclose(grads[0].asnumpy(), np.zeros([2], np.float32))
    assert np.allclose(grads[1].asnumpy(), np.ones([2], np.float32))


def test_grad_refactor_13():
    """Weight gradient of a Cell computing ``x * z * y`` w.r.t. its trainable Parameter ``z``."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    weights = ParameterTuple(net.trainable_params())
    grads = grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)),
                                       Tensor(np.zeros([2]).astype(np.float32)))
    # ``z`` is the only Parameter; d/dz = x * y = 0.
    assert len(grads) == 1
    assert np.allclose(grads[0].asnumpy(), np.zeros([2], np.float32))


def grad_refactor_14(a, b):
    """Sum of three nested helpers: ``b * b``, ``a * b`` (ignores its argument), and ``a if a > 2 else b``."""

    def inner1(x):
        return x * b

    def inner2(x):
        return a * b

    def inner3(x):
        if x > 2:
            return a
        return b

    return inner1(b) + inner2(a) + inner3(a)


# pylint: disable=using-constant-test
class IfDeferInline(nn.Cell):
    """Elementwise multiply by a constant 0.6 tensor followed by a constant-condition ``if`` branch."""

    def __init__(self, mul_size):
        super().__init__()
        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        # Deliberate no-op branch with a constant condition; presumably here to
        # exercise inlining of the ``if`` when defer_inline is disabled.
        if True:
            x = x
        return x


def test_grad_if_defer_inline():
    """ test_grad_if_defer_inline """
    network = IfDeferInline([128, 96])
    network.add_flags(defer_inline=False)
    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grads = grad_all(network)(inp)
    # d/dinputs (inputs * 0.6) = 0.6 everywhere.
    assert np.all(grads[0].asnumpy() == np.full([128, 96], 0.6, dtype=np.float32))


def test_dict_const():
    """A Cell whose construct returns a constant dict attribute."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.res = {'1': 10}

        def construct(self):
            return self.res

    res = Net()()
    assert res == {'1': 10}