# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compile-time tests for sharding strategies on reduce-type operators.

Each test builds a small network (Mul + a reduce op such as ReduceSum /
ReduceMax / ReduceMin / ReduceMean / ReduceAny / ArgMaxWithValue /
ArgMinWithValue), wraps it with a virtual loss and a gradient wrapper,
assigns explicit shard strategies, and checks that the graph compiles
under semi_auto_parallel or auto_parallel mode.  No numerical results
are checked — compilation succeeding (or raising, where expected) is
the assertion.
"""

import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient op returning gradients w.r.t. all inputs of the wrapped net.
grad_all = C.GradOperation(get_all=True)


class NetWithLossNoBias(nn.Cell):
    """Attach a VirtualLoss to a two-input network (no bias input)."""

    def __init__(self, network):
        super(NetWithLossNoBias, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class NetWithLoss(nn.Cell):
    """Attach a VirtualLoss to a three-input network (x, y, bias b)."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrapNoBias(nn.Cell):
    """Return gradients of the wrapped two-input network w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWrapNoBias, self).__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)


class GradWrap(nn.Cell):
    """Return gradients of the wrapped three-input network w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def compile_net_no_bias(net, x, y):
    """Compile a two-input net in training mode (compile only, no execution)."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def compile_net(net, x, y, b):
    """Compile a three-input net in training mode (compile only, no execution)."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


# model_parallel test
def test_sum_mul():
    """Mul -> ReduceSum(axis 1, keep_dims=False) -> Mul on 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, (1,))
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 1, 8), (1, 1, 8))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul2():
    """Mul -> ReduceSum(axes (0, 1)) -> Mul with 4-D inputs on 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, (0, 1))
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 1, 4, 2), (1, 1, 4, 2))
    strategy2 = ((2, 4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul3():
    """Mul -> ReduceSum(axis -1) -> Mul on 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 2, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul4():
    """Mul -> ReduceSum(axis -1, keep_dims=True) -> Mul; bias keeps the reduced dim as 1."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((2, 2, 2),)
    strategy3 = ((4, 2, 1), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32, 1]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul5():
    """Mul -> ReduceSum(axis 0, keep_dims=True) on 64 devices, no bias input."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_sum_mul6():
    """Mul -> ReduceSum(axis 1, keep_dims=True) on 64 devices, no bias input."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, 1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 1, 4),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_sum_mul7():
    """Mul -> ReduceSum(axes (0, 1), keep_dims=True) on 64 devices, no bias input."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, (0, 1))
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_max_mul():
    """Mul -> ReduceMax(axis -1) -> Mul on 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_max = P.ReduceMax(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_max(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_min_mul():
    """Mul -> ReduceMin(axis 0) -> Mul on 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_min = P.ReduceMin(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_min(out, 0)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_reduce_mean_mul_float32():
    """Mul -> ReduceMean(axis 0) -> Mul with float32 inputs on 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, 0)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)

    compile_net(net, x, y, b)


class ArgMaxWithValueNet(nn.Cell):
    """Mul -> ArgMaxWithValue(axis=-1, value branch used) -> Mul."""

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_max_with_value = P.ArgMaxWithValue(keep_dims=False, axis=-1).shard(strategy2)
        self.mul2 = P.Mul().shard(strategy3)

    def construct(self, x, y, b):
        out = self.mul1(x, y)
        # ArgMaxWithValue returns (index, value); only the value is consumed.
        _, out = self.arg_max_with_value(out)
        out = self.mul2(out, b)
        return out


class ArgMinWithValueNet(nn.Cell):
    """Mul -> ArgMinWithValue(axis=-1, value branch used) -> Mul."""

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_min_with_value = P.ArgMinWithValue(keep_dims=False, axis=-1).shard(strategy2)
        self.mul2 = P.Mul().shard(strategy3)

    def construct(self, x, y, b):
        out = self.mul1(x, y)
        # ArgMinWithValue returns (index, value); only the value is consumed.
        _, out = self.arg_min_with_value(out)
        out = self.mul2(out, b)
        return out


def gen_inputs_and_compile_net(net):
    """Build fixed-shape three-input tensors and compile the net."""
    x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def gen_inputs_and_compile_net_no_bias(net):
    """Build fixed-shape two-input tensors and compile the net."""
    x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


# NOTE: 'tobefixed_' prefix keeps this test out of pytest collection until fixed.
def tobefixed_test_arg_max_with_value_mul_semi_axis_parallel():
    """ArgMaxWithValue with the reduced axis sharded (strategy2 last dim = 2)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_max_with_value_mul_semi():
    """ArgMaxWithValue with the reduced axis unsharded, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_max_with_value_mul_auto():
    """ArgMaxWithValue with strategies left to the auto_parallel searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_min_with_value_mul_semi_axis_parallel():
    """ArgMinWithValue with the reduced axis sharded, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_min_with_value_mul_semi():
    """ArgMinWithValue with the reduced axis unsharded, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_min_with_value_mul_auto():
    """ArgMinWithValue with strategies left to the auto_parallel searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net(net)


class ArgMinWithValueNet2(nn.Cell):
    """Mul -> ArgMinWithValue(axis=-1, keep_dims=True) -> ReLU (two-input variant)."""

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_min_with_value = P.ArgMinWithValue(keep_dims=True, axis=-1).shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, y):
        out = self.mul1(x, y)
        # ArgMinWithValue returns (index, value); only the value is consumed.
        _, out = self.arg_min_with_value(out)
        out = self.relu(out)
        return out


# NOTE: 'tobefixed_' prefix keeps this test out of pytest collection until fixed.
def tobefixed_test_arg_min_with_value_mul_semi_axis_parallel2():
    """ArgMinWithValueNet2 with the reduced axis sharded (strategy2 last dim = 2)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)


def test_arg_min_with_value_mul_semi2():
    """ArgMinWithValueNet2 with the reduced axis unsharded, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)


