# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Fixed-size mock dataset that yields the same (predict, label) pair on every step.

    Iterating produces exactly ``length`` items; ``reset`` rewinds the cursor so
    the dataset can be consumed again for the next epoch.
    """

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        # The dataset is its own iterator; state lives in self.index.
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        # Rewind the cursor so a fresh epoch sees all `length` items again.
        self.index = 0


class AllToAllNet(nn.Cell):
    """MatMul followed by a Transpose whose shard strategy is supplied by the test.

    The MatMul is sharded ((1, 1), (1, 8)) while the Transpose uses the injected
    ``strategy1``; the mismatch is what forces the framework to insert
    redistribution (all-to-all) communication between the two ops.
    """

    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose().shard(strategy1)

    def construct(self, x):
        product = self.matmul(x, self.matmul_weight)
        return self.transpose1(product, (1, 0))


def all_to_all_net(strategy1):
    """Factory wrapper kept for parity with the other parallel test files."""
    return AllToAllNet(strategy1=strategy1)


def all_to_all_common(strategy1):
    """Run a short SEMI_AUTO_PARALLEL training loop with the given transpose strategy."""
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.set_context(mode=context.GRAPH_MODE)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8,
                                      dataset_strategy="full_batch")

    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
    label = Tensor(np.ones([256]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)

    net = all_to_all_net(strategy1)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # Shard the loss internals along the batch dimension across all 8 devices.
    loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    loss.one_hot.shard(((8, 1), (), ()))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)

    model = Model(net, loss, opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_all_to_all():
    """Compiling/training with an (8, 1) transpose strategy should succeed."""
    strategy1 = ((8, 1),)
    _reset_op_id()
    all_to_all_common(strategy1)


def test_data_parallel_mode():
    """full_batch=True is invalid under DATA_PARALLEL mode, so training must raise."""
    _reset_op_id()
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.set_context(mode=context.GRAPH_MODE)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, full_batch=True)

    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
    label = Tensor(np.ones([256]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)

    net = all_to_all_net(None)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    with pytest.raises(RuntimeError):
        model.train(epoch_size, dataset, dataset_sink_mode=False)