1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <algorithm>
#include <cstdlib>
#include <functional>
#include <map>
#include <numeric>
#include <set>
#include "frontend/parallel/tensor_layout/layout_utils.h"
21
22 namespace mindspore::parallel {
GetTensorSize(const Shape & shape)23 int64_t GetTensorSize(const Shape &shape) {
24 int64_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
25 return std::abs(size);
26 }
27
// Records `value` for dimension `key` in `memo`.
// - key absent: insert and return true (regardless of `update`).
// - key present: overwrite only when `update` is true; otherwise return false
//   and leave the stored value untouched.
// Uses a single find() instead of the original's duplicate lookup.
bool RecordDimsChange(size_t key, int64_t value, std::map<size_t, int64_t> *memo, bool update) {
  auto iter = memo->find(key);
  if (iter != memo->end()) {
    if (!update) {
      return false;
    }
    iter->second = value;
    return true;
  }
  memo->insert({key, value});
  return true;
}
40
GetFactors(const TensorLayout & layout,Array * array)41 Status GetFactors(const TensorLayout &layout, Array *array) {
42 std::vector<int64_t> factors(layout.tensor_shape().array().size());
43 for (uint64_t i = 0; i < layout.tensor_map().GetDimSize(); i++) {
44 if (layout.tensor_map().GetDimByIdx(i) != -1) {
45 int64_t divisor = layout.GetSliceNumByTensorDimensionIndex(i);
46 if (divisor == 0) {
47 MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0";
48 return Status::FAILED;
49 }
50 factors[i] = divisor;
51 } else {
52 factors[i] = 1;
53 }
54 }
55 array->Init(factors);
56 return Status::SUCCESS;
57 }
58
// Strict mode applies only when the two shapes are identical (same rank and
// same value in every dimension). Shape is a std::vector, whose operator==
// already performs exactly this size-and-elementwise comparison, so the
// hand-rolled loop is unnecessary.
bool UseStrictMode(const Shape &from_shape, const Shape &to_shape) {
  return from_shape == to_shape;
}
70
GetLeastFactorWithoutConstDims(const Shape & to_shape,const Array & to_factors)71 int64_t GetLeastFactorWithoutConstDims(const Shape &to_shape, const Array &to_factors) {
72 Shape new_to_factors;
73 for (size_t i = 0; i < to_shape.size(); i++) {
74 if (to_shape.at(i) == -1 && to_factors.GetDimByIdx(i) != -1) {
75 new_to_factors.emplace_back(to_factors.GetDimByIdx(i));
76 }
77 }
78 if (new_to_factors.empty()) {
79 return 1;
80 }
81 int64_t factor = std::accumulate(new_to_factors.begin(), new_to_factors.end(), 1, std::multiplies<int64_t>());
82 return factor;
83 }
84
InitShapeVec(const Shape & src_shape,Shape * tgt_shape)85 void InitShapeVec(const Shape &src_shape, Shape *tgt_shape) {
86 size_t src_size = src_shape.size();
87 size_t tgt_size = tgt_shape->size();
88 size_t copy_size = std::min(src_size, tgt_size);
89 std::copy(src_shape.begin(), src_shape.begin() + copy_size, tgt_shape->begin());
90 if (tgt_size >= src_size) {
91 return;
92 }
93 for (size_t i = tgt_size; i < src_size; ++i) {
94 (*tgt_shape)[tgt_size - 1] *= src_shape[i];
95 }
96 if (GetTensorSize(src_shape) != GetTensorSize(*tgt_shape)) {
97 MS_LOG(ERROR) << "Failed to copy init tensor.";
98 }
99 }
100
CheckDynamicShape(const TensorLayout & from_in,const TensorLayout & to_in)101 bool CheckDynamicShape(const TensorLayout &from_in, const TensorLayout &to_in) {
102 Shape from_shape = from_in.tensor_shape().array();
103 Shape to_shape = to_in.tensor_shape().array();
104 auto func = [](const Shape &shape) -> bool { return std::find(shape.begin(), shape.end(), -1) != shape.end(); };
105 return func(from_shape) && func(to_shape);
106 }
107
UnifyFromAndToShape(Shape * new_from_shape,Shape * new_to_shape,const TensorLayout & from_in,const TensorLayout & to_in,ReplacementMemo * from_dims_replace_memo)108 void UnifyFromAndToShape(Shape *new_from_shape, Shape *new_to_shape, const TensorLayout &from_in,
109 const TensorLayout &to_in, ReplacementMemo *from_dims_replace_memo) {
110 Shape original_from_shape = from_in.tensor_shape().array();
111 Shape original_to_shape = to_in.tensor_shape().array();
112 for (size_t i = 0; i < new_from_shape->size(); ++i) {
113 if (original_from_shape[i] == -1) {
114 if (i < new_to_shape->size() && new_from_shape->at(i) < new_to_shape->at(i) &&
115 new_to_shape->at(i) % new_from_shape->at(i) == 0) {
116 int64_t scalar = new_to_shape->at(i) / new_from_shape->at(i);
117 for (size_t j = i + 1; j < new_from_shape->size(); ++j) {
118 if (original_from_shape[j] != -1) {
119 continue;
120 }
121 if (new_from_shape->at(j) > scalar && new_from_shape->at(j) % scalar == 0) {
122 (*new_from_shape)[j] = new_from_shape->at(j) / scalar;
123 (*new_from_shape)[i] = new_from_shape->at(i) * scalar;
124 RecordDimsChange(i, new_from_shape->at(i), from_dims_replace_memo, true);
125 RecordDimsChange(j, new_from_shape->at(j), from_dims_replace_memo, true);
126 break;
127 }
128 }
129 }
130 }
131 }
132 }
133
// Rewrites tgt_shape in place so every constant (non -1) dim of
// expected_tgt_shape is honored, moving the surplus or missing factors into
// other, not-yet-fixed dims while keeping the total tensor size unchanged.
// ([80,7,768,16], [-1,-1,3072,-1]) -> [80,7,3072,4]
// ([20480,768,1,1], [-1, 1024, 12, 64]) -> [20, 1024, 12, 64]
// Record fix dim index.
void IntroduceConstraints(const Shape &expected_tgt_shape, Shape *tgt_shape) {
  // `index` holds dims already pinned to their expected value; they must not
  // be used as factor donors/receivers in later iterations.
  std::set<size_t> index;
  std::vector<size_t> dynamic_dim_index;
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] == -1) {
      dynamic_dim_index.emplace_back(i);
    }
  }
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] == -1) {
      continue;  // dynamic dim: no constraint to enforce
    }
    if (tgt_shape->at(i) == expected_tgt_shape[i]) {
      index.insert(i);  // already satisfied; just pin it
      continue;
    }
    if (tgt_shape->at(i) > expected_tgt_shape[i]) {
      // Current dim is too large: push the excess factor f into the rightmost
      // unpinned dim (multiplication always succeeds, so one receiver is enough).
      if (tgt_shape->at(i) % expected_tgt_shape[i] == 0) {
        int64_t f = tgt_shape->at(i) / expected_tgt_shape[i];
        for (int32_t j = static_cast<int32_t>(tgt_shape->size()) - 1; j >= 0; --j) {
          if (j == static_cast<int32_t>(i) || index.find(j) != index.end()) {
            continue;
          }
          (*tgt_shape)[j] *= f;
          break;
        }
        (*tgt_shape)[i] = expected_tgt_shape[i];
      } else {
        MS_LOG(ERROR) << "Can't be divided.";
      }
    } else {
      if (expected_tgt_shape[i] % tgt_shape->at(i) == 0) {
        // Current dim is too small by an integer factor f: collect f from the
        // other unpinned dims (right to left), dividing each by gcd(f, dim).
        int64_t f = expected_tgt_shape[i] / tgt_shape->at(i);
        for (int32_t j = static_cast<int32_t>(tgt_shape->size()) - 1; j >= 0; --j) {
          if (j == static_cast<int32_t>(i) || index.find(j) != index.end()) {
            continue;
          }
          int64_t divider = std::gcd(f, tgt_shape->at(j));
          (*tgt_shape)[j] /= divider;
          f /= divider;
          if (f == 1) {
            break;
          }
        }
        if (f != 1) {
          MS_LOG(ERROR) << "Can't merge shape.";
        }
        (*tgt_shape)[i] = expected_tgt_shape[i];
      } else {
        // Expected dim shares no simple ratio with the current dim: gather the
        // whole expected value from the other unpinned dims via gcds.
        int64_t target_dim = expected_tgt_shape[i];  // 1024
        for (int32_t j = static_cast<int32_t>(tgt_shape->size()) - 1; j >= 0; --j) {
          if (index.find(j) != index.end()) {
            continue;
          }
          int64_t divider = std::gcd(target_dim, tgt_shape->at(j));
          (*tgt_shape)[j] /= divider;
          target_dim /= divider;
          if (target_dim == 1) {
            break;
          }
        }
        if (target_dim != 1) {
          MS_LOG(ERROR) << "Can't be divided.";
        } else {
          // find last dyn dim on right and put tgt_shape->at(i) to it
          // NOTE(review): dynamic_dim_index.back() assumes expected_tgt_shape
          // has at least one -1 dim — UB if it has none; confirm callers
          // guarantee this (CheckDynamicShape suggests they do).
          (*tgt_shape)[dynamic_dim_index.back()] = tgt_shape->at(dynamic_dim_index.back()) * tgt_shape->at(i);
          (*tgt_shape)[i] = expected_tgt_shape[i];
        }
      }
    }
    index.insert(i);
  }
}
210
// Left-to-right pass: make every dim of tgt_shape divisible by its shard
// factor (tgt_factors), borrowing the missing factor from dims to the RIGHT.
// Returns false when the borrowing cannot be completed.
// Borrow the size from right dim, then borrow the size from left dim.
// tgt_shape must be inited with value 1 and has fixed size.
bool ForwardMatching(const Shape &src_shape, const Shape &expected_tgt_shape, Shape *tgt_shape,
                     const Array &tgt_factors) {
  InitShapeVec(src_shape, tgt_shape);
  IntroduceConstraints(expected_tgt_shape, tgt_shape);
  // tensor_size tracks the product of the not-yet-finalized dims; it must
  // drain to exactly 1 when every dim has been consumed.
  int64_t tensor_size = GetTensorSize(*tgt_shape);
  size_t src_size = tgt_shape->size();
  // Dims pinned by constants in expected_tgt_shape cannot donate factors.
  std::set<size_t> fix_index;
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] != -1) {
      fix_index.insert(i);
    }
  }
  for (size_t i = 0; i < tgt_shape->size(); ++i) {
    // NOTE(review): assumes every shard factor is > 0 (GetFactors fills 1 for
    // unmapped dims and rejects 0) — division by GetDimByIdx(i) is unguarded.
    if (tgt_shape->at(i) % tgt_factors.GetDimByIdx(i) == 0) {
      tensor_size /= tgt_shape->at(i);
      continue;
    }
    // Borrow the size from right dim.
    int64_t factor = tgt_factors.GetDimByIdx(i);
    int64_t val = tgt_shape->at(i) * factor;
    if (val > tensor_size) {
      MS_LOG(DEBUG) << "Out of size when calculate index " << i;
      return false;
    }
    size_t ptr = i + 1;
    while (ptr < src_size) {
      if (fix_index.find(ptr) != fix_index.end()) {
        ++ptr;
        continue;
      }
      if (tgt_shape->at(ptr) >= factor && tgt_shape->at(ptr) % factor == 0) {
        // Right dim can absorb the whole remaining factor at once.
        (*tgt_shape)[ptr] /= factor;
        factor = 1;
        break;
      }
      // Otherwise peel off whatever part (gcd) this dim can contribute.
      int64_t divisor = std::gcd(tgt_shape->at(ptr), factor);
      factor /= divisor;
      (*tgt_shape)[ptr] /= divisor;
      ++ptr;
    }
    if (factor != 1) {
      MS_LOG(DEBUG) << "Out of size when calculate index " << i << ". Can't borrow dim from right.";
      return false;
    }
    (*tgt_shape)[i] = val;
    tensor_size /= val;
  }
  if (tensor_size != 1) {
    MS_LOG(ERROR) << "Failed to forward matching.";
    return false;
  }
  return true;
}
266
// Right-to-left pass (complement of ForwardMatching): make every dim of
// tgt_shape divisible by its shard factor, borrowing the missing factor from
// dims to the LEFT. Total tensor size must be preserved.
// Borrow the size from right dim.
// Then borrow the size from left dim.
bool BackwardMatching(const Shape &expected_tgt_shape, Shape *tgt_shape, const Array &tgt_factors) {
  int64_t ori_tensor_size = GetTensorSize(*tgt_shape);
  int64_t dst_size = SizeToLong(tgt_shape->size());
  // Dims pinned by constants in expected_tgt_shape cannot donate factors.
  std::set<size_t> fix_index;
  for (size_t i = 0; i < expected_tgt_shape.size(); ++i) {
    if (expected_tgt_shape[i] != -1) {
      fix_index.insert(i);
    }
  }
  for (int32_t i = dst_size - 1; i >= 0; --i) {
    // Borrow the size from left dim.
    // NOTE(review): assumes factor > 0 — the modulo below is unguarded.
    int64_t factor = tgt_factors.GetDimByIdx(i);
    if (tgt_shape->at(i) % factor == 0) {
      continue;  // already compatible with its shard factor
    }
    int64_t to_be_filled_dim = tgt_shape->at(i) * factor;
    int32_t ptr = i - 1;
    while (ptr >= 0) {
      if (fix_index.find(ptr) != fix_index.end()) {
        --ptr;
        continue;
      }
      // Left dim can absorb the whole factor only if what remains of it is
      // still divisible by its OWN shard factor afterwards.
      if (tgt_shape->at(ptr) % factor == 0 && tgt_shape->at(ptr) / factor % tgt_factors.GetDimByIdx(ptr) == 0) {
        (*tgt_shape)[ptr] /= factor;
        factor = 1;
        break;
      }
      // Otherwise peel off the part (gcd) this dim can contribute.
      int64_t divisor = std::gcd(tgt_shape->at(ptr), factor);
      factor /= divisor;
      (*tgt_shape)[ptr] /= divisor;
      --ptr;
    }
    if (factor != 1) {
      MS_LOG(ERROR) << "Can't borrow factor from left.";
      return false;
    }
    (*tgt_shape)[i] = to_be_filled_dim;
  }
  // Borrowing must only redistribute factors, never change the element count.
  if (ori_tensor_size != GetTensorSize(*tgt_shape)) {
    MS_LOG(ERROR) << "After backward matching, tensor size is not equal.";
    return false;
  }
  return true;
}
313
// Backtracking search: fill candidates_values[offset..] with one value per
// slot, each drawn from enum_numbers[offset] (assumed sorted ascending), so
// that the product of the chosen values equals `target`. Values that are a
// small multiple (1..7) of the current src_shape_arr dim are tried first, to
// bias the solution toward the source shape. Returns true on success;
// candidates_values is mutated in place during the search.
bool SolveCombination(const Shape &src_shape_arr, size_t src_index,
                      const std::vector<std::vector<int64_t>> &enum_numbers, size_t offset, int64_t target,
                      std::vector<int64_t> *candidates_values) {
  bool is_last = (enum_numbers.size() - offset - 1) == 0;
  if (src_index < src_shape_arr.size()) {
    // Preferred branch: try factor * src_dim first so the chosen dim stays
    // aligned with the source shape when possible.
    constexpr size_t MAX_DIM = 8;
    for (size_t factor = 1; factor < MAX_DIM; ++factor) {
      int64_t preferred_choose = SizeToLong(factor) * src_shape_arr[src_index];
      if (std::find(enum_numbers[offset].begin(), enum_numbers[offset].end(), preferred_choose) !=
            enum_numbers[offset].end() &&
          preferred_choose <= target && target % preferred_choose == 0) {
        (*candidates_values)[offset] = preferred_choose;
        // Advance both the slot and the source dim on this branch.
        if (!is_last && SolveCombination(src_shape_arr, src_index + 1, enum_numbers, offset + 1,
                                         target / candidates_values->at(offset), candidates_values)) {
          return true;
        }
      }
    }
  }
  // Fallback: try every admissible candidate for this slot in order.
  // enum_numbers[offset] must be sorted ascending for the early `break`.
  for (size_t i = 0; i < enum_numbers[offset].size(); ++i) {
    if (enum_numbers[offset][i] > target) {
      break;
    }
    if (target % enum_numbers[offset][i] != 0) {
      continue;
    }
    (*candidates_values)[offset] = enum_numbers[offset][i];
    // Last slot succeeds only if it consumes the remaining target exactly.
    if (is_last && target / enum_numbers[offset][i] == 1) {
      return true;
    }
    // Note: src_index is NOT advanced here — only the preferred branch walks
    // the source shape.
    if (!is_last && SolveCombination(src_shape_arr, src_index, enum_numbers, offset + 1,
                                     target / enum_numbers[offset][i], candidates_values)) {
      return true;
    }
  }
  return false;
}
351 } // namespace mindspore::parallel
352