• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <algorithm>
#include <functional>
#include <map>
#include <numeric>
#include <set>
#include "frontend/parallel/tensor_layout/layout_utils.h"
21 
22 namespace mindspore::parallel {
GetTensorSize(const Shape & shape)23 int64_t GetTensorSize(const Shape &shape) {
24   int64_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
25   return std::abs(size);
26 }
27 
/**
 * Record a dimension replacement in the memo map.
 *
 * @param key    Dimension index to record.
 * @param value  New dimension value.
 * @param memo   Map of dimension index -> recorded value.
 * @param update When true, overwrite an existing entry; when false, refuse
 *               to touch a key that is already present.
 * @return false only when the key exists and update is false; true otherwise.
 */
bool RecordDimsChange(size_t key, int64_t value, std::map<size_t, int64_t> *memo, bool update) {
  auto iter = memo->find(key);
  if (iter != memo->end()) {
    if (!update) {
      return false;
    }
    // Reuse the iterator from the first lookup instead of searching again.
    iter->second = value;
    return true;
  }
  memo->insert({key, value});
  return true;
}
40 
GetFactors(const TensorLayout & layout,Array * array)41 Status GetFactors(const TensorLayout &layout, Array *array) {
42   std::vector<int64_t> factors(layout.tensor_shape().array().size());
43   for (uint64_t i = 0; i < layout.tensor_map().GetDimSize(); i++) {
44     if (layout.tensor_map().GetDimByIdx(i) != -1) {
45       int64_t divisor = layout.GetSliceNumByTensorDimensionIndex(i);
46       if (divisor == 0) {
47         MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0";
48         return Status::FAILED;
49       }
50       factors[i] = divisor;
51     } else {
52       factors[i] = 1;
53     }
54   }
55   array->Init(factors);
56   return Status::SUCCESS;
57 }
58 
UseStrictMode(const Shape & from_shape,const Shape & to_shape)59 bool UseStrictMode(const Shape &from_shape, const Shape &to_shape) {
60   if (from_shape.size() == to_shape.size()) {
61     for (size_t i = 0; i < from_shape.size(); ++i) {
62       if (from_shape[i] != to_shape[i]) {
63         return false;
64       }
65     }
66     return true;
67   }
68   return false;
69 }
70 
GetLeastFactorWithoutConstDims(const Shape & to_shape,const Array & to_factors)71 int64_t GetLeastFactorWithoutConstDims(const Shape &to_shape, const Array &to_factors) {
72   Shape new_to_factors;
73   for (size_t i = 0; i < to_shape.size(); i++) {
74     if (to_shape.at(i) == -1 && to_factors.GetDimByIdx(i) != -1) {
75       new_to_factors.emplace_back(to_factors.GetDimByIdx(i));
76     }
77   }
78   if (new_to_factors.empty()) {
79     return 1;
80   }
81   int64_t factor = std::accumulate(new_to_factors.begin(), new_to_factors.end(), 1, std::multiplies<int64_t>());
82   return factor;
83 }
84 
InitShapeVec(const Shape & src_shape,Shape * tgt_shape)85 void InitShapeVec(const Shape &src_shape, Shape *tgt_shape) {
86   size_t src_size = src_shape.size();
87   size_t tgt_size = tgt_shape->size();
88   size_t copy_size = std::min(src_size, tgt_size);
89   std::copy(src_shape.begin(), src_shape.begin() + copy_size, tgt_shape->begin());
90   if (tgt_size >= src_size) {
91     return;
92   }
93   for (size_t i = tgt_size; i < src_size; ++i) {
94     (*tgt_shape)[tgt_size - 1] *= src_shape[i];
95   }
96   if (GetTensorSize(src_shape) != GetTensorSize(*tgt_shape)) {
97     MS_LOG(ERROR) << "Failed to copy init tensor.";
98   }
99 }
100 
CheckDynamicShape(const TensorLayout & from_in,const TensorLayout & to_in)101 bool CheckDynamicShape(const TensorLayout &from_in, const TensorLayout &to_in) {
102   Shape from_shape = from_in.tensor_shape().array();
103   Shape to_shape = to_in.tensor_shape().array();
104   auto func = [](const Shape &shape) -> bool { return std::find(shape.begin(), shape.end(), -1) != shape.end(); };
105   return func(from_shape) && func(to_shape);
106 }
107 
/**
 * Try to make dynamic dimensions of new_from_shape agree with new_to_shape
 * by moving multiplicative factors between dims of new_from_shape.
 *
 * For each index i that is dynamic (-1) in the original from-shape and where
 * new_from_shape[i] evenly divides the larger new_to_shape[i], the quotient
 * `scalar` is borrowed from some later dynamic dim j (divided out of j,
 * multiplied into i). Each modified dim is recorded in from_dims_replace_memo.
 *
 * Note: new_to_shape is only read here, despite being a non-const pointer.
 *
 * @param new_from_shape [in/out] Working from-shape, adjusted in place.
 * @param new_to_shape   Working to-shape (read-only in this function).
 * @param from_in        Original from layout (source of dynamic-dim markers).
 * @param to_in          Original to layout.
 * @param from_dims_replace_memo [out] Records index -> new value for every
 *                               dim changed here (update mode).
 */
void UnifyFromAndToShape(Shape *new_from_shape, Shape *new_to_shape, const TensorLayout &from_in,
                         const TensorLayout &to_in, ReplacementMemo *from_dims_replace_memo) {
  Shape original_from_shape = from_in.tensor_shape().array();
  // NOTE(review): original_to_shape is never used below — candidate for removal.
  Shape original_to_shape = to_in.tensor_shape().array();
  for (size_t i = 0; i < new_from_shape->size(); ++i) {
    // Only dims that were dynamic in the original from-shape may be adjusted.
    if (original_from_shape[i] == -1) {
      if (i < new_to_shape->size() && new_from_shape->at(i) < new_to_shape->at(i) &&
          new_to_shape->at(i) % new_from_shape->at(i) == 0) {
        // Factor needed to lift from-dim i up to to-dim i.
        int64_t scalar = new_to_shape->at(i) / new_from_shape->at(i);
        for (size_t j = i + 1; j < new_from_shape->size(); ++j) {
          if (original_from_shape[j] != -1) {
            continue;
          }
          // Borrow `scalar` from the first later dynamic dim that can supply it.
          if (new_from_shape->at(j) > scalar && new_from_shape->at(j) % scalar == 0) {
            (*new_from_shape)[j] = new_from_shape->at(j) / scalar;
            (*new_from_shape)[i] = new_from_shape->at(i) * scalar;
            RecordDimsChange(i, new_from_shape->at(i), from_dims_replace_memo, true);
            RecordDimsChange(j, new_from_shape->at(j), from_dims_replace_memo, true);
            break;
          }
        }
      }
    }
  }
}
133 
/**
 * Force the constant (non -1) dims of expected_tgt_shape onto tgt_shape,
 * redistributing the surplus/deficit factors among the remaining dims so the
 * total element count stays unchanged.
 *
 * Examples (from the original author):
 * ([80,7,768,16], [-1,-1,3072,-1]) -> [80,7,3072,4]
 * ([20480,768,1,1], [-1, 1024, 12, 64]) -> [20, 1024, 12, 64]
 *
 * @param expected_tgt_shape Constraint shape; -1 marks an unconstrained dim.
 * @param tgt_shape [in/out] Shape to adjust in place.
 */
void IntroduceConstraints(const Shape &expected_tgt_shape, Shape *tgt_shape) {
  // ([80,7,768,16], [-1,-1,3072,-1]) -> [80,7,3072,4]
  // ([20480,768,1,1], [-1, 1024, 12, 64]) -> [20, 1024, 12, 64]
  // Record fix dim index.
  std::set<size_t> index;
  std::vector<size_t> dynamic_dim_index;
  // Collect the unconstrained (-1) positions up front.
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] == -1) {
      dynamic_dim_index.emplace_back(i);
    }
  }
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] == -1) {
      continue;
    }
    // Already satisfies the constraint: just pin it.
    if (tgt_shape->at(i) == expected_tgt_shape[i]) {
      index.insert(i);
      continue;
    }
    if (tgt_shape->at(i) > expected_tgt_shape[i]) {
      // Current dim too large: push the surplus factor f into the
      // right-most unpinned dim.
      if (tgt_shape->at(i) % expected_tgt_shape[i] == 0) {
        int64_t f = tgt_shape->at(i) / expected_tgt_shape[i];
        for (int32_t j = static_cast<int32_t>(tgt_shape->size()) - 1; j >= 0; --j) {
          if (j == static_cast<int32_t>(i) || index.find(j) != index.end()) {
            continue;
          }
          (*tgt_shape)[j] *= f;
          break;
        }
        (*tgt_shape)[i] = expected_tgt_shape[i];
      } else {
        MS_LOG(ERROR) << "Can't be divided.";
      }
    } else {
      if (expected_tgt_shape[i] % tgt_shape->at(i) == 0) {
        // Current dim too small but divides the target: gather the missing
        // factor f from unpinned dims, scanning right to left via gcd.
        int64_t f = expected_tgt_shape[i] / tgt_shape->at(i);
        for (int32_t j = static_cast<int32_t>(tgt_shape->size()) - 1; j >= 0; --j) {
          if (j == static_cast<int32_t>(i) || index.find(j) != index.end()) {
            continue;
          }
          int64_t divider = std::gcd(f, tgt_shape->at(j));
          (*tgt_shape)[j] /= divider;
          f /= divider;
          if (f == 1) {
            break;
          }
        }
        if (f != 1) {
          MS_LOG(ERROR) << "Can't merge shape.";
        }
        (*tgt_shape)[i] = expected_tgt_shape[i];
      } else {
        // Not divisible either way: collect the whole target_dim from other
        // unpinned dims via gcd, right to left.
        int64_t target_dim = expected_tgt_shape[i];  // 1024
        for (int32_t j = static_cast<int32_t>(tgt_shape->size()) - 1; j >= 0; --j) {
          if (index.find(j) != index.end()) {
            continue;
          }
          int64_t divider = std::gcd(target_dim, tgt_shape->at(j));
          (*tgt_shape)[j] /= divider;
          target_dim /= divider;
          if (target_dim == 1) {
            break;
          }
        }
        if (target_dim != 1) {
          MS_LOG(ERROR) << "Can't be divided.";
        } else {
          // find last dyn dim on right and put tgt_shape->at(i) to it
          // NOTE(review): assumes dynamic_dim_index is non-empty (i.e.
          // expected_tgt_shape holds at least one -1); .back() on an empty
          // vector is UB — confirm callers guarantee this.
          (*tgt_shape)[dynamic_dim_index.back()] = tgt_shape->at(dynamic_dim_index.back()) * tgt_shape->at(i);
          (*tgt_shape)[i] = expected_tgt_shape[i];
        }
      }
    }
    // Dim i now matches the constraint; protect it from later adjustments.
    index.insert(i);
  }
}
210 
/**
 * Left-to-right matching: make every dim of tgt_shape divisible by its
 * factor in tgt_factors, borrowing the missing factor from dims to the RIGHT.
 *
 * Steps: seed tgt_shape from src_shape (InitShapeVec), apply the constant
 * dims of expected_tgt_shape (IntroduceConstraints), then for each dim i not
 * divisible by tgt_factors[i], multiply dim i by its factor and divide that
 * factor out of later non-fixed dims (whole dim first, gcd piecemeal after).
 *
 * @param src_shape          Source shape used to seed tgt_shape.
 * @param expected_tgt_shape Constraint shape; non -1 dims are fixed.
 * @param tgt_shape [in/out] Pre-sized working shape, rewritten in place.
 * @param tgt_factors        Required divisibility factor per target dim.
 * @return false when a factor can't be borrowed or sizes stop matching.
 */
