# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import contextlib
import os

import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import ms_function

context.set_context(mode=context.GRAPH_MODE)

grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)


@contextlib.contextmanager
def _recursive_eval(max_call_depth):
    """Run the enclosed code with recursive evaluation and a call-depth limit.

    On entry ``ENV_RECURSIVE_EVAL`` is set to ``'1'`` and the context option
    ``max_call_depth`` is set to ``max_call_depth``. On exit, including when
    the enclosed code raises, ``ENV_RECURSIVE_EVAL`` is set to ``'0'`` and the
    previous ``max_call_depth`` is restored so the settings cannot leak into
    other tests.
    """
    old_max_call_depth = context.get_context('max_call_depth')
    os.environ['ENV_RECURSIVE_EVAL'] = '1'
    context.set_context(max_call_depth=max_call_depth)
    try:
        yield
    finally:
        os.environ['ENV_RECURSIVE_EVAL'] = '0'
        context.set_context(max_call_depth=old_max_call_depth)


def cond_data_test(x_init, y_init):
    """Run a network built from explicit GeSwitch/Merge ops and return its output.

    Args:
        x_init: scalar used to build the float32 tensor ``x``.
        y_init: scalar used to build the float32 tensor ``y``.

    Returns:
        The first element of the ``Merge`` result.
    """
    class Net(nn.Cell):
        def __init__(self):
            """Create the primitives and the constant used by construct."""
            super(Net, self).__init__()
            self.square = P.Square()
            self.add = P.Add()
            self.value = Tensor(3, dtype=ms.float32)
            self.switch = P.GeSwitch()
            self.merge = P.Merge()
            self.less = P.Less()

        def construct(self, x, y):
            cond = self.less(x, y)
            # The addition consumes the second switch output of x and y; the
            # square consumes the first switch output of the constant.
            st1, _ = self.switch(x, cond)
            st2, _ = self.switch(y, cond)
            add_ret = self.add(st1, st2)
            _, sf3 = self.switch(self.value, cond)
            sq_ret = self.square(sf3)
            ret = self.merge((add_ret, sq_ret))
            return ret[0]

    x = Tensor(x_init, dtype=ms.float32)
    y = Tensor(y_init, dtype=ms.float32)
    net = Net()
    output = net(x, y)
    return output


def test_cond_data_true():
    """Run cond_data_test with x < y and print the result."""
    output = cond_data_test(3, 8)
    print("test_cond_data_true:", output)


def test_cond_data_false():
    """Run cond_data_test with x > y and print the result."""
    output = cond_data_test(8, 3)
    print("test_cond_data_false:", output)


def if_compile_test(x_init, y_init):
    """Run a network whose ``if`` condition is a tensor comparison.

    Args:
        x_init: scalar used to build the float32 tensor ``x``.
        y_init: scalar used to build the float32 tensor ``y``.

    Returns:
        The network output.
    """
    class Net(nn.Cell):
        def __init__(self):
            """Create the primitives and the constant used by construct."""
            super(Net, self).__init__()
            self.square = P.Square()
            self.add = P.Add()
            self.value = Tensor(3, dtype=ms.float32)
            self.switch = P.GeSwitch()
            self.merge = P.Merge()
            self.less = P.Less()

        def construct(self, x, y):
            cond = self.less(x, y)
            ret = self.value
            if cond:
                ret = self.add(x, ret)
                ret = self.add(y, ret)
            else:
                ret = self.square(self.value)
            return ret

    x = Tensor(x_init, dtype=ms.float32)
    y = Tensor(y_init, dtype=ms.float32)
    net = Net()
    output = net(x, y)
    return output


def test_if_none():
    """``if None`` must take the else branch."""
    class Net(nn.Cell):
        def __init__(self, z: None):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = None
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())


def test_if_str_is_not_none_right():
    """``self.z is None`` with a non-empty string must take the else branch."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())


def test_if_str_is_not_none_left():
    """``self.z is None`` with a non-empty string must take the else branch."""
    # NOTE(review): this body is identical to test_if_str_is_not_none_right;
    # the name suggests the operands were meant to be swapped - confirm.
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())


def test_if_none_equal_none():
    """``self.z is None`` with ``z = None`` must take the if branch."""
    class Net(nn.Cell):
        def __init__(self, z: None):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = None
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())


def test_if_str_is_null():
    """An empty string condition must take the else branch."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = ""
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())


