# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for control-flow ops (while / if / for) in PyNative precompile mode."""
import pytest
import numpy as np
from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore import ms_function
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P

# Gradient w.r.t. trainable parameters (weights list).
grad_by_list = C.GradOperation(get_by_list=True)
# Gradient w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)


@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Run every test in PyNative precompile-only mode; restore graph mode after."""
    context.set_context(mode=context.PYNATIVE_MODE, precompile_only=True)
    yield
    context.set_context(mode=context.GRAPH_MODE, precompile_only=False)


def test_while_with_param_forward_with_const_branch():
    """Forward while loop whose body contains a constant-condition if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        @ms_function
        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:  # constant branch: always taken
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_opt_endless():
    """Endless-loop-during-optimization regression case (grad of while + AddN)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        @ms_function
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    net(idx, end, x)


def test_no_while_call():
    """Constant-condition if with a parameter, no while loop at all."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        @ms_function
        def construct(self, idx, end, x):
            out = self.zero
            if 2 > 1:  # constant branch: always taken
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_grad_with_const_branch():
    """Parameter gradient through a while loop containing a constant if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:  # constant branch: always taken
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_for_while_with_param_grad_with_const_branch():
    """Parameter gradient through a for loop wrapping a while with a constant if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    if 2 > 1:  # constant branch: always taken
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_for_while_with_param_grad_basic():
    """Parameter gradient through a for loop wrapping a plain while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_for_while_with_param_grad_normal():
    """Like the basic for+while grad case, but the accumulator starts from input x."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_basic_grad():
    """Basic parameter gradient through a while loop (single parameter)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_basic_grad_mul():
    """Parameter gradient through a while loop with multiplicative accumulation."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # Start from ones so repeated multiplication is meaningful.
            self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_basic_grad_two():
    """Gradient through a while loop w.r.t. two trainable parameters."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_basic_grad_three():
    """Gradient through a while loop w.r.t. three trainable parameters."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_if_with_param_grad():
    """Parameter gradient through a while loop with a data-dependent if branch."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_grad_not_enter_while():
    """Parameter gradient when the while condition is false from the start."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    # idx > end, so the loop body never executes.
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_with_param_if_by_if_forward():
    """Forward pass through two sequential if/else blocks that read a parameter."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        @ms_function
        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x * 3 + self.param
            else:
                out = out + x * 2
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_with_param_if_by_if_grad_inputs():
    """Input gradients through two sequential ifs (no else branches)."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x * 3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        @ms_function
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_with_param_if_by_if_grad_parameter():
    """Parameter gradients through two sequential ifs (no else branches)."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x * 3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_with_param_if_by_if_grad_param_excute_null():
    """Parameter gradients when the lone if branch is never executed."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a > b, so the branch body is skipped entirely.
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_return_inside_grad():
    """Parameter gradients when each if branch returns directly."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward():
    """Forward pass through three chained if/else blocks on scalar tensors."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward_control_tuple_switch():
    """tuple_get from switch op will generate new switch inside to eliminate tuple_get."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward_control_inside_net():
    """Chained ifs spread across nested sub-networks, scalar result inside."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward_use_namespace():
    """Chained ifs where ops are instantiated inline from the P namespace."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward_use_global_op():
    """Chained ifs where ops are instantiated once before the branches."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_for_with_if_by_if_forward():
    """Forward for loop containing a data-dependent if/else."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        @ms_function
        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_for_with_if_by_if_forward_namespace():
    """Forward for loop with inline-instantiated ops inside the if/else."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward_const_branch_inner():
    """Chained ifs where the middle condition is a compile-time constant."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:  # constant branch: always taken
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)


def test_if_by_if_forward_all_const_branch():
    """Chained ifs where every condition is a compile-time constant."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if 2 < 12:  # constant: always taken
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:  # constant: always taken
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:  # constant: never taken
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)