/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <utility>
#include "utils/ms_utils.h"
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/status.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); }

std::string TensorLayout::StandardToString() const {
  std::ostringstream buffer;
  buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString());
  buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString());
  buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString());
  return buffer.str();
}

std::string TensorLayout::OriginToString() const {
  std::ostringstream buffer;
  buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString());
  buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString());
  buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString());
  return buffer.str();
}

Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map,
                          const Arrangement &tensor_shape) {
  device_arrangement_origin_ = device_arrangement;
  tensor_map_origin_ = tensor_map;
  tensor_shape_origin_ = tensor_shape;
  device_arrangement_ = device_arrangement;
  tensor_map_ = tensor_map;
  tensor_shape_ = tensor_shape;
  if (IsValidTensorLayout()) {
    MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString();
    RemoveElementEqualToOneInDeviceArrangement();
    MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
    return Status::SUCCESS;
  } else {
    if (layout_transfer_) {
      MS_LOG(DEBUG) << "invalid origin tensor layout " << this->OriginToString();
    } else {
      MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
    }
    return Status::FAILED;
  }
}

Status TensorLayout::InitFromVector(const Shape &device_arrangement, const Shape &tensor_map,
                                    const Shape &tensor_shape) {
  if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
    MS_LOG(ERROR) << "Init device_arrangement failed.";
    return FAILED;
  }
  if (tensor_map_origin_.Init(tensor_map) != SUCCESS) {
    MS_LOG(ERROR) << "Init tensor_map failed.";
    return FAILED;
  }
  if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) {
    MS_LOG(ERROR) << "Init tensor_shape failed.";
    return FAILED;
  }
  if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) {
    MS_LOG(ERROR) << "Init tensor_layout failed.";
    return FAILED;
  }
  return SUCCESS;
}

/*
 * example1:
 * in_device_arrangement = [8, 2, 4],
 * in_tensor_map = [[2], [1, 0]],
 * in_tensor_shape = [512, 1024],
 * =>
 * out_device_arrangement = [8, 2, 4],
 * out_tensor_map = [2, 1, 0],
 * out_tensor_shape = [512, 2, 512],
 * example2:
 * in_device_arrangement = [8, 2, 4],
 * in_tensor_map = [[1], [0, 2]],
 * in_tensor_shape = [512, 1024],
 * =>
 * out_device_arrangement = [8, 2, 4],
 * out_tensor_map = [1, 0, 2],
 * out_tensor_shape = [512, 4, 256],
 */
Status TensorLayout::InitFromExtendVector(const Shape &device_matrix, const std::vector<Shape> &tensor_map,
                                          const Shape &tensor_shape, bool interleaved_parallel,
                                          bool check_device_num) {
  auto device_arrangement = device_matrix;
  if (interleaved_parallel) {
    if (device_arrangement_interleaved_.Init(device_matrix) != SUCCESS) {
      return FAILED;
    }
    if (parallel::ParallelContext::GetInstance()->fine_grained_micro_interleaved_size() == -1) {
      parallel::ParallelContext::GetInstance()->set_fine_grained_micro_interleaved_size(
        device_arrangement[device_arrangement.size() - 1]);
    } else if (parallel::ParallelContext::GetInstance()->fine_grained_micro_interleaved_size() !=
               device_arrangement[device_arrangement.size() - 1]) {
      MS_LOG(EXCEPTION) << "The micro interleaved num should be configured consistently for each operator's layout.";
    }
    device_arrangement[device_arrangement.size() - 1] = 1;
  }

  if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
    return FAILED;
  }
  CheckGlobalDeviceManager();
  auto device_num = g_device_manager->stage_device_num();
  int64_t device_total =
    std::accumulate(device_arrangement.begin(), device_arrangement.end(), 1, std::multiplies<int64_t>());
  if (device_num != device_total && check_device_num) {
    MS_LOG(ERROR) << "The configured device_matrix " << device_arrangement << " accumulated value " << device_total
                  << " does not equal the device number in one stage " << device_num;
    return FAILED;
  }
  Shape extended_tensor_map;
  Shape reshaped_tensor_shape;
  if (tensor_shape.size() != tensor_map.size()) {
    MS_LOG(ERROR) << "The tensor_shape " << tensor_shape << " does not have the same size as tensor_map "
                  << tensor_map;
    return FAILED;
  }

  size_t not_none_count = 0;
  for (size_t i = 0; i < tensor_map.size(); ++i) {
    for (size_t j = 0; j < tensor_map[i].size(); ++j) {
      extended_tensor_map.push_back(tensor_map[i][j]);
      if (tensor_map[i][j] >= 0) {
        ++not_none_count;
      }
    }
  }

  if (not_none_count > device_arrangement.size()) {
    MS_LOG(ERROR) << "The device_matrix " << device_arrangement
                  << " length must be greater than or equal to the number of not-None elements of extended_tensor_map "
                  << extended_tensor_map;
    return FAILED;
  }
  (void)tensor_shape_before_.Init(tensor_shape);
  for (size_t i = 0; i < tensor_map.size(); ++i) {
    if (tensor_map[i].size() == 1) {
      reshaped_tensor_shape.push_back(tensor_shape[i]);
      continue;
    }
    int64_t accu_shp = 1;
    for (size_t j = 0; j < tensor_map[i].size() - 1; ++j) {
      size_t tensor_index = device_arrangement.size() - 1 - static_cast<size_t>(tensor_map[i][j]);
      auto shard_size = device_arrangement[tensor_index];
      accu_shp *= shard_size;
      reshaped_tensor_shape.push_back(shard_size);
    }
    auto last_shp = tensor_shape[i] / accu_shp;
    reshaped_tensor_shape.push_back(last_shp);
  }
  if (tensor_map_origin_.Init(extended_tensor_map) != SUCCESS) {
    return FAILED;
  }
  if (tensor_shape_origin_.Init(reshaped_tensor_shape) != SUCCESS) {
    return FAILED;
  }
  if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) {
    return FAILED;
  }
  tensor_map_before_ = tensor_map;
  return SUCCESS;
}

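/*
 * example (illustrative): if the global rank is 3 and the interleaved number is 2, the virtual
 * ranks are [6, 7], i.e. rank * interleaved_num + i for i in [0, interleaved_num).
 */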
std::vector<int64_t> TensorLayout::GetVirtualRank() const {
  int64_t rank = g_device_manager->global_rank();
  if (!IsInterleavedParallel()) {
    return {rank};
  }
  auto interleaved_num = device_arrangement_interleaved_.array().back();
  std::vector<int64_t> virtual_ranks;
  for (int64_t i = 0; i < interleaved_num; ++i) {
    virtual_ranks.push_back(rank * interleaved_num + i);
  }
  return virtual_ranks;
}

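// For a micro-interleaved layout, rebuild the layout over the full interleaved device matrix so that
// redistribution sees the real device arrangement; otherwise the current layout is returned unchanged.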
TensorLayout TensorLayout::LayoutForRedistribution() const {
  if (!IsInterleavedParallel()) {
    return *this;
  }
  TensorLayout interleaved_layout;
  if (interleaved_layout.InitFromExtendVector(device_arrangement_interleaved_.array(), tensor_map_before_,
                                              tensor_shape_before_.array(), false, false) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Init layout for micro interleaved failed, device_matrix:"
                      << device_arrangement_interleaved_.array() << ", tensor_map:" << tensor_map_before_;
  }
  return interleaved_layout;
}

bool TensorLayout::IsValidTensorLayout() const {
  int64_t max_tensor_map_item = tensor_map_origin_.GetMaxItem();
  int64_t device_arr_size = SizeToLong(device_arrangement_origin_.GetDimSize());
  if (max_tensor_map_item >= device_arr_size) {
    MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size! "
                  << "Max element in tensor_map_origin_ is " << max_tensor_map_item
                  << ", device_arrangement_origin_ size is " << device_arr_size;
    return false;
  }
  size_t tensor_map_size = tensor_map_origin_.GetDimSize();
  size_t tensor_shape_size = tensor_shape_origin_.GetDimSize();
  if (tensor_map_size != tensor_shape_size) {
    MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size! "
                  << "tensor_map_origin_ size is " << tensor_map_size << ", tensor_shape_origin_ size is "
                  << tensor_shape_size;
    return false;
  }
  if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
    if (layout_transfer_) {
      MS_LOG(DEBUG) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
    } else {
      MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
    }
    return false;
  }
  return true;
}

bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const {
  for (uint64_t i = 0; i < tensor_map_.GetDimSize(); i++) {
    if (tensor_map_.GetDimByIdx(i) != -1) {
      int64_t divisor = GetSliceNumByTensorDimensionIndex(i);
      if (divisor == 0) {
        MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0";
        return false;
      }
      if (tensor_shape_.GetDimByIdx(i) != -1 && tensor_shape_.GetDimByIdx(i) % divisor != 0) {
        if (layout_transfer_) {
          MS_LOG(DEBUG) << i << "th input shape is not divisible. The input shape is " << tensor_shape_.GetDimByIdx(i)
                        << ", but the slice number is " << divisor;
        } else {
          MS_LOG(ERROR) << i << "th input shape is not divisible. The input shape is " << tensor_shape_.GetDimByIdx(i)
                        << ", but the slice number is " << divisor;
        }
        return false;
      }
    }
  }
  return true;
}

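/*
 * example (illustrative):
 * device arrangement origin = [8, 1, 4],
 * tensor map origin = [2, 1, 0],
 * tensor shape origin = [512, 2, 1024]
 * =>
 * device arrangement = [8, 4],
 * tensor map = [1, -1, 0],
 * tensor shape = [512, 2, 1024]
 */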
void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() {
  Shape device_arrangement_shape;
  Shape tensor_map_shape = tensor_map_origin_.array();
  size_t dev_num = device_arrangement_origin_.GetDimSize();
  size_t dev_num_left = device_arrangement_origin_.GetDimSize();
  for (size_t i = 0; i < dev_num; i++) {
    if (device_arrangement_origin_.GetDimByIdx(i) == 1) {
      int64_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast<int64_t>(dev_num - 1 - i));
      if (idx != -1) {
        tensor_map_shape[static_cast<uint64_t>(idx)] = -1;
      }
      for (auto &value : tensor_map_shape) {
        if (value >= SizeToLong(dev_num_left) - 1 - static_cast<int64_t>(i)) {
          value--;
        }
      }
      continue;
    }
    device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i));
  }
  (void)device_arrangement_.Init(device_arrangement_shape);
  (void)tensor_map_.Init(tensor_map_shape);
  tensor_shape_ = tensor_shape_origin_;
}

// if idx is not in tensor_map, return -1
int64_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const {
  return tensor_map_.GetIndexByValue(idx);
}

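/*
 * example (illustrative): with device_arrangement_ = [8, 4] and tensor_map_ = [1, 0],
 * GetSliceDeviceDimensionByTensorDimensionIndex(0) = 2 - 1 - 1 = 0, and
 * GetSliceNumByTensorDimensionIndex(0) = device_arrangement_[0] = 8,
 * i.e. tensor dimension 0 is sliced into 8 parts.
 */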
// tensor_map_.GetDimByIdx(idx) should not be -1
int64_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const {
  return static_cast<int64_t>(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx);
}

// tensor_map_.GetDimByIdx(idx) should not be -1
int64_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint64_t idx) const {
  return device_arrangement_.GetDimByIdx(static_cast<uint64_t>(GetSliceDeviceDimensionByTensorDimensionIndex(idx)));
}

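// Expand the tensor shape in two steps: first derive a finer device arrangement that matches the expanded
// shape (ComputeArrangementByExpandedShape + ExpandDeviceArrangement), then expand the tensor shape under
// that arrangement without further extending the device arrangement.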
std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const {
  std::shared_ptr<Arrangement> expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape);
  if (expanded_arrangement_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<TensorLayout> temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr);
  if (temp_tensor_layout_ptr == nullptr) {
    return nullptr;
  }
  return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape);
}

/*
 * example1:
 * in_device_arrangement = [8, 4],
 * in_tensor_map = [1, 0],
 * in_tensor_shape = [512, 1024],
 * out_tensor_shape = [128, 4, 2, 512],
 * =>
 * out_device_arrangement = [8, 2, 2]
 */
std::shared_ptr<Arrangement> TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const {
  std::shared_ptr<std::vector<Arrangement>> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape);
  if (expand_list_ptr == nullptr) {
    return nullptr;
  }
  std::vector<Arrangement> re_map_expand_list;
  Arrangement empty_arrangement;
  for (int64_t i = static_cast<int64_t>(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) {
    if (tensor_map_.GetIndexByValue(i) < 0) {
      re_map_expand_list.push_back(empty_arrangement);
    } else {
      re_map_expand_list.push_back((*expand_list_ptr)[LongToUlong(tensor_map_.GetIndexByValue(i))]);
    }
  }
  std::shared_ptr<Arrangement> new_arrangement_ptr =
    device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list);
  return new_arrangement_ptr;
}

/*
 * example1:
 * in_device_arrangement = [8, 4],
 * in_tensor_map = [1, 0],
 * in_tensor_shape = [512, 1024],
 * out_tensor_shape = [8, 64, 4, 256]
 * =>
 * out_device_arrangement = [8, 4],
 * out_tensor_map = [1, -1, 0, -1],
 */
std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement(
  const Arrangement &expanded_shape) const {
  std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr =
    tensor_shape_.GetExpandShapeListPair(expanded_shape);
  if (expand_list_pair_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second);
  if (tensor_map_new_ptr == nullptr) {
    return nullptr;
  }
  TensorLayout tensor_layout_new;
  tensor_layout_new.set_layout_transfer(true);
  Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<TensorLayout>(tensor_layout_new);
}

/*
 * example1:
 * in_device_arrangement = [8, 4],
 * in_tensor_map = [1, 0],
 * in_tensor_shape = [512, 1024],
 * out_device_arrangement = [4, 2, 2, 2]
 * =>
 * out_tensor_map = [3, 2, 1, 0],
 * out_tensor_shape = [4, 128, 2, 512]
 *
 * example2:
 * in_device_arrangement = [8, 4],
 * in_tensor_map = [0, 1],
 * in_tensor_shape = [512, 1024],
 * out_device_arrangement = [4, 2, 2, 2]
 * =>
 * out_tensor_map = [1, 0, 3, 2],
 * out_tensor_shape = [2, 256, 4, 256]
 *
 * example3:
 * in_device_arrangement = [8, 4],
 * in_tensor_map = [1, -1],
 * in_tensor_shape = [512, 1024],
 * out_device_arrangement = [4, 2, 2, 2]
 * =>
 * out_tensor_map = [3, 2, -1],
 * out_tensor_shape = [4, 128, 1024]
 *
 * example4:
 * in_device_arrangement = [8, 4],
 * in_tensor_map = [0, 1],
 * in_tensor_shape = [512, 1024],
 * out_device_arrangement = [4, 2, 4]
 * =>
 * out_tensor_map = [0, 2, 1],
 * out_tensor_shape = [512, 4, 256]
 */
std::shared_ptr<TensorLayout> TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const {
  std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr =
    device_arrangement_.GetExpandShapeListPair(expanded_arrangement);
  if (expand_list_pair_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second);
  if (tensor_map_new_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<std::vector<Arrangement>> re_map_shape_list_ptr =
    tensor_map_.ReMapVector(expand_list_pair_ptr->first);
  if (re_map_shape_list_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Arrangement> tensor_shape_new_ptr =
    tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr);
  if (tensor_shape_new_ptr == nullptr) {
    return nullptr;
  }
  TensorLayout tensor_layout_new;
  Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<TensorLayout>(tensor_layout_new);
}

bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const {
  Shape in_expand_shape_shape;
  Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
  if (status != Status::SUCCESS) {
    return false;
  }
  return (in_expand_shape_shape == tensor_shape_.array());
}

std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const {
  Shape in_expand_shape_shape;
  Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  Arrangement expanded_shape;
  status = expanded_shape.Init(in_expand_shape_shape);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<Arrangement>(expanded_shape);
}

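/*
 * example (illustrative):
 * device arrangement = [8, 4], tensor map = [1, 0], tensor shape = [512, 1024]
 * => slice shape = [64, 256]  (512 / 8, 1024 / 4)
 * Dimensions with tensor map value -1 or dynamic size -1 are kept as-is.
 */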
Arrangement TensorLayout::slice_shape() const {
  Shape shape;
  for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
    int64_t dim = tensor_map_.GetDimByIdx(index);
    int64_t num = tensor_shape_.GetDimByIdx(index);
    if (dim == -1 || num == -1) {
      shape.push_back(num);  // num == -1 means dynamic shape
    } else {
      int64_t divisor = device_arrangement_.GetDimByReverseIdx(LongToUlong(dim));
      shape.push_back(num / divisor);
    }
  }
  Arrangement new_tensor_shape;
  if (new_tensor_shape.Init(shape) == Status::FAILED) {
    ValuePtr ptr = MakeValue(shape);
    MS_LOG(EXCEPTION) << "Can't get slice shape when initializing a new shape " << ptr->ToString();
  } else {
    return new_tensor_shape;
  }
}

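/*
 * Slice shape computed from the pre-reshape (extended) layout set by InitFromExtendVector; falls back to
 * slice_shape() when tensor_map_before_ is empty.
 * example (illustrative), reusing example1 of InitFromExtendVector:
 * device_matrix = [8, 2, 4], tensor_map = [[2], [1, 0]], tensor_shape = [512, 1024]
 * => base slice shape = [64, 128]  (512 / 8, 1024 / (2 * 4))
 */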
Arrangement TensorLayout::base_slice_shape() const {
  if (tensor_map_before_.empty()) {
    return slice_shape();
  }
  Shape shape;
  for (size_t index = 0; index < tensor_map_before_.size(); index++) {
    auto dim_map = tensor_map_before_[index];
    int64_t num = tensor_shape_before_.GetDimByIdx(index);
    int64_t axis_shard = 1;
    for (const auto &dim : dim_map) {
      if (dim != -1) {
        int64_t divisor = device_arrangement_origin_.GetDimByReverseIdx(LongToUlong(dim));
        axis_shard *= divisor;
      }
    }
    if (num == -1) {
      shape.push_back(num);  // num == -1 means dynamic shape
    } else {
      shape.push_back(num / axis_shard);
    }
  }
  Arrangement new_slice_shape;
  if (new_slice_shape.Init(shape) == Status::FAILED) {
    MS_LOG(EXCEPTION) << "Can't get slice shape when initializing a new shape " << shape;
  } else {
    return new_slice_shape;
  }
}

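/*
 * example (illustrative):
 * device arrangement = [8, 4], tensor map = [1, 0] => shard strategy = [8, 4];
 * unmapped dimensions (tensor map value -1) contribute 1.
 */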
Shape TensorLayout::shard_strategy() const {
  Shape ret;
  for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
    int64_t dim = tensor_map_.GetDimByIdx(index);
    if (dim == -1) {
      ret.push_back(1);
    } else {
      int64_t divisor = device_arrangement_.GetDimByReverseIdx(LongToUlong(dim));
      ret.push_back(divisor);
    }
  }
  return ret;
}

Status TensorLayout::UpdateTensorMap(size_t index, int64_t value) {
  if (index >= tensor_map_.GetDimSize()) {
    MS_LOG(ERROR) << "Index is out of the size of the tensor map!";
    return Status::FAILED;
  }
  auto shape = tensor_map_.array();
  shape[index] = value;
  if (tensor_map_.Init(shape) == Status::FAILED) {
    MS_LOG(ERROR) << "Update tensor map failed!";
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

bool TensorLayout::operator==(const TensorLayout &t1) const {
  return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
}

bool TensorLayout::operator!=(const TensorLayout &t1) const {
  return !(IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
}

bool TensorLayout::IsSameWithoutSplit(const TensorLayout &t1) const {
  if (!IsSameTensorMap(t1) || !IsSameTensorShape(t1)) {
    return false;
  }
  auto first_array = tensor_map().array();
  auto second_array = t1.tensor_map().array();
  auto first_pos = std::find_if(first_array.begin(), first_array.end(), [&](const int64_t ele) { return ele != -1; });
  auto second_pos =
    std::find_if(second_array.begin(), second_array.end(), [&](const int64_t ele) { return ele != -1; });
  if (first_pos != first_array.end() || second_pos != second_array.end()) {
    return false;
  }
  return true;
}

// Check whether the layout has an interleaved device matrix and the tensor map uses the interleaved dimension
bool TensorLayout::IsInterleavedParallel() const {
  if (device_arrangement_interleaved_.array().empty()) {
    return false;
  }
  bool is_interleaved_parallel = false;
  for (size_t i = 0; i < origin_tensor_map().array().size(); ++i) {
    if (origin_tensor_map().array()[i] == 0) {
      is_interleaved_parallel = true;
      break;
    }
  }
  return is_interleaved_parallel;
}

/*
 * remove elements equal to 1 in tensor_shape; if all elements are 1, squeeze the tensor_shape to [ 1 ]
 * example 1:
 * original tensor layout:
 * device arrangement = [ 8 ]
 * tensor map = [ 0 -1 -1 -1 ]
 * tensor shape = [ 128 64 1 1 ]
 * return tensor layout:
 * device arrangement = [ 8 ]
 * tensor map = [ 0 -1 ]
 * tensor shape = [ 128 64 ]
 *
 * example 2:
 * original tensor layout:
 * device arrangement = [ 8 ]
 * tensor map = [ -1 -1 -1 -1 ]
 * tensor shape = [ 1 1 1 1 ]
 * return tensor layout:
 * device arrangement = [ 8 ]
 * tensor map = [ -1 ]
 * tensor shape = [ 1 ]
 */
TensorLayout TensorLayout::SqueezeShape() const {
  TensorLayout out;
  Map out_map;
  Arrangement out_shape;
  auto is_dynamic_func = [](const Shape &shape) -> bool {
    return std::find(shape.begin(), shape.end(), -1) != shape.end();
  };
  // tensor_shape's size doesn't make sense in the dynamic shape scene.
  if (!is_dynamic_func(tensor_shape_.array()) && tensor_shape_.size() == 1) {
    (void)out_map.Init({MAP_NONE});
    (void)out_shape.Init({1});
    (void)out.Init(device_arrangement_, out_map, out_shape);
    return out;
  }
  std::vector<size_t> squeeze_list = tensor_shape_.GetSqueezeIdx();
  if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) {
    MS_LOG(ERROR) << "CheckNoneByIdxList failed; this should not happen in the current situation";
    return *this;
  }
  out_shape = tensor_shape_.GetSqueezeArrangement();
  out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list);
  (void)out.Init(device_arrangement_, out_map, out_shape);
  return out;
}

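// Return a layout with the same device arrangement and tensor shape but a fully replicated tensor map
// (all entries -1), e.g. dev_mat [8, 4], tensor_map [1, 0] => tensor_map [-1, -1].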
TensorLayout TensorLayout::TransferRepeatLayout() const {
  Shape dev_mat(device_arrangement_origin_.array());
  Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
  Shape tensor_shape(tensor_shape_origin_.array());
  TensorLayout repeat;
  if (repeat.InitFromVector(dev_mat, tensor_map, tensor_shape) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Init from vector failed.";
  }
  return repeat;
}

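// Infer the ranks in this stage that hold the same tensor slice as the current rank, i.e. the group
// spanned by the device dimensions not referenced by the tensor map (the repeated/replicated group).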
RankList TensorLayout::InferRepeatedGroup() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), device_arrangement_origin_.array());
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map_origin_.array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Tensor layout:" << ToString() << " infer repeated group failed.";
  }
  return group_devices;
}

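/*
 * example (illustrative):
 * device arrangement = [8, 4], tensor map = [1, -1], slice shape = [64, 1024]
 * => the device dimension of size 4 is not used by the tensor map, so repeated_num = 4 (capped by
 *    optimizer_weight_shard_size if configured) and opt_shard_slice_shape_ = [16, 1024].
 */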
// Generate a fully sharded tensor slice shape for the parallel optimizer
Status TensorLayout::GenerateOptShardSliceShape() {
  MS_LOG(INFO) << "layout for GenerateOptShardSliceShape is " << StandardToString();
  Shape dev_max = device_arrangement_.array();

  Shape repeated_dev;
  for (size_t i = 0; i < dev_max.size(); i++) {
    if (tensor_map_.GetIndexByValue(static_cast<int64_t>(i)) == MAP_NONE) {
      repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]);
      dev_max[dev_max.size() - 1 - i] = 1;
    }
  }
  if (repeated_dev.empty()) {
    MS_LOG(INFO) << "Tensor is fully sharded already.";
    return Status::FAILED;
  }
  int64_t repeated_num =
    std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1 && repeated_num >= optimizer_weight_shard_size) {
    repeated_num = optimizer_weight_shard_size;
  }

  Shape origin_slice_shape = base_slice_shape().array();
  if (origin_slice_shape[0] % repeated_num != 0) {
    MS_LOG(INFO) << "Tensor could not be sharded on the first dimension.";
    return Status::FAILED;
  }
  origin_slice_shape[0] = origin_slice_shape[0] / repeated_num;
  opt_shard_slice_shape_ = origin_slice_shape;
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore