# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

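"""Tests for shard strategy assignment in semi-auto parallel mode and for
argument validation of the AlltoAll communication operator."""
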
import re
import pytest
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AlltoAll
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication.management import GlobalComm, init
from tests.dataset_mock import MindData

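# Initialize HCCL for this single-process unit test: CHECK_ENVS is toggled
# off so init("hccl") skips the environment checks a real distributed job
# would need, then restored afterwards.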
context.set_context(device_target="Ascend")
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True

_x1 = Tensor(np.ones([64, 3, 224, 224]), dtype=ms.float32)


class Dataset(MindData):
    """Minimal mock dataset yielding the same (predict, label) pair `length` times per epoch."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


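# AllToAllNet shards the MatMul weight column-wise across 8 devices, then
# applies a Transpose sharded by strategy1; the strategy change between the
# two ops presumably triggers the tensor redistribution (all-to-all) that
# test_all_to_all exercises.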
class AllToAllNet(nn.Cell):
    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
        self.transpose1 = P.Transpose().shard(strategy1)

    def construct(self, x):
        x = self.matmul(x, self.matmul_weight)
        x = self.transpose1(x, (1, 0))
        return x


def all_to_all_net(strategy1):
    return AllToAllNet(strategy1=strategy1)


def all_to_all_common(strategy1):
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = all_to_all_net(strategy1)

    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    loss.one_hot.shard(((8, 1), (), ()))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)

    model.train(epoch_size, dataset, dataset_sink_mode=False)
    strategies = _cell_graph_executor._get_shard_strategy(model._train_network)
    return strategies


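# _get_shard_strategy returns a dict mapping op instance names to the input
# strategies assigned to them, roughly of the form (key name illustrative):
#     {'...MatMul-op0': [[1, 1], [1, 8]], ...}
# so the test below matches keys by regex rather than by exact name.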
def test_all_to_all():
    """
    Feature: shard strategy propagation
    Description: train a small sharded net and check the strategy assigned to each op
    Expectation: strategies match the configured shard settings
    """
    strategy1 = ((8, 1),)
    context.set_context(mode=context.GRAPH_MODE)
    _reset_op_id()
    strategies = all_to_all_common(strategy1)
    print(strategies)
    for (k, v) in strategies.items():
        if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None:
            assert v == [[8, 1], [8, 1]]
        elif re.search('OneHot-op', k) is not None:
            assert v == [[8, 1], [], []]
        elif re.search('Transpose-op', k) is not None:
            assert v == [[8, 1]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [1, 8]]


def test_all_to_all_success():
    """
    Feature: AlltoAll
    Description: on 8 devices, a 4-D tensor is split along dim 2 and concatenated along dim 3
    Expectation: compilation succeeds
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=2, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    net = Net()
    _cell_graph_executor.compile(net, _x1)


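# For reference, the shape transformation AlltoAll performs in the success
# case above: the input is split into split_count blocks along split_dim and
# the blocks received from the other ranks are concatenated along concat_dim.
# A minimal sketch (the helper name is ours, not part of the operator API):
def _expected_alltoall_shape(shape, split_count, split_dim, concat_dim):
    out = list(shape)
    out[split_dim] //= split_count  # the scattered dimension shrinks
    out[concat_dim] *= split_count  # the gathered dimension grows
    return tuple(out)

# e.g. _expected_alltoall_shape((64, 3, 224, 224), 8, 2, 3) == (64, 3, 28, 1792)

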
def test_all_to_all_invalid_split_count_value_failed():
    """
    Feature: AlltoAll
    Description: split_count (7) does not equal the rank size (8)
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=7, split_dim=2, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(ValueError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_count_type_failed():
    """
    Feature: AlltoAll
    Description: split_count should be an int, but a list is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=[8], split_dim=2, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_dim_value_failed():
    """
    Feature: AlltoAll
    Description: split_dim (4) is out of range for the 4-D input
    Expectation: IndexError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=4, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(IndexError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_dim_type_failed():
    """
    Feature: AlltoAll
    Description: split_dim should be an int, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=(3,), concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_concat_dim_value_failed():
    """
    Feature: AlltoAll
    Description: concat_dim (4) is out of range for the 4-D input
    Expectation: IndexError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=4)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(IndexError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_concat_dim_type_failed():
    """
    Feature: AlltoAll
    Description: concat_dim should be an int, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=([3],))

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_split_count_cannot_be_divisible_failed():
    """
    Feature: AlltoAll
    Description: the input size along split_dim (224) is not divisible by split_count (3)
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=3, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=3, split_dim=3, concat_dim=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(ValueError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


def test_all_to_all_invalid_group_type_failed():
    """
    Feature: AlltoAll
    Description: group should be a str, but an int is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.alltoallv = AlltoAll(split_count=8, split_dim=3, concat_dim=3, group=3)

        def construct(self, x1):
            out = self.alltoallv(x1)
            return out

    with pytest.raises(TypeError):
        net = Net()
        _cell_graph_executor.compile(net, _x1)


if __name__ == '__main__':
    test_all_to_all()