# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from mindspore.parallel._tensor import _transform_tensor_by_layout, _get_needed_rank_list_by_layouts, \
    _get_needed_rank_transform_operator_map_by_layouts, _generate_transform_operator_stack, \
    _apply_tensor_transform_operators, _construct_from_to_tensor_layout, _construct_tensor_layout_for_opt_shard, \
    _get_tensor_strategy
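# NOTE: in the tests below a tensor layout is written as a 3-tuple
# (device_matrix, tensor_map, tensor_shape), e.g. ((1, 8), (1, 0), (32, 256)).
# This reading is inferred from the arguments passed to _transform_tensor_by_layout
# and from the layouts returned by _construct_from_to_tensor_layout; treat it as a
# reading aid rather than a documented contract.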
83 """ 84 from_layout = ((8, 1), (1, -1), (32, 64)) 85 to_layout = ((1, 8), (-1, 0), (32, 64)) 86 device_list = list(range(0, 8)) 87 rank_id = 0 88 op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id) 89 assert op_list == [('AllGather', [0, 1, 2, 3, 4, 5, 6, 7, 0]), ('StridedSlice', [0, 0, 32, 8, 1, 1])] 90 91 92def test_transform_tensor_by_layout_mix(): 93 """ 94 Feature: transform tensor by layout. 95 Description: allconcat + allsplit. 96 Expectation: assert no error. 97 """ 98 from_layout = ((2, 2, 2), (2, 1, 0), (32, 64, 128)) 99 to_layout = ((8, 1, 1), (2, 1, 0), (32, 64, 128)) 100 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 101 rank_id = 1 102 op_list = _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id) 103 assert op_list == [('Reshape', [1, 2, 8, 32, 64]), ('AllGather', [1, 3, 3]), 104 ('StridedSlice', [0, 0, 0, 0, 0, 1, 1, 8, 64, 64, 1, 1, 1, 1, 1]), 105 ('AllGather', [0, 1, 4]), ('StridedSlice', [0, 0, 4, 0, 0, 1, 1, 8, 64, 128, 1, 1, 1, 1, 1]), 106 ('Reshape', [4, 64, 128])] 107 108 109def test_needed_rank_list_by_layouts_1(): 110 """ 111 Feature: get needed rank list for transform tensor by layout. 112 Description: allconcat + allsplit. 113 Expectation: assert no error. 114 """ 115 from_layout = ((2, 2, 2), (2, 1, 0), (32, 64, 128)) 116 to_layout = ((8, 1, 1), (2, 1, 0), (32, 64, 128)) 117 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 118 rank_id = 1 119 needed_rank_list = _get_needed_rank_list_by_layouts(from_layout, to_layout, device_list, rank_id) 120 assert needed_rank_list == [0, 1, 2, 3] 121 122 123def test_needed_rank_list_by_layouts_2(): 124 """ 125 Feature: get needed rank list for transform tensor by layout. 126 Description: allconcat + allsplit, 128p. 127 Expectation: assert no error. 128 """ 129 from_layout = ((32, 1, 8), (2, -1, 0), (32, 64, 128)) 130 to_layout = ((2, 64, 2), (2, 1, 0), (32, 64, 128)) 131 device_list = list(range(0, 256)) 132 rank_id = 1 133 needed_rank_list = _get_needed_rank_list_by_layouts(from_layout, to_layout, device_list, rank_id) 134 assert needed_rank_list == list(range(0, 128)) 135 136 137def test_generate_transform_operator_stack_1(): 138 """ 139 Feature: generate transform operator stack. 140 Description: moe transform. 141 Expectation: assert no error. 142 """ 143 from_layout = ((8, 1, 8), (2, 1, -1), (32, 64, 128)) 144 to_layout = ((8, 8, 1), (2, 1, 0), (32, 64, 128)) 145 device_list = list(range(0, 64)) 146 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 147 device_list, 0) 148 assert param_rank_map == {0: [('StridedSlice', [0, 0, 0, 4, 8, 128, 1, 1, 1])]} 149 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, 0) 150 assert transform_operator_stack == [(0, 0, ('StridedSlice', [0, 0, 0, 4, 8, 128, 1, 1, 1]))] 151 152 153def test_generate_transform_operator_stack_2(): 154 """ 155 Feature: generate transform operator stack. 156 Description: all to all. 157 Expectation: assert no error. 
158 """ 159 from_layout = ((8, 1), (1, -1), (32, 64)) 160 to_layout = ((1, 8), (-1, 0), (32, 64)) 161 device_list = list(range(0, 8)) 162 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 163 device_list, 0) 164 assert len(param_rank_map) == 8 165 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, 0) 166 assert transform_operator_stack == [(0, 1, ('StridedSlice', [0, 0, 32, 8, 1, 1])), 167 (0, 0, ('AllGather', [0, 1, 2, 3, 4, 5, 6, 7, 0]))] 168 169 170def test_generate_transform_operator_stack_3(): 171 """ 172 Feature: generate transform operator stack. 173 Description: mix. 174 Expectation: assert no error. 175 """ 176 from_layout = ((8, 1, 8), (2, -1, 0), (32, 64, 128)) 177 to_layout = ((2, 8, 4), (2, 1, 0), (32, 64, 128)) 178 device_list = list(range(0, 64)) 179 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 180 device_list, 0) 181 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, 0) 182 assert len(transform_operator_stack) == 79 183 184 185def test_generate_transform_operator_stack_4(): 186 """ 187 Feature: generate transform operator stack. 188 Description: multi allconcat and allsplit. 189 Expectation: assert no error. 190 """ 191 from_layout = ((2, 2, 2), (2, 1, 0), (32, 64, 128)) 192 to_layout = ((8, 1, 1), (2, 1, 0), (32, 64, 128)) 193 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 194 rank_id = 1 195 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 196 device_list, rank_id) 197 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 198 assert transform_operator_stack == [(1, 5, ('Reshape', [4, 64, 128])), 199 (1, 4, ('StridedSlice', [0, 0, 4, 0, 0, 1, 1, 8, 64, 128, 1, 1, 1, 1, 1])), 200 (1, 3, ('AllGather', [0, 1, 4])), 201 (0, 2, ('StridedSlice', [0, 0, 0, 0, 0, 1, 1, 8, 64, 64, 1, 1, 1, 1, 1])), 202 (1, 2, ('StridedSlice', [0, 0, 0, 0, 0, 1, 1, 8, 64, 64, 1, 1, 1, 1, 1])), 203 (0, 1, ('AllGather', [0, 2, 3])), (1, 1, ('AllGather', [1, 3, 3])), 204 (0, 0, ('Reshape', [1, 2, 8, 32, 64])), (2, 0, ('Reshape', [1, 2, 8, 32, 64])), 205 (1, 0, ('Reshape', [1, 2, 8, 32, 64])), (3, 0, ('Reshape', [1, 2, 8, 32, 64]))] 206 207 208def test_apply_tensor_transform_operators_allconcat(): 209 """ 210 Feature: apply tensor transform operators. 211 Description: allconcat. 212 Expectation: assert no error. 213 """ 214 device_num = 8 215 tensor_dict = {} 216 for rank in range(device_num): 217 tensor_dict[rank] = np.full((1, 8, 8), rank) 218 from_layout = ((8, 1, 1), (2, 1, 0), (8, 8, 8)) 219 to_layout = ((1, 1, 1, 8), (-1, -1, -1), (8, 8, 8)) 220 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 221 rank_id = 0 222 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 223 device_list, rank_id) 224 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 225 _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num) 226 assert tensor_dict.get(0).shape == (8, 8, 8) 227 for rank in range(8): 228 assert np.all(tensor_dict.get(0)[rank] == rank) 229 230 231def test_apply_tensor_transform_operators_allsplit(): 232 """ 233 Feature: apply tensor transform operators. 234 Description: allsplit. 235 Expectation: assert no error. 
236 """ 237 device_num = 8 238 tensor_dict = {} 239 for rank in range(device_num): 240 tensor_dict[rank] = np.array([np.full((8, 8), i) for i in range(device_num)]) 241 from_layout = ((8,), (-1, -1, -1), (8, 8, 8)) 242 to_layout = ((8,), (-1, -1, 0), (8, 8, 8)) 243 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 244 rank_id = 0 245 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 246 device_list, rank_id) 247 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 248 _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num) 249 assert tensor_dict.get(0).shape == (8, 8, 1) 250 for rank in range(8): 251 assert np.all(tensor_dict.get(0)[rank] == rank) 252 253 254def test_apply_tensor_transform_operators_mix(): 255 """ 256 Feature: apply tensor transform operators. 257 Description: mulit allconcat, allsplit. 258 Expectation: assert no error. 259 """ 260 device_num = 8 261 tensor_dict = {} 262 for rank in range(device_num): 263 tensor_dict[rank] = np.full((1, 8, 8), rank) 264 from_layout = ((8, 1, 1), (2, 1, 0), (8, 8, 8)) 265 to_layout = ((2, 2, 2), (2, 1, 0), (8, 8, 8)) 266 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 267 rank_id = 0 268 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 269 device_list, rank_id) 270 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 271 _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num) 272 assert tensor_dict.get(0).shape == (4, 4, 4) 273 for rank in range(4): 274 assert np.all(tensor_dict.get(0)[rank] == rank) 275 276 277def test_apply_tensor_transform_operators_no_need_transform(): 278 """ 279 Feature: apply tensor transform operators. 280 Description: no need transform. 281 Expectation: assert no error. 282 """ 283 device_num = 8 284 tensor_dict = {} 285 for rank in range(device_num): 286 tensor_dict[rank] = np.full((1, 8, 8), rank) 287 from_layout = ((8, 1, 1), (2, -1, -1), (8, 8, 8)) 288 to_layout = ((8, 1, 1), (2, -1, -1), (8, 8, 8)) 289 device_list = [0, 1, 2, 3, 4, 5, 6, 7] 290 rank_id = 0 291 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_layout, to_layout, 292 device_list, rank_id) 293 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 294 _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num) 295 assert tensor_dict.get(0).shape == (1, 8, 8) 296 assert np.all(tensor_dict.get(0) == rank_id) 297 298 299def test_construct_tensor_layout_for_opt_shard(): 300 """ 301 Feature: construct tensor layout for optimizer shard. 302 Description: construct tensor layout for optimizer shard. 303 Expectation: assert no error. 304 """ 305 dev_matrix = (2, 2, 2) 306 tensor_map = (2, 1, -1) 307 opt_shard_step = 4 308 opt_shard_size = 2 309 origin_full_tensor_shape = (16, 16, 16) 310 new_dev_matrix, new_tensor_map, new_shape = _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, 311 opt_shard_step, opt_shard_size, 312 origin_full_tensor_shape) 313 assert new_dev_matrix == [2, 2, 2, 1] 314 assert new_tensor_map == [2, 3, 1, -1] 315 assert new_shape == [2, 8, 16, 16] 316 317 318def test_construct_from_to_tensor_layout(): 319 """ 320 Feature: construct from and to tensor layout. 321 Description: construct from and to tensor layout. 322 Expectation: assert no error. 
323 """ 324 tensor_shape = (8, 1024) 325 from_dev_matrix = (8, 8) 326 from_tensor_map = (-1, 0) 327 to_dev_matrix = (16, 8) 328 to_tensor_map = (1, -1) 329 from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(tensor_shape, from_dev_matrix, 330 from_tensor_map, tensor_shape, 331 to_dev_matrix, to_tensor_map) 332 assert from_tensor_layout == ([2, 8, 8], [-1, 0], [8, 1024]) 333 assert to_tensor_layout == ([16, 8], [1, -1], [8, 1024]) 334 335 336def conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, 337 to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, 338 tensor_dict, rank_id): 339 device_num = np.prod(from_dev_matrix) 340 tensor_shape = tensor_dict[rank_id % device_num].shape 341 param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map) 342 origin_tensor_shape = () 343 for i, item in enumerate(tensor_shape): 344 if i == 0 and from_opt_shard_size > 0: 345 origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,) 346 continue 347 origin_tensor_shape += (item * param_strategy[i],) 348 349 from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard( 350 from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape) 351 to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard( 352 to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape) 353 # Convert tensor layout to same device num 354 from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix, 355 from_tensor_map, to_full_tensor_shape, 356 to_dev_matrix, to_tensor_map) 357 358 # when the from_layout is less devices, the checkpoint_map for map[device_num] should using map[0] 359 360 device_list = list(range(0, np.prod(from_tensor_layout[0]))) 361 if rank_id % device_num not in tensor_dict: 362 raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num)) 363 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout, 364 device_list, rank_id) 365 for param_rank, _ in param_rank_map.items(): 366 if from_opt_shard_size > 0: 367 from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map) 368 from_slice_tensor_shape = () 369 for i, item in enumerate(from_full_tensor_shape): 370 from_slice_tensor_shape += (item // from_tensor_strategy[i],) 371 param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape))) 372 if to_opt_shard_size > 0: 373 to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin) 374 to_slice_tensor_shape = () 375 for i, item in enumerate(origin_tensor_shape): 376 if i == 0 and to_opt_shard_size > 0: 377 to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),) 378 continue 379 to_slice_tensor_shape += (item // to_tensor_strategy[i],) 380 param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape))) 381 382 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 383 _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num) 384 385 return tensor_dict[rank_id % device_num] 386 387 388def test_transform_parallel_checkpoint(): 389 """ 390 Feature: transform parallel checkpoint. 391 Description: device_num 16. 
        -> optimizer_shard 4, model_parallel 2 -> optimizer_shard 16, model_parallel 1
    Expectation: assert no error.
    """
    import copy
    device_num = 16
    tensor_dict = {}
    for rank in range(device_num):
        tensor_dict[rank] = np.array([np.full((8,), i) for i in range(device_num)])
    no_change_tensor_dict = copy.deepcopy(tensor_dict)
    result_dict = {}
    from_dev_matrix = (16,)
    from_tensor_map = (-1, -1)
    from_opt_shard_step = 0
    from_opt_shard_size = 0
    to_dev_matrix = (4, 4)
    to_tensor_map = (0, -1)
    to_opt_shard_step = 4
    to_opt_shard_size = 2
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                                         to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank % 4) * 4 + (rank // 4) * 2
        assert np.all(result[0] == first_value)
        assert np.all(result[1] == first_value + 1)
    to_dev_matrix1 = (8, 2)
    to_tensor_map1 = (0, -1)
    to_opt_shard_step1 = 2
    to_opt_shard_size1 = 4
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size,
                                         to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        rank = rank_id % 8
        first_value = (rank % 2) * 8 + (rank // 2) * 2
        assert np.all(result[0] == first_value)
        assert np.all(result[1] == first_value + 1)
    to_dev_matrix2 = (16,)
    to_tensor_map2 = (-1, -1)
    to_opt_shard_step2 = 1
    to_opt_shard_size2 = 16
    tensor_dict = copy.deepcopy(result_dict)
    no_change_tensor_dict = copy.deepcopy(result_dict)
    for rank_id in range(device_num):
        result = conver_tensor_by_layout(to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1,
                                         to_dev_matrix2, to_tensor_map2, to_opt_shard_step2, to_opt_shard_size2,
                                         tensor_dict, rank_id)
        result_dict[rank_id] = result
        tensor_dict = copy.deepcopy(no_change_tensor_dict)
        assert np.all(result == rank_id)


def test_transform_parallel_checkpoint_1():
    """
    Feature: transform parallel checkpoint.
    Description: model_parallel in last dim. device_num 16. None -> optimizer_shard 2, model_parallel 4
        -> optimizer_shard 4, model_parallel 2 -> optimizer_shard 16, model_parallel 1
    Expectation: assert no error.
457 """ 458 import copy 459 device_num = 16 460 tensor_dict = {} 461 for rank in range(device_num): 462 tensor_dict[rank] = np.array([np.full((8,), i) for i in range(device_num)]) 463 no_change_tensor_dict = copy.deepcopy(tensor_dict) 464 result_dict = {} 465 from_dev_matrix = (16,) 466 from_tensor_map = (-1, -1) 467 from_opt_shard_step = 0 468 from_opt_shard_size = 0 469 to_dev_matrix = (4, 4) 470 to_tensor_map = (-1, 0) 471 to_opt_shard_step = 4 472 to_opt_shard_size = 2 473 for rank_id in range(device_num): 474 result = conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, 475 to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size, 476 tensor_dict, rank_id) 477 result_dict[rank_id] = result 478 tensor_dict = copy.deepcopy(no_change_tensor_dict) 479 rank = rank_id % 8 480 first_value = (rank // 4) * 8 481 assert np.all(result[0] == first_value) 482 to_dev_matrix1 = (8, 2) 483 to_tensor_map1 = (-1, 0) 484 to_opt_shard_step1 = 2 485 to_opt_shard_size1 = 4 486 tensor_dict = copy.deepcopy(result_dict) 487 no_change_tensor_dict = copy.deepcopy(result_dict) 488 for rank_id in range(device_num): 489 result = conver_tensor_by_layout(to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size, 490 to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1, 491 tensor_dict, rank_id) 492 result_dict[rank_id] = result 493 tensor_dict = copy.deepcopy(no_change_tensor_dict) 494 rank = rank_id % 8 495 first_value = (rank // 2) * 4 496 assert np.all(result[0] == first_value) 497 to_dev_matrix2 = (16,) 498 to_tensor_map2 = (-1, -1) 499 to_opt_shard_step2 = 1 500 to_opt_shard_size2 = 16 501 tensor_dict = copy.deepcopy(result_dict) 502 no_change_tensor_dict = copy.deepcopy(result_dict) 503 for rank_id in range(device_num): 504 result = conver_tensor_by_layout(to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1, 505 to_dev_matrix2, to_tensor_map2, to_opt_shard_step2, to_opt_shard_size2, 506 tensor_dict, rank_id) 507 result_dict[rank_id] = result 508 tensor_dict = copy.deepcopy(no_change_tensor_dict) 509 assert np.all(result == rank_id) 510 511 512def test_transform_parallel_checkpoint_2(): 513 """ 514 Feature: transform parallel checkpoint. 515 Description: model_parallel in last dim. device_num 16. None -> device_num 8, optimizer_shard 2, model_parallel 4 516 -> device_num 16, optimizer_shard 4, model_parallel 2. 517 Expectation: assert no error. 
518 """ 519 import copy 520 device_num = 16 521 tensor_dict = {} 522 for rank in range(device_num): 523 tensor_dict[rank] = np.array([np.full((8,), i) for i in range(device_num)]) 524 no_change_tensor_dict = copy.deepcopy(tensor_dict) 525 result_dict = {} 526 from_dev_matrix = (16,) 527 from_tensor_map = (-1, -1) 528 from_opt_shard_step = 0 529 from_opt_shard_size = 0 530 to_dev_matrix = (2, 4) 531 to_tensor_map = (0, -1) 532 to_opt_shard_step = 4 533 to_opt_shard_size = 2 534 for rank_id in range(8): 535 result = conver_tensor_by_layout(from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, 536 to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size, 537 tensor_dict, rank_id) 538 result_dict[rank_id] = result 539 tensor_dict = copy.deepcopy(no_change_tensor_dict) 540 rank = rank_id % 8 541 first_value = (rank % 4) * 4 + (rank // 4) * 2 542 assert np.all(result[0] == first_value) 543 assert np.all(result[1] == first_value + 1) 544 to_dev_matrix1 = (8, 2) 545 to_tensor_map1 = (0, -1) 546 to_opt_shard_step1 = 2 547 to_opt_shard_size1 = 4 548 tensor_dict = copy.deepcopy(result_dict) 549 no_change_tensor_dict = copy.deepcopy(result_dict) 550 for rank_id in range(device_num): 551 result = conver_tensor_by_layout(to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size, 552 to_dev_matrix1, to_tensor_map1, to_opt_shard_step1, to_opt_shard_size1, 553 tensor_dict, rank_id) 554 result_dict[rank_id] = result 555 tensor_dict = copy.deepcopy(no_change_tensor_dict) 556 rank = rank_id % 8 557 first_value = (rank % 2) * 8 + (rank // 2) * 2 558 assert np.all(result[0] == first_value) 559 assert np.all(result[1] == first_value + 1) 560