def test_if_str_is_true():
    """A non-empty string condition must take the if branch."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 9, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())


def test_if_str_equal():
    """String equality against a literal must take the if branch when equal."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z == "ok":
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())


def test_if_tuple_is_null():
    """An empty tuple condition must take the else branch."""
    class Net(nn.Cell):
        def __init__(self, z: tuple):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = ()
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())


def test_if_tuple_is_not_null():
    """A non-empty tuple condition must take the if branch."""
    class Net(nn.Cell):
        def __init__(self, z: tuple):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = (1, 2, 3)
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())


def test_if_dict_is_null():
    """An empty dict condition must take the else branch."""
    class Net(nn.Cell):
        def __init__(self, z: dict):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = {}
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())


def test_if_dict_is_not_null():
    """A non-empty dict condition must take the if branch."""
    class Net(nn.Cell):
        def __init__(self, z: dict):
            """Store the value used as the branch condition."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = {"one": 1, "two": 2}
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())


def test_if_else_assign():
    """Conditional expressions on lists feeding nested if/else statements."""
    class Net(nn.Cell):
        def __init__(self, m: list):
            """Store the caller's list and a fixed comparison list."""
            super(Net, self).__init__()
            self.m = m
            self.n = [4, 5, 6]

        def construct(self, x, y):
            exp_1 = self.m if self.m else self.n
            exp_2 = self.m if exp_1 == self.n else self.n
            if exp_2 == self.m:
                if self.m:
                    ret = x
                else:
                    ret = y
            else:
                if self.m:
                    ret = x
                else:
                    ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = [1, 2]
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())


def test_if_compile_true():
    """Run if_compile_test with x < y and print the result."""
    output = if_compile_test(3, 8)
    print("test_if_compile_true:", output)


def test_if_compile_false():
    """Run if_compile_test with x > y and print the result."""
    output = if_compile_test(8, 3)
    print("test_if_compile_false:", output)


