# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import re
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE)

# Strategy-table keys containing this text are the entries the tests inspect.
# Compiled once so the scan helper does not look the pattern up per key.
_VIRTUAL_OUTPUT_PATTERN = re.compile('VirtualOutput-op')


class DenseMutMulNet(nn.Cell):
    """Three parallel Dense layers combined by two MatMuls, followed by a bias-free Dense.

    Only the MatMul inside ``fc4`` is given an explicit shard strategy; every
    other operator is left without one.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(128, 768)
        self.fc2 = nn.Dense(128, 768)
        self.fc3 = nn.Dense(128, 768)
        self.fc4 = nn.Dense(768, 768, has_bias=False)
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()
        self.fc4.matmul.shard(((1, 1), (8, 1)))

    def construct(self, x):
        q = self.fc1(x)
        k = self.fc2(x)
        v = self.fc3(x)
        k = self.transpose(k, (1, 0))
        c = self.relu4(self.matmul1(q, k))
        s = self.relu5(self.matmul2(c, v))
        s = self.fc4(s)
        return s


class MulNegTwoOutputNet(nn.Cell):
    """Multiply the input by a (32, 128) weight, negate it, and return both results."""

    def __init__(self):
        super().__init__()
        self.mul = P.Mul().shard(((2, 4), (2, 4)))
        self.neg = P.Neg().shard(((2, 4),))
        self.mul_weight = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight")

    def construct(self, x):
        out1 = self.mul(x, self.mul_weight)
        out2 = self.neg(out1)
        return out1, out2


class ReshapeMatMulNet(nn.Cell):
    """Reshape the input to (64, 28), then MatMul with a (28, 64) weight.

    Args:
        strategy1: Accepted for signature symmetry with ``MatMulReshapeNet``;
            not used by this cell.
        strategy2: Shard strategy applied to the MatMul.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy2)
        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

    # x (64, 4, 7)
    def construct(self, x):
        out = self.reshape(x, (64, 28))
        out = self.matmul(out, self.matmul_weight)
        return out


class MatMulReshapeNet(nn.Cell):
    """MatMul the input with a (28, 64) weight, then reshape the result to (64, -1).

    Args:
        strategy1: Shard strategy applied to the MatMul.
        strategy2: Accepted for signature symmetry with ``ReshapeMatMulNet``;
            not used by this cell.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

    # x (128, 28)
    def construct(self, x):
        out = self.matmul(x, self.matmul_weight)
        out = self.reshape(out, (64, -1))
        return out


class ReshapeMulNet(nn.Cell):
    """Multiply a (1, 128, 96) reshaped view of the weight by the weight itself.

    The ``x`` argument of ``construct`` is not read; the output depends only
    on ``mul_weight``.
    """

    def __init__(self):
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
        self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

    def construct(self, x):
        weight = self.reshape(self.mul_weight, (1, 128, 96))
        out = self.mul(weight, self.mul_weight)
        return out


class ParallelMulNet(nn.Cell):
    """Flatten -> Dense -> element-wise square (``Mul(x, x)``).

    Args:
        dense_in_channel: Input width of the Dense layer.
        dense_out_channel: Output width of the Dense layer.
    """

    def __init__(self, dense_in_channel=2048, dense_out_channel=250):
        super().__init__()
        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
        bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
        self.flat = nn.Flatten()
        self.dense = nn.Dense(in_channels=dense_in_channel,
                              out_channels=dense_out_channel,
                              weight_init=Tensor(weight_np),
                              bias_init=Tensor(bias_np),
                              has_bias=True)
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.flat(inputs)
        x = self.dense(x)
        x = self.mul(x, x)
        return x


def _compile_and_get_strategies(net, *inputs):
    """Compile ``net`` in auto-parallel, non-training mode and return its shard strategies."""
    net.set_auto_parallel()
    net.set_train(False)
    _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
    return _cell_graph_executor._get_shard_strategy(net)


def compile_graph(x, net):
    """Compile a single-input ``net`` and return the executor's shard-strategy mapping."""
    return _compile_and_get_strategies(net, x)


