• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import pytest
16import numpy as np
17import mindspore as ms
18import mindspore.context as context
19from mindspore import Tensor, Parameter
20import mindspore.nn as nn
21from mindspore.common.api import _cell_graph_executor
22from mindspore.nn import TrainOneStepCell, Momentum
23from mindspore.ops import operations as P
24from mindspore.ops.operations.comm_ops import NeighborExchange
25
# Shared all-ones float32 fixtures: x1 @ x2 yields a [32, 32] product that is
# then multiplied elementwise by the [32, 32] weight w1.
_w1 = Tensor(np.ones((32, 32)), dtype=ms.float32)  # weight for the Mul after MatMul
_x1 = Tensor(np.ones((32, 16)), dtype=ms.float32)  # left MatMul operand
_x2 = Tensor(np.ones((16, 32)), dtype=ms.float32)  # right MatMul operand
29
30
def compile_net(net):
    """Compile ``net`` for training in graph mode.

    The network is wrapped in a ``TrainOneStepCell`` driven by a Momentum
    optimizer (lr=0.1, momentum=0.9) and compiled with the module-level
    ``_x1`` / ``_x2`` tensors as inputs.
    """
    context.set_context(mode=context.GRAPH_MODE)
    momentum_opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    trainer = TrainOneStepCell(net, momentum_opt)
    trainer.set_train()
    _cell_graph_executor.compile(trainer, _x1, _x2)