bool ForwardMatching(const Shape &src_shape, const Shape &expected_tgt_shape, Shape *tgt_shape,
                     const Array &tgt_factors) {
  // Borrow the size from right dim, then borrow the size from left dim.
  // tgt_shape must be inited with value 1 and has fixed size.
  InitShapeVec(src_shape, tgt_shape);
  IntroduceConstraints(expected_tgt_shape, tgt_shape);
  // tensor_size tracks the element budget not yet accounted for.
  int64_t tensor_size = GetTensorSize(*tgt_shape);
  size_t src_size = tgt_shape->size();
  std::set<size_t> fix_index;
  // Dims pinned by the constraint shape must never lend factors.
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] != -1) {
      fix_index.insert(i);
    }
  }
  for (size_t i = 0; i < tgt_shape->size(); ++i) {
    // Already divisible by its factor: consume its size from the budget.
    if (tgt_shape->at(i) % tgt_factors.GetDimByIdx(i) == 0) {
      tensor_size /= tgt_shape->at(i);
      continue;
    }
    // Borrow the size from right dim.
    int64_t factor = tgt_factors.GetDimByIdx(i);
    int64_t val = tgt_shape->at(i) * factor;
    if (val > tensor_size) {
      MS_LOG(DEBUG) << "Out of size when calculate index " << i;
      return false;
    }
    size_t ptr = i + 1;
    while (ptr < src_size) {
      if (fix_index.find(ptr) != fix_index.end()) {
        ++ptr;
        continue;
      }
      // Best case: one right dim can absorb the whole factor.
      if (tgt_shape->at(ptr) >= factor && tgt_shape->at(ptr) % factor == 0) {
        (*tgt_shape)[ptr] /= factor;
        factor = 1;
        break;
      }
      // Otherwise peel off whatever gcd this dim shares with the factor.
      int64_t divisor = std::gcd(tgt_shape->at(ptr), factor);
      factor /= divisor;
      (*tgt_shape)[ptr] /= divisor;
      ++ptr;
    }
    if (factor != 1) {
      MS_LOG(DEBUG) << "Out of size when calculate index " << i << ". Can't borrow dim from right.";
      return false;
    }
    (*tgt_shape)[i] = val;
    tensor_size /= val;
  }
  // All dims must exactly exhaust the element budget.
  if (tensor_size != 1) {
    MS_LOG(ERROR) << "Failed to forward matching.";
    return false;
  }
  return true;
}
266 
/**
 * Right-to-left matching: make every dim of tgt_shape divisible by its
 * factor in tgt_factors, borrowing the missing factor from dims to the LEFT.
 *
 * Complements ForwardMatching; dims fixed by expected_tgt_shape (non -1)
 * never lend factors. The total element count is verified at the end.
 *
 * @param expected_tgt_shape Constraint shape; non -1 dims are fixed.
 * @param tgt_shape [in/out] Working shape, rewritten in place.
 * @param tgt_factors        Required divisibility factor per target dim.
 * @return false when a factor can't be borrowed or the size check fails.
 */