def test_switch_layer():
    """Forward and gradient passes through ``F.switch_layer`` over two cells."""
    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = F.switch_layer(index, self.layers)(x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_by_list(net, ParameterTuple(net.trainable_params()))(
        index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_index_to_switch_layer():
    """Forward and gradient passes when a cell tuple is indexed by a tensor."""
    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_by_list(net, ParameterTuple(net.trainable_params()))(
        index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_parser_switch_layer_switch_in_bprop():
    """Index a tuple of cells with a tensor inside a user-defined bprop."""
    class OneInputBprop(nn.Cell):
        def __init__(self, funcs):
            super(OneInputBprop, self).__init__()
            self.op = P.ReLU()
            self.funcs = funcs

        def construct(self, i, x):
            return self.op(x)

        def bprop(self, i, x, out, dout):
            return i, self.funcs[i](x, dout)

    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x, y):
            return self.op(x, y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x, y):
            return self.op(x, y)

    func1 = Add()
    func2 = Mul()
    funcs = (func1, func2)
    net = OneInputBprop(funcs)
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = Tensor(np.random.randn(2, 2).astype(np.float32))
    i = Tensor(1, mstype.int32)
    grad_net = grad_all_with_sens(net)
    grad_net(i, input1, grad)


def test_parser_switch_layer_inputs_tuple():
    """Gradient through tensor-indexed cells that take a tuple as input."""
    class TwoInputTupleFinalNet(nn.Cell):
        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        def construct(self, i, inputa, inputb):
            inputs = (inputa, inputb)
            x = self.funcs[i](inputs)
            return x

    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    func1 = Add()
    func2 = Mul()

    funcs = (func1, func2)
    net = TwoInputTupleFinalNet(funcs)

    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)
    grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    back_net = grad_all_with_sens(net)
    back_net(i, input1, input2, grad)


def test_switch_layer_with_single_prim():
    """Tensor-indexed selection between two ``nn.ReLU`` cells, with gradients."""
    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (nn.ReLU(), nn.ReLU())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_by_list(net, ParameterTuple(net.trainable_params()))(
        index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_switch_layer_env_eliminate():
    """Weight gradients through a tensor-indexed choice of two Conv2d cells."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same')
            self.funs = (self.conv, self.conv2)

        def construct(self, x, index):
            x = self.funs[index](x)
            return x

    class NetGrad(nn.Cell):
        def __init__(self, net):
            super(NetGrad, self).__init__()
            self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            grad = self.grad_op(self.net, weights)(x, index)
            return grad

    net = Net()
    net2 = NetGrad(net)
    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    i = Tensor(1, ms.int32)
    net2(x, i)


def test_switch_layer_single_layer():
    """Weight gradients through a tensor-indexed one-element cell tuple."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.funs = (self.conv,)

        def construct(self, x, index):
            x = self.funs[index](x)
            return x

    class NetGrad(nn.Cell):
        def __init__(self, net):
            super(NetGrad, self).__init__()
            self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            grad = self.grad_op(self.net, weights)(x, index)
            return grad

    net = Net()
    net2 = NetGrad(net)
    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    i = Tensor(1, ms.int32)
    net2(x, i)


def test_if_nested_compile():
    """Compile and run an ``if`` nested inside an ``else`` on tensor conditions."""
    class Net(nn.Cell):
        def __init__(self, auto_prefix=True):
            super().__init__(auto_prefix=auto_prefix)
            self.squre = P.Square()
            self.value = Tensor(3, dtype=ms.float32)

        def construct(self, x, y):
            res = self.value
            if x <= y:
                res = x + res
                res = y + res
            else:
                if x == y:
                    res = self.squre(self.value * y)
                else:
                    res = self.squre(self.value)
            return res

    x = Tensor(1.0, dtype=ms.float32)
    y = Tensor(2.0, dtype=ms.float32)
    net = Net()
    net(x, y)


def test_if_inside_for():
    """Compile and run an ``if`` on a tensor comparison inside a ``for`` loop."""
    class Net(nn.Cell):
        def __init__(self, auto_prefix=True):
            super().__init__(auto_prefix=auto_prefix)
            self.squre = P.Square()
            self.value = Tensor(3, dtype=ms.float32)
            self.count = 4

        def construct(self, x, y):
            res = 0
            for i in range(self.count):
                if i == x:
                    res = res + x
                else:
                    res = res - y
            return res

    c1 = Tensor(1, dtype=ms.int32)
    c2 = Tensor(1, dtype=ms.int32)
    net = Net()
    net(c1, c2)


def test_while_in_while():
    """Compile and run a nested ``while`` loop inside an ``ms_function``."""
    c1 = Tensor(1, dtype=ms.int32)
    c2 = Tensor(2, dtype=ms.int32)
    c3 = Tensor(3, dtype=ms.int32)
    c4 = Tensor(4, dtype=ms.int32)

    # The function body reads c4 from the enclosing scope; parameter u is
    # accepted but never read and z is overwritten before it is read.
    @ms_function
    def while_in_while(x, y, z, u):
        out = c4
        while x < y:
            z = c4 + c4
            while z < y:
                z = z + 1
                out = out + 1
            x = x + 1

        out = out + 3
        return out

    while_in_while(c1, c2, c3, c4)


def test_tensor_cond():
    """Use a 0-d and a single-element bool tensor as ``if`` conditions."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # np.bool_ replaces the np.bool alias, which newer NumPy removed.
            self.t = Tensor(np.array(0, np.bool_))
            self.t1 = Tensor(np.array([True], np.bool_))

        def construct(self, x, y):
            t = 0
            if self.t:
                t = t - x * y
            else:
                t = t - x / y
            if self.t1:
                t = t + x / y
            else:
                t = t + x * y
            return t

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net = Net()
    net(x, y)


def test_tensor_cond_exception():
    """A two-element bool tensor used as a condition must raise ValueError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # np.bool_ replaces the np.bool alias, which newer NumPy removed.
            self.t = Tensor(np.array([True, False], np.bool_))

        def construct(self, x, y):
            t = 0
            if self.t:
                t = t - x * y
            else:
                t = t - x / y
            return t

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net = Net()
    with pytest.raises(ValueError):
        net(x, y)


def test_while_scalar():
    """Compile and run a ``while`` loop driven by a Python scalar counter."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.x = 10

        def construct(self, x, y):
            i = 0
            t = 0
            while (i < 10):
                t = t + x + y
                i = i + 1
            return t

    net = Net()
    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net(x, y)


def test_while_with_weight_in_condition():
    """Gradient of a ``while`` loop whose condition reads a Parameter."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")

        def construct(self, x):
            while self.loop < 5:
                self.loop += 1
                x += 1
            return x

    net = Net()
    x = Tensor(-1, dtype=ms.float32)
    grad_all(net)(x)


def test_mixed_precision_cast():
    """``F.mixed_precision_cast`` to float16 must yield a float16 tensor."""
    x = Tensor(np.ones([2, 3], dtype=np.float32))
    z = F.mixed_precision_cast(mstype.float16, x)
    assert z.dtype == mstype.float16


def test_while_add():
    """Accumulate input slices in a ``while`` loop indexed by a tensor."""
    class Net(nn.Cell):
        def __init__(self, data):
            super(Net, self).__init__()
            self.start = Tensor(0, dtype=mstype.int32)
            self.end = Tensor(2, dtype=mstype.int32)
            self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
            self.add = P.Add()

        def construct(self, inputs):
            idx = self.start
            end = self.end
            out = self.out
            while idx < end:
                xi = inputs[idx, :, :]
                out = self.add(out, xi)
                idx = idx + 1
            return out

    x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
    net = Net(x)
    net(x)


def test_tensor_all_construct_lack_branch():
    """A construct that can fall off the end without returning must raise."""
    class NetConditionLackBranch(nn.Cell):
        def __init__(self):
            super(NetConditionLackBranch, self).__init__()
            self.logicaland = P.LogicalAnd()
            self.logicalor = P.LogicalOr()

        def construct(self, input1, input2):
            if input1.all():
                return self.logicaland(input1, input2)
            while input1.any():
                return self.logicalor(input1, input2)
            # NOTICE: here missing return statement, default return None

    input_np_1 = np.random.choice([True], size=(2, 3, 4, 5))
    input_tensor_1 = Tensor(input_np_1)
    input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5))
    input_tensor_2 = Tensor(input_np_2)
    net = NetConditionLackBranch()
    with pytest.raises(Exception):
        net(input_tensor_1, input_tensor_2)


def test_parser_switch_layer_func_primitive():
    """Tensor-indexing a tuple of bare primitives must raise ValueError."""
    class FinalNet(nn.Cell):
        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        def construct(self, i, input1):
            x = self.funcs[i](input1)
            return x

    func1 = P.ReLU()
    func2 = P.Softmax()
    funcs = (func1, func2)
    net = FinalNet(funcs)

    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)

    with pytest.raises(ValueError):
        net(i, input1)


def test_switch_layer_shape_join_failed():
    """Tensor-indexed cells whose outputs cannot be joined must raise ValueError."""
    class AddFuncNet(nn.Cell):
        def __init__(self, funcs, new_func):
            super(AddFuncNet, self).__init__()
            self.funcs = funcs
            self.new_func = new_func

        def construct(self, i, inputs):
            final_funcs = self.funcs + (self.new_func,)
            x = final_funcs[i](inputs)
            return x

    class ReLUTuple(nn.Cell):
        def __init__(self):
            super(ReLUTuple, self).__init__()
            self.op = nn.ReLU()

        def construct(self, x):
            return self.op(x[0])

    func1 = nn.Softmax()
    func2 = nn.ReLU()
    func3 = ReLUTuple()

    funcs = (func1, func2)

    net = AddFuncNet(funcs, func3)

    inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)
    with pytest.raises(ValueError):
        net(i, inp)


