# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData

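# Shard strategies are assigned during graph compilation, so run in graph mode.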
context.set_context(mode=context.GRAPH_MODE)


class Dataset(MindData):
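    """Mock dataset that yields the same (predict, label) pair for `length` steps."""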
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class AllToAllNet(nn.Cell):
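    """A MatMul against a [128, 32] weight followed by a (1, 0) transpose."""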
    def __init__(self):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul()
        self.matmul_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose()

    def construct(self, x):
        x = self.matmul(x, self.matmul_weight)
        x = self.transpose1(x, (1, 0))
        return x


class SoftmaxCrossEntropyWithLogits(LossBase):
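    """Softmax cross-entropy loss; one-hot encodes labels when `sparse` is True."""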
    def __init__(self,
                 sparse=False,
                 reduction='none'):
        super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
        self.sparse = sparse
        self.reduction = reduction
        self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0., mstype.float32)
        self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]

        if self.is_cpugpu:
            # CPU and GPU provide a fused kernel for sparse labels with mean reduction.
            self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()

    def construct(self, logits, labels):
        if self.is_cpugpu and self.sparse and self.reduction == 'mean':
            x = self.sparse_softmax_cross_entropy(logits, labels)
            return x

        if self.sparse:
            labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
        # The op returns (loss, dlogits); keep only the loss.
        x = self.softmax_cross_entropy(logits, labels)[0]
        return self.get_loss(x)


def all_to_all_net():
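    """Build the network under test."""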
    return AllToAllNet()


def all_to_all_common():
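    """Train for two epochs under AUTO_PARALLEL on one device and return the shard strategies."""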
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=1, global_rank=0)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net()

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    return strategies


def test_one_dev():
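    """With a single device, every operator's shard strategy should be the trivial [1, 1] slicing."""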
    _reset_op_id()
    strategies = all_to_all_common()
    for (k, v) in strategies.items():
        if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None:
            assert v == [[1, 1], [1, 1]]
        elif re.search('Transpose-op', k) is not None:
            assert v == [[1, 1]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [1, 1]]