37
38
def test_NeighborExchange_two_inputs_success():
    """
    Feature: NeighborExchange
    Description: send two tensors and receive two tensors, all arguments valid
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class MatMulNet(nn.Cell):
        def __init__(self, weight1):
            super().__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            # Two send buffers: the scaled MatMul result ([32, 32]) and x1 ([32, 16]).
            self.alltoallv = NeighborExchange(send_rank_ids=[0, 1],
                                              send_shapes=([32, 32], [32, 16]),
                                              recv_rank_ids=[1, 2],
                                              recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)
            self.weight1 = Parameter(weight1, name="w1")

        def construct(self, x1, x2):
            scaled = self.mul(self.matmul(x1, x2), self.weight1)
            received = self.alltoallv((scaled, x1))
            return received[0]

    compile_net(MatMulNet(_w1))
65
66
def test_NeighborExchange_single_input_success():
    """
    Feature: NeighborExchange
    Description: one input and two outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class MatMulNet2(nn.Cell):
        def __init__(self, weight1):
            super().__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            # One send buffer to rank 0; two receive buffers from ranks 1 and 2.
            self.alltoallv = NeighborExchange(send_rank_ids=[0], send_shapes=([32, 32],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)
            self.weight1 = Parameter(weight1, name="w1")

        def construct(self, x1, x2):
            scaled = self.mul(self.matmul(x1, x2), self.weight1)
            received = self.alltoallv((scaled,))
            return received[0]

    compile_net(MatMulNet2(_w1))
92
93
def test_NeighborExchange_empty_send_success():
    """
    Feature: NeighborExchange
    Description: nothing is sent and a single [1] tensor is received, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # No send ranks/shapes, so the primitive is invoked with no inputs below.
            self.alltoallv = NeighborExchange(send_rank_ids=[], send_shapes=(),
                                              recv_rank_ids=[1], recv_shapes=([1],),
                                              recv_type=ms.float32)

        def construct(self, x1):
            self.alltoallv()
            return x1

    _cell_graph_executor.compile(Net(), _x1)
114
115
def test_NeighborExchange_empty_recv_success():
    """
    Feature: NeighborExchange
    Description: x1 is sent and nothing is received, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # Single [32, 16] send to rank 0; empty receive lists.
            self.alltoallv = NeighborExchange(send_rank_ids=[0], send_shapes=([32, 16],),
                                              recv_rank_ids=[], recv_shapes=(),
                                              recv_type=ms.float32)

        def construct(self, x1):
            self.alltoallv((x1,))
            return x1

    _cell_graph_executor.compile(Net(), _x1)
136
137
def test_NeighborExchange_empty_send_empty_recv_success():
    """
    Feature: NeighborExchange
    Description: no sends and no receives at all, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # Every rank/shape list is empty; the primitive takes no inputs.
            self.alltoallv = NeighborExchange(send_rank_ids=[], send_shapes=(),
                                              recv_rank_ids=[], recv_shapes=(),
                                              recv_type=ms.float32)

        def construct(self, x1):
            self.alltoallv()
            return x1

    _cell_graph_executor.compile(Net(), _x1)
158
159
def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids lists 2 ranks, but recv_shapes gives only 1 shape
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            # Mismatch under test: 2 receive ranks ([1, 2]) vs. 1 receive shape.
            # The send side (1 rank, 1 shape) is consistent.
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32],),
                                              send_shapes=([32, 32],), recv_type=ms.float32)
            self.weight1 = Parameter(weight1, "w1")

        def construct(self, x1, x2):
            out = self.matmul(x1, x2)
            out = self.mul(out, self.weight1)
            out = self.alltoallv((out,))
            return out[0]

    net = Net(_w1)
    with pytest.raises(ValueError):
        compile_net(net)
186
187
def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids lists 2 ranks, but send_shapes gives only 1 shape
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super().__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            # Mismatch under test: 2 send ranks vs. 1 send shape.
            self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], send_shapes=([32, 32],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 32]),
                                              recv_type=ms.float32)
            self.weight1 = Parameter(weight1, name="w1")

        def construct(self, x1, x2):
            scaled = self.mul(self.matmul(x1, x2), self.weight1)
            received = self.alltoallv((scaled,))
            return received[0]

    faulty_net = Net(_w1)
    with pytest.raises(ValueError):
        compile_net(faulty_net)
215
216
def test_NeighborExchange_send_shape_num_diff_with_input_num_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids and send_shapes describe 2 inputs, but only 1 tensor is passed
    Expectation: throw Exception
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super().__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            # Attributes are self-consistent (2 sends); the mismatch is with construct's input tuple.
            self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], send_shapes=([32, 32], [32, 32]),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 32]),
                                              recv_type=ms.float32)
            self.weight1 = Parameter(weight1, name="w1")

        def construct(self, x1, x2):
            scaled = self.mul(self.matmul(x1, x2), self.weight1)
            received = self.alltoallv((scaled,))
            return received[0]

    faulty_net = Net(_w1)
    with pytest.raises(Exception):
        compile_net(faulty_net)
244
245
def test_NeighborExchange_send_shape_diff_with_input_shape_failed():
    """
    Feature: NeighborExchange
    Description: send_shapes declares [16, 16], but the actual input is [32, 32]
    Expectation: throw Exception
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self, weight1):
            super().__init__()
            self.matmul = P.MatMul()
            self.mul = P.Mul()
            # Declared send shape disagrees with the [32, 32] tensor built in construct.
            self.alltoallv = NeighborExchange(send_rank_ids=[0], send_shapes=([16, 16],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)
            self.weight1 = Parameter(weight1, name="w1")

        def construct(self, x1, x2):
            scaled = self.mul(self.matmul(x1, x2), self.weight1)
            received = self.alltoallv((scaled,))
            return received[0]

    faulty_net = Net(_w1)
    with pytest.raises(Exception):
        compile_net(faulty_net)
272
273
def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids should be list, but a plain int is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # NOTE: a parenthesized ``(0)`` is just the int 0, not a tuple, so this case
            # checks a non-list scalar. The one-element tuple ``(0,)`` is covered by
            # test_NeighborExchange_attr_check_send_rank_ids_is_tuple_2_failed.
            self.alltoallv = NeighborExchange(send_rank_ids=0, recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            out = self.alltoallv((x1,))
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)
295
296
def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_2_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids should be list, but a one-element tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=(0,), send_shapes=([32, 16],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
319
320
def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
    """
    Feature: NeighborExchange
    Description: send_rank_ids elements should be int, but a float (1.0) is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1.0], send_shapes=([32, 16],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
343
344
def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids should be list, but a tuple wrapping a list is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], send_shapes=([32, 16],),
                                              recv_rank_ids=([1, 2],), recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
367
368
def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_2_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids should be list, but a flat tuple of ints is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[0], send_shapes=([32, 16],),
                                              recv_rank_ids=(1, 2), recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
391
392
def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
    """
    Feature: NeighborExchange
    Description: recv_rank_ids elements should be int, but a float (2.0) is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], send_shapes=([32, 16],),
                                              recv_rank_ids=[1, 2.0], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
415
416
def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
    """
    Feature: NeighborExchange
    Description: send_shapes should be tuple(list), but a flat list is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # A bare shape list rather than a tuple containing shape lists.
            self.alltoallv = NeighborExchange(send_rank_ids=[1], send_shapes=[32, 16],
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
439
440
def test_NeighborExchange_attr_check_send_shape_list_failed():
    """
    Feature: NeighborExchange
    Description: send_shapes should be tuple(list), but list(list) is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # Outer container is a list instead of the required tuple.
            self.alltoallv = NeighborExchange(send_rank_ids=[1], send_shapes=[[32, 16]],
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
463
464
def test_NeighborExchange_attr_check_recv_type_numpy_failed():
    """
    Feature: NeighborExchange
    Description: recv_type should be a mindspore dtype, but numpy's float32 is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], send_shapes=([32, 16],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=np.float32)

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
487
488
def test_NeighborExchange_attr_invalid_grpup_failed():
    """
    Feature: NeighborExchange
    Description: group should be str, but a tuple is given
    Expectation: throw TypeError
    """
    # NOTE(review): "grpup" in the function name is a typo for "group"; it is kept
    # unchanged so the collected test ID stays stable.
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.alltoallv = NeighborExchange(send_rank_ids=[1], send_shapes=([32, 16],),
                                              recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
                                              recv_type=ms.float32, group=("str",))

        def construct(self, x1):
            return self.alltoallv((x1,))[0]

    invalid_net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(invalid_net, _x1)
511