/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <utility>
#include "utils/ms_utils.h"
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/status.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); }

std::string TensorLayout::StandardToString() const {
  std::ostringstream buffer;
  buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString());
  buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString());
  buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString());
  return buffer.str();
}

std::string TensorLayout::OriginToString() const {
  std::ostringstream buffer;
  buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString());
  buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString());
  buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString());
  return buffer.str();
}

Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map,
                          const Arrangement &tensor_shape) {
  device_arrangement_origin_ = device_arrangement;
  tensor_map_origin_ = tensor_map;
  tensor_shape_origin_ = tensor_shape;
  device_arrangement_ = device_arrangement;
  tensor_map_ = tensor_map;
  tensor_shape_ = tensor_shape;
  if (IsValidTensorLayout()) {
    MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString();
    RemoveElementEqualToOneInDeviceArrangement();
    MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
    return Status::SUCCESS;
  } else {
    if (layout_transfer_) {
      MS_LOG(DEBUG) << "invalid origin tensor layout " << this->OriginToString();
    } else {
      MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
    }
    return Status::FAILED;
  }
}

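/*
 * Illustrative usage (assumed values, not from the original source):
 *   TensorLayout layout;
 *   (void)layout.InitFromVector({8, 4}, {1, 0}, {512, 1024});
 * builds a layout whose first tensor dimension is sharded over the "8" device
 * axis and whose second is sharded over the "4" axis.
 */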
Status TensorLayout::InitFromVector(const Shape &device_arrangement, const Shape &tensor_map,
                                    const Shape &tensor_shape) {
  if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
    MS_LOG(ERROR) << "Init device_arrangement failed.";
    return FAILED;
  }
  if (tensor_map_origin_.Init(tensor_map) != SUCCESS) {
    MS_LOG(ERROR) << "Init tensor_map failed.";
    return FAILED;
  }
  if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) {
    MS_LOG(ERROR) << "Init tensor_shape failed.";
    return FAILED;
  }
  if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) {
    MS_LOG(ERROR) << "Init tensor_layout failed.";
    return FAILED;
  }
  return SUCCESS;
}

/*
 *  example1:
 *    in_device_arrangement = [8, 2, 4],
 *    in_tensor_map = [[2], [1, 0]],
 *    in_tensor_shape = [512, 1024],
 *  =>
 *    in_device_arrangement = [8, 2, 4],
 *    in_tensor_map = [2, 1, 0],
 *    in_tensor_shape = [512, 2, 512],
 *  example2:
 *    in_device_arrangement = [8, 2, 4],
 *    in_tensor_map = [[1], [0, 2]],
 *    in_tensor_shape = [512, 1024],
 *  =>
 *    in_device_arrangement = [8, 2, 4],
 *    in_tensor_map = [1, 0, 2],
 *    in_tensor_shape = [512, 4, 256],
 */
Status TensorLayout::InitFromExtendVector(const Shape &device_matrix, const std::vector<Shape> &tensor_map,
                                          const Shape &tensor_shape, bool interleaved_parallel, bool check_device_num) {
  auto device_arrangement = device_matrix;
  if (interleaved_parallel) {
    if (device_arrangement_interleaved_.Init(device_matrix) != SUCCESS) {
      return FAILED;
    }
    if (parallel::ParallelContext::GetInstance()->fine_grained_micro_interleaved_size() == -1) {
      parallel::ParallelContext::GetInstance()->set_fine_grained_micro_interleaved_size(
        device_arrangement[device_arrangement.size() - 1]);
    } else if (parallel::ParallelContext::GetInstance()->fine_grained_micro_interleaved_size() !=
               device_arrangement[device_arrangement.size() - 1]) {
      MS_LOG(EXCEPTION) << "The micro interleaved num should be configured consistently for each operator's layout.";
    }
    device_arrangement[device_arrangement.size() - 1] = 1;
  }

  if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
    return FAILED;
  }
  CheckGlobalDeviceManager();
  auto device_num = g_device_manager->stage_device_num();
  int64_t device_total = std::accumulate(device_arrangement.begin(), device_arrangement.end(),
                                         static_cast<int64_t>(1), std::multiplies<int64_t>());
  if (device_num != device_total && check_device_num) {
    MS_LOG(ERROR) << "The configured device_matrix " << device_arrangement << " accumulated value " << device_total
                  << " does not equal the device number in one stage " << device_num;
    return FAILED;
  }
  Shape extended_tensor_map;
  Shape reshaped_tensor_shape;
  if (tensor_shape.size() != tensor_map.size()) {
    MS_LOG(ERROR) << "The tensor_shape " << tensor_shape << " does not have the same size as tensor_map "
                  << tensor_map;
    return FAILED;
  }

  size_t not_none_count = 0;
  for (size_t i = 0; i < tensor_map.size(); ++i) {
    for (size_t j = 0; j < tensor_map[i].size(); ++j) {
      extended_tensor_map.push_back(tensor_map[i][j]);
      // MAP_NONE is -1, so any non-negative value marks a mapped (not None) dimension.
      if (tensor_map[i][j] >= 0) {
        ++not_none_count;
      }
    }
  }

  if (not_none_count > device_arrangement.size()) {
    MS_LOG(ERROR) << "The device_matrix " << device_arrangement
                  << " length must be no less than the number of not-None elements in extended_tensor_map "
                  << extended_tensor_map;
    return FAILED;
  }
  (void)tensor_shape_before_.Init(tensor_shape);
  for (size_t i = 0; i < tensor_map.size(); ++i) {
    if (tensor_map[i].size() == 1) {
      reshaped_tensor_shape.push_back(tensor_shape[i]);
      continue;
    }
    int64_t accu_shp = 1;
    for (size_t j = 0; j < tensor_map[i].size() - 1; ++j) {
      size_t tensor_index = device_arrangement.size() - 1 - static_cast<size_t>(tensor_map[i][j]);
      auto shard_size = device_arrangement[tensor_index];
      accu_shp *= shard_size;
      reshaped_tensor_shape.push_back(shard_size);
    }
    auto last_shp = tensor_shape[i] / accu_shp;
    reshaped_tensor_shape.push_back(last_shp);
  }
  if (tensor_map_origin_.Init(extended_tensor_map) != SUCCESS) {
    return FAILED;
  }
  if (tensor_shape_origin_.Init(reshaped_tensor_shape) != SUCCESS) {
    return FAILED;
  }
  if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) {
    return FAILED;
  }
  tensor_map_before_ = tensor_map;
  return SUCCESS;
}

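/*
 * Illustrative example (assumed values, not from the original source): with an
 * interleaved number of 2, global rank 3 expands to virtual ranks {6, 7}.
 */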
std::vector<int64_t> TensorLayout::GetVirtualRank() const {
  int64_t rank = g_device_manager->global_rank();
  if (!IsInterleavedParallel()) {
    return {rank};
  }
  auto interleaved_num = device_arrangement_interleaved_.array().back();
  std::vector<int64_t> virtual_ranks;
  for (int64_t i = 0; i < interleaved_num; ++i) {
    virtual_ranks.push_back(rank * interleaved_num + i);
  }
  return virtual_ranks;
}

TensorLayout TensorLayout::LayoutForRedistribution() const {
  if (!IsInterleavedParallel()) {
    return *this;
  }
  TensorLayout interleaved_layout;
  if (interleaved_layout.InitFromExtendVector(device_arrangement_interleaved_.array(), tensor_map_before_,
                                              tensor_shape_before_.array(), false, false) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Init layout for micro interleaved failed, device_matrix:"
                      << device_arrangement_interleaved_.array() << ", tensor_map:" << tensor_map_before_;
  }
  return interleaved_layout;
}

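/*
 * Illustrative example (assumed values, not from the original source): device
 * arrangement [8, 4] has size 2, so tensor map [2, 0] is invalid (its max item
 * 2 is not smaller than 2), while tensor map [1, 0] is valid.
 */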
bool TensorLayout::IsValidTensorLayout() const {
  int64_t max_tensor_map_item = tensor_map_origin_.GetMaxItem();
  int64_t device_arr_size = SizeToLong(device_arrangement_origin_.GetDimSize());
  if (max_tensor_map_item >= device_arr_size) {
    MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size! "
                  << "Max element in tensor_map_origin_ is " << max_tensor_map_item
                  << ", device_arrangement_origin_ size is " << device_arr_size;
    return false;
  }
  size_t tensor_map_size = tensor_map_origin_.GetDimSize();
  size_t tensor_shape_size = tensor_shape_origin_.GetDimSize();
  if (tensor_map_size != tensor_shape_size) {
    MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size! "
                  << "tensor_map_origin_ size is " << tensor_map_size << ", tensor_shape_origin_ size is "
                  << tensor_shape_size;
    return false;
  }
  if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
    if (layout_transfer_) {
      MS_LOG(DEBUG) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
    } else {
      MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
    }
    return false;
  }
  return true;
}

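/*
 * Illustrative example (assumed values, not from the original source): with
 * device arrangement [8, 4] and tensor map [1, 0], tensor shape [512, 1024]
 * passes (512 % 8 == 0, 1024 % 4 == 0), but [510, 1024] fails on dimension 0.
 */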
bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const {
  for (uint64_t i = 0; i < tensor_map_.GetDimSize(); i++) {
    if (tensor_map_.GetDimByIdx(i) != -1) {
      int64_t divisor = GetSliceNumByTensorDimensionIndex(i);
      if (divisor == 0) {
        MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0";
        return false;
      }
      if (tensor_shape_.GetDimByIdx(i) != -1 && tensor_shape_.GetDimByIdx(i) % divisor != 0) {
        if (layout_transfer_) {
          MS_LOG(DEBUG) << i << "th input shape is not divisible. The input shape is " << tensor_shape_.GetDimByIdx(i)
                        << ", but the slice number is " << divisor;
        } else {
          MS_LOG(ERROR) << i << "th input shape is not divisible. The input shape is " << tensor_shape_.GetDimByIdx(i)
                        << ", but the slice number is " << divisor;
        }
        return false;
      }
    }
  }
  return true;
}

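/*
 * Illustrative example (assumed values, not from the original source):
 *   device arrangement origin = [8, 1, 4], tensor map origin = [2, 0]
 * drops the device dimension equal to 1 and renumbers the map:
 *   device arrangement = [8, 4], tensor map = [1, 0]
 */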
void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() {
  Shape device_arrangement_shape;
  Shape tensor_map_shape = tensor_map_origin_.array();
  size_t dev_num = device_arrangement_origin_.GetDimSize();
  size_t dev_num_left = device_arrangement_origin_.GetDimSize();
  for (size_t i = 0; i < dev_num; i++) {
    if (device_arrangement_origin_.GetDimByIdx(i) == 1) {
      int64_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast<int64_t>(dev_num - 1 - i));
      if (idx != -1) {
        tensor_map_shape[static_cast<uint64_t>(idx)] = -1;
      }
      for (auto &value : tensor_map_shape) {
        if (value >= SizeToLong(dev_num_left) - 1 - static_cast<int64_t>(i)) {
          value--;
        }
      }
      continue;
    }
    device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i));
  }
  (void)device_arrangement_.Init(device_arrangement_shape);
  (void)tensor_map_.Init(tensor_map_shape);
  tensor_shape_ = tensor_shape_origin_;
}

// if idx is not in tensor_map, return -1
int64_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const {
  return tensor_map_.GetIndexByValue(idx);
}

// tensor_map_.GetDimByIdx(idx) should not be -1
int64_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const {
  return static_cast<int64_t>(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx);
}

// tensor_map_.GetDimByIdx(idx) should not be -1
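// Illustrative example (assumed values, not from the original source): with device
// arrangement [8, 4] and tensor map [1, 0], tensor dimension 0 maps to value 1,
// i.e. device dimension index 0, so its slice number is 8.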
int64_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint64_t idx) const {
  return device_arrangement_.GetDimByIdx(static_cast<uint64_t>(GetSliceDeviceDimensionByTensorDimensionIndex(idx)));
}

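/*
 * ExpandTensorShape chains the three helpers below: it derives a compatible
 * device arrangement from the expanded shape, expands the device arrangement,
 * and then expands the tensor shape itself. A plausible end-to-end trace
 * (assumed values, relying on the examples documented below):
 *   device arrangement [8, 4], tensor map [1, 0], tensor shape [512, 1024],
 *   expanded shape [128, 4, 2, 512]
 * =>
 *   device arrangement [8, 2, 2], tensor map [2, -1, 1, 0],
 *   tensor shape [128, 4, 2, 512]
 */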
std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const {
  std::shared_ptr<Arrangement> expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape);
  if (expanded_arrangement_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<TensorLayout> temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr);
  if (temp_tensor_layout_ptr == nullptr) {
    return nullptr;
  }
  return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape);
}

/*
 *  example1:
 *    in_device_arrangement = [8, 4],
 *    in_tensor_map = [1, 0],
 *    in_tensor_shape = [512, 1024],
 *    out_tensor_shape = [128, 4, 2, 512],
 *  =>
 *    out_device_arrangement = [8, 2, 2]
 */
std::shared_ptr<Arrangement> TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const {
  std::shared_ptr<std::vector<Arrangement>> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape);
  if (expand_list_ptr == nullptr) {
    return nullptr;
  }
  std::vector<Arrangement> re_map_expand_list;
  Arrangement empty_arrangement;
  for (int64_t i = static_cast<int64_t>(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) {
    if (tensor_map_.GetIndexByValue(i) < 0) {
      re_map_expand_list.push_back(empty_arrangement);
    } else {
      re_map_expand_list.push_back((*expand_list_ptr)[LongToUlong(tensor_map_.GetIndexByValue(i))]);
    }
  }
  std::shared_ptr<Arrangement> new_arrangement_ptr =
    device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list);
  return new_arrangement_ptr;
}

/*
 *  example1:
 *    in_device_arrangement = [8, 4],
 *    in_tensor_map = [1, 0],
 *    in_tensor_shape = [512, 1024],
 *    out_tensor_shape = [8, 64, 4, 256]
 *  =>
 *    out_device_arrangement = [8, 4],
 *    out_tensor_map = [1, -1, 0, -1],
 */
std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement(
  const Arrangement &expanded_shape) const {
  std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr =
    tensor_shape_.GetExpandShapeListPair(expanded_shape);
  if (expand_list_pair_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second);
  if (tensor_map_new_ptr == nullptr) {
    return nullptr;
  }
  TensorLayout tensor_layout_new;
  tensor_layout_new.set_layout_transfer(true);
  Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<TensorLayout>(tensor_layout_new);
}

/*
 *  example1:
 *    in_device_arrangement = [8, 4],
 *    in_tensor_map = [1, 0],
 *    in_tensor_shape = [512, 1024],
 *    out_device_arrangement = [4, 2, 2, 2]
 *  =>
 *    out_tensor_map = [3, 2, 1, 0],
 *    out_tensor_shape = [4, 128, 2, 512]
 *
 *  example2:
 *    in_device_arrangement = [8, 4],
 *    in_tensor_map = [0, 1],
 *    in_tensor_shape = [512, 1024],
 *    out_device_arrangement = [4, 2, 2, 2]
 *  =>
 *    out_tensor_map = [1, 0, 3, 2],
 *    out_tensor_shape = [2, 256, 4, 256]
 *
 *  example3:
 *    in_device_arrangement = [8, 4],
 *    in_tensor_map = [1, -1],
 *    in_tensor_shape = [512, 1024],
 *    out_device_arrangement = [4, 2, 2, 2]
 *  =>
 *    out_tensor_map = [3, 2, -1],
 *    out_tensor_shape = [4, 128, 1024]
 *
 *  example4:
 *    in_device_arrangement = [8, 4],
 *    in_tensor_map = [0, 1],
 *    in_tensor_shape = [512, 1024],
 *    out_device_arrangement = [4, 2, 4]
 *  =>
 *    out_tensor_map = [0, 2, 1],
 *    out_tensor_shape = [512, 4, 256]
 */
std::shared_ptr<TensorLayout> TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const {
  std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr =
    device_arrangement_.GetExpandShapeListPair(expanded_arrangement);
  if (expand_list_pair_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second);
  if (tensor_map_new_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<std::vector<Arrangement>> re_map_shape_list_ptr =
    tensor_map_.ReMapVector(expand_list_pair_ptr->first);
  if (re_map_shape_list_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Arrangement> tensor_shape_new_ptr =
    tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr);
  if (tensor_shape_new_ptr == nullptr) {
    return nullptr;
  }
  TensorLayout tensor_layout_new;
  Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<TensorLayout>(tensor_layout_new);
}

bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const {
  Shape in_expand_shape_shape;
  Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
  if (status != Status::SUCCESS) {
    return false;
  }
  return (in_expand_shape_shape == tensor_shape_.array());
}

std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const {
  Shape in_expand_shape_shape;
  Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  Arrangement expanded_shape;
  status = expanded_shape.Init(in_expand_shape_shape);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<Arrangement>(expanded_shape);
}

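/*
 * Illustrative example (assumed values, not from the original source): with
 * device arrangement [8, 4], tensor map [1, 0] and tensor shape [512, 1024],
 * each device holds a slice of shape [64, 256].
 */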
Arrangement TensorLayout::slice_shape() const {
  Shape shape;
  for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
    int64_t dim = tensor_map_.GetDimByIdx(index);
    int64_t num = tensor_shape_.GetDimByIdx(index);
    if (dim == -1 || num == -1) {
      shape.push_back(num);  // num == -1 means dynamic shape
    } else {
      int64_t divisor = device_arrangement_.GetDimByReverseIdx(LongToUlong(dim));
      shape.push_back(num / divisor);
    }
  }
  Arrangement new_tensor_shape;
  if (new_tensor_shape.Init(shape) == Status::FAILED) {
    ValuePtr ptr = MakeValue(shape);
    MS_LOG(EXCEPTION) << "Can't get slice shape when initializing a new shape " << ptr->ToString();
  } else {
    return new_tensor_shape;
  }
}

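/*
 * Illustrative example (assumed values, not from the original source): with
 * device arrangement origin [8, 2, 4], tensor_map_before_ = [[2, 1], [0]] and
 * tensor shape before [512, 1024], dimension 0 is sharded 8 * 2 = 16 ways and
 * dimension 1 is sharded 4 ways, giving a base slice shape of [32, 256].
 */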
Arrangement TensorLayout::base_slice_shape() const {
  if (tensor_map_before_.empty()) {
    return slice_shape();
  }
  Shape shape;
  for (size_t index = 0; index < tensor_map_before_.size(); index++) {
    auto dim_map = tensor_map_before_[index];
    int64_t num = tensor_shape_before_.GetDimByIdx(index);
    int64_t axis_shard = 1;
    for (const auto &dim : dim_map) {
      if (dim != -1) {
        int64_t divisor = device_arrangement_origin_.GetDimByReverseIdx(LongToUlong(dim));
        axis_shard *= divisor;
      }
    }
    if (num == -1) {
      shape.push_back(num);  // num == -1 means dynamic shape
    } else {
      shape.push_back(num / axis_shard);
    }
  }
  Arrangement new_slice_shape;
  if (new_slice_shape.Init(shape) == Status::FAILED) {
    MS_LOG(EXCEPTION) << "Can't get slice shape when initializing a new shape " << shape;
  } else {
    return new_slice_shape;
  }
}

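/*
 * Illustrative example (assumed values, not from the original source): with
 * device arrangement [8, 4] and tensor map [1, -1], the shard strategy is
 * [8, 1] (unmapped dimensions are not sharded).
 */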
Shape TensorLayout::shard_strategy() const {
  Shape ret;
  for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
    int64_t dim = tensor_map_.GetDimByIdx(index);
    if (dim == -1) {
      ret.push_back(1);
    } else {
      int64_t divisor = device_arrangement_.GetDimByReverseIdx(LongToUlong(dim));
      ret.push_back(divisor);
    }
  }
  return ret;
}

Status TensorLayout::UpdateTensorMap(size_t index, int64_t value) {
  if (index >= tensor_map_.GetDimSize()) {
    MS_LOG(ERROR) << "Index is out of the size of the tensor map!";
    return Status::FAILED;
  }
  auto shape = tensor_map_.array();
  shape[index] = value;
  if (tensor_map_.Init(shape) == Status::FAILED) {
    MS_LOG(ERROR) << "Update tensor map failed!";
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

bool TensorLayout::operator==(const TensorLayout &t1) const {
  return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
}

bool TensorLayout::operator!=(const TensorLayout &t1) const {
  return !(IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
}

bool TensorLayout::IsSameWithoutSplit(const TensorLayout &t1) const {
  if (!IsSameTensorMap(t1) || !IsSameTensorShape(t1)) {
    return false;
  }
  auto first_array = tensor_map().array();
  auto second_array = t1.tensor_map().array();
  auto first_pos = std::find_if(first_array.begin(), first_array.end(), [&](const int64_t ele) { return ele != -1; });
  auto second_pos =
    std::find_if(second_array.begin(), second_array.end(), [&](const int64_t ele) { return ele != -1; });
  if (first_pos != first_array.end() || second_pos != second_array.end()) {
    return false;
  }
  return true;
}

// Check whether the layout has an interleaved dev matrix and the tensor map uses interleaved parallelism
bool TensorLayout::IsInterleavedParallel() const {
  if (device_arrangement_interleaved_.array().empty()) {
    return false;
  }
  bool is_interleaved_parallel = false;
  for (size_t i = 0; i < origin_tensor_map().array().size(); ++i) {
    if (origin_tensor_map().array()[i] == 0) {
      is_interleaved_parallel = true;
      break;
    }
  }
  return is_interleaved_parallel;
}

/*
 * remove elements equal to 1 in tensor_shape; if all elements are 1, squeeze the tensor_shape to [ 1 ]
 * example 1:
 *  original tensor layout:
 *    device arrangement = [ 8 ]
 *    tensor map = [ 0 -1 -1 -1 ]
 *    tensor shape = [ 128 64 1 1 ]
 *  return tensor layout:
 *    device arrangement = [ 8 ]
 *    tensor map = [ 0 -1 ]
 *    tensor shape = [ 128 64 ]
 *
 * example 2:
 *  original tensor layout:
 *    device arrangement = [ 8 ]
 *    tensor map = [ -1 -1 -1 -1 ]
 *    tensor shape = [ 1 1 1 1 ]
 *  return tensor layout:
 *    device arrangement = [ 8 ]
 *    tensor map = [ -1 ]
 *    tensor shape = [ 1 ]
 */
TensorLayout TensorLayout::SqueezeShape() const {
  TensorLayout out;
  Map out_map;
  Arrangement out_shape;
  auto is_dynamic_func = [](const Shape &shape) -> bool {
    return std::find(shape.begin(), shape.end(), -1) != shape.end();
  };
  // tensor_shape's size doesn't make sense in dynamic-shape scenarios.
  if (!is_dynamic_func(tensor_shape_.array()) && tensor_shape_.size() == 1) {
    (void)out_map.Init({MAP_NONE});
    (void)out_shape.Init({1});
    (void)out.Init(device_arrangement_, out_map, out_shape);
    return out;
  }
  std::vector<size_t> squeeze_list = tensor_shape_.GetSqueezeIdx();
  if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) {
    MS_LOG(ERROR) << "CheckNoneByIdxList failed; this should not happen under the current situation";
    return *this;
  }
  out_shape = tensor_shape_.GetSqueezeArrangement();
  out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list);
  (void)out.Init(device_arrangement_, out_map, out_shape);
  return out;
}

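/*
 * Illustrative example (assumed values, not from the original source): for a
 * layout with device arrangement [8, 4] and tensor map [1, 0], the repeat
 * layout keeps the device arrangement but maps every tensor dimension to -1,
 * i.e. the tensor is replicated on all devices.
 */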
TensorLayout TensorLayout::TransferRepeatLayout() const {
  Shape dev_mat(device_arrangement_origin_.array());
  Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
  Shape tensor_shape(tensor_shape_origin_.array());
  TensorLayout repeat;
  if (repeat.InitFromVector(dev_mat, tensor_map, tensor_shape) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Init from vector failed.";
  }
  return repeat;
}

RankList TensorLayout::InferRepeatedGroup() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), device_arrangement_origin_.array());
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map_origin_.array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Tensor layout:" << ToString() << " infer repeated group failed.";
  }
  return group_devices;
}

// Generate a fully sharded tensor slice shape for the parallel optimizer
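/*
 * Illustrative example (assumed values, not from the original source): with
 * device arrangement [4, 2], tensor map [1, -1] and tensor shape [512, 1024],
 * the "2" device axis is unused by the map, so repeated_num = 2; the base
 * slice shape [128, 1024] is further split to opt_shard_slice_shape_ = [64, 1024].
 */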
Status TensorLayout::GenerateOptShardSliceShape() {
  MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString();
  Shape dev_max = device_arrangement_.array();

  Shape repeated_dev;
  for (size_t i = 0; i < dev_max.size(); i++) {
    if (tensor_map_.GetIndexByValue(static_cast<int64_t>(i)) == MAP_NONE) {
      repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]);
      dev_max[dev_max.size() - 1 - i] = 1;
    }
  }
  if (repeated_dev.empty()) {
    MS_LOG(INFO) << "Tensor is fully sharded already.";
    return Status::FAILED;
  }
  int64_t repeated_num =
    std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1 && repeated_num >= optimizer_weight_shard_size) {
    repeated_num = optimizer_weight_shard_size;
  }

  Shape origin_slice_shape = base_slice_shape().array();
  if (origin_slice_shape[0] % repeated_num != 0) {
    MS_LOG(INFO) << "Tensor could not be sharded on the first dimension.";
    return Status::FAILED;
  }
  origin_slice_shape[0] = origin_slice_shape[0] / repeated_num;
  opt_shard_slice_shape_ = origin_slice_shape;
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore