# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import numpy as np
import pytest

from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P


# Gradient w.r.t. a list of parameters / w.r.t. every positional input.
grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)


def test_while_grad():
    """Differentiate, w.r.t. all inputs, a while loop that writes into slices of `x`.

    The loop walks `idx` up to `end`, and on each step overwrites
    `x[idx, :, 0:2]` with the maximum of `x[idx, :, :]`.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    # NOTE(review): the whole gradient result is compared against the scalar 0
    # and this test carries no pytest marks -- confirm the intended check.
    assert grads == 0
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_const_param_grad():
    """Input gradients of `x = x * x + 1` repeated while `x < y`."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, 1)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    start = Tensor([1.1], dtype=ms.float32)
    limit = Tensor([8.0], dtype=ms.float32)
    grads = grad_net(start, limit)

    expected_dx = np.array([1.14433983e+02], dtype=np.float32)
    # `y` only appears in the loop condition, so its gradient is expected to be 0.
    expected_dy = np.array([0], dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected_dx, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[1].asnumpy(), expected_dy, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_variable_grad():
    """Input gradients of `x = x * x + y` repeated while `x < y`."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, y)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    start = Tensor([1.1], dtype=ms.float32)
    limit = Tensor([8.0], dtype=ms.float32)
    grads = grad_net(start, limit)

    expected_dx = np.array([2.20000005e+00], dtype=np.float32)
    expected_dy = np.array([1.00000000e+00], dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected_dx, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[1].asnumpy(), expected_dy, rtol=1e-4, atol=1e-4)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_forward():
    """Forward pass of a while loop that updates slices of `x` and accumulates a parameter."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    forward_net = MyWhileNet()
    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    result = forward_net(start, stop, data)

    expected = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32)
    assert np.allclose(result.asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_endless_case():
    """Forward pass summing the slices `x[idx]` in a while loop.

    Kept as a regression case for an endless loop during optimization.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                out = out + part
                idx = idx + 1
            return out

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    forward_net = MyWhileNet()
    result = forward_net(start, stop, data)

    expected = np.array([[[4, 6], [8, 10]],
                         [[4, 6], [8, 10]]]).astype(np.float32)
    assert np.allclose(result.asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_grad():
    """Parameter gradient of a while loop that adds the parameter once per iteration.

    The loop runs for idx = 0, 1, so every gradient entry is expected to be 2.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 2, dtype=np.int32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_forward_with_const_branch():
    """Forward pass of a while loop whose body holds an `if` with a constant condition.

    Four iterations each add the parameter, so the output is 4 * param.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Constant condition on purpose: only the first branch is ever taken.
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    forward_net = MyWhileNet()
    result = forward_net(start, stop, data)

    expected = np.array([[[0, 4], [8, 12]],
                         [[16, 20], [24, 28]]]).astype(np.float32)
    assert np.allclose(result.asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_opt_endless():
    """Input gradients of an AddN accumulation in a while loop.

    Kept as a regression case for an endless loop during optimization.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.full((2, 2, 2), 3, dtype=np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected_d_idx = 0
    expected_d_end = 0
    # 3 (initial AddN) + 4 iterations * 3 + 1 (final AddN) = 16 per element.
    expected_dx = np.full((2, 2, 2), 16, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected_d_idx, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[1].asnumpy(), expected_d_end, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[2].asnumpy(), expected_dx, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_no_while_call():
    """Forward pass of a constant-condition `if` without any loop; output equals the parameter."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            # Constant condition on purpose: only the first branch is ever taken.
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    forward_net = MyWhileNet()
    result = forward_net(start, stop, data)

    expected = np.array([[[0, 1], [2, 3]],
                         [[4, 5], [6, 7]]]).astype(np.float32)
    assert np.allclose(result.asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_grad_with_const_branch():
    """Parameter gradient of a while loop with a constant-condition `if` in its body.

    The parameter is added once in each of the four iterations, so the
    expected gradient is 4 everywhere.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 4, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_with_const_branch():
    """Parameter gradient of a `for` wrapping a `while` wrapping a constant-condition `if`.

    Two outer passes of four inner iterations each add the parameter,
    so the expected gradient is 8 everywhere.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                # Restart the inner counter on every outer pass.
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 8, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_basic():
    """Parameter gradient of a `while` nested in a `for`, accumulating from zero.

    2 outer passes * 4 inner iterations gives an expected gradient of 8.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 8, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_normal():
    """Parameter gradient of a `while` nested in a `for`, accumulating onto the input `x`.

    2 outer passes * 4 inner iterations gives an expected gradient of 8.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 8, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad():
    """Parameter gradient when the parameter is used inside and after a while loop.

    Three loop additions plus one after the loop give an expected gradient of 4.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 4, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_mul():
    """Parameter gradient when the loop multiplies by the parameter.

    Starting from ones, three iterations give param**3, then param is added,
    so the expected gradient is 3 * param**2 + 1.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            # Despite its name, this accumulator starts at ones (multiplicative identity).
            self.zero = Tensor(np.ones((2, 2, 2)), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.array([[[1, 4], [13, 28]],
                         [[49, 76], [109, 148]]]).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_two():
    """Gradients of two parameters used inside a while loop.

    `param` is used three times in the loop and once after it (gradient 4);
    `weight` is used only in the loop (gradient 3).
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected_first = np.full((2, 2, 2), 4, dtype=np.float32)
    expected_second = np.full((2, 2, 2), 3, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected_first, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[1].asnumpy(), expected_second, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_three():
    """Gradients of three parameters used inside a while loop.

    `param` is used three times in the loop and once after it (gradient 4);
    `weight` and `key` are used only in the loop (gradient 3 each).
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8, dtype=np.float32).reshape(2, 2, 2), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected_first = np.full((2, 2, 2), 4, dtype=np.float32)
    expected_second = np.full((2, 2, 2), 3, dtype=np.float32)
    expected_third = np.full((2, 2, 2), 3, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected_first, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[1].asnumpy(), expected_second, rtol=1e-4, atol=1e-4)
    assert np.allclose(grads[2].asnumpy(), expected_third, rtol=1e-4, atol=1e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_if_with_param_grad():
    """Parameter gradient of a while loop whose body branches on tensor values.

    The branch taken depends on comparing the maxima of `out` and `x`;
    the expected gradient over three iterations plus the final add is 5.
    """

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(8).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros((2, 2, 2)), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.ones([2, 2, 2], dtype=np.float32), dtype=ms.float32)
    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    grads = grad_net(start, stop, data)

    expected = np.full((2, 2, 2), 5, dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expected, rtol=1e-4, atol=1e-4)


782@pytest.mark.level0 783@pytest.mark.platform_arm_ascend_training 784@pytest.mark.platform_x86_ascend_training 785@pytest.mark.env_onecard 786def test_while_with_param_grad_not_enter_while(): 787 class MyWhileNet(nn.Cell): 788 def __init__(self): 789 super().__init__() 790 self.max = P.ReduceMax() 791 self.param = Parameter(Tensor(2, ms.float32), name="weight") 792 self.zero = Tensor(0, ms.float32) 793 794 def construct(self, idx, end, x): 795 out = self.zero 796 while idx < end: 797 out = out + self.param * 3 798 idx = idx + 1 799 return out + self.param 800 801 class GradNet(nn.Cell): 802 def __init__(self, net): 803 super(GradNet, self).__init__() 804 self.net = net 805 self.weights = ParameterTuple(net.trainable_params()) 806 807 def construct(self, a, b, c): 808 return grad_by_list(self.net, self.weights)(a, b, c) 809 810 idx = Tensor(np.array(3), dtype=ms.int32) 811 end = Tensor(np.array(0), dtype=ms.int32) 812 x = Tensor(2, dtype=ms.float32) 813 # graph mode 814 context.set_context(mode=context.GRAPH_MODE) 815 while_net = MyWhileNet() 816 net = GradNet(while_net) 817 graph_output = net(idx, end, x) 818 819 assert np.allclose(graph_output[0].asnumpy(), 1, 0.0001, 0.0001) 820 821 822@pytest.mark.level0 823@pytest.mark.platform_arm_ascend_training 824@pytest.mark.platform_x86_gpu_training 825@pytest.mark.env_onecard 826def test_with_param_if_by_if_forward(): 827 class MyIfByIfNet(nn.Cell): 828 def __init__(self): 829 super().__init__() 830 self.max = P.ReduceMax() 831 self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") 832 self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) 833 834 def construct(self, a, b, x): 835 out = self.zero 836 if a < b: 837 out = out + x + self.param 838 else: 839 out = out + x 840 if a == b: 841 out = out + x * 3 + self.param 842 else: 843 out = out + x * 2 844 return out 845 846 idx = Tensor(np.array(0), dtype=ms.int32) 847 end = Tensor(np.array(4), dtype=ms.int32) 848 x = 
Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) 849 # graph mode 850 context.set_context(mode=context.GRAPH_MODE) 851 if_net = MyIfByIfNet() 852 net = if_net 853 graph_output = net(idx, end, x) 854 expect = np.array([[[3, 4], [5, 6]], 855 [[7, 8], [9, 10]]]).astype(np.float32) 856 assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001) 857 858 859@pytest.mark.level0 860@pytest.mark.platform_arm_ascend_training 861@pytest.mark.platform_x86_gpu_training 862@pytest.mark.env_onecard 863def test_with_param_if_by_if_grad_inputs(): 864 class MyIfByIfNet(nn.Cell): 865 def __init__(self): 866 super().__init__() 867 self.max = P.ReduceMax() 868 self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") 869 self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) 870 871 def construct(self, a, b, x): 872 out = self.zero 873 if a < b: 874 out = out + x + self.param * 4 875 if a == b: 876 out = out + x * 3 + self.param * 3 877 return out 878 879 class GradNet(nn.Cell): 880 def __init__(self, net): 881 super(GradNet, self).__init__() 882 self.net = net 883 884 def construct(self, *inputs): 885 return grad_all(self.net)(*inputs) 886 887 idx = Tensor(np.array(0), dtype=ms.int32) 888 end = Tensor(np.array(0), dtype=ms.int32) 889 x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) 890 # graph mode 891 context.set_context(mode=context.GRAPH_MODE) 892 if_net = MyIfByIfNet() 893 net = GradNet(if_net) 894 graph_output = net(idx, end, x) 895 expect1 = Tensor(np.array(0), dtype=ms.int32) 896 expect2 = Tensor(np.array(0), dtype=ms.int32) 897 expect3 = np.array([[[3, 3], [3, 3]], 898 [[3, 3], [3, 3]]]).astype(np.float32) 899 assert np.allclose(graph_output[0].asnumpy(), expect1.asnumpy(), 0.0001, 0.0001) 900 assert np.allclose(graph_output[1].asnumpy(), expect2.asnumpy(), 0.0001, 0.0001) 901 assert np.allclose(graph_output[2].asnumpy(), expect3, 0.0001, 0.0001) 902 903 904@pytest.mark.level0 
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_parameter():
    """Check the parameter gradient when the first of two sequential `if`s is taken.

    With a=0 and b=2, `a < b` holds and `a == b` does not, so the output is
    `zero + x + param * 2` and the expected gradient for `param` is 2 everywhere.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()  # not referenced by construct
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            # Two independent `if` statements (not if/elif): each is evaluated on its own.
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x * 3 + self.param
            return out

    class GradNet(nn.Cell):
        # Wraps `net` and returns the result of grad_by_list for its trainable parameters.
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)

    # d(out)/d(param) for the `a < b` branch is the constant factor 2.
    expect = np.array([[[2, 2], [2, 2]],
                       [[2, 2], [2, 2]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_param_excute_null():
    """Check the parameter gradient when the only branch using the parameter is skipped.

    With a=4 and b=0, `a < b` is false, so the output is the constant zero
    tensor and the expected gradient for `param` is all zeros.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()  # not referenced by construct
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            # `param` only contributes to the output inside this branch.
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        # Wraps `net` and returns the result of grad_by_list for its trainable parameters.
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)

    # The branch is not executed, so the parameter does not influence the output.
    expect = np.array([[[0, 0], [0, 0]],
                       [[0, 0], [0, 0]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_return_inside_grad():
    """Check the parameter gradient when each `if` body returns directly.

    With a=1 and b=0, neither `a < b` nor `a == b` holds, so the final
    `return out + self.param * 3` is used and the expected gradient is 3.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()  # not referenced by construct
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            # Each branch returns early; the last return is the fall-through case.
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        # Wraps `net` and returns the result of grad_by_list for its trainable parameters.
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)

    # d(out)/d(param) for the fall-through return is the constant factor 3.
    expect = np.array([[[3, 3], [3, 3]],
                       [[3, 3], [3, 3]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_forward():
    """Forward check for three consecutive if/else statements in one cell.

    With a=2, b=3, x=4: a becomes 2+3=5, then 5/3 (since 5 != 4); b becomes
    5/3 + 4 (since 3 != 4); the result a*b + b + x is about 19.11111.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 19.11111
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_forward_control_tuple_switch():
    """Forward check for if/else statements spread over three nested cells returning a tuple.

    Original note: tuple_get from switch op will generate new switch inside to
    eliminate tuple_get.

    With a=2, b=3, x=0: a becomes 5, then 5/3; b becomes 5/3 + 0; the outer
    cell unpacks the returned tuple and computes a*b + b + x, about 4.444444.
    """

    class Branch3Net(nn.Cell):
        # Innermost cell: one if/else on `b`, then returns all three values as a tuple.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        # Middle cell: one if/else on `a`, then delegates to Branch3Net.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        # Outer cell: first if/else, then unpacks the tuple produced by the nested cells.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_forward_control_inside_net():
    """Forward check for if/else statements spread over three nested cells returning a scalar.

    Same arithmetic as the single-cell variant, but each if/else lives in its
    own cell and the innermost cell computes the final value. With a=2, b=3,
    x=0 the expected result is about 4.444444.
    """

    class Branch3Net(nn.Cell):
        # Innermost cell: last if/else plus the final arithmetic.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        # Middle cell: one if/else on `a`, then delegates to Branch3Net.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        # Outer cell: first if/else, then returns whatever the nested cells produce.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_use_namespace():
    """Forward check for consecutive if/else statements using inline `P.<Op>()(...)` calls.

    The operators are instantiated inside `construct` at each call site rather
    than through the `self.*` attributes set in `__init__`. With a=2, b=3, x=0
    the expected result is about 4.444444.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # These attributes are not referenced by construct, which builds its operators inline.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_use_global_op():
    """Forward check for consecutive if/else statements using operators bound to locals.

    The operators are created once at the top of `construct` and reused by the
    branches. With a=2, b=3, x=0 the expected result is about 4.444444.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # These attributes are not referenced by construct, which uses local operators.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_with_if_by_if_forward():
    """Forward check for an if/else nested in a fixed-count `for` loop.

    With a=2, b=3, x=0: the first iteration sets a=5; the remaining three take
    the else branch and leave b=3. The result a*b + b + x is 18.0.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 18.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_with_if_by_if_forward_namespace():
    """Forward check for an if/else in a `for` loop using inline `P.<Op>()(...)` calls.

    Six iterations instead of four; because x=0 the else branch leaves b
    unchanged, so the result is still 5*3 + 3 + 0 = 18.0.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # These attributes are not referenced by construct, which builds its operators inline.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 18.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_const_branch_inner():
    """Forward check where the middle of three if/else statements has a constant condition.

    With a=2, b=3, x=0: a becomes 5, then 5*3=15 (`2 > 1` is always true);
    b becomes 15 + 0; the result 15*15 + 15 + 0 is 240.0.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # These attributes are not referenced by construct, which uses local operators.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Constant condition sandwiched between two tensor-dependent ones.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 240.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_all_const_branch():
    """Forward check where all three if/else statements have constant conditions.

    The conditions select add, mul and the else-add respectively, so with a=2,
    b=3, x=0: a=5, a=15, b=15, and the result 15*15 + 15 + 0 is 240.0.
    """

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # These attributes are not referenced by construct, which uses local operators.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            # Every condition below is a Python constant expression.
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 240.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_const_grad():
    """Smoke test: a constant-condition `if` placed before a gradient call.

    There is no assertion on the result; the test only requires that the call
    completes in GRAPH_MODE.
    """

    class MyNet(nn.Cell):
        # Adds its inputs; declares no Parameter, so the weight list is presumably empty.
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # `a` and `b` are plain Python constants that do not feed the returned value;
            # they exist only to put constant control flow ahead of the gradient call.
            a = 1
            b = 2
            if a > 0:
                b = 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
    """Smoke test: three constant-condition `if`s placed before a gradient call.

    There is no assertion on the result; the test only requires that the call
    completes in GRAPH_MODE.
    """

    class MyNet(nn.Cell):
        # Adds its inputs; declares no Parameter, so the weight list is presumably empty.
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # `a` and `b` are plain Python constants that do not feed the returned value;
            # they exist only to put constant control flow ahead of the gradient call.
            a = 1
            b = 2
            if a > 0:
                b = 1
            if a < 0:
                b = 0
            if a == 0:
                b = 3
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_const_grad():
    """Smoke test: a constant-condition `while` placed before a gradient call.

    The loop condition `a > 1` is false on entry (a is 1), so the body never
    runs. There is no assertion on the result; the test only requires that the
    call completes in GRAPH_MODE.
    """

    class MyNet(nn.Cell):
        # Adds its inputs; declares no Parameter, so the weight list is presumably empty.
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # `a` is a plain Python constant that does not feed the returned value.
            a = 1
            while a > 1:
                a = a - 1
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
    """Smoke test: a constant `if` followed by a constant `while` before a gradient call.

    Nothing is asserted about the returned value; the test passes when the
    gradient network can be invoked in GRAPH_MODE without raising.
    """

    class MyNet(nn.Cell):
        # Forward network: adds its positional inputs together.
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            return self.add(*inputs)

    class GradNet(nn.Cell):
        # Wraps a network and calls grad_by_list on it with its trainable parameters.
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # The two constants below never reach the returned value; they only
            # place constant-condition control flow ahead of the gradient call.
            counter = 1
            offset = 2
            if counter > 0:
                offset = 0
            while counter > 1:
                counter = counter - 1
            counter += offset
            grad_fn = grad_by_list(self.net, self.weights)
            return grad_fn(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyNet())
    lhs = Tensor(np.array(0), dtype=ms.int32)
    rhs = Tensor(np.array(1), dtype=ms.int32)
    grad_net(lhs, rhs)