# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData

context.set_context(mode=context.GRAPH_MODE)


class Dataset(MindData):
    """Mock dataset that yields the same (predict, label) pair `length` times."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class AllToAllNet(nn.Cell):
    """MatMul followed by Transpose; the ops whose shard strategies the test inspects."""

    def __init__(self):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul()
        self.matmul_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose()

    def construct(self, x):
        x = self.matmul(x, self.matmul_weight)
        x = self.transpose1(x, (1, 0))
        return x


class SoftmaxCrossEntropyWithLogits(LossBase):
    """Loss wrapper around P.SoftmaxCrossEntropyWithLogits with optional sparse labels."""

    def __init__(self,
                 sparse=False,
                 reduction='none'):
        super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
        self.sparse = sparse
        self.reduction = reduction
        self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0., mstype.float32)
        self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]

        if self.is_cpugpu:
            self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()

    def construct(self, logits, labels):
        # On CPU/GPU the fused sparse op already applies mean reduction.
        if self.is_cpugpu and self.sparse and self.reduction == 'mean':
            x = self.sparse_softmax_cross_entropy(logits, labels)
            return x

        if self.sparse:
            labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
        x = self.softmax_cross_entropy(logits, labels)[0]
        return self.get_loss(x)


def all_to_all_net():
    return AllToAllNet()


def all_to_all_common():
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    # Auto-parallel on a single device: every operator should stay fully replicated.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=1, global_rank=0)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net()

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    return strategies


def test_one_dev():
    """With device_num=1, every shard strategy must be all 1s (no dimension is split)."""
    _reset_op_id()
    strategies = all_to_all_common()
    for (k, v) in strategies.items():
        if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None:
            assert v == [[1, 1], [1, 1]]
        elif re.search('Transpose-op', k) is not None:
            assert v == [[1, 1]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [1, 1]]
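

# A minimal usage sketch, assuming this module is run directly rather than
# through a pytest runner; it only relies on test_one_dev() defined above.
if __name__ == '__main__':
    test_one_dev()
    print('test_one_dev passed: single-device strategies are fully replicated')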