# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from mindspore.parallel._tensor import _transform_tensor_by_layout, _get_needed_rank_list_by_layouts, \
    _get_needed_rank_transform_operator_map_by_layouts, _generate_transform_operator_stack, \
    _apply_tensor_transform_operators, _construct_from_to_tensor_layout, _construct_tensor_layout_for_opt_shard, \
    _get_tensor_strategy

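# A note on the conventions exercised below (summarized from the expected values
# in these tests, not from authoritative API documentation): a layout is a
# (device_matrix, tensor_map, tensor_shape) triple, and the transform utilities
# emit (name, args) operator pairs where 'Reshape' args are the target shape,
# 'AllGather' args are the participating ranks followed by the concatenation
# axis, and 'StridedSlice' args are the flattened begin/end/stride values for
# every dimension.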

def test_transform_tensor_by_layout_allconcat_axis_1():
    """
    Feature: transform tensor by layout.
    Description: allconcat.
    Expectation: assert no error.
    """
    from_layout = ((1, 8), (1, 0), (32, 256))
    to_layout = ((1, 4, 2), (2, 1), (32, 256))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id)
    assert op_list == [('Reshape', [32, 1, 32]), ('AllGather', [0, 1, 2]), ('Reshape', [32, 64])]

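# Illustrative only: a plain-NumPy reading of the operator list asserted above,
# with the AllGather arguments read as "participating ranks + concat axis"
# (inferred from the expected values, not from MindSpore documentation). Under
# the from_layout, rank 0 holds columns 0..31 and rank 1 holds columns 32..63
# of the full (32, 256) tensor.
def _sketch_allconcat_axis_1():
    full = np.arange(32 * 256).reshape(32, 256)
    rank0_slice = full[:, 0:32].reshape(32, 1, 32)    # ('Reshape', [32, 1, 32])
    rank1_slice = full[:, 32:64].reshape(32, 1, 32)
    gathered = np.concatenate([rank0_slice, rank1_slice], axis=2)    # ('AllGather', [0, 1, 2])
    result = gathered.reshape(32, 64)    # ('Reshape', [32, 64])
    # rank 0 now holds the first 64 columns required by the to_layout
    assert np.array_equal(result, full[:, 0:64])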

def test_transform_tensor_by_layout_allconcat_axis_1_using_none_map():
    """
    Feature: transform tensor by layout.
    Description: allconcat, tensor map contains -1.
    Expectation: assert no error.
    """
    from_layout = ((8,), (0, -1), (32, 256))
    to_layout = ((4, 2), (1, -1), (32, 256))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id)
    assert op_list == [('Reshape', [1, 4, 256]), ('AllGather', [0, 1, 1]), ('Reshape', [8, 256])]


def test_transform_tensor_by_layout_allconcat_to_single():
    """
    Feature: transform tensor by layout.
    Description: allconcat, transform to single device.
    Expectation: assert no error.
    """
    from_layout = ((4, 2), (1, 0), (32, 256))
    to_layout = ((1, 8), (-1, -1), (32, 256))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id)
    assert op_list == [('AllGather', [0, 2, 4, 6, 0]), ('AllGather', [0, 1, 1])]

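# Illustrative only: a plain-NumPy reading of the two AllGather steps asserted
# above (semantics inferred from the expected values). Under the from_layout,
# rank r holds the (8, 128) block starting at row (r // 2) * 8 and
# column (r % 2) * 128 of the full tensor.
def _sketch_allconcat_to_single():
    full = np.arange(32 * 256).reshape(32, 256)

    def block(r):
        return full[(r // 2) * 8:(r // 2 + 1) * 8, (r % 2) * 128:(r % 2 + 1) * 128]

    left = np.concatenate([block(r) for r in (0, 2, 4, 6)], axis=0)    # ('AllGather', [0, 2, 4, 6, 0])
    right = np.concatenate([block(r) for r in (1, 3, 5, 7)], axis=0)   # same step within rank 1's group
    result = np.concatenate([left, right], axis=1)                     # ('AllGather', [0, 1, 1])
    assert np.array_equal(result, full)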

def test_transform_tensor_by_layout_allconcat_axis_0():
    """
    Feature: transform tensor by layout.
    Description: allconcat.
    Expectation: assert no error.
    """
    from_layout = ((8, 1), (1, 0), (32, 256))
    to_layout = ((2, 1, 4), (2, 1), (32, 256))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id)
    assert op_list == [('Reshape', [1, 4, 256]), ('AllGather', [0, 1, 2, 3, 1]), ('Reshape', [16, 256])]


def test_transform_tensor_by_layout_all_to_all():
    """
    Feature: transform tensor by layout.
    Description: all to all.
    Expectation: assert no error.
    """
    from_layout = ((8, 1), (1, -1), (32, 64))
    to_layout = ((1, 8), (-1, 0), (32, 64))
    device_list = list(range(0, 8))
    rank_id = 0
    op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id)
    assert op_list == [('AllGather', [0, 1, 2, 3, 4, 5, 6, 7, 0]), ('StridedSlice', [0, 0, 32, 8, 1, 1])]

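# Illustrative only: a plain-NumPy reading of the AllGather + StridedSlice pair
# asserted above, with the StridedSlice arguments read as begin (0, 0),
# end (32, 8) and strides (1, 1) for a (32, 64) tensor (inferred from the
# expected values). Under the from_layout, rank r holds rows r * 4 .. r * 4 + 3.
def _sketch_all_to_all():
    full = np.arange(32 * 64).reshape(32, 64)
    gathered = np.concatenate([full[r * 4:(r + 1) * 4, :] for r in range(8)],
                              axis=0)              # ('AllGather', [0, 1, 2, 3, 4, 5, 6, 7, 0])
    result = gathered[0:32:1, 0:8:1]               # ('StridedSlice', [0, 0, 32, 8, 1, 1])
    # rank 0 now holds the first 8 columns required by the to_layout
    assert np.array_equal(result, full[:, 0:8])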

def test_transform_tensor_by_layout_mix():
    """
    Feature: transform tensor by layout.
    Description: allconcat + allsplit.
    Expectation: assert no error.
    """
    from_layout = ((2, 2, 2), (2, 1, 0), (32, 64, 128))
    to_layout = ((8, 1, 1), (2, 1, 0), (32, 64, 128))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 1
    op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id)
    assert op_list == [('Reshape', [1, 2, 8, 32, 64]), ('AllGather', [1, 3, 3]),
                       ('StridedSlice', [0, 0, 0, 0, 0, 1, 1, 8, 64, 64, 1, 1, 1, 1, 1]),
                       ('AllGather', [0, 1, 4]), ('StridedSlice', [0, 0, 4, 0, 0, 1, 1, 8, 64, 128, 1, 1, 1, 1, 1]),
                       ('Reshape', [4, 64, 128])]


def test_needed_rank_list_by_layouts_1():
    """
    Feature: get needed rank list for transform tensor by layout.
    Description: allconcat + allsplit.
    Expectation: assert no error.
    """
    from_layout = ((2, 2, 2), (2, 1, 0), (32, 64, 128))
    to_layout = ((8, 1, 1), (2, 1, 0), (32, 64, 128))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 1
    needed_rank_list = _get_needed_rank_list_by_layouts(from_layout, to_layout, device_list, rank_id)
    assert needed_rank_list == [0, 1, 2, 3]


def test_needed_rank_list_by_layouts_2():
    """
    Feature: get needed rank list for transform tensor by layout.
    Description: allconcat + allsplit, 128p.
    Expectation: assert no error.
    """
    from_layout = ((32, 1, 8), (2, -1, 0), (32, 64, 128))
    to_layout = ((2, 64, 2), (2, 1, 0), (32, 64, 128))
    device_list = list(range(0, 256))
    rank_id = 1
    needed_rank_list = _get_needed_rank_list_by_layouts(from_layout, to_layout, device_list, rank_id)
    assert needed_rank_list == list(range(0, 128))


def test_generate_transform_operator_stack_1():
    """
    Feature: generate transform operator stack.
    Description: moe transform.
    Expectation: assert no error.
    """
    from_layout = ((8, 1, 8), (2, 1, -1), (32, 64, 128))
    to_layout = ((8, 8, 1), (2, 1, 0), (32, 64, 128))
    device_list = list(range(0, 64))
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, 0)
    assert param_rank_map == {0: [('StridedSlice', [0, 0, 0, 4, 8, 128, 1, 1, 1])]}
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, 0)
    assert transform_operator_stack == [(0, 0, ('StridedSlice', [0, 0, 0, 4, 8, 128, 1, 1, 1]))]


def test_generate_transform_operator_stack_2():
    """
    Feature: generate transform operator stack.
    Description: all to all.
    Expectation: assert no error.
    """
    from_layout = ((8, 1), (1, -1), (32, 64))
    to_layout = ((1, 8), (-1, 0), (32, 64))
    device_list = list(range(0, 8))
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, 0)
    assert len(param_rank_map) == 8
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, 0)
    assert transform_operator_stack == [(0, 1, ('StridedSlice', [0, 0, 32, 8, 1, 1])),
                                        (0, 0, ('AllGather', [0, 1, 2, 3, 4, 5, 6, 7, 0]))]


def test_generate_transform_operator_stack_3():
    """
    Feature: generate transform operator stack.
    Description: mix.
    Expectation: assert no error.
    """
    from_layout = ((8, 1, 8), (2, -1, 0), (32, 64, 128))
    to_layout = ((2, 8, 4), (2, 1, 0), (32, 64, 128))
    device_list = list(range(0, 64))
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, 0)
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, 0)
    assert len(transform_operator_stack) == 79


def test_generate_transform_operator_stack_4():
    """
    Feature: generate transform operator stack.
    Description: multi allconcat and allsplit.
    Expectation: assert no error.
    """
    from_layout = ((2, 2, 2), (2, 1, 0), (32, 64, 128))
    to_layout = ((8, 1, 1), (2, 1, 0), (32, 64, 128))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 1
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, rank_id)
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
    assert transform_operator_stack == [(1, 5, ('Reshape', [4, 64, 128])),
                                        (1, 4, ('StridedSlice', [0, 0, 4, 0, 0, 1, 1, 8, 64, 128, 1, 1, 1, 1, 1])),
                                        (1, 3, ('AllGather', [0, 1, 4])),
                                        (0, 2, ('StridedSlice', [0, 0, 0, 0, 0, 1, 1, 8, 64, 64, 1, 1, 1, 1, 1])),
                                        (1, 2, ('StridedSlice', [0, 0, 0, 0, 0, 1, 1, 8, 64, 64, 1, 1, 1, 1, 1])),
                                        (0, 1, ('AllGather', [0, 2, 3])), (1, 1, ('AllGather', [1, 3, 3])),
                                        (0, 0, ('Reshape', [1, 2, 8, 32, 64])), (2, 0, ('Reshape', [1, 2, 8, 32, 64])),
                                        (1, 0, ('Reshape', [1, 2, 8, 32, 64])), (3, 0, ('Reshape', [1, 2, 8, 32, 64]))]
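    # Each stack entry above reads as (rank, operator index, operator); entries are
    # ordered by descending operator index, so when the list is consumed as a stack
    # the index-0 operators are applied first (a reading inferred from the expected
    # list, not from the implementation).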


def test_apply_tensor_transform_operators_allconcat():
    """
    Feature: apply tensor transform operators.
    Description: allconcat.
    Expectation: assert no error.
    """
    device_num = 8
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.full((1, 8, 8), rank)
    from_layout = ((8, 1, 1), (2, 1, 0), (8, 8, 8))
    to_layout = ((1, 1, 1, 8), (-1, -1, -1), (8, 8, 8))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, rank_id)
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
    _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num)
    assert tensor_dict.get(0).shape == (8, 8, 8)
    for rank in range(8):
        assert np.all(tensor_dict.get(0)[rank] == rank)


def test_apply_tensor_transform_operators_allsplit():
    """
    Feature: apply tensor transform operators.
    Description: allsplit.
    Expectation: assert no error.
    """
    device_num = 8
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.array([np.full((8, 8), i) for i in range(device_num)])
    from_layout = ((8,), (-1, -1, -1), (8, 8, 8))
    to_layout = ((8,), (-1, -1, 0), (8, 8, 8))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, rank_id)
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
    _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num)
    assert tensor_dict.get(0).shape == (8, 8, 1)
    for rank in range(8):
        assert np.all(tensor_dict.get(0)[rank] == rank)


def test_apply_tensor_transform_operators_mix():
    """
    Feature: apply tensor transform operators.
    Description: multi allconcat and allsplit.
    Expectation: assert no error.
    """
    device_num = 8
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.full((1, 8, 8), rank)
    from_layout = ((8, 1, 1), (2, 1, 0), (8, 8, 8))
    to_layout = ((2, 2, 2), (2, 1, 0), (8, 8, 8))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, rank_id)
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
    _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num)
    assert tensor_dict.get(0).shape == (4, 4, 4)
    for rank in range(4):
        assert np.all(tensor_dict.get(0)[rank] == rank)


def test_apply_tensor_transform_operators_no_need_transform():
    """
    Feature: apply tensor transform operators.
    Description: no transform needed.
    Expectation: assert no error.
    """
    device_num = 8
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.full((1, 8, 8), rank)
    from_layout = ((8, 1, 1), (2, -1, -1), (8, 8, 8))
    to_layout = ((8, 1, 1), (2, -1, -1), (8, 8, 8))
    device_list = [0, 1, 2, 3, 4, 5, 6, 7]
    rank_id = 0
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout,
                                                                        device_list, rank_id)
    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
    _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num)
    assert tensor_dict.get(0).shape == (1, 8, 8)
    assert np.all(tensor_dict.get(0) == rank_id)


def test_construct_tensor_layout_for_opt_shard():
    """
    Feature: construct tensor layout for optimizer shard.
    Description: construct tensor layout for optimizer shard.
    Expectation: assert no error.
    """
    dev_matrix = (2, 2, 2)
    tensor_map = (2, 1, -1)
    opt_shard_step = 4
    opt_shard_size = 2
    origin_full_tensor_shape = (16, 16, 16)
    new_dev_matrix, new_tensor_map, new_shape = _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map,
                                                                                       opt_shard_step, opt_shard_size,
                                                                                       origin_full_tensor_shape)
    assert new_dev_matrix == [2, 2, 2, 1]
    assert new_tensor_map == [2, 3, 1, -1]
    assert new_shape == [2, 8, 16, 16]
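    # Readable directly from the asserts: the optimizer-shard split (opt_shard_size=2)
    # is peeled off the first tensor dimension, so (16, 16, 16) becomes [2, 8, 16, 16],
    # and the device matrix (2, 2, 2) is extended to [2, 2, 2, 1].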


def test_construct_from_to_tensor_layout():
    """
    Feature: construct from and to tensor layout.
    Description: construct from and to tensor layout.
    Expectation: assert no error.
    """
    tensor_shape = (8, 1024)
    from_dev_matrix = (8, 8)
    from_tensor_map = (-1, 0)
    to_dev_matrix = (16, 8)
    to_tensor_map = (1, -1)
    from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(tensor_shape, from_dev_matrix,
                                                                            from_tensor_map, tensor_shape,
                                                                            to_dev_matrix, to_tensor_map)
    assert from_tensor_layout == ([2, 8, 8], [-1, 0], [8, 1024])
    assert to_tensor_layout == ([16, 8], [1, -1], [8, 1024])
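    # The two layouts are padded to a common device count: from_dev_matrix (8, 8)
    # covers 64 devices and is extended to [2, 8, 8] to match the 128 devices of
    # to_dev_matrix (16, 8), as the asserts above show.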


def conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size,
                            tensor_dict, rank_id):
    """
    Test helper: convert the slice held by rank_id from the source layout
    (with optional optimizer sharding) to the destination layout, using the
    transform utilities under test.
    """
    device_num = np.prod(from_dev_matrix)
    tensor_shape = tensor_dict[rank_id % device_num].shape
    param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
    origin_tensor_shape = ()
    for i, item in enumerate(tensor_shape):
        if i == 0 and from_opt_shard_size > 0:
            origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
            continue
        origin_tensor_shape += (item * param_strategy[i],)

    from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
    to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
    # Convert both tensor layouts to the same device num
    from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
                                                                            from_tensor_map, to_full_tensor_shape,
                                                                            to_dev_matrix, to_tensor_map)

    # When from_layout uses fewer devices, the checkpoint map entry for map[device_num] should use map[0].

    device_list = list(range(0, np.prod(from_tensor_layout[0])))
    if rank_id % device_num not in tensor_dict:
        raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num))
    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                        device_list, rank_id)
    for param_rank, _ in param_rank_map.items():
        if from_opt_shard_size > 0:
            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
            from_slice_tensor_shape = ()
            for i, item in enumerate(from_full_tensor_shape):
                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
        if to_opt_shard_size > 0:
            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
            to_slice_tensor_shape = ()
            for i, item in enumerate(origin_tensor_shape):
                if i == 0 and to_opt_shard_size > 0:
                    to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
                    continue
                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))

    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
    _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num)

    return tensor_dict[rank_id % device_num]


def test_transform_parallel_checkpoint():
    """
    Feature: transform parallel checkpoint.
    Description: device_num 16. None -> optimizer_shard 2, model_parallel 4
        -> optimizer_shard 4, model_parallel 2 -> optimizer_shard 16, model_parallel 1
    Expectation: assert no error.
    """
    import copy
    device_num = 16
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.array([np.full((8,), i) for i in range(device_num)])
    no_change_tensor_dict = copy.deepcopy(tensor_dict)
    result_dict = {}
    from_dev_matrix = (16,)
    from_tensor_map = (-1, -1)
    from_opt_shard_step = 0
    from_opt_shard_size = 0
    to_dev_matrix = (4, 4)
    to_tensor_map = (0, -1)
    to_opt_shard_step = 4
    to_opt_shard_size = 2
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                                         to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank % 4) * 4 + (rank // 4) * 2
        assert np.all(result[0] == first_value)
        assert np.all(result[1] == first_value + 1)
    to_dev_matrix1 = (8, 2)
    to_tensor_map1 = (0, -1)
    to_opt_shard_step1 = 2
    to_opt_shard_size1 = 4
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank % 2) * 8 + (rank // 2) * 2
        assert np.all(result[0] == first_value)
        assert np.all(result[1] == first_value + 1)
    to_dev_matrix2 = (16,)
    to_tensor_map2 = (-1, -1)
    to_opt_shard_step2 = 1
    to_opt_shard_size2 = 16
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         to_dev_matrix2, to_tensor_map2, to_opt_shard_step2, to_opt_shard_size2,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        assert np.all(result == rank_id)


def test_transform_parallel_checkpoint_1():
    """
    Feature: transform parallel checkpoint.
    Description: model_parallel in last dim. device_num 16. None -> optimizer_shard 2, model_parallel 4
        -> optimizer_shard 4, model_parallel 2 -> optimizer_shard 16, model_parallel 1
    Expectation: assert no error.
    """
    import copy
    device_num = 16
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.array([np.full((8,), i) for i in range(device_num)])
    no_change_tensor_dict = copy.deepcopy(tensor_dict)
    result_dict = {}
    from_dev_matrix = (16,)
    from_tensor_map = (-1, -1)
    from_opt_shard_step = 0
    from_opt_shard_size = 0
    to_dev_matrix = (4, 4)
    to_tensor_map = (-1, 0)
    to_opt_shard_step = 4
    to_opt_shard_size = 2
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                                         to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank // 4) * 8
        assert np.all(result[0] == first_value)
    to_dev_matrix1 = (8, 2)
    to_tensor_map1 = (-1, 0)
    to_opt_shard_step1 = 2
    to_opt_shard_size1 = 4
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank // 2) * 4
        assert np.all(result[0] == first_value)
    to_dev_matrix2 = (16,)
    to_tensor_map2 = (-1, -1)
    to_opt_shard_step2 = 1
    to_opt_shard_size2 = 16
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         to_dev_matrix2, to_tensor_map2, to_opt_shard_step2, to_opt_shard_size2,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        assert np.all(result == rank_id)


def test_transform_parallel_checkpoint_2():
    """
    Feature: transform parallel checkpoint.
    Description: device_num changes across layouts. None -> device_num 8, optimizer_shard 2, model_parallel 4
        -> device_num 16, optimizer_shard 4, model_parallel 2.
    Expectation: assert no error.
    """
    import copy
    device_num = 16
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.array([np.full((8,), i) for i in range(device_num)])
    no_change_tensor_dict = copy.deepcopy(tensor_dict)
    result_dict = {}
    from_dev_matrix = (16,)
    from_tensor_map = (-1, -1)
    from_opt_shard_step = 0
    from_opt_shard_size = 0
    to_dev_matrix = (2, 4)
    to_tensor_map = (0, -1)
    to_opt_shard_step = 4
    to_opt_shard_size = 2
    for rank_id in range(8):
        result = conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                                         to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank % 4) * 4 + (rank // 4) * 2
        assert np.all(result[0] == first_value)
        assert np.all(result[1] == first_value + 1)
    to_dev_matrix1 = (8, 2)
    to_tensor_map1 = (0, -1)
    to_opt_shard_step1 = 2
    to_opt_shard_size1 = 4
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank % 2) * 8 + (rank // 2) * 2
        assert np.all(result[0] == first_value)
        assert np.all(result[1] == first_value + 1)