def compile_graph_two_input(x, y, net):
    """Compile a two-input ``net`` and return the executor's shard-strategy mapping."""
    return _compile_and_get_strategies(net, x, y)


def _set_parallel_context(parallel_mode, dataset_strategy):
    """Reset the auto-parallel context, then configure 8 devices / rank 0 with the given modes."""
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=parallel_mode,
                                      dataset_strategy=dataset_strategy)


def _assert_virtual_output_split(strategies, expected):
    """Assert ``strategy[0][0] == expected`` for every 'VirtualOutput-op' entry.

    Args:
        strategies: Mapping of operator name to strategy, as returned by
            ``compile_graph`` / ``compile_graph_two_input``.
        expected: Value required at ``strategy[0][0]`` (presumably the split
            count of the first dimension of the first tensor -- confirm
            against the executor's strategy layout).

    Returns:
        The number of matching entries. A result of 0 means nothing was
        checked, so callers that need a fixed number of outputs should assert
        on the return value.
    """
    matched = 0
    for op_name, strategy in strategies.items():
        if _VIRTUAL_OUTPUT_PATTERN.search(op_name) is None:
            continue
        matched += 1
        assert strategy[0][0] == expected
    return matched


def test_dense_relu_semi_auto():
    """
    Feature: VirtualOutput strategy.
    Description: DenseMutMulNet, semi_auto_parallel, data_parallel dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 8.
    """
    _set_parallel_context("semi_auto_parallel", "data_parallel")
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 8)


def test_dense_relu_semi_auto_full_batch():
    """
    Feature: VirtualOutput strategy.
    Description: DenseMutMulNet, semi_auto_parallel, full_batch dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 1.
    """
    _set_parallel_context("semi_auto_parallel", "full_batch")
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 1)


def test_dense_relu_auto():
    """
    Feature: VirtualOutput strategy.
    Description: DenseMutMulNet, auto_parallel, data_parallel dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 8.
    """
    _set_parallel_context("auto_parallel", "data_parallel")
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 8)


def test_dense_relu_auto_full_batch():
    """
    Feature: VirtualOutput strategy.
    Description: DenseMutMulNet, auto_parallel, full_batch dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 1.
    """
    _set_parallel_context("auto_parallel", "full_batch")
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 1)


def test_mul_neg_two_output_semi_auto():
    """
    Feature: VirtualOutput strategy.
    Description: MulNegTwoOutputNet, semi_auto_parallel, data_parallel dataset strategy.
    Expectation: exactly two VirtualOutput entries, each with strategy[0][0] == 8.
    """
    _set_parallel_context("semi_auto_parallel", "data_parallel")
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = _assert_virtual_output_split(strategies, 8)
    assert count == 2


def test_mul_neg_two_output_semi_auto_full_batch():
    """
    Feature: VirtualOutput strategy.
    Description: MulNegTwoOutputNet, semi_auto_parallel, full_batch dataset strategy.
    Expectation: exactly two VirtualOutput entries, each with strategy[0][0] == 1.
    """
    _set_parallel_context("semi_auto_parallel", "full_batch")
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = _assert_virtual_output_split(strategies, 1)
    assert count == 2


def test_mul_neg_two_output_auto():
    """
    Feature: VirtualOutput strategy.
    Description: MulNegTwoOutputNet, auto_parallel, data_parallel dataset strategy.
    Expectation: exactly two VirtualOutput entries, each with strategy[0][0] == 8.
    """
    _set_parallel_context("auto_parallel", "data_parallel")
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = _assert_virtual_output_split(strategies, 8)
    assert count == 2


def test_mul_neg_two_output_full_batch():
    """
    Feature: VirtualOutput strategy.
    Description: MulNegTwoOutputNet, auto_parallel, full_batch dataset strategy.
    Expectation: exactly two VirtualOutput entries, each with strategy[0][0] == 1.
    """
    _set_parallel_context("auto_parallel", "full_batch")
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = _assert_virtual_output_split(strategies, 1)
    assert count == 2


