# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph-mode control-flow tests for ``break``, ``continue`` and early ``return`` in loops.

Each network is run forward and, where applicable, differentiated with
``GradOperation(get_all=True)``. Test docstrings show the hand trace that
produces each expected value.
"""
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import context
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
grad_all = C.GradOperation(get_all=True)


class Grad(nn.Cell):
    """Wrap ``net`` so that calling the wrapper returns the gradients w.r.t. all inputs."""

    def __init__(self, net):
        super(Grad, self).__init__(auto_prefix=False)
        self.forward_net = net
        self.grad = C.GradOperation(get_all=True)

    def construct(self, *inputs):
        grads = self.grad(self.forward_net)(*inputs)
        return grads


class ForBreakForwardNet(nn.Cell):
    """``for`` loop that skips even iterations and accumulates ``x * y`` on odd ones.

    The loop is left early via ``return`` when the sum equals 20 and via
    ``break`` once it exceeds 20.
    """

    def __init__(self, max_cycles=10):
        super(ForBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        self.zero = Tensor(np.array(0), mstype.int32)

    def construct(self, x, y):
        out = self.zero
        for i in range(self.max_cycles):
            if i % 2 == 0:
                continue
            out = x * y + out
            if out == 20:
                return out
            if out > 20:
                break

        return out


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_break_forward():
    """With max_cycles=3 only i == 1 is odd, so out = x * y = 3."""
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = ForBreakForwardNet(max_cycles=3)
    graph_out = forward_net(x, y)
    assert graph_out == Tensor(np.array(3), mstype.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_break_backward():
    """out = x * y, so the gradients are (y, x) = (3, 1)."""
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = ForBreakForwardNet(max_cycles=3)
    backward_net = Grad(forward_net)
    graph_grads = backward_net(x, y)
    assert graph_grads == (Tensor(np.array(3), mstype.int32), Tensor(np.array(1), mstype.int32))


class WhileBreakForwardNet(nn.Cell):
    """``while`` counterpart of ForBreakForwardNet, driven by a tensor loop counter."""

    def __init__(self, max_cycles=10):
        super(WhileBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        self.i = Tensor(np.array(0), mstype.int32)
        self.zero = Tensor(np.array(0), mstype.int32)

    def construct(self, x, y):
        i = self.i
        out = self.zero
        while i < self.max_cycles:
            if i % 2 == 0:
                # The counter must be advanced before `continue`, or the loop never ends.
                i = i + 1
                continue
            out = x * y + out
            if out > 20:
                break
            if out == 20:
                return out
            i = i + 1
        return out


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_break_forward():
    """The odd counters 1, 3, 5, 7, 9 each add x * y = 3, so out = 15."""
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = WhileBreakForwardNet(max_cycles=10)
    graph_mode_out = forward_net(x, y)
    # Pin int32 like every other expected value; a bare np.array(15) takes the
    # platform-default integer dtype instead of matching the network output.
    assert graph_mode_out == Tensor(np.array(15), mstype.int32)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_break_backward():
    """out = 5 * x * y, so the gradients are (5 * y, 5 * x) = (15, 5)."""
    # The context is global: enabling save_graphs here would dump IR files for
    # this test and every test that runs after it, so keep it off.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = WhileBreakForwardNet(max_cycles=10)
    backward_net = Grad(forward_net)
    graph_grads = backward_net(x, y)
    assert graph_grads == (Tensor(np.array(15), mstype.int32), Tensor(np.array(5), mstype.int32))


class IfAfterIfInWhileBreakForwardNet(nn.Cell):
    """``while`` loop with ``continue``/``break`` that writes a Parameter, followed by an ``if``.

    Each pass stores the counter in ``self.weight``. Even counters are skipped.
    Odd ones add ``x * y`` while ``out <= 20`` and ``break`` otherwise. After the
    loop, ``out >= 30`` returns ``out - 30``; otherwise ``out + 1`` is returned.
    """

    def __init__(self, max_cycles=10):
        super(IfAfterIfInWhileBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        self.i = Tensor(np.array(0), mstype.int32)
        self.zero = Tensor(np.array(0), mstype.int32)
        self.weight = Parameter(Tensor(np.array(0), mstype.int32))

    def construct(self, x, y):
        i = self.i
        out = self.zero
        while i < self.max_cycles:
            self.weight = i
            if self.weight % 2 == 0:
                i = i + 1
                continue
            if out <= 20:
                self.weight = i
                out = x * y + out
            else:
                break
            i = i + 1
        if out >= 30:
            self.weight = out
            out = out - 30
            return out
        out = out + 1
        return out


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_while_break_forward():
    """The loop accumulates out = 15 (< 30), so the trailing branch returns 15 + 1 = 16."""
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    # Graph Mode
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    graph_forward_net = IfAfterIfInWhileBreakForwardNet(max_cycles=10)
    graph_mode_out = graph_forward_net(x, y)
    assert graph_mode_out == Tensor(np.array(16), mstype.int32)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_while_break_backward():
    """out = 5 * x * y + 1, so the gradients are (15, 5)."""
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    # Graph Mode
    context.set_context(mode=context.GRAPH_MODE)
    graph_forward_net = IfAfterIfInWhileBreakForwardNet(max_cycles=10)
    graph_backward_net = Grad(graph_forward_net)
    graph_mode_grads = graph_backward_net(x, y)

    assert graph_mode_grads == (Tensor(np.array(15), mstype.int32), Tensor(np.array(5), mstype.int32))


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_in_if_break():
    """``for`` with ``continue``/``break`` inside an ``if``, followed by another ``if``.

    With x = 2: out = 2 + 5 = 7. The loop runs twice (param_a -> 7,
    param_b -> -2) and breaks. param_b then becomes 13, and 2 < 13 returns
    7 - 13 = -6, with d(out)/dx = 1.
    """
    class IfAfterForInIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

        def construct(self, x):
            out = x + self.param_a
            if self.param_a > self.param_b:
                for _ in range(4):
                    self.param_a += 1
                    if self.param_b < 0:
                        continue
                    self.param_b -= 3
                    if self.param_a > 6:
                        break

            self.param_b += 15
            if x < self.param_b:
                out -= self.param_b
                return out
            out = self.param_b + out
            return out

    x = Tensor(2, mstype.int32)

    # graph mode

    forward_net = IfAfterForInIfNet()
    graph_forward_res = forward_net(x)

    # A fresh net is built for the backward pass because the forward call above
    # mutated param_a and param_b.
    context.set_context(mode=context.GRAPH_MODE)
    if_after_for_in_if_net = IfAfterForInIfNet()
    net = Grad(if_after_for_in_if_net)
    graph_backward_res = net(x)

    assert graph_forward_res == Tensor(-6, mstype.int32)
    assert graph_backward_res == (Tensor(1, mstype.int32),)


@pytest.mark.skip(reason="ME EvalCNode error.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_in_for_break():
    """Nested ``for`` loops with ``continue``, ``break`` and an early ``return``, then an ``if``."""
    class IfAfterForInForNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
            self.param_b = Parameter(Tensor(2, mstype.int32), name='b')

        def construct(self, x):
            out = x + self.param_a
            for _ in range(0, 10):
                x *= 2
                if self.param_a % 2 == 0:
                    self.param_a += 1
                    continue
                for _ in range(0, 5):
                    self.param_a += 1
                    x += self.param_b
                    if x > 10:
                        break
                if x > 100:
                    return x
            if self.param_a > self.param_b:
                out += x
            return out

    x = Tensor(2, mstype.int32)

    # graph mode
    forward_net = IfAfterForInForNet()
    graph_forward_res = forward_net(x)

    # Fresh net: the forward call above mutated param_a.
    if_after_for_in_for_net = IfAfterForInForNet()
    net = Grad(if_after_for_in_for_net)
    graph_backward_res = net(x)

    print("test_if_after_for_in_for_break graph_forward_res:", graph_forward_res)
    print("test_if_after_for_in_for_break graph_backward_res:", graph_backward_res)
    # NOTE(review): a plain-Python trace of construct as written returns x = 106
    # on the 4th outer iteration (d/dx = 16), which does not match the values
    # below; re-derive them before re-enabling these asserts.
    # assert graph_forward_res == Tensor(12285, mstype.int32)
    # assert graph_backward_res == (Tensor(1025, mstype.int32),)


class WhileAfterWhileInWhileBreakForwardNet(nn.Cell):
    """Nested ``while`` loops exited with ``break``, followed by a second ``while``.

    Per outer pass, the inner loop adds ``x * y`` up to five times (it breaks
    once j > 4). The outer loop breaks after at most three passes (i > 2).
    A trailing loop then adds ``x * y`` another ``max_cycles`` times.
    """

    def __init__(self, max_cycles=10):
        super(WhileAfterWhileInWhileBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        self.zero = Tensor(np.array(0), mstype.int32)
        self.i = Tensor(np.array(0), mstype.int32)

    def construct(self, x, y):
        out = self.zero
        i = self.i
        while i < self.max_cycles:
            j = self.i
            while j < self.max_cycles + 3:
                out = x * y + out
                j = j + 1
                if j > 4:
                    break
            i = i + 1
            if i > 2:
                break
        i = self.i
        while i < self.max_cycles:
            out = x * y + out
            i = i + 1
        return out


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_while_in_while_break_forward():
    """3 outer passes x 5 inner adds + 3 trailing adds = 18 * x * y = 54."""
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = WhileAfterWhileInWhileBreakForwardNet(max_cycles=3)
    graph_out = forward_net(x, y)

    assert graph_out == Tensor(np.array(54), mstype.int32)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_while_in_while_break_backward():
    """out = 18 * x * y, so the gradients are (18 * y, 18 * x) = (54, 18)."""
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = WhileAfterWhileInWhileBreakForwardNet(max_cycles=3)
    backward_net = Grad(forward_net)
    graph_grads = backward_net(x, y)

    assert graph_grads == (Tensor(np.array(54), mstype.int32), Tensor(np.array(18), mstype.int32))


class TwoBreakDeadForwardNet(nn.Cell):
    """``while`` loop with two ``break`` branches.

    For x < 4 the loop breaks on its first pass, so ``x = x + 1`` never runs.
    """

    def __init__(self):
        super(TwoBreakDeadForwardNet, self).__init__()
        self.zero = Tensor(np.array(0), mstype.int32)

    def construct(self, x):
        while x < 5:
            if x > 3:
                x -= 2
            elif x == 3:
                break
            else:
                break
            x = x + 1
        return x


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_2break_dead_block():
    """x = 1 takes the ``else: break`` branch immediately and is returned unchanged."""
    x = Tensor(np.array(1), mstype.int32)
    forward_net = TwoBreakDeadForwardNet()
    graph_out = forward_net(x)

    assert graph_out == Tensor(np.array(1), mstype.int32)


class ForInFor2BreakForwardNet(nn.Cell):
    """Nested ``for`` loops whose inner loop can ``break``, followed by a ReLU."""

    def __init__(self):
        super(ForInFor2BreakForwardNet, self).__init__()
        self.relu = P.ReLU()
        # P.Add, as used elsewhere in this file, instead of the deprecated P.TensorAdd alias.
        self.add = P.Add()

    def construct(self, x, y, z):
        out = z
        for _ in range(2):
            for _ in range(3):
                if 2 * x < y:
                    out = self.add(out, out)
                    x = x + 1
                if x + 6 == y:
                    break
        out = self.relu(out)
        return out


@pytest.mark.skip(reason="Get wrong parent graph")
def test_for_in_for_break():
    """Smoke test: run the nested ``for``/``break`` net forward and print the result."""
    x = Tensor(np.array(7), mstype.float32)
    y = Tensor(np.array(20), mstype.float32)
    z = Tensor(np.array(2), mstype.float32)
    forward_net = ForInFor2BreakForwardNet()
    graph_out = forward_net(x, y, z)
    print("test_for_in_for_break graph out:", graph_out)


# Graph inference raises an endless-loop exception.
@pytest.mark.skip(reason="Infer raise a endless loop exception")
def test_while_true_break():
    """``while True`` with ``continue`` and ``break``, differentiated w.r.t. both inputs.

    With x = 5 the body sets x to 3, which never satisfies the ``x == 2``
    break. This is presumably why inference reports an endless loop.
    """
    context.set_context(save_graphs=True)

    class WhileTrueBreakNet(nn.Cell):
        def __init__(self, t):
            super(WhileTrueBreakNet, self).__init__()
            self.add = P.Add()
            self.mul = P.Mul()
            self.para = Parameter(Tensor(t, mstype.int32), name="a")

        def construct(self, x, y):
            out = self.mul(y, self.para)
            while True:
                if x == 5:
                    x = x - 3
                    continue
                if x == 2:
                    break
                out = self.add(out, out)
            return out

    t = np.array([1]).astype(np.int32)
    y = Tensor([1], mstype.int32)
    x = Tensor([5], mstype.int32)
    net = WhileTrueBreakNet(t)
    grad_net = Grad(net)
    grad_out = grad_net(x, y)
    print(grad_out)


# Gets stuck in the VM backend.
@pytest.mark.skip(reason="Stuck in vm backend")
def test_continue_stuck_in_vm():
    """Nested ``while`` loops with an early ``return`` and a trailing ``continue``, differentiated."""
    context.set_context(save_graphs=True)

    class NetWork(nn.Cell):
        def __init__(self, t):
            super().__init__()
            self.add = P.Add()
            self.mul = P.Mul()
            self.para = Parameter(Tensor(t, mstype.int32), name="a")

        def construct(self, x, y):
            out = self.mul(y, y)
            while x != 3:
                while self.para > 5:
                    # self.param -= 1 if set after if_switch, which is wrong
                    self.para -= 1
                    x += 1
                    if x > 3:
                        self.para -= x
                        return out
                    out = self.add(out, y)
                continue
            out = self.mul(out, y)
            return out

    x = Tensor(2, mstype.int32)
    t = 8
    y = Tensor(1, mstype.int32)
    net = NetWork(t)
    grad_net = Grad(net)
    grad = grad_net(x, y)
    print(grad)