def test_switch_layer_dtype_join_failed():
    """Tensor-indexed cells with different output dtypes must raise TypeError."""
    class Cast(nn.Cell):
        def __init__(self, dtype):
            super(Cast, self).__init__()
            self.op = P.Cast()
            self.dtype = dtype

        def construct(self, x):
            y = self.op(x, self.dtype)
            return y + y

    class SwitchNegNet(nn.Cell):
        def __init__(self, funcs):
            super(SwitchNegNet, self).__init__()
            self.funcs = funcs
            self.op = P.Neg()

        def construct(self, i, inputs):
            x = self.funcs[i](inputs)
            x = self.op(x)
            return x

    func1 = nn.ReLU()
    func2 = Cast(mstype.int32)
    funcs = (func1, func2)
    net = SwitchNegNet(funcs)

    inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(0, mstype.int32)

    with pytest.raises(TypeError):
        net(i, inp)


def test_large_for_loop():
    """A 1899-iteration ``for`` loop must exceed a call depth limit of 60."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.flatten = P.ReLU()  # nn.Flatten()

        def construct(self, x):
            for elem in range(1, 1900):
                x = self.flatten(x + elem)
            return x

    t = Tensor(np.ones([2, 3], dtype=np.float32))
    net = Net()
    with _recursive_eval(max_call_depth=60):
        with pytest.raises(RuntimeError) as err:
            net(t)
    assert 'Exceed function call depth limit 60' in str(err.value)


def test_large_for_loop_case2():
    """A 1500-iteration ``for`` loop must exceed a call depth limit of 80."""
    class Menet(nn.Cell):
        def __init__(self, axis, flag_boottom, flag_top):
            super(Menet, self).__init__()
            self.squeeze = P.Squeeze(axis)
            self.expanddims = P.ExpandDims()
            self.flatten = nn.Flatten()
            self.neg = P.Neg()
            self.axis = axis
            self.flag_boottom = flag_boottom
            self.flag_top = flag_top

        def construct(self, x):
            if self.flag_boottom:
                x = self.neg(x)
            for i in range(0, 1500):
                x = self.expanddims(x, self.axis)
                x = self.squeeze(x)
            x = self.flatten(x)
            if self.flag_top:
                x = self.neg(x)
            return x

    x = Tensor(np.ones([2, 3], dtype=np.float32))
    net = Menet(axis=0, flag_boottom=True, flag_top=True)
    with _recursive_eval(max_call_depth=80):
        with pytest.raises(RuntimeError) as err:
            net(x)
    assert 'Exceed function call depth limit 80' in str(err.value)


def test_large_for_loop_with_continue_break():
    """A 200-iteration loop with continue/break runs with call depth 2000."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.flatten = P.ReLU()  # nn.Flatten()

        def construct(self, x):
            idx = 0
            for elem1 in range(200):
                idx = idx + 1
                if idx < 10:
                    x = x + 0.5
                    continue
                if idx > 500:
                    break
                x = self.flatten(x + elem1)
            return x

    with _recursive_eval(max_call_depth=2000):
        t = Tensor(np.ones([2, 3], dtype=np.float32))
        net = Net()
        net(t)


def test_recursive_call():
    """Two cells that invoke each other must raise RuntimeError."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.fc = nn.Dense(10, 10)  # padding=0

        def construct(self, x):
            # Net2 is instantiated here, and Net2 itself holds a Net, which
            # makes the call chain recursive.
            net2 = Net2()
            x = net2(x)
            out = self.fc(x)
            return out

    class Net2(nn.Cell):
        def __init__(self):
            super(Net2, self).__init__()
            self.net = Net()
            self.fc = nn.Dense(10, 10)

        def construct(self, x):
            x = self.net(x)
            out = self.fc(x)
            return out

    context.set_context(mode=context.GRAPH_MODE)
    with _recursive_eval(max_call_depth=80):
        input_data = Tensor(np.identity(10).astype(np.float32))
        net = Net2()
        with pytest.raises(RuntimeError):
            net(input_data)


# grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool)))
def test_grad_tensor_bool():
    """Gradient of a ``while`` loop whose condition is a bool tensor input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            out = z
            while x:
                out = out + z
                x = y
            return out

    # np.bool_ replaces the np.bool alias, which newer NumPy removed.
    x = Tensor(np.array(False).astype(np.bool_))
    y = Tensor(np.array(False).astype(np.bool_))
    z = Tensor(np.ones([2, 3], dtype=np.float32))
    net = grad_all(Net())
    net(x, y, z)