def test_arg_min_with_value_mul_auto2():
    """ArgMinWithValueNet2 with strategies left to the auto_parallel searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)


def test_cross_batch():
    """ReduceMean carrying the 'cross_batch' primitive attribute after ReduceSum."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy3).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, -1)
            out = self.reduce_mean(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (4, 2))
    strategy2 = ((2, 1),)
    strategy3 = ((8,),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_cross_batch2():
    """ReduceSum carrying the 'cross_batch' primitive attribute after ReduceMean."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy2)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy3).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, -1)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (4, 2))
    strategy2 = ((2, 1),)
    strategy3 = ((8,),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_cross_batch_auto():
    """'cross_batch' attribute under auto_parallel with no explicit strategies."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul1 = P.Mul()
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            self.reduce_sum = P.ReduceSum(keep_dims=False).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, -1)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrapNoBias(NetWithLossNoBias(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_max_empty_tuple():
    """ReduceMax called with no axis argument (full reduction); Add uses a () strategy."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul = P.Mul().shard(strategy1)
            self.reduce_max = P.ReduceMax(keep_dims=False).shard(strategy2)
            self.add = P.Add().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul(x, y)
            # No axis given: reduces over all dimensions to a scalar.
            out = self.reduce_max(out)
            out = self.add(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((), (1, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)

    compile_net(net, x, y, b)


def test_any_mul():
    """ReduceAny with the reduced axis (1) sharded: compilation must raise RuntimeError."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
            self.cast = P.Cast()

        def construct(self, x, y):
            out = self.mul1(x, y)
            # ReduceAny operates on booleans, so cast the float product first.
            out = self.cast(out, ms.bool_)
            out = self.reduce_any(out, 1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 8, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    with pytest.raises(RuntimeError):
        compile_net_no_bias(net, x, y)


def test_any_mul2():
    """ReduceAny with only non-reduced axes sharded: compilation succeeds."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
            self.cast = P.Cast()

        def construct(self, x, y):
            out = self.mul1(x, y)
            # ReduceAny operates on booleans, so cast the float product first.
            out = self.cast(out, ms.bool_)
            out = self.reduce_any(out, -1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((8, 1, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)