1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
18 #include <functional>
19 #include <numeric>
20 #include <memory>
21 #include <set>
22 #include <utility>
23 #include <algorithm>
24 #include <string>
25 #include "frontend/parallel/status.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "frontend/parallel/graph_util/graph_utils.h"
28 #include "frontend/parallel/tensor_layout/shape_util.h"
29 #include "frontend/parallel/step_parallel_utils.h"
30 #include "frontend/parallel/tensor_layout/prime_generator.h"
31 #include "frontend/parallel/tensor_layout/layout_utils.h"
32
33 namespace mindspore {
34 namespace parallel {
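// Note (editorial summary of the function below): MakeFromToLayout prepares from_origin_/to_origin_ for
// redistribution. When only one side uses micro-interleaved sharding, the other side's device matrix gains
// an extra trailing dimension so that both layouts are expressed on the same virtual-rank-expanded device
// arrangement.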
35 Status TensorRedistribution::MakeFromToLayout(const TensorLayout &from, const TensorLayout &to) {
36 auto from_layout = from.LayoutForRedistribution();
37 auto to_layout = to.LayoutForRedistribution();
38 if (virtual_rank_ >= 0) {
39 from_origin_ = from_layout;
40 to_origin_ = to_layout;
41 virtual_rank_list_ = {virtual_rank_};
42 return SUCCESS;
43 }
44 if (from.GetVirtualRank().size() == to.GetVirtualRank().size()) {
45 from_origin_ = from_layout;
46 to_origin_ = to_layout;
47 virtual_rank_list_ = from.GetVirtualRank();
48 return SUCCESS;
49 }
50 if (from.GetVirtualRank().size() == 1) {
51 auto device_matrix = from_layout.device_arrangement_origin().array();
52 device_matrix.push_back(to.GetVirtualRank().size());
53 virtual_rank_list_ = to.GetVirtualRank();
54 to_origin_ = to_layout;
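// Tensor-map values index the device matrix from its tail, so appending the micro-interleaved dimension
// above shifts every existing mapping by one; incrementing each non-negative entry keeps it pointing at
// the same device dimension.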
55 if (!from_layout.tensor_map_before().empty()) {
56 auto new_tensor_map = from_layout.tensor_map_before();
57 std::for_each(new_tensor_map.begin(), new_tensor_map.end(), [](auto &inner_vec) {
58 std::for_each(inner_vec.begin(), inner_vec.end(), [](auto &val) {
59 if (val >= 0) {
60 val++;
61 }
62 });
63 });
64 return from_origin_.InitFromExtendVector(device_matrix, new_tensor_map, from_layout.tensor_shape_before().array(),
65 false, false);
66 }
67 auto new_map = from_layout.origin_tensor_map().array();
68 std::transform(new_map.begin(), new_map.end(), new_map.begin(),
69 [](const auto &val) { return val >= 0 ? val + 1 : val; });
70 return from_origin_.InitFromVector(device_matrix, new_map, from_layout.tensor_shape().array());
71 }
72 if (to.GetVirtualRank().size() == 1) {
73 auto device_matrix = to_layout.device_arrangement_origin().array();
74 device_matrix.push_back(from.GetVirtualRank().size());
75 virtual_rank_list_ = from.GetVirtualRank();
76 from_origin_ = from_layout;
77 if (!to_layout.tensor_map_before().empty()) {
78 auto new_tensor_map = to_layout.tensor_map_before();
79 std::for_each(new_tensor_map.begin(), new_tensor_map.end(), [](auto &inner_vec) {
80 std::for_each(inner_vec.begin(), inner_vec.end(), [](auto &val) {
81 if (val >= 0) {
82 val++;
83 }
84 });
85 });
86 return to_origin_.InitFromExtendVector(device_matrix, new_tensor_map, to_layout.tensor_shape_before().array(),
87 false, false);
88 }
89 auto new_map = to_layout.origin_tensor_map().array();
90 std::transform(new_map.begin(), new_map.end(), new_map.begin(),
91 [](const auto &val) { return val >= 0 ? val + 1 : val; });
92 return to_origin_.InitFromVector(device_matrix, new_map, to_layout.tensor_shape().array());
93 }
94 MS_LOG(ERROR) << "The from layout sharding micro interleaved num:" << from.GetVirtualRank().size()
95 << " dose not match the to layout sharding micro interleaved num:" << to.GetVirtualRank().size();
96 return FAILED;
97 }
98
99 Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) {
100 if (MakeFromToLayout(from, to) != SUCCESS) {
101 MS_LOG(ERROR) << "Make from_layout and to_layout failed.";
102 return FAILED;
103 }
104 this->is_dynamic_shape_ = CheckDynamicShape(from, to);
105 if (this->is_dynamic_shape_) {
106 // Dynamic info of func_graph should be considered.
107 MS_LOG(INFO) << "LayoutTransfer inited with dynamic shape.";
108 this->from_origin_no_assembled_ = this->from_origin_;
109 this->to_origin_no_assembled_ = this->to_origin_;
110 Status ret = this->AssembleStaticTensorShape(this->from_origin_no_assembled_, this->to_origin_no_assembled_,
111 &this->from_origin_, &this->to_origin_);
112 if (ret != Status::SUCCESS) {
113 return ret;
114 }
115 this->is_assembled_static_shape_ = true;
116 }
117 const Shape from_origin_shape = from_origin_.tensor_shape().array();
118 const Shape to_origin_shape = to_origin_.tensor_shape().array();
119 bool is_from_dyn = std::find(from_origin_shape.begin(), from_origin_shape.end(), -1) != from_origin_shape.end();
120 bool is_to_dyn = std::find(to_origin_shape.begin(), to_origin_shape.end(), -1) != to_origin_shape.end();
121 if (!is_from_dyn && !is_to_dyn && from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) {
122 MS_LOG(ERROR) << "from shape size must be equal to to shape size! from shape size is "
123 << from_origin_.tensor_shape().size() << ", to shape size is " << to_origin_.tensor_shape().size();
124 MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString();
125 MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString();
126 return Status::FAILED;
127 }
128
129 if (virtual_rank_list_.size() == 1) {
130 dev_list_ = dev_list;
131 } else {
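// Micro interleaving: expand every physical rank into |virtual_rank_list_| consecutive virtual ranks,
// e.g. rank 3 with two virtual ranks contributes virtual ranks 6 and 7 to dev_list_.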
132 for (const auto &rank : dev_list) {
133 for (size_t i = 0; i < virtual_rank_list_.size(); ++i) {
134 dev_list_.push_back(int64_t(rank * virtual_rank_list_.size() + i));
135 }
136 }
137 }
138 from_ = from_origin_.SqueezeShape();
139 to_ = to_origin_.SqueezeShape();
140
141 this->is_inited_ = true;
142 return Status::SUCCESS;
143 }
144
145 Status TensorRedistribution::CalculateFromTensorShape(Shape *from_shape, const Array &from_factors,
146 const Shape &to_shape, const Array &to_factors) {
147 if (from_shape->size() != from_factors.GetDimSize() || to_shape.size() != to_factors.GetDimSize()) {
148 MS_LOG(ERROR) << "Shape size is not equal to factor size.";
149 return Status::FAILED;
150 }
151 int64_t to_layout_added_factor = GetLeastFactorWithoutConstDims(to_shape, to_factors);
152 int64_t to_layout_const_size = GetTensorSize(to_shape);
153 int64_t from_layout_const_size = GetTensorSize(*from_shape);
154 if (to_layout_const_size > from_layout_const_size && to_layout_const_size % from_layout_const_size == 0) {
155 to_layout_added_factor *= (to_layout_const_size / from_layout_const_size);
156 }
157 MS_LOG(INFO) << "from_shape=" << (*from_shape) << ", from_factors=" << from_factors.array()
158 << ", to_shape=" << to_shape << ", to_factors=" << to_factors.array()
159 << ", to_layout_added_factor=" << to_layout_added_factor;
160 if (from_layout_const_size > to_layout_const_size && from_layout_const_size % to_layout_const_size == 0) {
161 int64_t merged_const_factor = from_layout_const_size / to_layout_const_size;
162 // Existing dims in from_layout already satisfy to_layout_added_factor.
163 if (to_layout_added_factor > merged_const_factor && to_layout_added_factor % merged_const_factor == 0) {
164 to_layout_added_factor /= merged_const_factor;
165 }
166 if (to_layout_added_factor == 1) {
167 to_layout_added_factor = -1;
168 }
169 }
170 bool strict_mode = UseStrictMode(*from_shape, to_shape);
171 std::vector<int64_t> known_dims;
172 (void)std::copy_if(from_shape->begin(), from_shape->end(), std::back_inserter(known_dims),
173 [](int64_t dim) -> bool { return dim != -1; });
174 constexpr size_t INVALID_TENSOR_RANK = 9999;
175 size_t last_dyn_dim = INVALID_TENSOR_RANK;
176 auto last_dyn_dim_iter = std::find(from_shape->rbegin(), from_shape->rend(), -1);
177 if (last_dyn_dim_iter != from_shape->rend()) {
178 last_dyn_dim = from_shape->size() - (last_dyn_dim_iter - from_shape->rbegin()) - 1;
179 }
180 for (size_t i = 0; i < from_shape->size(); ++i) {
181 if (from_shape->at(i) != -1) {
182 continue;
183 }
184 int64_t prime_num = PrimeGenerator::GetInstance()->GetCoprimeNum(known_dims);
185 if (prime_num == -1) {
186 return Status::FAILED;
187 }
188 (*from_shape)[i] = prime_num * from_factors.GetDimByIdx(i);
189 if (strict_mode && from_shape->at(i) < to_factors.GetDimByIdx(i) &&
190 from_factors.GetDimByIdx(i) < to_factors.GetDimByIdx(i)) {
191 int64_t common_factor = std::gcd(from_factors.GetDimByIdx(i), to_factors.GetDimByIdx(i));
192 int64_t left_factor = to_factors.GetDimByIdx(i) / common_factor;
193 (*from_shape)[i] *= left_factor;
194 if (to_layout_added_factor >= left_factor && to_layout_added_factor % left_factor == 0) {
195 to_layout_added_factor /= left_factor;
196 }
197 if (to_layout_added_factor < left_factor) {
198 to_layout_added_factor = -1;
199 }
200 }
201 if (strict_mode && from_shape->at(i) >= to_factors.GetDimByIdx(i) &&
202 from_shape->at(i) % to_factors.GetDimByIdx(i) != 0) {
203 (*from_shape)[i] *= to_factors.GetDimByIdx(i);
204 if (to_layout_added_factor >= to_factors.GetDimByIdx(i) &&
205 to_layout_added_factor % to_factors.GetDimByIdx(i) == 0) {
206 to_layout_added_factor /= to_factors.GetDimByIdx(i);
207 }
208 }
209 if (i == last_dyn_dim && to_layout_added_factor > 0) {
210 if (from_shape->at(i) % to_layout_added_factor != 0) {
211 (*from_shape)[i] *= to_layout_added_factor;
212 }
213 to_layout_added_factor = -1;
214 }
215 known_dims.emplace_back(from_shape->at(i));
216 MS_LOG(DEBUG) << "Replace " << i << " with value " << from_shape->at(i) << " prime " << prime_num;
217 if (!RecordDimsChange(i, from_shape->at(i), &this->from_dims_replace_memo_)) {
218 MS_LOG(ERROR) << "Index " << i << " conflicts.";
219 return Status::FAILED;
220 }
221 }
222 return Status::SUCCESS;
223 }
224
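// Returns the first `length` positive multiples of base_n, e.g. EnumerateArray(4) yields {4, 8, 12, ..., 400}.
// These multiples serve as candidate values for a dynamic axis whose sharding factor is base_n.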
225 static std::vector<int64_t> EnumerateArray(int64_t base_n, size_t length = 100) {
226 static std::map<int64_t, std::vector<int64_t>> enum_numbers;
227 if (enum_numbers.find(base_n) != enum_numbers.end()) {
228 return enum_numbers.at(base_n);
229 }
230 std::vector<int64_t> array(length);
231 for (size_t i = 1; i < length + 1; ++i) {
232 array[i - 1] = base_n * SizeToLong(i);
233 }
enum_numbers[base_n] = array;  // Memoize the result so the static cache lookup above can actually hit.
234 return array;
235 }
236
237 Status TensorRedistribution::CalculateToTensorShapeUsingEnumeration(const Shape &from_tsr_shape, Shape *to_tsr_shape,
238 const Array &factors) {
239 int64_t src_element_size = GetTensorSize(from_tsr_shape);
240 int64_t dst_element_size = GetTensorSize(*to_tsr_shape);
241 if (src_element_size % dst_element_size != 0) {
242 MS_LOG(ERROR) << "Calculate to tensor shape failed. Tensor shape size is not matched.";
243 return Status::FAILED;
244 }
245 const int64_t dyn_dim_val = -1;
246 int64_t dyn_axis_cnt = std::count(to_tsr_shape->begin(), to_tsr_shape->end(), dyn_dim_val);
247 int64_t left_size = src_element_size / dst_element_size;
248
249 if (dyn_axis_cnt == 0) {
250 if (left_size != 1) {
251 MS_LOG(ERROR) << "Calculate to tensor shape failed. Tensor shape size is not matched.";
252 return Status::FAILED;
253 }
254 return Status::SUCCESS;
255 }
256
257 if (dyn_axis_cnt == 1) {
258 /**
259 * Case1:
260 * from: c1, -1(32), c3, c4; to: c1/2, -1(32)*c3, c4
261 */
262 auto iter = std::find(to_tsr_shape->begin(), to_tsr_shape->end(), dyn_dim_val);
263 size_t index = static_cast<size_t>(iter - to_tsr_shape->begin());
264 if (left_size % factors.GetDimByIdx(index) != 0) {
265 MS_LOG(ERROR) << "Generate static shape failed, the shape cannot be divided by factor. dim=" << left_size
266 << ", factor=" << factors.GetDimByIdx(index);
267 return Status::FAILED;
268 }
269 (*iter) = left_size;
270 if (!RecordDimsChange(index, left_size, &this->to_dims_replace_memo_)) {
271 MS_LOG(ERROR) << "Index " << iter - to_tsr_shape->begin() << " conflicts.";
272 return Status::FAILED;
273 }
274 return Status::SUCCESS;
275 } else {
276 /**
277 * Case2:
278 * from: -1(16), c1, c2; to: -1(2), c1*c2/2, 2*-1(8)
279 * Solution:
280 * -1(16), c1*c2/2, 2
281 * A, B, c1*c2/2, 2
282 * A, c1*c2/2, 2* B
283 *
284 * A*B=3*16 && A%2=0 && B%8=0
285 */
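// Illustrative reading of the case above: each dynamic to-dim gets a candidate list of multiples of its
// factor, and SolveCombination searches those lists for a combination consistent with left_size, e.g. the
// constraint A*B = 3*16 = 48 with A % 2 == 0 and B % 8 == 0 admits A = 6, B = 8.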
286 std::vector<std::vector<int64_t>> enum_numbers;
287 for (size_t i = 0; i < to_tsr_shape->size(); ++i) {
288 if (to_tsr_shape->at(i) == -1) {
289 std::vector<int64_t> array = EnumerateArray(factors.GetDimByIdx(i));
290 enum_numbers.emplace_back(array);
291 }
292 }
293 std::vector<int64_t> candidates(enum_numbers.size());
294 if (!SolveCombination(from_tsr_shape, 0, enum_numbers, 0, left_size, &candidates)) {
295 MS_LOG(ERROR) << "Not supported for now.";
296 return Status::FAILED;
297 }
298 size_t cnt = 0;
299 for (size_t i = 0; i < to_tsr_shape->size(); ++i) {
300 if (to_tsr_shape->at(i) == -1) {
301 (*to_tsr_shape)[i] = candidates[cnt++];
302 if (!RecordDimsChange(i, to_tsr_shape->at(i), &this->to_dims_replace_memo_)) {
303 MS_LOG(ERROR) << "Index " << i << " conflicts.";
304 return Status::FAILED;
305 }
306 }
307 }
308 return Status::SUCCESS;
309 }
310 }
311
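// With exactly one dynamic axis the static value follows from element-count conservation: every known
// to-dim cancels matching factors of from_shape via gcd, and the leftover product becomes the dynamic dim.
// Example: from_shape = {8, 3, 4}, origin_to_shape = {2, -1, 4} leaves a divisor of {1, 3, 4}, so
// to_shape = {2, 12, 4} (and 8 * 3 * 4 == 2 * 12 * 4).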
312 void CalculateToTensorShapeForOneDynamicAxis(const Shape &from_shape, const Shape &origin_to_shape, Shape *to_shape) {
313 Shape from_shape_divisor(from_shape);
314 size_t dynamic_axis = 0;
315 for (size_t i = 0; i < origin_to_shape.size(); ++i) {
316 int64_t dim_val = origin_to_shape[i];
317 (*to_shape)[i] = dim_val;
318 if (dim_val == -1) {
319 dynamic_axis = i;
320 continue;
321 }
322 for (int64_t &from_dim_val : from_shape_divisor) {
323 if (dim_val == 1) {
324 break;
325 }
326 int64_t f = std::gcd(dim_val, from_dim_val);
327 from_dim_val /= f;
328 dim_val /= f;
329 }
330 }
331 (*to_shape)[dynamic_axis] = GetTensorSize(from_shape_divisor);
332 MS_LOG(INFO) << "to_shape=" << (*to_shape) << ", from_shape_divisor=" << from_shape_divisor;
333 }
334
335 Status TensorRedistribution::CalculateToTensorShape(const Shape &from_shape, const Shape &origin_to_shape,
336 const Array &to_in_factors, Shape *to_shape) {
337 MS_LOG(INFO) << "from_shape=" << from_shape << ", origin_to_shape=" << origin_to_shape
338 << ", to_in_factors=" << to_in_factors.array();
339 // Use forward and backward matching first, if failed, turn to enumeration.
340 if (std::count(origin_to_shape.begin(), origin_to_shape.end(), -1) == 1) {
341 CalculateToTensorShapeForOneDynamicAxis(from_shape, origin_to_shape, to_shape);
342 return Status::SUCCESS;
343 }
344 bool flag_forward_match = ForwardMatching(from_shape, origin_to_shape, to_shape, to_in_factors);
345 if (!flag_forward_match && !BackwardMatching(origin_to_shape, to_shape, to_in_factors)) {
346 MS_LOG(DEBUG) << "Backward matching failed.";
347 if (CalculateToTensorShapeUsingEnumeration(from_shape, to_shape, to_in_factors) != Status::SUCCESS) {
348 MS_LOG(ERROR) << "Calculate to tensor shape failed trying to use enumeration method.";
349 return Status::FAILED;
350 }
351 }
352 return Status::SUCCESS;
353 }
354
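// Editorial summary: assembles static stand-in shapes for dynamic layouts. Each -1 dim of the from-shape is
// replaced by a fresh prime times its sharding factor, the to-shape is derived from it, and both sides are
// then balanced and unified so their element counts and sharding factors stay consistent.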
355 Status TensorRedistribution::AssembleStaticTensorShape(const TensorLayout &from_in, const TensorLayout &to_in,
356 TensorLayout *new_from_layout, TensorLayout *new_to_layout) {
357 Shape new_from_shape(from_in.tensor_shape().array());
358 Shape original_to_shape = to_in.tensor_shape().array();
359 Array from_in_factors;
360 if (GetFactors(from_in, &from_in_factors) != Status::SUCCESS) {
361 MS_LOG(ERROR) << "Get from_in factors failed.";
362 return Status::FAILED;
363 }
364 Array to_in_factors;
365 if (GetFactors(to_in, &to_in_factors) != Status::SUCCESS) {
366 MS_LOG(ERROR) << "Get to_in factors failed.";
367 return Status::FAILED;
368 }
369 if (CalculateFromTensorShape(&new_from_shape, from_in_factors, original_to_shape, to_in_factors) != Status::SUCCESS) {
370 MS_LOG(ERROR) << "Failed to generate static shape for from_tensor layout: " << from_in.ToString();
371 return Status::FAILED;
372 }
373 Shape new_to_shape(to_in_factors.GetDimSize(), 1);
374 if (CalculateToTensorShape(new_from_shape, original_to_shape, to_in_factors, &new_to_shape) != Status::SUCCESS) {
375 MS_LOG(ERROR) << "Failed to generate static shape for to_tensor layout: " << to_in.ToString() << std::endl
376 << "from_in layout: " << from_in.ToString() << std::endl
377 << "Already generate from_in shape: " << new_from_shape;
378 return Status::FAILED;
379 }
380 size_t size = std::min(new_from_shape.size(), new_to_shape.size());
381 if (GetTensorSize(new_from_shape) != GetTensorSize(new_to_shape)) {
382 int64_t acc_scalar = 1;
383 for (size_t i = 0; i < size; ++i) {
384 if (new_from_shape.at(i) > new_to_shape.at(i) && new_from_shape.at(i) % new_to_shape.at(i) == 0) {
385 int64_t scalar = new_from_shape.at(i) / new_to_shape.at(i);
386 new_to_shape[i] = new_to_shape[i] * scalar;
387 acc_scalar *= scalar;
388 }
389 }
390 const Shape &f_in_tensor_shape = from_in.tensor_shape().array();
391 auto last_dyn_dim_iter = std::find(f_in_tensor_shape.rbegin(), f_in_tensor_shape.rend(), -1);
392 if (last_dyn_dim_iter != f_in_tensor_shape.rend()) {
393 size_t last_dyn_dim =
394 f_in_tensor_shape.size() - static_cast<size_t>(last_dyn_dim_iter - f_in_tensor_shape.rbegin()) - 1;
395 new_from_shape[static_cast<size_t>(last_dyn_dim)] *= acc_scalar;
396 }
397 }
398
399 // Unify shape from begin to end.
400 UnifyFromAndToShape(&new_from_shape, &new_to_shape, from_in, to_in, &this->from_dims_replace_memo_);
401
402 MS_LOG(INFO) << "new_from_shape=" << new_from_shape << ", new_to_shape=" << new_to_shape;
403 if (new_from_layout->InitFromVector(from_in.device_arrangement().array(), from_in.tensor_map().array(),
404 new_from_shape) != Status::SUCCESS) {
405 MS_LOG(ERROR) << "Failed to init new from_tensor layout.";
406 return Status::FAILED;
407 }
408 MS_LOG(DEBUG) << "Init new_from_tensor layout, origin:" << from_in.ToString()
409 << ", new:" << new_from_layout->ToString();
410
411 if (new_to_layout->InitFromVector(to_in.device_arrangement().array(), to_in.tensor_map().array(), new_to_shape) !=
412 Status::SUCCESS) {
413 MS_LOG(ERROR) << "Failed to init new to_tensor layout.";
414 return Status::FAILED;
415 }
416 MS_LOG(DEBUG) << "Init new_to_layout layout, origin:" << to_in.ToString() << ", new:" << new_to_layout->ToString();
417
418 return Status::SUCCESS;
419 }
420
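// Depth-limited DFS over cnode's inputs: returns true if dst_cnode is reachable from cnode within
// MAX_RECURSIVE_DEPTH levels.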
421 bool IsVirtualDatasetNextInput(const CNodePtr &cnode, const CNodePtr &dst_cnode, size_t depth = 0) {
422 if (depth >= MAX_RECURSIVE_DEPTH) {
423 return false;
424 }
425 for (size_t j = 1; j < cnode->inputs().size(); ++j) {
426 auto cur_cnode = cnode->input(j)->cast<CNodePtr>();
427 if (cur_cnode == nullptr) {
428 continue;
429 }
430 if (cur_cnode->UniqueId() == dst_cnode->UniqueId()) {
431 return true;
432 }
433 if (IsVirtualDatasetNextInput(cur_cnode, dst_cnode, depth + 1)) {
434 return true;
435 }
436 }
437 return false;
438 }
439
440 CNodePtr UpdateShapeNodeInput(const CNodePtr &current_cnode, const CNodePtr &dst_cnode, size_t redistribution_index) {
441 for (size_t i = redistribution_index; i < current_cnode->inputs().size(); ++i) {
442 auto prev_cnode = current_cnode->input(i)->cast<CNodePtr>();
443 if (prev_cnode == nullptr) {
444 continue;
445 }
446 bool found = IsVirtualDatasetNextInput(prev_cnode, dst_cnode);
447 if (found) {
448 MS_LOG(INFO) << "change input to " << current_cnode->input(1)->fullname_with_scope();
449 return prev_cnode;
450 }
451 }
452 return nullptr;
453 }
454
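// Looks up the mapping entry whose recorded dim index equals `index`; raises an exception if it is absent.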
455 std::pair<int64_t, AnfNodePtr> GetDimMapping(const AssembledDynamicDimsMapping &mapping, int64_t index) {
456 for (const auto &iter : mapping) {
457 if (SizeToLong(iter.second.first) == index) {
458 return std::make_pair(iter.first, iter.second.second);
459 }
460 }
461 MS_LOG(EXCEPTION) << "Cannot find index " << index << " in AssembledDynamicDimsMapping.";
462 }
463
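// The from-shape lost a size-1 dimension when it was squeezed, so rebuild every TupleGetItem to read
// dimension (index + 1) of the original shape tuple, keeping it pointed at the same runtime value.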
464 void TensorRedistribution::UnifyAssembledMappingWithSqueezedFromShape() {
465 AssembledDynamicDimsMapping new_mapping;
466 for (const auto &iter : this->dynamic_dim_mapping_) {
467 auto origin_tuple_get_item = iter.second.second;
468 auto origin_tuple_get_item_cnode = origin_tuple_get_item->cast<CNodePtr>();
469 MS_EXCEPTION_IF_NULL(origin_tuple_get_item_cnode);
470 auto func_graph = origin_tuple_get_item->func_graph();
471 MS_EXCEPTION_IF_NULL(func_graph);
472 auto prim_tuple_get_item = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
473 int64_t index = SizeToLong(iter.second.first) + 1;
474 AnfNodePtrList inputs{NewValueNode(prim_tuple_get_item), origin_tuple_get_item_cnode->input(1),
475 NewValueNode(MakeValue(index))};
476 auto tuple_get_item_cnode = func_graph->NewCNode(inputs);
477 tuple_get_item_cnode->set_fullname_with_scope(iter.second.second->fullname_with_scope());
478 prim_tuple_get_item->set_instance_name("tuple_getitem_for_value_" + std::to_string(iter.first));
479 if (iter.second.second->isa<CNode>()) {
480 auto raw_cnode = iter.second.second->cast<CNodePtr>();
481 if (IsValueNode<Primitive>(raw_cnode->input(0))) {
482 auto prim_node = raw_cnode->input(0)->cast<ValueNodePtr>();
483 auto prim = GetValueNode<PrimitivePtr>(prim_node);
484 prim_tuple_get_item->set_instance_name(prim->instance_name());
485 }
486 }
487 new_mapping.insert({iter.first, {iter.second.first, tuple_get_item_cnode}});
488 MS_LOG(WARNING) << "Adjust TupleGetItem for dim=" << iter.second.first << " to " << iter.second.first + 1
489 << " to replace value=" << iter.first;
490 }
491 this->dynamic_dim_mapping_ = new_mapping;
492 }
493
494 void TensorRedistribution::UnifyAssembledMappingWithSameSize(const std::set<int64_t> &index_mapping) {
495 Shape from_shape = this->assembled_static_origin_from_.tensor_shape().array();
496 Shape origin_slice_shape = this->assembled_static_origin_from_.slice_shape().array();
497 AssembledDynamicDimsMapping new_mapping;
498 for (int64_t i = SizeToLong(from_shape.size()) - 1; i >= 0; --i) {
499 if (index_mapping.find(i) == index_mapping.end()) {
500 continue;
501 }
502 auto dyn_dim = GetDimMapping(this->dynamic_dim_mapping_, i);
503 int64_t real_dim_value = origin_slice_shape[i];
504 new_mapping.insert({real_dim_value, {i, dyn_dim.second}});
505 MS_LOG(INFO) << "insert at " << i << " with " << real_dim_value;
506 }
507 this->dynamic_dim_mapping_ = new_mapping;
508 }
509
510 void TensorRedistribution::UnifyAssembledMappingWithDiffSize(const std::set<int64_t> &index_mapping) {
511 auto func_graph = this->next_cnode_->func_graph();
512 MS_EXCEPTION_IF_NULL(func_graph);
513
514 Shape from_shape = this->assembled_static_origin_from_.tensor_shape().array();
515 Shape origin_slice_shape = this->assembled_static_origin_from_.slice_shape().array();
516 Shape unified_from_shape = this->layout_transfer_.from_in().tensor_shape().array();
517 Shape unified_slice_shape = this->layout_transfer_.from_in().slice_shape().array();
518
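// Walk both shapes from the tail. A constant dim of the original shape is consumed by one or more unified
// dims whose product matches it; a dynamic dim is re-anchored at the matching unified position, dividing
// the recorded TupleGetItem value when the unified dim is only a factor of it.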
519 AssembledDynamicDimsMapping new_mapping;
520 // Assume the length of unified_from_shape is greater than that of from_shape.
521 int64_t unified_offset = SizeToLong(unified_from_shape.size()) - 1;
522 for (int64_t i = SizeToLong(from_shape.size()) - 1; i >= 0 && unified_offset >= 0; --i) {
523 int64_t real_dim_value = origin_slice_shape[i];
524 // It means it's a const dim.
525 if (index_mapping.find(i) == index_mapping.end()) {
526 MS_EXCEPTION_IF_CHECK_FAIL(real_dim_value >= unified_slice_shape[unified_offset] &&
527 real_dim_value % unified_slice_shape[unified_offset] == 0,
528 "Tensor layout tensor shape is illegal.");
529 int64_t left_size = real_dim_value / unified_slice_shape[unified_offset];
530 --unified_offset;
531 if (left_size == 1) {
532 continue;
533 }
534 while (left_size != 1 && unified_offset >= 0) {
535 MS_EXCEPTION_IF_CHECK_FAIL(left_size % unified_slice_shape[unified_offset] == 0,
536 "Tensor layout tensor shape is illegal, left_size is " + std::to_string(left_size) +
537 ", factor is " + std::to_string(unified_slice_shape[unified_offset]));
538 left_size = left_size / unified_slice_shape[unified_offset];
539 --unified_offset;
540 }
541 continue;
542 }
543 auto dyn_dim = GetDimMapping(this->dynamic_dim_mapping_, i);
544 // It means it's a dynamic dim.
545 if (from_shape[i] == unified_from_shape[unified_offset]) {
546 new_mapping.insert({real_dim_value, {unified_offset, dyn_dim.second}});
547 MS_LOG(INFO) << "insert at " << unified_offset << " with " << real_dim_value;
548 --unified_offset;
549 } else if (from_shape[i] > unified_slice_shape[unified_offset] &&
550 from_shape[i] % unified_slice_shape[unified_offset] == 0) {
551 // left_size must be at least 1.
552 int64_t left_size = real_dim_value / unified_slice_shape[unified_offset];
553 MS_EXCEPTION_IF_CHECK_FAIL(left_size >= 1, "left_size must be greater than or equal to 1.");
554 int64_t divisor = real_dim_value / unified_slice_shape[unified_offset];
555 if (GetPrimeFactor(unified_slice_shape[unified_offset]) != -1) {
556 AnfNodePtr new_dim_node = CreateDiv(dyn_dim.second, divisor, func_graph, true, "assemble_dynamic_shape_op");
557 new_mapping.insert({unified_slice_shape[unified_offset], {unified_offset, new_dim_node}});
558 MS_LOG(INFO) << "insert at " << unified_offset << " with " << unified_slice_shape[unified_offset];
559 } else {
560 new_mapping.insert({unified_slice_shape[unified_offset], {unified_offset, dyn_dim.second}});
561 MS_LOG(INFO) << "insert at " << unified_offset << " with " << unified_slice_shape[unified_offset];
562 }
563 --unified_offset;
564 while (left_size != 1 && unified_offset >= 0) {
565 left_size = left_size / unified_slice_shape[unified_offset];
566 // If it's prime then add it to mapping.
567 if (GetPrimeFactor(unified_slice_shape[unified_offset]) != -1) {
568 new_mapping.insert({unified_slice_shape[unified_offset], {unified_offset, dyn_dim.second}});
569 MS_LOG(INFO) << "insert at " << unified_offset << " with " << unified_slice_shape[unified_offset];
570 } else {
571 MS_LOG(INFO) << "skip at " << unified_offset << " for " << unified_slice_shape[unified_offset]
572 << ", because it's not a prime.";
573 }
574 --unified_offset;
575 }
576 if (left_size != 1 && unified_offset < 0) {
577 MS_LOG(EXCEPTION) << "Tensor shape cannot be unified.";
578 }
579 } else {
580 MS_LOG(EXCEPTION) << "Tensor shape cannot be unified.";
581 }
582 }
583 this->dynamic_dim_mapping_ = new_mapping;
584 }
585
586 void TensorRedistribution::UnifyAssembledMapping() {
587 // 12,10,2,2 -> 2,6,10,2,2, 12 and 10 are all dynamic.
588 // 4, 6,2,2 -> 2,2, 6,2,2, 4 is static and 6 is dynamic.
589 // After refactor, from_origin_ and layout_transfer_.from_in() are both in static shape.
590 // 1. If origin_from_shape.size > before_unified_from_shape, it means the shape is squeezed.
591 // Squeezed could be in head and also be in tail.
592 // 2. If before_unified_from_shape < unified_from_shape, it means the shape is expanded.
593 Shape origin_from_shape = this->from_origin_.tensor_shape().array();
594 Shape origin_from_slice_shape = this->from_origin_.slice_shape().array();
595 Shape before_unified_from_shape = this->assembled_static_origin_from_.tensor_shape().array();
596 Shape before_unified_from_slice_shape = this->assembled_static_origin_from_.slice_shape().array();
597 Shape unified_from_shape = this->layout_transfer_.from_in().tensor_shape().array();
598 Shape unified_from_slice_shape = this->layout_transfer_.from_in().slice_shape().array();
599
600 std::set<int64_t> index_mapping;
601 for (const auto &iter : this->dynamic_dim_mapping_) {
602 index_mapping.insert(iter.second.first);
603 }
604 MS_LOG(INFO) << "\norigin_from_shape=" << origin_from_shape << ", origin_from_slice_shape=" << origin_from_slice_shape
605 << ", \nbefore_unified_from_shape=" << before_unified_from_shape
606 << ", before_unified_from_slice_shape=" << before_unified_from_slice_shape
607 << ", \nunified_from_shape=" << unified_from_shape
608 << ", unified_from_slice_shape=" << unified_from_slice_shape;
609 if (before_unified_from_shape.size() == origin_from_shape.size() - 1 &&
610 (origin_from_shape.front() == 1 || origin_from_shape.back() == 1)) {
611 // It means unified_from_shape and before_unified_from_shape are squeezed,
612 // origin_from_shape has no squeezed info.
613 MS_LOG(WARNING) << "before_unified_from_shape == origin_from_shape - 1.";
614 this->UnifyAssembledMappingWithSqueezedFromShape();
615 return;
616 }
617 if (unified_from_shape.size() == origin_from_shape.size()) {
618 MS_LOG(WARNING) << "unified_from_shape == origin_from_shape.";
619 this->UnifyAssembledMappingWithSameSize(index_mapping);
620 return;
621 }
622 if (unified_from_shape.size() > before_unified_from_shape.size()) {
623 // In this branch, it means the unified_from_shape is expanded,
624 // or it's reshaped to another shape.
625 MS_LOG(WARNING) << "unified_from_shape > before_unified_from_shape.";
626 if (before_unified_from_shape.size() == origin_from_shape.size() - 1 &&
627 (origin_from_shape.front() == 1 || origin_from_shape.back() == 1)) {
628 // It means shape has been squeezed, so add one to index in mapping.
629 this->UnifyAssembledMappingWithSqueezedFromShape();
630 }
631 this->UnifyAssembledMappingWithDiffSize(index_mapping);
632 return;
633 }
634 MS_LOG(EXCEPTION) << "unified_from_shape.size() must be greater than before_unified_from_shape.size().";
635 }
636
637 void TensorRedistribution::CreateAssembledDynamicMapping(const CNodePtr &cur_cnode, const AnfNodePtr &pre_cnode,
638 const FuncGraphPtr &func_graph, int64_t redistribution_index) {
639 MS_EXCEPTION_IF_NULL(func_graph);
640 if (!this->IsAssembledStaticShape()) {
641 return;
642 }
643 MS_LOG(INFO) << "Start to create assembled dynamic shape mapping for " << pre_cnode->fullname_with_scope() << "->"
644 << cur_cnode->fullname_with_scope();
645 this->dynamic_dim_mapping_.clear();
646
647 AnfNodePtr shape_root = pre_cnode;
648 if (pre_cnode->isa<CNode>() && IsPrimitiveCNode(pre_cnode, std::make_shared<Primitive>(VIRTUAL_DATA_SET))) {
649 // Find VirtualDataset successor.
650 auto shape_input = UpdateShapeNodeInput(cur_cnode, pre_cnode->cast<CNodePtr>(), redistribution_index);
651 if (shape_input == nullptr) {
652 MS_LOG(WARNING) << "Cannot find real input of shape node.";
653 } else {
654 shape_root = shape_input;
655 }
656 }
657 const std::set<std::string> multi_output_op = {ARGMAXWITHVALUE, LAYER_NORM};
658 if (pre_cnode->isa<CNode>() && IsSomePrimitiveList(pre_cnode->cast<CNodePtr>(), multi_output_op)) {
659 shape_root = cur_cnode->input(redistribution_index);
660 MS_LOG(INFO) << "Change shape_root to " << shape_root->fullname_with_scope();
661 }
662
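// dynamic_dim_mapping_: assembled placeholder value -> (dim index, TupleGetItem node producing the runtime
// dim), so the placeholder constant can later be replaced by the real dynamic value.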
663 ReplacementMemo from_layout_memo = this->from_dims_replace_memo_;
664 Shape assembled_origin_slice_shape = this->from_origin_.slice_shape().array();
665 MS_LOG(INFO) << "Start to create assembled dynamic shape mapping: " << pre_cnode->fullname_with_scope() << "->"
666 << cur_cnode->fullname_with_scope() << ", shape_root=" << shape_root->fullname_with_scope()
667 << ", assembled_origin_slice_shape=" << assembled_origin_slice_shape;
668 // 1. New shape and set pre_cnode to its inputs.
669 std::string instance_name = std::string(REDISTRIBUTION_OP) + "_" + pre_cnode->fullname_with_scope();
670 auto shape_cnode = CreateShape(shape_root, func_graph, instance_name + "_get_shape");
671 // 2. Create TupleGetItem node to get dim value and insert to mapping.
672 for (const auto &iter : from_layout_memo) {
673 int64_t dim = SizeToLong(iter.first);
674 int64_t replacement = iter.second;
675 MS_EXCEPTION_IF_CHECK_FAIL(replacement % assembled_origin_slice_shape[LongToSize(dim)] == 0,
676 "Slice shape is not matched.");
677 MS_EXCEPTION_IF_CHECK_FAIL(LongToSize(dim) < assembled_origin_slice_shape.size(), "Slice shape is not matched.");
678 replacement = assembled_origin_slice_shape[dim];
679 auto prim_tuple_get_item = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
680 AnfNodePtrList inputs{NewValueNode(prim_tuple_get_item), shape_cnode, NewValueNode(MakeValue(dim))};
681 auto tuple_get_item_cnode = func_graph->NewCNode(inputs);
682 tuple_get_item_cnode->set_fullname_with_scope(std::string(REDISTRIBUTION_OP) + "_getitem");
683 prim_tuple_get_item->set_instance_name(instance_name + "_getitem");
684 this->dynamic_dim_mapping_.insert({replacement, {iter.first, tuple_get_item_cnode}});
685 MS_LOG(INFO) << "Create TupleGetItem for dim=" << dim << " to replace value=" << replacement;
686 }
687 }
688
689 void AppendOperatorVecStr(const OperatorVector &vec, std::string *res) {
690 for (size_t i = 0; i < vec.size(); ++i) {
691 res->append(vec.at(i).first);
692 if (i != vec.size() - 1) {
693 res->append(", ");
694 }
695 }
696 }
697
698 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) {
699 MS_LOG(INFO) << "Start to infer tensor redistribution with unexpanded.";
700 TensorLayout from_origin = this->from_origin_;
701 TensorLayout to_origin = this->to_origin_;
702 TensorLayout from_repeat = from_origin.TransferRepeatLayout();
703 TensorLayout to_repeat = to_origin.TransferRepeatLayout();
704 MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin.ToString();
705 MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin.ToString();
706 MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
707 MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
708
709 OperatorVector operator_vector;
710 OutPutInfoVector output_info_vector;
711 if (InferRedistribution(from_origin, from_repeat, &operator_vector, &output_info_vector, is_cost_model) ==
712 Status::FAILED) {
713 return nullptr;
714 }
715 std::string operator_vec_str;
716 AppendOperatorVecStr(operator_vector, &operator_vec_str);
717 MS_LOG(INFO) << "After InferRedistribution, operator_vector size: " << operator_vector.size()
718 << ", operator_vector: " << operator_vec_str;
719 if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) {
720 reshape_flag_ = true;
721 ConstructOperator constructor;
722 constructor.UpdateTensorShape(from_repeat.slice_shape().array());
723 Arrangement shape = to_repeat.slice_shape();
724 MS_LOG(INFO) << "from_repeat.slice_shape is not same with to_repeat.slice_shape: "
725 << "from_repeat.slice_shape=" << from_repeat.slice_shape().array()
726 << ", to_repeat.slice_shape=" << to_repeat.slice_shape().array() << ", reshape to "
727 << shape.ToString();
728 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
729 return nullptr;
730 } else {
731 operator_vector.push_back(constructor.GetOperator());
732 output_info_vector.emplace_back(std::make_pair(false, 0));
733 }
734 }
735 if (InferRedistribution(to_repeat, to_origin, &operator_vector, &output_info_vector, is_cost_model) ==
736 Status::FAILED) {
737 return nullptr;
738 }
739 operator_vec_str.clear();
740 AppendOperatorVecStr(operator_vector, &operator_vec_str);
741 MS_LOG(INFO) << "After InferRedistribution, operator_vector size: " << operator_vector.size()
742 << ", operator_vector: " << operator_vec_str;
743 return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
744 std::make_pair(operator_vector, output_info_vector));
745 }
746
747 void GetRedistributionOperators(const RedistributionOperatorInfer &operator_infer, OperatorVector *operator_vector,
748 OutPutInfoVector *output_info_vector, OperatorList *operator_list) {
749 for (const auto &op : operator_infer.operator_vector()) {
750 (void)operator_vector->emplace_back(op);
751 }
752 for (auto info : operator_infer.output_info_vector()) {
753 (void)output_info_vector->emplace_back(info);
754 }
755 for (const auto &opc : operator_infer.operator_list()) {
756 (void)operator_list->emplace_back(opc);
757 }
758 }
759
760 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListForMultiDynamicReshape(
761 bool is_cost_model) {
762 MS_LOG(INFO) << "Start to infer tensor redistribution for multi dynamic axis reshape.";
763 if (this->pre_cnode_ != nullptr && this->next_cnode_ != nullptr) {
764 MS_LOG(DEBUG) << this->PrintRedistribution();
765 }
766 OperatorVector operator_vector;
767 OutPutInfoVector output_info_vector;
768 RedistributionOperatorInfer allgather_infer(this->construct_op_flag_);
769 if (allgather_infer.Init(this->from_origin_no_assembled_, this->to_origin_no_assembled_.tensor_map(), this->dev_list_,
770 is_cost_model, this->is_dynamic_shape_) == Status::FAILED) {
771 MS_LOG(EXCEPTION) << "Init operatorInfer failed.";
772 }
773 // 1. Do AllGather on dynamic axis, skip const axis?
774 if (allgather_infer.MergePartialToFullForReshapeHasMultiDynamicAxis() != Status::SUCCESS) {
775 MS_LOG(EXCEPTION) << "Insert AllGather for Reshape which has multi dynamic axis failed.";
776 }
777 GetRedistributionOperators(allgather_infer, &operator_vector, &output_info_vector, &this->operator_list_);
778 // 2. Do Reshape. Const axis value should be divided later?
779 ConstructOperator constructor;
780 // Actually there is no need to create a virtual shape; store the original inputs and replace them later in the replace-op pass.
781 Shape full_shape = this->to_origin_no_assembled_.tensor_shape().array();
782 MS_LOG(INFO) << "before ReshapeOP, full_shape:" << full_shape;
783 if (constructor.ReshapeOP(full_shape, true) == Status::FAILED) {
784 MS_LOG(EXCEPTION) << "Cannot construct Reshape op for shape " << full_shape;
785 }
786 (void)operator_vector.emplace_back(constructor.GetOperator());
787 (void)output_info_vector.emplace_back(std::make_pair(false, 0));
788 // 3. Do Split, skip const axis?
789 RedistributionOperatorInfer allsplit_infer(this->construct_op_flag_);
790 if (allsplit_infer.Init(this->to_origin_no_assembled_, this->to_origin_no_assembled_.tensor_map(), this->dev_list_,
791 is_cost_model, this->is_dynamic_shape_) == Status::FAILED) {
792 MS_LOG(ERROR) << "Init operatorInfer failed";
793 return nullptr;
794 }
795 if (allsplit_infer.SegmentFullShapeToPartial() != Status::SUCCESS) {
796 MS_LOG(EXCEPTION) << "Insert AllSplit for Reshape which has multi dynamic axis failed.";
797 }
798 GetRedistributionOperators(allsplit_infer, &operator_vector, &output_info_vector, &this->operator_list_);
799 std::string operator_vec_str;
800 AppendOperatorVecStr(operator_vector, &operator_vec_str);
801 MS_LOG(INFO) << "After InferAllSplit, operator_vector size: " << operator_vector.size()
802 << ", operator_vector: " << operator_vec_str;
803 return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
804 std::make_pair(operator_vector, output_info_vector));
805 }
806
807 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
808 MS_LOG(INFO) << "Start to infer tensor redistribution.";
809 if (this->pre_cnode_ != nullptr && this->next_cnode_ != nullptr) {
810 MS_LOG(DEBUG) << this->PrintRedistribution();
811 }
812 // Step 1: Match device arrangement between from_ and to_
813 // RedistributionLayoutTransfer layout_transfer;
814 // Step 0: Do dynamic shape to static shape conversion.
815 // TensorRedistribution::Init() only saves the from and to tensor layouts, plus their squeezed forms.
816 // We can change from_ and to_ in the RedistributionLayoutTransfer object directly.
817 // RedistributionLayoutTransfer::Init() will check whether the shape is dynamic;
818 // if a static shape cannot be created, the earlier process is reused.
819 Status status = this->layout_transfer_.Init(from_, to_);
820 if (status != Status::SUCCESS) {
821 return nullptr;
822 }
823 TensorLayout from_layout;
824 TensorLayout to_layout;
825 if (this->is_dynamic_shape_ && !this->is_assembled_static_shape_) {
826 from_layout = this->layout_transfer_.from_in();
827 to_layout = this->layout_transfer_.to_in();
828 } else {
829 // init a new layout_transfer
830 // assembled_static_origin_from_ records the layout before unification.
831 // When the device matrix or tensor shape needs to be unified, a 1 may be inserted in front of the
832 // tensor shape, or a dim may be split into multiple dims.
833 this->assembled_static_origin_from_ = this->layout_transfer_.from_in();
834 std::shared_ptr<ReshapeLayoutTransfer> ptr = this->layout_transfer_.UnifyDeviceArrangementAndTensorShape();
835 if (ptr == nullptr) {
836 MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
837 return nullptr;
838 }
839 this->layout_transfer_.Init(ptr->from_in(), ptr->to_in());
840 if (!ptr->ExpandAble()) {
841 expand_able_ = false;
842 return InferTensorRedistributionOperatorListUnExpand(is_cost_model);
843 }
844 from_layout = ptr->from_in();
845 to_layout = ptr->to_in();
846 }
847 MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
848 MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString();
849 MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
850 MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
851 MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
852 MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
853
854 // Step 2: Infer redistribution and insert operators
855 OperatorVector operator_vector;
856 OutPutInfoVector output_info_vector;
857 if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=
858 Status::SUCCESS) {
859 return nullptr;
860 }
861 // Step 3: Infer reshape and insert operators
862 if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) {
863 MS_LOG(ERROR) << "Construct Reshape operator failed!";
864 return nullptr;
865 }
866 std::string operator_vec_str;
867 AppendOperatorVecStr(operator_vector, &operator_vec_str);
868 MS_LOG(INFO) << "After InferRedistribution, operator_vector size: " << operator_vector.size()
869 << ", operator_vector: " << operator_vec_str;
870 return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
871 std::make_pair(operator_vector, output_info_vector));
872 }
873
874 std::vector<RedistributionOpListPtr> TensorRedistribution::InferTensorRedistributionOperatorVirtualGraphs() {
875 std::vector<RedistributionOpListPtr> redis_list_vector;
876 for (const auto &virtual_rank : virtual_rank_list_) {
877 this->SetVirtualRank(virtual_rank);
878 auto redis_list = this->InferTensorRedistributionOperatorList();
879 if (!redis_list) {
880 MS_LOG(INTERNAL_EXCEPTION) << "Infer tensor redistribution failed. from_layout:" << from_origin_.ToString()
881 << ", to_layout:" << to_origin_.ToString();
882 }
883 redis_list_vector.push_back(redis_list);
884 }
885 return redis_list_vector;
886 }
887
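// Shape comparison that treats -1 (a dynamic dim) as a wildcard matching any value.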
888 bool IsSameShape(const Shape &src, const Shape &tgt) {
889 if (src.size() != tgt.size()) {
890 return false;
891 }
892 for (size_t i = 0; i < src.size(); ++i) {
893 if (src[i] == -1 || tgt[i] == -1) {
894 continue;
895 }
896 if (src[i] != tgt[i]) {
897 return false;
898 }
899 }
900 return true;
901 }
902
903 Shape AlignToLayoutShape(const Shape &to_origin_shape, const Shape &to_layout_shape) {
904 Shape target_shape(to_origin_shape);
905 auto cnt = std::count(target_shape.begin(), target_shape.end(), -1);
906 if (cnt < SizeToInt(SIZE_TWO) || to_layout_shape[0] != 1 || to_layout_shape.size() - 1 != target_shape.size()) {
907 return target_shape;
908 }
909 for (size_t i = 0; i < target_shape.size(); ++i) {
910 if (target_shape[i] != -1) {
911 continue;
912 }
913 target_shape[i] = to_layout_shape[i + 1];
914 }
915 return target_shape;
916 }
917
918 Status TensorRedistribution::OperatorListIsEmpty(ConstructOperator *constructor, OperatorVector *const operator_vector,
919 OutPutInfoVector *const output_info_vector) {
920 if (from_origin_.base_slice_shape().array() != to_origin_.base_slice_shape().array() || keep_reshape_) {
921 reshape_flag_ = true;
922 constructor->UpdateTensorShape(from_origin_.base_slice_shape().array());
923 Arrangement shape = to_origin_.base_slice_shape();
924 MS_LOG(INFO) << "from_origin_.base_slice_shape is not same with to_origin_.base_slice_shape: "
925 << "from_origin_.base_slice_shape=" << from_origin_.base_slice_shape().array()
926 << ", to_origin_.base_slice_shape=" << to_origin_.base_slice_shape().array() << ", reshape to "
927 << shape.ToString();
928 auto reshape_mode = ReshapeMode::FROM_ORIGIN_BASE_SLICE_TO_TO_ORIGIN_BASE_SLICE;
929 reshape_mode = this->is_dynamic_shape_ ? reshape_mode : ReshapeMode::NO_RESHAPE;
930 if (constructor->ReshapeOP(shape.array(), false, reshape_mode) == Status::FAILED) {
931 return Status::FAILED;
932 } else {
933 (void)operator_vector->insert(operator_vector->cbegin(), constructor->GetOperator());
934 (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
935 }
936 }
937 return Status::SUCCESS;
938 }
939
940 Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
941 OperatorVector *const operator_vector,
942 OutPutInfoVector *const output_info_vector) {
943 MS_EXCEPTION_IF_NULL(operator_vector);
944 MS_EXCEPTION_IF_NULL(output_info_vector);
945 ConstructOperator constructor;
946 if (operator_list_.empty()) {
947 return OperatorListIsEmpty(&constructor, operator_vector, output_info_vector);
948 }
949 // 1. We need to know which axes are dynamic and which are constant, and compare only the constant axes.
950 //    Can we guarantee that from_origin_ and from_layout have the same rank? If from_origin_ is static then
951 //    from_layout is static too (and likewise for dynamic). For now, only the same-rank case is supported.
952 if (!IsSameShape(from_origin_.slice_shape().array(), from_layout.slice_shape().array())) {
953 reshape_flag_ = true;
954 constructor.UpdateTensorShape(from_origin_.slice_shape().array());
955 Arrangement shape = from_layout.slice_shape();
956 MS_LOG(INFO) << "from_origin.slice_shape is not same with from_layout.slice_shape: "
957 << "from_origin_.slice_shape=" << from_origin_.slice_shape().array()
958 << ", from_layout.slice_shape=" << from_layout.slice_shape().array() << ", reshape to "
959 << shape.ToString();
960 auto reshape_mode = ReshapeMode::FROM_ORIGIN_SLICE_TO_FROM_LAYOUT_SLICE;
961 reshape_mode = this->is_dynamic_shape_ ? reshape_mode : ReshapeMode::NO_RESHAPE;
962 if (constructor.ReshapeOP(shape.array(), false, reshape_mode) == Status::FAILED) {
963 return Status::FAILED;
964 } else {
965 // Before all-gather.
966 (void)operator_vector->insert(operator_vector->cbegin(), constructor.GetOperator());
967 (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
968 }
969 }
970
971 if (from_origin_.base_slice_shape().array() != from_origin_.slice_shape().array()) {
972 reshape_flag_ = true;
973 constructor.UpdateTensorShape(from_origin_.base_slice_shape().array());
974 Arrangement shape = from_origin_.slice_shape();
975 MS_LOG(INFO) << "from_origin_.base_slice_shape is not same with from_origin_.slice_shape: "
976 << "from_origin_.base_slice_shape=" << from_origin_.base_slice_shape().array()
977 << ", from_origin_.slice_shape=" << from_origin_.slice_shape().array() << ", reshape to "
978 << shape.ToString();
979 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
980 return Status::FAILED;
981 } else {
982 // Before all-gather.
983 (void)operator_vector->insert(operator_vector->cbegin(), constructor.GetOperator());
984 (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
985 }
986 }
987
988 if (!IsSameShape(to_origin_.slice_shape().array(), to_layout.slice_shape().array())) {
989 reshape_flag_ = true;
990 constructor.UpdateTensorShape(to_layout.slice_shape().array());
991 // If to_origin_ is all -1, it cannot be reshaped.
992 Shape target_shape = to_origin_.slice_shape().array();
993 size_t cnt = std::count(target_shape.begin(), target_shape.end(), -1);
994 if (this->IsAssembledStaticShape() && cnt >= SIZE_TWO) {
995 target_shape = AlignToLayoutShape(to_origin_.slice_shape().array(), to_layout.slice_shape().array());
996 MS_LOG(INFO) << "update reshape target shape.";
997 }
998 MS_LOG(INFO) << "to_origin_.slice_shape is not same with to_layout.slice_shape: "
999 << "to_origin_.slice_shape=" << to_origin_.slice_shape().array()
1000 << ", to_layout.slice_shape=" << to_layout.slice_shape().array() << ", reshape to " << target_shape;
1001 auto reshape_mode = ReshapeMode::TO_ORIGIN_SLICE_TO_TO_LAYOUT_SLICE;
1002 reshape_mode = this->is_dynamic_shape_ ? reshape_mode : ReshapeMode::NO_RESHAPE;
1003 if (constructor.ReshapeOP(target_shape, false, reshape_mode) == Status::FAILED) {
1004 return Status::FAILED;
1005 } else {
1006 // After all-gather.
1007 (void)operator_vector->insert(operator_vector->cend(), constructor.GetOperator());
1008 (void)output_info_vector->insert(output_info_vector->cend(), std::make_pair(false, 0));
1009 }
1010 }
1011
1012 if (to_origin_.slice_shape().array() != to_origin_.base_slice_shape().array()) {
1013 reshape_flag_ = true;
1014 constructor.UpdateTensorShape(to_origin_.slice_shape().array());
1015 Arrangement shape = to_origin_.base_slice_shape();
1016 MS_LOG(INFO) << "to_origin_.slice_shape is not same with to_origin_.base_slice_shape: "
1017 << "to_origin_.slice_shape=" << to_origin_.slice_shape().array()
1018 << ", to_origin_.base_slice_shape=" << to_origin_.base_slice_shape().array() << ", reshape to "
1019 << shape.ToString();
1020 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
1021 return Status::FAILED;
1022 } else {
1023 // After all-gather.
1024 (void)operator_vector->insert(operator_vector->cend(), constructor.GetOperator());
1025 (void)output_info_vector->insert(output_info_vector->cend(), std::make_pair(false, 0));
1026 }
1027 }
1028 return Status::SUCCESS;
1029 }
1030
1031 Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
1032 OperatorVector *const operator_vector,
1033 OutPutInfoVector *const output_info_vector, bool is_cost_model) {
1034 MS_EXCEPTION_IF_NULL(operator_vector);
1035 MS_EXCEPTION_IF_NULL(output_info_vector);
1036 MS_LOG(DEBUG) << "Start to infer redistribution.";
1037 RedistributionOperatorInfer operator_infer(construct_op_flag_);
1038 if (virtual_rank_ >= 0) {
1039 operator_infer.SetVirtualRank(virtual_rank_);
1040 }
1041 if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model, this->is_dynamic_shape_) ==
1042 Status::FAILED) {
1043 MS_LOG(ERROR) << "Init operatorInfer failed";
1044 return Status::FAILED;
1045 }
1046 if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
1047 MS_LOG(ERROR) << "Infer redistribution failed";
1048 return Status::FAILED;
1049 } else {
1050 for (auto op : operator_infer.operator_vector()) {
1051 (void)operator_vector->insert(operator_vector->cend(), op);
1052 }
1053 for (auto info : operator_infer.output_info_vector()) {
1054 (void)output_info_vector->insert(output_info_vector->cend(), info);
1055 }
1056 for (auto opc : operator_infer.operator_list()) {
1057 (void)operator_list_.insert(operator_list_.cend(), opc);
1058 }
1059 }
1060 return Status::SUCCESS;
1061 }
1062
1063 Status TensorRedistribution::RollbackToDynamicShape() {
1064 if (!this->IsAssembledStaticShape()) {
1065 return Status::FAILED;
1066 }
1067 for (auto &iter : this->from_dims_replace_memo_) {
1068 MS_LOG(DEBUG) << "from index=" << iter.first << ", value=" << iter.second << std::endl;
1069 }
1070 for (auto &iter : this->to_dims_replace_memo_) {
1071 MS_LOG(DEBUG) << "to index=" << iter.first << ", value=" << iter.second << std::endl;
1072 }
1073 MS_LOG(DEBUG) << "RollbackToDynamicShape: from_in_=" << this->from_origin_.ToString() << std::endl
1074 << "to_in_=" << this->to_origin_.ToString() << std::endl;
1075 return Status::SUCCESS;
1076 }
1077
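// Cost model sketch: every operator in operator_list_ contributes the element count of its slice shape
// (product of dims) to the communication/computation/memory costs; if a reshape was inserted, an extra
// COST_FACTOR * (elements of the pre-reshape shape) is added to the computation and memory costs.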
1078 Status TensorRedistribution::ComputeCost() {
1079 RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
1080 if (redistribution_oplist_ptr == nullptr) {
1081 MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
1082 return Status::FAILED;
1083 }
1084 // Compute redistribution communication cost and computation cost
1085 for (auto &op_cost : operator_list_) {
1086 OperatorR op = op_cost.first;
1087 Shape slice_shape = op_cost.second;
1088 double prod =
1089 std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
1090 std::string str = op.first;
1091 if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
1092 return Status::FAILED;
1093 } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
1094 return Status::FAILED;
1095 } else {
1096 // There is only computation cost in SplitByAxis.
1097 // computation cost = before_slice_shape
1098 computation_cost_ += prod;
1099 // This addition may be erroneous
1100 memory_cost_ += prod;
1101 }
1102 }
1103 if (reshape_flag()) {
1104 Shape prev_shape;
1105 if (expand_able_) {
1106 prev_shape = from_.slice_shape().array();
1107 } else {
1108 prev_shape = from_.tensor_shape().array();
1109 }
1110 double prev_prod =
1111 std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
1112 computation_cost_ += COST_FACTOR * prev_prod;
1113 memory_cost_ += COST_FACTOR * prev_prod;
1114 }
1115 return Status::SUCCESS;
1116 }
1117
1118 Status TensorRedistribution::ComputePermuteCost(double input_size, const Shape &attrs) {
1119 // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
1120 // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
1121 if (attrs.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
1122 MS_LOG(ERROR) << "attrs size should not be less than 5!";
1123 return Status::FAILED;
1124 }
1125 forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
1126 backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
1127 comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
1128 int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
1129 if (concat_dim == 0) {
1130 // memory cost = all_gather
1131 computation_cost_ += input_size;
1132 memory_cost_ += input_size;
1133 } else {
1134 // memory cost = all_gather + split + concat
1135 int64_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
1136 computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
1137 memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
1138 }
1139 return Status::SUCCESS;
1140 }
1141
1142 Status TensorRedistribution::ComputeConcatCost(double input_size, const Shape &attrs) {
1143 // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
1144 // computation cost = before_slice_shape
1145 if (attrs.size() < TRANSFER_CONCAT_ARGS_SIZE) {
1146 MS_LOG(ERROR) << "op.second size should not be less than 3!";
1147 return Status::FAILED;
1148 }
1149 double dev_num = attrs[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
1150 // here, communication cost = all_gather + reduce_scatter
1151 forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
1152 backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
1153 comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
1154 int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
1155 if (concat_dim == 0) {
1156 // computation cost = all_gather
1157 computation_cost_ += input_size;
1158 memory_cost_ += input_size * dev_num;
1159 } else {
1160 // computation cost = all_gather + split + concat
1161 computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
1162 memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
1163 }
1164 return Status::SUCCESS;
1165 }
1166 } // namespace parallel
1167 } // namespace mindspore
1168