# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests: graph compilation of sharded MatMul networks under semi-auto parallel."""

import numpy as np

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.context import set_auto_parallel_context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient meta-op returning gradients w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wraps a network with a VirtualLoss head so the graph has a scalar output."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        # Forward through the wrapped network, then through the virtual loss.
        return self.loss(self.network(x, y))


class GradWrap(nn.Cell):
    """Wraps a cell so that calling it yields gradients w.r.t. both inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)


def compile_net(net, x, y):
    """Compile `net` on inputs (x, y); the test passes if compilation succeeds."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


# model_parallel test
def test_two_matmul():
    """Three sharded MatMuls fed by a Diag(Fill) identity-like constant.

    The constant feeds matmul1 on the left and matmul2 on the right,
    exercising strategy propagation through a shared parameter-free input.
    """

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.diag = P.Diag()
            self.fill = P.Fill()

        def construct(self, x, y):
            # Diagonal matrix built from a 128-element vector of ones.
            diag_mat = self.diag(self.fill(mstype.float32, (128,), 1.0))
            left = self.matmul1(diag_mat, x)
            right = self.matmul2(y, diag_mat)
            return self.matmul3(left, right)

    set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((1, 8), (8, 1))
    strategy3 = ((2, 4), (4, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)

    compile_net(net, x, y)


def test_matmul_mul_broadcast2():
    """Sharded MatMul followed by Mul against a scalar Tensor (broadcast case)."""

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.mul = P.Mul().shard(strategy2)
            # Scalar multiplier; its (empty) shard strategy is the `()` below.
            self.t = Tensor(0.9, ms.float32)

        def construct(self, x, y):
            product = self.matmul(x, y)
            return self.mul(product, self.t)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((2, 4), (4, 1))
    strategy2 = ((4, 1), ())
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    compile_net(net, x, y)


def test_two_matmul1():
    """Like test_two_matmul, but the Diag(Fill) constant is the LEFT operand
    of both matmul1 and matmul2, and both data inputs are square 128x128."""

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.diag = P.Diag()
            self.fill = P.Fill()

        def construct(self, x, y):
            # Diagonal matrix built from a 128-element vector of ones.
            diag_mat = self.diag(self.fill(mstype.float32, (128,), 1.0))
            left = self.matmul1(diag_mat, x)
            right = self.matmul2(diag_mat, y)
            return self.matmul3(left, right)

    set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((1, 8), (8, 1))
    strategy3 = ((2, 4), (4, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128]), dtype=ms.float32)

    compile_net(net, x, y)


def test_matmul_add_tensor():
    """Sharded MatMul followed by Add with a scalar Tensor bias (broadcast)."""

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().shard(strategy1)
            self.add = P.Add().shard(strategy2)
            # Scalar bias; its (empty) shard strategy is the `()` below.
            self.b = Tensor(0.9, ms.float32)

        def construct(self, x, y):
            product = self.matmul(x, y)
            return self.add(product, self.b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), ())
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

    compile_net(net, x, y)