# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Control-flow inline tests: if/switch nodes inlined into a kernel graph."""

import os

from mindspore import Tensor, jit, ops, mutable, nn, lazy_inline, export, load, context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell, GraphCell
import mindspore.ops.operations as P
import numpy as np
import pytest


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; a single if updates a
        parameter argument.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    # 5 > 4, so param_b becomes 5 inside the graph and each output is 2 + 5.
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_parameter():
    """
    Feature: Control flow inline.
    Description: Each branch of an if returns a different Parameter argument.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            return param_a
        return param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_param_untail_call():
    """
    Feature: Control flow inline.
    Description: An if selects one of two Parameters and the selected value is
        consumed by further computation after the if (non-tail position).
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(6))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            z = param_a
        else:
            z = param_b
        z = z + 1
        z = z - 2
        z = z * 3
        z = z / 4
        return z

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_valuenode():
    """
    Feature: Control flow inline.
    Description: Each branch of an if returns a constant (value node).
    Expectation: Not throw exception.
    """

    @jit
    def foo(x):
        if x < 3:
            return 1
        return 2

    ret1 = foo(Tensor(1))
    assert ret1


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_input():
    """
    Feature: Control flow inline.
    Description: Each branch of an if returns a different graph input.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, z):
        if x < 3:
            return y
        return z

    ret1 = foo(Tensor(1), Tensor(2), Tensor(3))
    assert ret1


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_value_node_output_in_single_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; only one branch outputs a
        constant tensor (value node).
    Expectation: Not throw exception.
    """

    @jit
    def BranchReturnTensor(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        if x < 5:
            return y, Tensor(2, mstype.int32)
        return x, y

    x = Tensor(2, mstype.int32)
    ret1 = BranchReturnTensor(x, x)
    ret2 = BranchReturnTensor(x, x)
    ret3 = BranchReturnTensor(x, x)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_diff_ref_count_in_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; the two branches use the
        same values a different number of times.
    Expectation: Not throw exception.
    """

    @jit
    def BranchDiffRefCount(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        if x < 5:
            x = x + 3
            y = x + y
        else:
            x = x + 3
            x = x + 4
            x = x + 5
            y = x + y
            y = x + y
            y = x + y
        return x, y

    x = Tensor(2, mstype.int32)
    ret1 = BranchDiffRefCount(x, x)
    x = Tensor(4, mstype.int32)
    ret2 = BranchDiffRefCount(x, x)
    assert ret1
    assert ret2


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_kernel_backoff():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; one branch reshapes with a
        mutable shape tuple while the other passes the tensor through.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, shape):
        x = x + Tensor(2, mstype.int32)
        if y < 5:
            z = ops.reshape(x, shape)
        else:
            z = x
        return z + 1

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = foo(x, y, mutable((2, 3)))
    ret2 = foo(x, y, mutable((2, 3)))
    ret3 = foo(x, y, mutable((2, 3)))
    assert ret1[0][0]
    assert ret2[0][0]
    assert ret3[0][0]


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_parameter():
    """
    Feature: Control flow inline.
    Description: Both branches rebind a Parameter argument to a new value that
        is returned.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))

    @jit
    def foo(x, param_a):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
        else:
            param_a = param_a + x
        return param_a

    ret1 = foo(Tensor(1), param_a)
    ret2 = foo(Tensor(1), param_a)
    ret3 = foo(Tensor(1), param_a)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_and_return_parameter():
    """
    Feature: Control flow inline.
    Description: Both paths compute from Parameter arguments; the true branch
        returns early with a constant tensor.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input_in_branch():
    """
    Feature: Control flow inline.
    Description: The true branch returns a value computed before the if (an
        input of the switch).
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return x, param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input():
    """
    Feature: Control flow inline.
    Description: A value computed before the if is returned together with a
        branch output and a constant.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
        else:
            param_a = param_a + x
            param_b = param_b + param_a
        return x, param_b, 3

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_args_to_dynamic_tuple_para():
    """
    Feature: Control flow inline.
    Description: A shape tuple is repeated a different number of times in each
        branch and one element of the result is returned.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape[0]

    ret1 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret2 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret3 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: The shape tuple of a tensor produced by unique + reshape is used
        as an input of both branches.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, dst_shape):
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1[0]
    assert ret2[0]
    assert ret3[0]


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: A dynamic-length mutable tuple is an input of both branches.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, dyn_tuple):
        if x < 3:
            dyn_tuple = dyn_tuple * 2
        else:
            dyn_tuple = dyn_tuple * 3
        return dyn_tuple

    ret1 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret2 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret3 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_condition():
    """
    Feature: Control flow inline.
    Description: The if condition itself is returned from both paths.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, cond):
        if cond:
            x = x * 2
            return x, cond
        x = x * 3
        return x, cond

    ret1 = foo(Tensor(1), Tensor(True))
    ret2 = foo(Tensor(1), Tensor(True))
    ret3 = foo(Tensor(1), Tensor(True))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_include_other_output():
    """
    Feature: Control flow inline.
    Description: The graph output combines a branch result with a value
        computed before the if.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        y = y + 2
        y = y * 3
        y = y / 4
        y = y - 5
        y = y * y
        if x < 5:
            x = x * 2
        else:
            x = x + 2
        return x, y

    ret1 = foo(Tensor(1), Tensor(2))
    ret2 = foo(Tensor(1), Tensor(2))
    ret3 = foo(Tensor(1), Tensor(2))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_with_dynamic_shape():
    """
    Feature: Control flow inline.
    Description: A branch applies expand_dims and flatten to a tensor whose shape
        comes from unique (dynamic shape).
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, dst_shape):
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        if x < 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1[0][0]
    assert ret2[0][0]
    assert ret3[0][0]


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_true():
    """
    Feature: Control flow inline.
    Description: The taken (true) branch contains expand_dims and flatten.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x < 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            y = y + Tensor([[6, 12], [18, 24], [30, 36]])
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_false():
    """
    Feature: Control flow inline.
    Description: The branch with expand_dims and flatten is not taken (x=1 is
        not > 3), so the else branch runs.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x > 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            y = y + Tensor([[6, 12], [18, 24], [30, 36]])
        else:
            z = y + Tensor([[36, 30], [24, 18], [12, 6]])
            y = y + Tensor([[36, 30], [24, 18], [12, 36]])
            y = z + y
        return y * 2

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_output_ref():
    """
    Feature: Control flow inline.
    Description: The true branch ends with flatten, so its output is the
        expand_dims/flatten result itself.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x > 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
        else:
            z = y + Tensor([[36, 30], [24, 18], [12, 6]])
            y = y + Tensor([[36, 30], [24, 18], [12, 36]])
            y = z + y
        return y * 2

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_twice():
    """
    Feature: Control flow inline.
    Description: Inside one branch, the expand_dims result is consumed by both
        flatten and reshape.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x > 3:
            y = ops.expand_dims(y, 1)
            z1 = ops.flatten(y)
            z2 = ops.reshape(y, (3, 2))
            z3 = z2 * 2
            z4 = z2 * 3
            y = z1 + z2 + z3 + z4
        else:
            z = y + Tensor([[36, 30], [24, 18], [12, 6]])
            y = y + Tensor([[36, 30], [24, 18], [12, 36]])
            y = z + y
        return y * 2

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_include_dynamic_shape():
    """
    Feature: Control flow inline.
    Description: Both branches operate on the output of unique, whose shape
        changes between calls.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        y, _ = ops.unique(y)
        if x < 3:
            y = y * 2
        else:
            z1 = y / 6
            z2 = y * 2
            z3 = y - Tensor([[6, 12, 18], [24, 30, 36]])
            z4 = y + Tensor([[1, 2, 3], [4, 5, 6]])
            y = z1 + z2 + z3 + z4
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [12, 18, 30], [18, 24, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1[0]
    assert ret2[0]
    assert ret3[0]


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_control_arrow_from_switch_to_gather():
    """
    Feature: Control flow inline.
    Description: The true branch returns early; the fall-through path computes a
        value (x) that is never used by the output.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        x = x + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_only_u_input():
    """
    Feature: Control flow inline.
    Description: The true branch contains only a print side effect and no data
        computation.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_u_input_and_input():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches mix a print side effect with data
        computation.
    Expectation: Not throw exception.
    """
    # NOTE(review): this body is token-for-token identical to
    # test_branch_only_u_input; confirm whether a different branch layout was
    # intended for the "u input and input" case.

    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_real_tuple():
    """
    Feature: Control flow inline.
    Description: Both branches output a shape tuple; in the true branch the shape
        comes from the output of unique.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x < 3:
            y, _ = ops.unique(y)
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            z = ops.shape(y)
        else:
            z = ops.shape(y)
        return z

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(5), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1
    assert ret2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_dynamic_tuple():
    """
    Feature: Control flow inline.
    Description: One branch outputs the shape of a tensor reshaped with a
        dynamic-length tuple; the other outputs the original shape.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, shape):
        if y < 5:
            z = ops.reshape(x, shape)
            out = ops.shape(z)
        else:
            out = ops.shape(x)
        return out

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = foo(x, y, mutable((2, 3), dynamic_len=True))
    assert ret1[0]


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; two sequential if
        statements update parameter arguments.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
        if param_a + param_b > 10:
            param_a += 3
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    # param_b becomes 5; 5 + 5 is not > 10, so each output is 2 + 5.
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; an if nested inside
        another if updates parameter arguments.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
            if param_a + param_b > 10:
                param_a += 3
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    # param_b becomes 5; 5 + 5 is not > 10, so each output is 2 + 5.
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_output_ref_of_parameter():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; one branch outputs the
        result of assign on a Parameter.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')

    @jit
    def foo(x, y, param_a):
        if x > y:
            out = ops.addn([x, x, param_a])
        else:
            out = ops.assign(param_a, x)
        return out

    x = Tensor(2, mstype.int32)
    y = Tensor(1, mstype.int32)
    ret1 = foo(x, x, param_a)  # x > x is false: takes the assign branch.
    ret2 = foo(x, y, param_a)  # x > y is true: takes the addn branch.
    assert ret1
    assert ret2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gather_switch_gather_output():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; the output of one if feeds
        a second if that may overwrite it with assign.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')

    @jit
    def foo(x, y, param_a):
        if x > y:
            out = param_a
        else:
            out = ops.addn([x, x, x])
        if x > y:
            out = ops.assign(param_a, x)
        return out

    x = Tensor(1, mstype.int32)
    y = Tensor(1, mstype.int32)
    ret1 = foo(x, y, param_a)
    assert ret1


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_directly():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; an if is the only
        statement in the body of another if.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        x = x + 2
        if param_a > param_b:
            if x > y:
                x += 3
        x = x + param_a
        y = x + y
        return y

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    assert ret1
    assert ret2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lazy_inline():
    """
    Feature: Switch inline with lazy inline.
    Description: All inline in single graph: a lazy_inline block is
        differentiated and the gradient is returned from an if.
    Expectation: Run successfully (memory usage is not asserted here).
    """
    class Grad(Cell):
        def __init__(self, net):
            super().__init__()
            self.grad = ops.GradOperation()
            self.net = net

        def construct(self, x):
            grad_net = self.grad(self.net)
            return grad_net(x)

    class Block(Cell):
        def __init__(self):
            super().__init__()
            self.batch_matmul = P.BatchMatMul()
            self.expand_dims = P.ExpandDims()
            self.y = Parameter(Tensor(np.ones((8,)).astype(np.float32)))

        def construct(self, x):
            z1 = self.batch_matmul(x, x)
            z2 = self.expand_dims(self.y, 1)
            return z1 + z2

    class BaseBlock(Cell):
        @lazy_inline
        def __init__(self):
            super().__init__()
            self.block = Block()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.blocks = nn.CellList()
            b = BaseBlock()
            self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(1):
                out = self.blocks[i](out)
            return out

    class GradNet(Cell):
        def __init__(self, net):
            super().__init__()
            self.grad_net = Grad(net)
            self.a = Parameter(Tensor(np.ones((8,)).astype(np.float32)))
            self.b = Parameter(Tensor(np.ones((8,)).astype(np.float32)))

        def construct(self, x, y):
            out = self.grad_net(x)
            if y > 3:
                return out * 2, self.a
            return out, self.b

    x = Tensor(np.ones((8, 8)).astype(np.float32))
    y = Tensor(6)
    net = Net()
    grad_net = GradNet(net)
    grad_net(x, y)
    grad_net(x, y)


class TupleParaNet(Cell):
    """Net whose single input is a (possibly dynamic-length) tuple."""

    def __init__(self):
        super().__init__()
        self.add = ops.Add()

    def construct(self, paralist):
        """Add the first and last elements when the tuple has at least two
        elements; otherwise return the first element."""
        # Bug fix: was `len(list)` (the builtin type), not the argument.
        length = len(paralist)
        if length >= 2:
            x1 = paralist[0]
            x2 = paralist[length - 1]
            return self.add(x1, x2)
        return paralist[0]


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_parameter():
    """
    Feature: Control flow inline.
    Description: Export a net taking a dynamic-length tuple to MindIR, reload it
        and run it with a fixed-length tuple.
    Expectation: Not throw exception.
    """
    # NOTE: set_context changes global state that persists for later tests.
    context.set_context(mode=context.GRAPH_MODE, jit_config={"jit_level": "O0"})
    mindir_file = "test.mindir"
    net = TupleParaNet()
    input_2_ele = mutable((2, 3), dynamic_len=True)
    try:
        export(net, input_2_ele, file_name=mindir_file, file_format="MINDIR")
        input_3_ele = mutable((2, 2, 3), dynamic_len=False)
        y = load(mindir_file)
        mindir_load = GraphCell(y)
        print(mindir_load(input_3_ele))
    finally:
        # Remove the exported artifact so the working directory is left clean.
        if os.path.exists(mindir_file):
            os.remove(mindir_file)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_call_same_graph():
    """
    Feature: Control flow inline.
    Description: Two call nodes call the same graph (an if with break inside a
        loop).
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.float32), name='a')

    @jit
    def foo(x, y, param_a):
        out = Tensor(1, mstype.float32)
        for i in range(0, 2):
            if x + i < y:
                out += param_a
                break
        return out

    x = Tensor(2, mstype.int32)
    ret = foo(x, x, param_a)
    assert ret