bool BackwardMatching(const Shape &expected_tgt_shape, Shape *tgt_shape, const Array &tgt_factors) {
  // Borrow the size from right dim.
  // Then borrow the size from left dim.
  int64_t ori_tensor_size = GetTensorSize(*tgt_shape);
  int64_t dst_size = SizeToLong(tgt_shape->size());
  std::set<size_t> fix_index;
  // Dims pinned by the constraint shape must never lend factors.
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] != -1) {
      fix_index.insert(i);
    }
  }
  for (int32_t i = dst_size - 1; i >= 0; --i) {
    // Borrow the size from left dim.
    int64_t factor = tgt_factors.GetDimByIdx(i);
    if (tgt_shape->at(i) % factor == 0) {
      continue;
    }
    int64_t to_be_filled_dim = tgt_shape->at(i) * factor;
    int32_t ptr = i - 1;
    while (ptr >= 0) {
      if (fix_index.find(ptr) != fix_index.end()) {
        --ptr;
        continue;
      }
      // Best case: a left dim can yield the whole factor and still remain
      // divisible by its own required factor afterwards.
      if (tgt_shape->at(ptr) % factor == 0 && tgt_shape->at(ptr) / factor % tgt_factors.GetDimByIdx(ptr) == 0) {
        (*tgt_shape)[ptr] /= factor;
        factor = 1;
        break;
      }
      // Otherwise peel off whatever gcd this dim shares with the factor.
      int64_t divisor = std::gcd(tgt_shape->at(ptr), factor);
      factor /= divisor;
      (*tgt_shape)[ptr] /= divisor;
      --ptr;
    }
    if (factor != 1) {
      MS_LOG(ERROR) << "Can't borrow factor from left.";
      return false;
    }
    (*tgt_shape)[i] = to_be_filled_dim;
  }
  // Borrowing only moves factors around; the total size must be preserved.
  if (ori_tensor_size != GetTensorSize(*tgt_shape)) {
    MS_LOG(ERROR) << "After backward matching, tensor size is not equal.";
    return false;
  }
  return true;
}
313 
/**
 * Recursive backtracking search: pick one value per slot from enum_numbers
 * so that the product of all picked values equals `target`.
 *
 * At each slot the search first tries "preferred" values that are small
 * multiples (1..7) of the corresponding src_shape dim, to keep the result
 * close to the source shape; it then falls back to scanning the slot's
 * candidate list (assumed sorted ascending — the scan breaks on the first
 * value exceeding target; TODO confirm with callers).
 *
 * @param src_shape_arr     Source shape used to bias the choice per slot.
 * @param src_index         Current index into src_shape_arr.
 * @param enum_numbers      Candidate values for each slot.
 * @param offset            Current slot being filled.
 * @param target            Remaining product still to be reached.
 * @param candidates_values [out] Chosen value per slot (pre-sized).
 * @return true when a full combination multiplying to target is found.
 */
