# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore.nn import Cell
from mindspore.common import Tensor, Parameter
import mindspore.ops.operations as P
from mindspore import context, ops, lazy_inline, nn

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(jit_level='O2')


class Grad(Cell):
    def __init__(self, net):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation()
        self.net = net

    def construct(self, x):
        grad_net = self.grad(self.net)
        return grad_net(x)


class Block(Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.transpose1 = P.Transpose()
        self.transpose2 = P.Transpose()
        self.transpose3 = P.Transpose()
        self.transpose4 = P.Transpose()
        self.real_div1 = P.RealDiv()
        self.real_div2 = P.RealDiv()
        self.batch_matmul1 = P.BatchMatMul()
        self.batch_matmul2 = P.BatchMatMul()
        self.add = P.Add()
        self.softmax = P.Softmax(-1)
        self.dropout = P.Dropout(0.9)
        self.expand_dims = P.ExpandDims()
        self.sub = P.Sub()
        self.mul = P.Mul()
        self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32)))

    def construct(self, x):
        transpose1 = self.transpose1(x, (0, 2, 1, 3))
        real_div1 = self.real_div1(transpose1, Tensor(2.37891))
        transpose2 = self.transpose2(x, (0, 2, 3, 1))
        real_div2 = self.real_div2(transpose2, Tensor(2.37891))
        batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
        expand_dims = self.expand_dims(self.y, 1)
        sub = self.sub(Tensor([1.0]), expand_dims)
        mul = self.mul(sub, Tensor([-0.0001]))
        add = self.add(mul, batch_matmul1)
        soft_max = self.softmax(add)
        dropout = self.dropout(soft_max)
        transpose3 = self.transpose3(x, (0, 2, 1, 3))
        batch_matmul2 = self.batch_matmul2(dropout[0], transpose3)
        transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
        return transpose4


class TestBlock(Cell):
    def __init__(self):
        super(TestBlock, self).__init__()
        self.y = Parameter(Tensor(5))

    def construct(self, x):
        x = x + self.y
        x = x + self.y * 2
        x = x - 9
        return x


class TestIfBlock(Cell):
    def __init__(self):
        super(TestIfBlock, self).__init__()
        self.y = Parameter(Tensor(5))

    def construct(self, x):
        if x > 10:
            x = x + self.y * 2
        else:
            x = x + self.y
        x = x - 9
        return x


def test_recompute_block_recompute():
    """
    Feature: Recompute with lazy inline.
    Description: Each block is set to recompute by the cell recompute api.
    Expectation: Run successfully and the memory usage is reduced.
    """
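    # Cell-level recompute: calling recompute() on each @lazy_inline wrapper
    # marks every primitive inside the block for recomputation, so the block's
    # forward activations are rebuilt during the backward pass instead of
    # being kept alive in memory.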
104 """ 105 106 class OuterBlock(Cell): 107 @lazy_inline 108 def __init__(self): 109 super(OuterBlock, self).__init__() 110 self.block = Block() 111 112 def construct(self, x): 113 return self.block(x) 114 115 class Net(Cell): 116 def __init__(self): 117 super(Net, self).__init__() 118 self.blocks = nn.CellList() 119 for _ in range(3): 120 b = OuterBlock() 121 b.recompute() 122 self.blocks.append(b) 123 124 def construct(self, x): 125 out = x 126 for i in range(3): 127 out = self.blocks[i](out) 128 return out 129 130 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 131 net = Net() 132 grad_net = Grad(net) 133 grad_net(x) 134 135 136def test_recompute_op_recompute1(): 137 """ 138 Feature: Recompute with lazy inline. 139 Description: Each block is set recompute by the primitive recompute api. 140 Expectation: Run successfully and the memory usage is reduced. 141 """ 142 143 class OuterBlock(Cell): 144 @lazy_inline 145 def __init__(self): 146 super(OuterBlock, self).__init__() 147 self.block = Block() 148 self.block.real_div1.recompute() 149 self.block.batch_matmul1.recompute() 150 self.block.add.recompute() 151 self.block.softmax.recompute() 152 153 def construct(self, x): 154 return self.block(x) 155 156 class Net(Cell): 157 def __init__(self): 158 super(Net, self).__init__() 159 self.blocks = nn.CellList() 160 for _ in range(3): 161 b = OuterBlock() 162 self.blocks.append(b) 163 164 def construct(self, x): 165 out = x 166 for i in range(3): 167 out = self.blocks[i](out) 168 return out 169 170 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 171 net = Net() 172 grad_net = Grad(net) 173 grad_net(x) 174 175 176def test_recompute_op_recompute2(): 177 """ 178 Feature: Recompute with lazy inline. 179 Description: Each block is set recompute by the primitive recompute api. 180 Expectation: Run successfully and the memory usage is reduced. 181 """ 182 183 class OuterBlock(Cell): 184 @lazy_inline 185 def __init__(self): 186 super(OuterBlock, self).__init__() 187 self.block = Block() 188 self.block.transpose1.recompute() 189 self.block.transpose2.recompute() 190 self.block.real_div1.recompute() 191 self.block.real_div2.recompute() 192 self.block.batch_matmul1.recompute() 193 self.block.add.recompute() 194 self.block.softmax.recompute() 195 self.block.dropout.recompute() 196 197 def construct(self, x): 198 return self.block(x) 199 200 class Net(Cell): 201 def __init__(self): 202 super(Net, self).__init__() 203 self.blocks = nn.CellList() 204 for _ in range(3): 205 b = OuterBlock() 206 self.blocks.append(b) 207 208 def construct(self, x): 209 out = x 210 for i in range(3): 211 out = self.blocks[i](out) 212 return out 213 214 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 215 net = Net() 216 grad_net = Grad(net) 217 grad_net(x) 218 219 220def test_recompute_op_recompute3(): 221 """ 222 Feature: Recompute with lazy inline. 223 Description: Each block is set recompute by the primitive recompute api. 224 Expectation: Run successfully and the memory usage is reduced. 
225 """ 226 227 class Block1(Cell): 228 def __init__(self): 229 super(Block1, self).__init__() 230 self.transpose1 = P.Transpose() 231 self.transpose2 = P.Transpose() 232 self.transpose3 = P.Transpose() 233 self.transpose4 = P.Transpose() 234 self.real_div1 = P.RealDiv() 235 self.real_div2 = P.RealDiv() 236 self.batch_matmul1 = P.BatchMatMul() 237 self.batch_matmul2 = P.BatchMatMul() 238 self.add = P.Add() 239 self.softmax = P.Softmax(-1) 240 self.dropout = P.Dropout(0.9) 241 self.expand_dims = P.ExpandDims() 242 self.sub1 = P.Sub() 243 self.sub2 = P.Sub() 244 self.mul = P.Mul() 245 self.y = Parameter(Tensor(np.ones((8, 16, 128, 128)).astype(np.float32))) 246 247 def construct(self, x): 248 transpose1 = self.transpose1(x, (0, 2, 1, 3)) 249 real_div1 = self.real_div1(transpose1, Tensor(2.37891)) 250 sub1 = self.sub1(Tensor([1.0]), transpose1) 251 sub2 = self.sub2(Tensor([1.0]), sub1) 252 mul = self.mul(sub2, Tensor([-0.0001])) 253 add = self.add(mul, real_div1) 254 soft_max = self.softmax(add) 255 dropout = self.dropout(soft_max) 256 transpose3 = self.transpose3(x, (0, 2, 1, 3)) 257 batch_matmul2 = self.batch_matmul2(dropout[0], transpose3) 258 transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3)) 259 return transpose4 260 261 class OuterBlock(Cell): 262 @lazy_inline 263 def __init__(self): 264 super(OuterBlock, self).__init__() 265 self.block = Block1() 266 self.block.mul.recompute() 267 self.block.real_div1.recompute() 268 self.block.transpose1.recompute() 269 self.block.sub1.recompute() 270 self.block.add.recompute() 271 self.block.softmax.recompute() 272 273 def construct(self, x): 274 return self.block(x) 275 276 class Net(Cell): 277 def __init__(self): 278 super(Net, self).__init__() 279 self.blocks = nn.CellList() 280 for _ in range(3): 281 b = OuterBlock() 282 self.blocks.append(b) 283 284 def construct(self, x): 285 out = x 286 for i in range(3): 287 out = self.blocks[i](out) 288 return out 289 290 x = Tensor(np.ones((8, 128, 16, 128)).astype(np.float32)) 291 net = Net() 292 grad_net = Grad(net) 293 grad_net(x) 294 295 296def test_recompute_cell_and_op_recompute1(): 297 """ 298 Feature: Recompute with lazy inline. 299 Description: Each block is set recompute by both the primitive and cell recompute api. 300 Expectation: Run successfully and the memory usage is reduced. 
301 """ 302 303 class Net1(Cell): 304 def __init__(self): 305 super(Net1, self).__init__() 306 self.transpose2 = P.Transpose() 307 self.real_div2 = P.RealDiv() 308 309 def construct(self, x): 310 transpose2 = self.transpose2(x, (0, 2, 3, 1)) 311 real_div2 = self.real_div2(transpose2, Tensor(2.37891)) 312 return real_div2 313 314 class Block1(Cell): 315 def __init__(self): 316 super(Block1, self).__init__() 317 self.transpose1 = P.Transpose() 318 self.transpose2 = P.Transpose() 319 self.transpose3 = P.Transpose() 320 self.transpose4 = P.Transpose() 321 self.real_div1 = P.RealDiv() 322 self.real_div1.recompute() 323 self.real_div2 = P.RealDiv() 324 self.batch_matmul1 = P.BatchMatMul() 325 self.batch_matmul1.recompute() 326 self.batch_matmul2 = P.BatchMatMul() 327 self.add = P.Add() 328 self.add.recompute() 329 self.softmax = P.Softmax(-1) 330 self.softmax.recompute() 331 self.dropout = P.Dropout(0.9) 332 self.expand_dims = P.ExpandDims() 333 self.sub = P.Sub() 334 self.mul = P.Mul() 335 self.net1 = Net1() 336 self.net1.recompute() 337 self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32))) 338 339 def construct(self, x): 340 transpose1 = self.transpose1(x, (0, 2, 1, 3)) 341 real_div1 = self.real_div1(transpose1, Tensor(2.37891)) 342 real_div2 = self.net1(x) 343 batch_matmul1 = self.batch_matmul1(real_div1, real_div2) 344 expand_dims = self.expand_dims(self.y, 1) 345 sub = self.sub(Tensor([1.0]), expand_dims) 346 mul = self.mul(sub, Tensor([-0.0001])) 347 add = self.add(mul, batch_matmul1) 348 soft_max = self.softmax(add) 349 dropout = self.dropout(soft_max) 350 transpose3 = self.transpose3(x, (0, 2, 1, 3)) 351 batch_matmul2 = self.batch_matmul2(dropout[0], transpose3) 352 transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3)) 353 return transpose4 354 355 class OuterBlock(Cell): 356 @lazy_inline 357 def __init__(self): 358 super(OuterBlock, self).__init__() 359 self.block = Block1() 360 361 def construct(self, x): 362 return self.block(x) 363 364 class Net(Cell): 365 def __init__(self): 366 super(Net, self).__init__() 367 self.blocks = nn.CellList() 368 for _ in range(3): 369 b = OuterBlock() 370 self.blocks.append(b) 371 372 def construct(self, x): 373 out = x 374 for i in range(3): 375 out = self.blocks[i](out) 376 return out 377 378 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 379 net = Net() 380 grad_net = Grad(net) 381 grad_net(x) 382 383 384def test_recompute_cell_and_op_recompute2(): 385 """ 386 Feature: Recompute with lazy inline. 387 Description: Each block is set recompute by both the primitive and cell recompute api. 388 Expectation: Run successfully and the memory usage is reduced. 
389 """ 390 391 class Net1(Cell): 392 def __init__(self): 393 super(Net1, self).__init__() 394 self.transpose2 = P.Transpose() 395 self.real_div2 = P.RealDiv() 396 397 def construct(self, x): 398 transpose2 = self.transpose2(x, (0, 2, 3, 1)) 399 real_div2 = self.real_div2(transpose2, Tensor(2.37891)) 400 return real_div2 401 402 class Block1(Cell): 403 def __init__(self): 404 super(Block1, self).__init__() 405 self.transpose1 = P.Transpose() 406 self.transpose2 = P.Transpose() 407 self.transpose3 = P.Transpose() 408 self.transpose4 = P.Transpose() 409 self.real_div1 = P.RealDiv() 410 self.real_div1.recompute() 411 self.real_div2 = P.RealDiv() 412 self.batch_matmul1 = P.BatchMatMul() 413 self.batch_matmul1.recompute() 414 self.batch_matmul2 = P.BatchMatMul() 415 self.add = P.Add() 416 self.add.recompute() 417 self.softmax = P.Softmax(-1) 418 self.softmax.recompute() 419 self.dropout = P.Dropout(0.9) 420 self.expand_dims = P.ExpandDims() 421 self.sub = P.Sub() 422 self.mul = P.Mul() 423 self.depend = ops.Depend() 424 self.net1 = Net1() 425 self.net1.recompute() 426 self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32))) 427 428 def construct(self, x): 429 real_div2 = self.net1(x) 430 depend = self.depend(x, real_div2) 431 transpose1 = self.transpose1(depend, (0, 2, 1, 3)) 432 real_div1 = self.real_div1(transpose1, Tensor(2.37891)) 433 batch_matmul1 = self.batch_matmul1(real_div1, real_div2) 434 expand_dims = self.expand_dims(self.y, 1) 435 sub = self.sub(Tensor([1.0]), expand_dims) 436 mul = self.mul(sub, Tensor([-0.0001])) 437 add = self.add(mul, batch_matmul1) 438 soft_max = self.softmax(add) 439 dropout = self.dropout(soft_max) 440 transpose3 = self.transpose3(x, (0, 2, 1, 3)) 441 batch_matmul2 = self.batch_matmul2(dropout[0], transpose3) 442 transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3)) 443 return transpose4 444 445 class OuterBlock(Cell): 446 @lazy_inline 447 def __init__(self): 448 super(OuterBlock, self).__init__() 449 self.block = Block1() 450 451 def construct(self, x): 452 return self.block(x) 453 454 class Net(Cell): 455 def __init__(self): 456 super(Net, self).__init__() 457 self.blocks = nn.CellList() 458 for _ in range(3): 459 b = OuterBlock() 460 self.blocks.append(b) 461 462 def construct(self, x): 463 out = x 464 for i in range(3): 465 out = self.blocks[i](out) 466 return out 467 468 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 469 net = Net() 470 grad_net = Grad(net) 471 grad_net(x) 472 473 474def test_recompute_cell_and_op_recompute_with_tuple_outputs1(): 475 """ 476 Feature: Recompute with lazy inline. 477 Description: Each block is set recompute by both the primitive and cell recompute api and return a tuple. 478 Expectation: Run successfully and the memory usage is reduced. 
479 """ 480 481 class Net1(Cell): 482 def __init__(self): 483 super(Net1, self).__init__() 484 self.transpose2 = P.Transpose() 485 self.real_div2 = P.RealDiv() 486 487 def construct(self, x): 488 transpose2 = self.transpose2(x, (0, 2, 3, 1)) 489 real_div2 = self.real_div2(transpose2, Tensor(2.37891)) 490 return real_div2 491 492 class Block1(Cell): 493 def __init__(self): 494 super(Block1, self).__init__() 495 self.transpose1 = P.Transpose() 496 self.transpose2 = P.Transpose() 497 self.transpose3 = P.Transpose() 498 self.transpose4 = P.Transpose() 499 self.transpose4.recompute() 500 self.real_div1 = P.RealDiv() 501 self.real_div1.recompute() 502 self.real_div2 = P.RealDiv() 503 self.batch_matmul1 = P.BatchMatMul() 504 self.batch_matmul1.recompute() 505 self.batch_matmul2 = P.BatchMatMul() 506 self.add = P.Add() 507 self.add.recompute() 508 self.add1 = P.Add() 509 self.softmax = P.Softmax(-1) 510 self.softmax.recompute() 511 self.dropout = P.Dropout(0.9) 512 self.expand_dims = P.ExpandDims() 513 self.sub = P.Sub() 514 self.mul = P.Mul() 515 self.net1 = Net1() 516 self.net1.recompute() 517 self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32))) 518 519 def construct(self, x, z): 520 transpose1 = self.transpose1(x, (0, 2, 1, 3)) 521 real_div1 = self.real_div1(transpose1, Tensor(2.37891)) 522 real_div2 = self.net1(x) 523 batch_matmul1 = self.batch_matmul1(real_div1, real_div2) 524 expand_dims = self.expand_dims(self.y, 1) 525 sub = self.sub(Tensor([1.0]), expand_dims) 526 mul = self.mul(sub, Tensor([-0.0001])) 527 add = self.add(mul, batch_matmul1) 528 soft_max = self.softmax(add) 529 dropout = self.dropout(soft_max) 530 transpose3 = self.transpose3(x, (0, 2, 1, 3)) 531 batch_matmul2 = self.batch_matmul2(dropout[0], transpose3) 532 transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3)) 533 add1 = self.add1(transpose4, z) 534 return add1, transpose4 535 536 class OuterBlock(Cell): 537 @lazy_inline 538 def __init__(self): 539 super(OuterBlock, self).__init__() 540 self.block = Block1() 541 542 def construct(self, x, z): 543 return self.block(x, z) 544 545 class Net(Cell): 546 def __init__(self): 547 super(Net, self).__init__() 548 self.blocks = nn.CellList() 549 for _ in range(3): 550 b = OuterBlock() 551 self.blocks.append(b) 552 553 def construct(self, x): 554 out1, out2 = x, x 555 for i in range(3): 556 out1, out2 = self.blocks[i](out1, out2) 557 return out1, out2 558 559 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 560 net = Net() 561 grad_net = Grad(net) 562 grad_net(x) 563 564 565def test_recompute_cell_and_op_recompute_with_tuple_outputs2(): 566 """ 567 Feature: Recompute with lazy inline. 568 Description: Each block is set recompute by both the primitive and cell recompute api and return a tuple. 569 Expectation: Run successfully and the memory usage is reduced. 
570 """ 571 572 class Net1(Cell): 573 def __init__(self): 574 super(Net1, self).__init__() 575 self.transpose2 = P.Transpose() 576 self.real_div2 = P.RealDiv() 577 578 def construct(self, x): 579 transpose2 = self.transpose2(x, (0, 2, 3, 1)) 580 real_div2 = self.real_div2(transpose2, Tensor(2.37891)) 581 return real_div2 582 583 class Block1(Cell): 584 def __init__(self): 585 super(Block1, self).__init__() 586 self.transpose1 = P.Transpose() 587 self.transpose2 = P.Transpose() 588 self.transpose3 = P.Transpose() 589 self.transpose4 = P.Transpose() 590 self.transpose4.recompute() 591 self.real_div1 = P.RealDiv() 592 self.real_div1.recompute() 593 self.real_div2 = P.RealDiv() 594 self.batch_matmul1 = P.BatchMatMul() 595 self.batch_matmul1.recompute() 596 self.batch_matmul2 = P.BatchMatMul() 597 self.add = P.Add() 598 self.add.recompute() 599 self.add1 = P.Add() 600 self.add1.recompute() 601 self.softmax = P.Softmax(-1) 602 self.softmax.recompute() 603 self.dropout = P.Dropout(0.9) 604 self.expand_dims = P.ExpandDims() 605 self.sub = P.Sub() 606 self.mul = P.Mul() 607 self.net1 = Net1() 608 self.net1.recompute() 609 self.y = Parameter(Tensor(np.ones((8, 128, 128)).astype(np.float32))) 610 611 def construct(self, x, z): 612 transpose1 = self.transpose1(x, (0, 2, 1, 3)) 613 real_div1 = self.real_div1(transpose1, Tensor(2.37891)) 614 real_div2 = self.net1(x) 615 batch_matmul1 = self.batch_matmul1(real_div1, real_div2) 616 expand_dims = self.expand_dims(self.y, 1) 617 sub = self.sub(Tensor([1.0]), expand_dims) 618 mul = self.mul(sub, Tensor([-0.0001])) 619 add = self.add(mul, batch_matmul1) 620 soft_max = self.softmax(add) 621 dropout = self.dropout(soft_max) 622 transpose3 = self.transpose3(x, (0, 2, 1, 3)) 623 batch_matmul2 = self.batch_matmul2(dropout[0], transpose3) 624 transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3)) 625 add1 = self.add1(transpose4, z) 626 return add1, transpose4 627 628 class OuterBlock(Cell): 629 @lazy_inline 630 def __init__(self): 631 super(OuterBlock, self).__init__() 632 self.block = Block1() 633 634 def construct(self, x, z): 635 return self.block(x, z) 636 637 class Net(Cell): 638 def __init__(self): 639 super(Net, self).__init__() 640 self.blocks = nn.CellList() 641 for _ in range(3): 642 b = OuterBlock() 643 self.blocks.append(b) 644 645 def construct(self, x): 646 out1, out2 = x, x 647 for i in range(3): 648 out1, out2 = self.blocks[i](out1, out2) 649 return out1, out2 650 651 x = Tensor(np.ones((8, 128, 16, 32)).astype(np.float32)) 652 net = Net() 653 grad_net = Grad(net) 654 grad_net(x) 655