# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter, Tensor, context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def compile_net(net, x, y, b):
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def test_matmul_sub():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.sub = P.Sub().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.sub(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_add():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.add = P.Add().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.add(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_mul():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.mul = P.Mul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_mod():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.mod = P.Mod().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mod(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_floormod():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.floormod = P.FloorMod().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floormod(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_atan2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.atan2 = P.Atan2().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.atan2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_divNoNan():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.div_no_nan = P.DivNoNan().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.div_no_nan(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
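

# The two tests below run the MatMul output through comparison ops (Equal,
# NotEqual), which produce boolean tensors, and then combine the results with
# a logical op; all three element-wise ops share the same shard strategy.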


def test_matmul_logicaland():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.equal = P.Equal().shard(strategy2)
            self.notequal = P.NotEqual().shard(strategy2)
            self.logical = P.LogicalAnd().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out1 = self.equal(out, b)
            out = self.matmul(x, y)
            out2 = self.notequal(out, b)
            out = self.logical(out1, out2)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_logicalor():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.equal = P.Equal().shard(strategy2)
            self.notequal = P.NotEqual().shard(strategy2)
            self.logical = P.LogicalOr().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out1 = self.equal(out, b)
            out = self.matmul(x, y)
            out2 = self.notequal(out, b)
            out = self.logical(out1, out2)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_div():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.div = P.Div().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.div(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_add_broadcast():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.add = P.Add().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.add(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
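

# In the *_broadcast tests, the element-wise op broadcasts its second input:
# either a 1-D tensor (shape [64], one-entry strategy such as (2,)) against
# the [64, 64] MatMul output, or a [1, 64] tensor against a [64, 1] MatMul
# output, with each size-1 axis given a shard count of 1, e.g. ((4, 1), (1, 2)).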


def test_matmul_add_broadcast2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.add = P.Add().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.add(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_sub_broadcast():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.sub = P.Sub().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.sub(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_sub_broadcast2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.sub = P.Sub().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.sub(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_mul_broadcast():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.mul = P.Mul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_mul_broadcast2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.mul = P.Mul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_div_broadcast():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.div = P.Div().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.div(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_div_broadcast2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.div = P.Div().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.div(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_greater_broadcast():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.greater = P.Greater().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.greater(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_greater_broadcast2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.greater = P.Greater().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.greater(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_floordiv():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.floordiv = P.FloorDiv().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_floordiv_broadcast():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.floordiv = P.FloorDiv().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_floordiv_broadcast2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.floordiv = P.FloorDiv().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_assign_sub():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.assign_sub = P.AssignSub()
            self.mul = P.Mul()
            self.mul_weight = Parameter(
                Tensor(np.full([128, 32], 0.5, dtype=np.float32)),
                name="mul_weight")
            self.assignsub_weight = Parameter(
                Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
                name="assignsub_weight")

        def construct(self, x):
            out = self.mul(x, self.mul_weight)
            out = self.assign_sub(self.assignsub_weight, out)
            return out

    class SubNetWithLoss(nn.Cell):
        def __init__(self, network):
            super(SubNetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x):
            predict = self.network(x)
            return self.loss(predict)

    class SubGradWrap(nn.Cell):
        def __init__(self, network):
            super(SubGradWrap, self).__init__()
            self.network = network

        def construct(self, x):
            return grad_all(self.network)(x)

    def compile_sub_net(net, x):
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, x)

    context.set_auto_parallel_context(device_num=64, global_rank=15)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net = SubGradWrap(SubNetWithLoss(Net()))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_sub_net(net, x)


def test_assign_add():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.assign_add = P.AssignAdd()
            self.mul = P.Mul()
            self.mul_weight = Parameter(
                Tensor(np.full([128, 32], 0.5, dtype=np.float32)),
                name="mul_weight")
            self.assignadd_weight = Parameter(
                Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
                name="assignadd_weight")

        def construct(self, x):
            out = self.mul(x, self.mul_weight)
            out = self.assign_add(self.assignadd_weight, out)
            return out

    class SubNetWithLoss(nn.Cell):
        def __init__(self, network):
            super(SubNetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x):
            predict = self.network(x)
            return self.loss(predict)

    class SubGradWrap(nn.Cell):
        def __init__(self, network):
            super(SubGradWrap, self).__init__()
            self.network = network

        def construct(self, x):
            return grad_all(self.network)(x)

    def compile_sub_net(net, x):
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, x)

    context.set_auto_parallel_context(device_num=64, global_rank=15)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net = SubGradWrap(SubNetWithLoss(Net()))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_sub_net(net, x)


def test_assign():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.assign = P.Assign()
            self.mul = P.Mul()
            self.mul_weight = Parameter(
                Tensor(np.full([128, 32], 0.5, dtype=np.float32)),
                name="mul_weight")
            self.assign_weight = Parameter(
                Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
                name="assign_weight")

        def construct(self, x):
            out = self.mul(x, self.mul_weight)
            out = self.assign(self.assign_weight, out)
            return out

    class SubNetWithLoss(nn.Cell):
        def __init__(self, network):
            super(SubNetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x):
            predict = self.network(x)
            return self.loss(predict)

    class SubGradWrap(nn.Cell):
        def __init__(self, network):
            super(SubGradWrap, self).__init__()
            self.network = network

        def construct(self, x):
            return grad_all(self.network)(x)

    def compile_sub_net(net, x):
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, x)

    context.set_auto_parallel_context(device_num=64, global_rank=15)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net = SubGradWrap(SubNetWithLoss(Net()))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_sub_net(net, x)