bool SolveCombination(const Shape &src_shape_arr, size_t src_index,
                      const std::vector<std::vector<int64_t>> &enum_numbers, size_t offset, int64_t target,
                      std::vector<int64_t> *candidates_values) {
  bool is_last = (enum_numbers.size() - offset - 1) == 0;
  if (src_index < src_shape_arr.size()) {
    constexpr size_t MAX_DIM = 8;
    // Preferred pass: small multiples of the matching source dim.
    for (size_t factor = 1; factor < MAX_DIM; ++factor) {
      int64_t preferred_choose = SizeToLong(factor) * src_shape_arr[src_index];
      if (std::find(enum_numbers[offset].begin(), enum_numbers[offset].end(), preferred_choose) !=
            enum_numbers[offset].end() &&
          preferred_choose <= target && target % preferred_choose == 0) {
        (*candidates_values)[offset] = preferred_choose;
        // Advance both the slot and the source index on this branch.
        if (!is_last && SolveCombination(src_shape_arr, src_index + 1, enum_numbers, offset + 1,
                                         target / candidates_values->at(offset), candidates_values)) {
          return true;
        }
      }
    }
  }
  // Fallback pass: try every candidate for this slot that divides target.
  for (size_t i = 0; i < enum_numbers[offset].size(); ++i) {
    if (enum_numbers[offset][i] > target) {
      break;
    }
    if (target % enum_numbers[offset][i] != 0) {
      continue;
    }
    (*candidates_values)[offset] = enum_numbers[offset][i];
    // Last slot succeeds only if it consumes the remaining target exactly.
    if (is_last && target / enum_numbers[offset][i] == 1) {
      return true;
    }
    // Note: src_index is NOT advanced on this branch (unlike the preferred pass).
    if (!is_last && SolveCombination(src_shape_arr, src_index, enum_numbers, offset + 1,
                                     target / enum_numbers[offset][i], candidates_values)) {
      return true;
    }
  }
  return false;
}
351 }  // namespace mindspore::parallel
352