def test_reshape_matmul_semi_auto():
    """
    Feature: VirtualOutput strategy.
    Description: ReshapeMatMulNet, semi_auto_parallel, data_parallel dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 8.
    """
    _set_parallel_context("semi_auto_parallel", "data_parallel")
    strategy1 = None
    strategy2 = ((1, 1), (1, 8))
    net = ReshapeMatMulNet(strategy1, strategy2)
    x = Tensor(np.ones([64, 4, 7]), ms.float32)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 8)


def test_reshape_matmul_auto():
    """
    Feature: VirtualOutput strategy.
    Description: ReshapeMatMulNet, auto_parallel, data_parallel dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 8.
    """
    _set_parallel_context("auto_parallel", "data_parallel")
    strategy1 = None
    strategy2 = ((1, 1), (1, 8))
    net = ReshapeMatMulNet(strategy1, strategy2)
    x = Tensor(np.ones([64, 4, 7]), ms.float32)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 8)


def test_matmul_reshape_semi_auto():
    """
    Feature: VirtualOutput strategy.
    Description: MatMulReshapeNet, semi_auto_parallel, data_parallel dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 8.
    """
    _set_parallel_context("semi_auto_parallel", "data_parallel")
    strategy2 = None
    strategy1 = ((1, 1), (1, 8))
    net = MatMulReshapeNet(strategy1, strategy2)
    x = Tensor(np.ones([128, 28]), ms.float32)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 8)


def test_matmul_reshape_auto():
    """
    Feature: VirtualOutput strategy.
    Description: MatMulReshapeNet, auto_parallel, data_parallel dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 8.
    """
    _set_parallel_context("auto_parallel", "data_parallel")
    strategy2 = None
    strategy1 = ((1, 1), (1, 8))
    net = MatMulReshapeNet(strategy1, strategy2)
    x = Tensor(np.ones([128, 28]), ms.float32)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 8)


def test_reshape_mul_semi_auto():
    """
    Feature: VirtualOutput strategy.
    Description: ReshapeMulNet, semi_auto_parallel, full_batch dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 1.
    """
    _set_parallel_context("semi_auto_parallel", "full_batch")
    net = ReshapeMulNet()
    x = Tensor(np.ones([64, 4]), ms.float32)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 1)


def test_reshape_mul_auto():
    """
    Feature: VirtualOutput strategy.
    Description: ReshapeMulNet, auto_parallel, full_batch dataset strategy.
    Expectation: every VirtualOutput entry has strategy[0][0] == 1.
    """
    _set_parallel_context("auto_parallel", "full_batch")
    net = ReshapeMulNet()
    x = Tensor(np.ones([64, 4]), ms.float32)
    strategies = compile_graph(x, net)
    _assert_virtual_output_split(strategies, 1)


def test_scalar_output_semi_auto():
    """
    Feature: VirtualOutput strategy.
    Description: ParallelMulNet wrapped in WithEvalCell with a mean-reduced loss,
                 semi_auto_parallel, data_parallel dataset strategy.
    Expectation: exactly one VirtualOutput entry, with strategy[0][0] == 8.
    """
    _set_parallel_context("semi_auto_parallel", "data_parallel")
    net = ParallelMulNet()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    eval_net = nn.WithEvalCell(net, loss_fn)
    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32)*0.01)
    label = Tensor(np.ones([4096, 250]).astype(np.float32)*0.01)
    strategies = compile_graph_two_input(x, label, eval_net)
    count = _assert_virtual_output_split(strategies, 8)
    assert count == 1


def test_scalar_output_auto():
    """
    Feature: VirtualOutput strategy.
    Description: ParallelMulNet wrapped in WithEvalCell with a mean-reduced loss,
                 auto_parallel, data_parallel dataset strategy.
    Expectation: exactly one VirtualOutput entry, with strategy[0][0] == 8.
    """
    _set_parallel_context("auto_parallel", "data_parallel")
    net = ParallelMulNet()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    eval_net = nn.WithEvalCell(net, loss_fn)
    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32)*0.01)
    label = Tensor(np.ones([4096, 250]).astype(np.float32)*0.01)
    strategies = compile_graph_two_input(x, label, eval_net)
    count = _assert_virtual_output_split(strategies, 8)
    assert count == 1