# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compile-only check of two chained MatMul ops under many shard strategies."""

import math
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient operation that returns gradients w.r.t. all inputs (get_all=True).
grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wrap a three-input network and feed its output to ``VirtualLoss``."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        """Run the wrapped network on ``(x, y, b)`` and return the loss of its output."""
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Wrap a network so that calling it yields ``grad_all`` of that network."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, b):
        """Return the result of applying ``grad_all(self.network)`` to ``(x, y, b)``."""
        return grad_all(self.network)(x, y, b)


def loop_config(size):
    """Enumerate candidate strategy pairs for a MatMul given ``size``.

    Args:
        size: Positive number that bounds the split factors (the test passes
            the device count). ``math.log2`` raises ``ValueError`` for values
            that are not positive.

    Returns:
        A list of ``((a, b), (b, c))`` tuples in which ``a`` and ``b`` are
        powers of two no larger than ``2 ** int(log2(size))``, ``a * b`` does
        not exceed ``size``, and ``c`` is ``int(size / (a * b))``.
    """
    # Powers of two: 1, 2, 4, ..., 2 ** int(log2(size)).
    split_list = [2 ** exp for exp in range(int(math.log2(size)) + 1)]

    config_list = []
    for a in split_list:
        for b in split_list:
            # Skip factor pairs whose product already exceeds ``size``.
            if a * b > size:
                continue
            c = int(size / (a * b))
            # ``b`` appears in both tuples so that the second dimension of the
            # first operand and the first dimension of the second operand are
            # split by the same factor.
            config_list.append(((a, b), (b, c)))

    return config_list


# model_parallel test
def test_two_matmul():
    """Compile the gradient graph of two chained MatMuls for every strategy pair.

    Every combination of two entries from ``loop_config(size)`` is used as the
    shard strategies of the first and second MatMul. The test contains no
    assertions; it relies on ``_cell_graph_executor.compile`` raising if a
    configuration cannot be compiled.
    """

    class Net(nn.Cell):
        """Compute ``matmul2(matmul1(x, y), b)`` with per-op shard strategies."""

        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    size = 4
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    # Shapes chain as [128, 32] @ [32, 64] @ [64, 64].
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    config_list = loop_config(size)

    count = 0
    for strategy1 in config_list:
        for strategy2 in config_list:
            print("=======current config {}=========".format(count))
            print(strategy1, strategy2)
            net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
            context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
            net.set_auto_parallel()
            net.set_train()
            _cell_graph_executor.compile(net, x, y, b)
            count = count + 1