# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import pytest
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AlltoAll
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication.management import GlobalComm, init
from tests.dataset_mock import MindData

context.set_context(device_target="Ascend")
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True

_x1 = Tensor(np.ones([64, 3, 224, 224]), dtype=ms.float32)


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class AllToAllNet(nn.Cell):
    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose().shard(strategy1)

    def construct(self, x):
        x = self.matmul(x, self.matmul_weight)
        x = self.transpose1(x, (1, 0))
        return x


def all_to_all_net(strategy1):
    return AllToAllNet(strategy1=strategy1)


def all_to_all_common(strategy1):
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net(strategy1)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    loss.one_hot.shard(((8, 1), (), ()))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)
    strategys = _cell_graph_executor._get_shard_strategy(model._train_network)
    return strategys


def test_all_to_all():
    """
    Feature: AlltoAll
    Description: on 8p, train AllToAllNet in semi auto parallel and collect the shard strategy of each op
    Expectation: the recorded strategies match the configured shard settings
    """
    strategy1 = ((8, 1),)
    context.set_context(mode=context.GRAPH_MODE)
    _reset_op_id()
    strategys = all_to_all_common(strategy1)
    print(strategys)
    for (k, v) in strategys.items():
        if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None:
            assert v == [[8, 1], [8, 1]]
        elif re.search('OneHot-op', k) is not None:
            assert v == [[8, 1], [], []]
        elif re.search('Transpose-op', k) is not None:
            assert v == [[8, 1]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [1, 8]]
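

# For reference, a sketch of the shape arithmetic behind the AlltoAll cases
# below. The helper is hypothetical (not part of MindSpore's API): splitting
# split_count ways at split_dim and concatenating at concat_dim turns the
# [64, 3, 224, 224] input _x1, with split_count=8, split_dim=2, concat_dim=3,
# into [64, 3, 224 // 8, 224 * 8] = [64, 3, 28, 1792] on each device.
def _expected_alltoall_shape(shape, split_count, split_dim, concat_dim):
    """Hypothetical helper: the per-device output shape of an AlltoAll op."""
    out = list(shape)
    out[split_dim] //= split_count
    out[concat_dim] *= split_count
    return tuple(out)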


def test_all_to_all_success():
    """
    Feature: AlltoAll
    Description: on 8p, a 4d tensor is split at dim 2 and concatenated at dim 3
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=2, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    net = Net()
    _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_count_value_failed():
    """
    Feature: AlltoAll
    Description: split_count should equal the rank size, but does not
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=7, split_dim=2, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(ValueError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_count_type_failed():
    """
    Feature: AlltoAll
    Description: split_count should be an int, but a list is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=[8], split_dim=2, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_dim_value_failed():
    """
    Feature: AlltoAll
    Description: split_dim is out of range for the input shape
    Expectation: throw IndexError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=4, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(IndexError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_dim_type_failed():
    """
    Feature: AlltoAll
    Description: split_dim should be an int, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=(3,), concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_concat_dim_value_failed():
    """
    Feature: AlltoAll
    Description: concat_dim is out of range for the input shape
    Expectation: throw IndexError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=4)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(IndexError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_concat_dim_type_failed():
    """
    Feature: AlltoAll
    Description: concat_dim should be an int, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=([3],))

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)
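

# The next case exercises the divisibility check: _x1's dim 3 has size 224,
# which has no equal 3-way split, so compiling with split_count=3 is expected
# to raise ValueError. A plain-Python restatement of the same arithmetic:
assert 224 % 8 == 0  # the 8-way cases above shard dim 3 evenly
assert 224 % 3 != 0  # a 3-way split cannot, hence the expected ValueError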


def test_all_to_all_invalid_split_count_cannot_be_divisible_failed():
    """
    Feature: AlltoAll
    Description: the shape at split_dim should be divisible by split_count, but is not
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=3, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=3, split_dim=3, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(ValueError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_group_type_failed():
    """
    Feature: AlltoAll
    Description: group should be a str, but an int is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=3, group=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


if __name__ == '__main__':
    test_all_to_all()
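
# Note: running this file directly exercises only test_all_to_all(); the
# pytest.raises-based failure cases are meant to be collected and run by pytest.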