/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/graph/preprocess_dynamic_shape.h"
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ops/op_name.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
namespace {
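// Folds a Stack over scalar integer constants into a concrete shape vector.
// Variable (CNode) elements are recorded as -1 so callers can treat them as dynamic dims.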
int DoStack(const CNodePtr &cnode, const ShapeVector &out_shape, ShapeVector *out_data) {
  MS_ASSERT(cnode != nullptr && out_data != nullptr);
  if (!CheckPrimitiveType(cnode, prim::kPrimStack)) {
    return lite::RET_NOT_SUPPORT;
  }
  if (out_shape.size() != 1 || out_shape.front() <= 0) {
    return lite::RET_NOT_SUPPORT;
  }
  auto origin_inputs = cnode->inputs();
  if (lite::RemoveIfDepend(cnode) != RET_OK) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NOT_SUPPORT;
  }
  RemoveIfMonad(cnode);
  if (lite::RemoveIfMakeTuple(cnode) != RET_OK) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NOT_SUPPORT;
  }
  auto current_inputs = cnode->inputs();
  for (size_t i = 1; i < current_inputs.size(); ++i) {
    if (utils::isa<CNode>(current_inputs[i])) {
      out_data->push_back(-1);
      continue;
    }
    lite::DataInfo data_info;
    if (lite::FetchConstData(cnode, i, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) {
      cnode->set_inputs(origin_inputs);
      MS_LOG(ERROR) << "Fetch stack's const data failed.";
      return lite::RET_ERROR;
    }
    if (data_info.data_ptr_ == nullptr ||
        (data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
        std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<>()) != 1) {
      cnode->set_inputs(origin_inputs);
      return lite::RET_NOT_SUPPORT;
    }
    out_data->push_back(*static_cast<int *>(data_info.data_ptr_));
  }
  cnode->set_inputs(origin_inputs);
  return lite::RET_OK;
}

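// Infers the output shape of element-wise arithmetic ops (AddFusion/MulFusion)
// with standard broadcasting; bails out when a broadcast dim cannot be determined.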
int ArithmeticInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                         std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "Arithmetic op should have two inputs.";
    return lite::RET_ERROR;
  }
  const auto &first_shape = in_shapes.front();
  const auto &second_shape = in_shapes[1];
  size_t out_shape_size = first_shape.size() >= second_shape.size() ? first_shape.size() : second_shape.size();
  ShapeVector first_shape_expand;
  for (size_t i = 0; i < (out_shape_size - first_shape.size()); ++i) {
    first_shape_expand.push_back(1);
  }
  (void)first_shape_expand.insert(first_shape_expand.end(), first_shape.begin(), first_shape.end());
  ShapeVector second_shape_expand;
  for (size_t i = 0; i < (out_shape_size - second_shape.size()); ++i) {
    second_shape_expand.push_back(1);
  }
  (void)second_shape_expand.insert(second_shape_expand.end(), second_shape.begin(), second_shape.end());
  ShapeVector out_shape;
  for (size_t i = 0; i < out_shape_size; ++i) {
    if (first_shape_expand[i] == second_shape_expand[i]) {
      out_shape.push_back(first_shape_expand[i]);
      continue;
    }
    if (first_shape_expand[i] == 1) {
      out_shape.push_back(second_shape_expand[i]);
      continue;
    }
    if (second_shape_expand[i] == 1) {
      out_shape.push_back(first_shape_expand[i]);
      continue;
    }
    MS_LOG(INFO) << "Arithmetic op cannot determine out-shape.";
    return lite::RET_NOT_SUPPORT;
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

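// Shape-preserving ops (Activation, Cast, NotEqual): output shapes equal input shapes.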
int CommonInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  out_shapes->clear();
  (void)out_shapes->insert(out_shapes->begin(), in_shapes.begin(), in_shapes.end());
  return lite::RET_OK;
}

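// Concat: all inputs must share rank; the concat-axis dim becomes the sum of the
// per-input dims, or -1 when any of them is dynamic.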
int ConcatInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo || in_shapes.empty()) {
    MS_LOG(ERROR) << "Concat should have at least one input.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Concat's primitive is a nullptr.");
  int axis = 0;
  if (prim->GetAttr(ops::kAxis) != nullptr) {
    axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  }
  ShapeVector out_shape = in_shapes.front();
  size_t rank = out_shape.size();
  if (axis < 0) {
    axis += rank;
  }
  MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast<int>(rank), lite::RET_ERROR,
                    "Concat's axis doesn't match with shape.");
  int64_t axis_sum = 0;
  for (const auto &in_shape : in_shapes) {
    if (in_shape.size() != rank) {
      return lite::RET_NOT_SUPPORT;
    }
    if (in_shape[axis] < 0) {
      axis_sum = -1;
      break;
    }
    axis_sum += in_shape[axis];
  }
  out_shape[axis] = axis_sum;
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

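// ExpandDims: requires a constant scalar axis; inserts a 1 into the input shape at that axis.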
int ExpandDimsInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                         std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "Expanddims should have two inputs.";
    return lite::RET_ERROR;
  }
  auto second_input = cnode->input(kInputIndexTwo);
  MS_CHECK_TRUE_MSG(second_input != nullptr, lite::RET_ERROR, "Expanddims's second input is a nullptr.");
  if (second_input->isa<CNode>()) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Expanddims fetch second-input's data failed.");
  MS_CHECK_TRUE_MSG(data_info.data_ptr_ != nullptr, lite::RET_ERROR,
                    "Expanddims's second-input's data shouldn't be a nullptr.");
  MS_CHECK_TRUE_MSG(data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32, lite::RET_ERROR,
                    "Expanddims's second-input's data-type should be int.");
  auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
  MS_CHECK_TRUE_MSG(element_num == 1, lite::RET_ERROR, "Expanddims's second-input should be a scalar.");
  auto axis = *static_cast<int *>(data_info.data_ptr_);
  auto first_shape = in_shapes.front();
  auto first_shape_size = static_cast<int>(first_shape.size());
  if (axis < 0) {
    axis = first_shape_size + axis + 1;
  }
  MS_CHECK_TRUE_MSG(axis >= 0 && axis <= first_shape_size, lite::RET_ERROR, "Expanddims's second-input is invalid.");
  out_shapes->clear();
  (void)first_shape.insert(first_shape.begin() + axis, 1);
  out_shapes->push_back(first_shape);
  return lite::RET_OK;
}

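// Gather: requires a constant scalar axis (third input); the output shape is the
// input shape with the axis dim replaced by the indices' shape.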
int GatherInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeFour || in_shapes.size() < kInputSizeThree) {
    MS_LOG(ERROR) << "Gather should have three inputs.";
    return lite::RET_ERROR;
  }
  auto third_input = cnode->input(kInputIndexThree);
  MS_CHECK_TRUE_MSG(third_input != nullptr, lite::RET_ERROR, "Gather's third input is a nullptr.");
  if (third_input->isa<CNode>()) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexThree, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Gather fetch third-input's data failed.");
  auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
  MS_CHECK_TRUE_MSG(element_num <= 1, lite::RET_ERROR, "Gather's third-input should be a scalar.");
  int axis{0};
  if (element_num == 1) {
    MS_CHECK_TRUE_MSG(data_info.data_ptr_ != nullptr, lite::RET_ERROR,
                      "Gather's third-input's data shouldn't be a nullptr.");
    if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
      axis = *static_cast<int *>(data_info.data_ptr_);
    } else if (data_info.data_type_ == kNumberTypeInt64) {
      axis = *static_cast<int64_t *>(data_info.data_ptr_);
    } else {
      MS_LOG(ERROR) << "Gather's axis is invalid, which should be int or int64.";
      return lite::RET_ERROR;
    }
  }
  const auto &first_shape = in_shapes.front();
  auto first_shape_size = static_cast<int>(first_shape.size());
  if (axis < 0) {
    axis = first_shape_size + axis;
  }
  MS_CHECK_TRUE_MSG(axis >= 0 && axis < first_shape_size, lite::RET_ERROR, "Gather's axis out of range.");
  const auto &second_shape = in_shapes[1];
  ShapeVector out_shape;
  for (int i = 0; i < axis; ++i) {
    out_shape.push_back(first_shape[i]);
  }
  (void)out_shape.insert(out_shape.end(), second_shape.begin(), second_shape.end());
  for (int i = axis + 1; i < first_shape_size; ++i) {
    out_shape.push_back(first_shape[i]);
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

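// MatMul: broadcasts the batch dims, then appends the row dim of A and the
// column dim of B, honoring the transpose_a/transpose_b attributes.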
int MatMulInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "MatMul should have at least two inputs.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "MatMul's primitive is a nullptr.");
  bool a_trans = prim->GetAttr(ops::kTransposeA) && GetValue<bool>(prim->GetAttr(ops::kTransposeA));
  bool b_trans = prim->GetAttr(ops::kTransposeB) && GetValue<bool>(prim->GetAttr(ops::kTransposeB));
  const auto &a_shape = in_shapes.front();
  MS_CHECK_TRUE_RET(a_shape.size() >= kInputSizeTwo, lite::RET_NOT_SUPPORT);
  const auto &b_shape = in_shapes[1];
  MS_CHECK_TRUE_RET(b_shape.size() >= kInputSizeTwo, lite::RET_NOT_SUPPORT);
  size_t a_rank = a_shape.size();
  size_t b_rank = b_shape.size();
  size_t out_rank = std::max(a_rank, b_rank);
  ShapeVector a_pre_shape;
  (void)a_pre_shape.insert(a_pre_shape.end(), out_rank - a_rank, 1);
  (void)a_pre_shape.insert(a_pre_shape.end(), a_shape.begin(), a_shape.begin() + a_rank - C2NUM);
  ShapeVector b_pre_shape;
  (void)b_pre_shape.insert(b_pre_shape.end(), out_rank - b_rank, 1);
  (void)b_pre_shape.insert(b_pre_shape.end(), b_shape.begin(), b_shape.begin() + b_rank - C2NUM);
  ShapeVector out_shape;
  MS_ASSERT(a_pre_shape.size() == b_pre_shape.size());
  for (size_t i = 0; i < out_rank - C2NUM; ++i) {
    if (a_pre_shape[i] == b_pre_shape[i]) {
      out_shape.push_back(a_pre_shape[i]);
      continue;
    }
    if (a_pre_shape[i] == 1) {
      out_shape.push_back(b_pre_shape[i]);
      continue;
    }
    if (b_pre_shape[i] == 1) {
      out_shape.push_back(a_pre_shape[i]);
      continue;
    }
    return lite::RET_NOT_SUPPORT;
  }
  out_shape.push_back(a_trans ? a_shape.back() : a_shape[a_rank - C2NUM]);
  out_shape.push_back(b_trans ? b_shape[b_rank - C2NUM] : b_shape.back());
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

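// Reduce: collects the constant reduce axes (all axes when the second input carries
// no data), then drops them from the input shape, or keeps them as 1 with keep_dims.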
int ReduceInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  MS_CHECK_FALSE_MSG(cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo, lite::RET_ERROR,
                     "Reduce should have two inputs");
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Reduce's primitive is a nullptr.");
  bool keep_dim = prim->GetAttr(ops::kKeepDims) != nullptr && GetValue<bool>(prim->GetAttr(ops::kKeepDims));
  bool reduce_to_end = prim->GetAttr(ops::kReduceToEnd) != nullptr && GetValue<bool>(prim->GetAttr(ops::kReduceToEnd));
  if (reduce_to_end) {
    return lite::RET_NOT_SUPPORT;
  }
  auto second_input = cnode->input(kInputIndexTwo);
  MS_CHECK_TRUE_MSG(second_input != nullptr, lite::RET_ERROR, "Reduce's second input is a nullptr.");
  if (second_input->isa<CNode>()) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Reduce fetch second-input's data failed.");
  MS_CHECK_TRUE_MSG(data_info.shape_.size() <= 1, lite::RET_ERROR, "Reduce second-input should be <= 1D.");
  std::set<int> reduce_axes;
  int rank = static_cast<int>(in_shapes.front().size());
  if (data_info.data_ptr_ == nullptr) {
    MS_LOG(INFO) << "reduce op rank is: " << rank << ", cnode name: " << cnode->fullname_with_scope();
    for (int dim = 0; dim < rank; dim++) {
      (void)reduce_axes.insert(dim);
    }
  } else {
    int element_num = data_info.shape_.empty() ? 1 : data_info.shape_.front();
    std::vector<int> temp;
    int *axes{nullptr};
    if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
      axes = static_cast<int *>(data_info.data_ptr_);
    } else if (data_info.data_type_ == kNumberTypeInt64) {
      (void)temp.insert(temp.begin(), static_cast<int64_t *>(data_info.data_ptr_),
                        static_cast<int64_t *>(data_info.data_ptr_) + element_num);
      axes = temp.data();
    } else {
      return lite::RET_NOT_SUPPORT;
    }
    for (int i = 0; i < element_num; ++i) {
      int axis = axes[i] >= 0 ? axes[i] : axes[i] + rank;
      MS_CHECK_TRUE_MSG(axis >= 0 && axis < rank, lite::RET_ERROR, "Reduce's axis is out of range.");
      (void)reduce_axes.insert(axis);
    }
  }
  int start = 0;
  ShapeVector out_shape;
  for (auto iter = reduce_axes.begin(); iter != reduce_axes.end(); ++iter) {
    int end = *iter;
    for (; start < end; ++start) {
      out_shape.push_back(in_shapes.front()[start]);
    }
    if (keep_dim) {
      out_shape.push_back(1);
    }
    ++start;
  }
  for (; start < rank; ++start) {
    out_shape.push_back(in_shapes.front()[start]);
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

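// Reshape: reads the target shape from a constant second input, or from a Stack of
// scalars via DoStack; unknown entries become -1 and zeros copy the matching input dim.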
int ReshapeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                      std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo) {
    (void)out_shapes->emplace_back();
    return lite::RET_OK;
  }
  if (in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "Reshape should have two inputs.";
    return lite::RET_ERROR;
  }
  out_shapes->clear();
  auto second_input = cnode->input(kInputIndexTwo);
  MS_CHECK_TRUE_MSG(second_input != nullptr, lite::RET_ERROR, "Reshape's second input is a nullptr.");
  if (second_input->isa<CNode>()) {
    const auto &second_in_shape = in_shapes[1];
    if (second_in_shape.size() != 1 || second_in_shape.front() <= 0) {
      return lite::RET_NOT_SUPPORT;
    }
    ShapeVector out_shape;
    auto ret = DoStack(second_input->cast<CNodePtr>(), second_in_shape, &out_shape);
    if (ret == lite::RET_NOT_SUPPORT) {
      out_shape = ShapeVector(second_in_shape.front(), -1);
    } else if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << "Do stack failed.";
      return ret;
    }
    out_shapes->push_back(out_shape);
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Reshape fetch second-input's data failed.");
  MS_CHECK_TRUE_MSG(data_info.shape_.size() <= 1, lite::RET_ERROR, "Reshape second-input should be <= 1D.");
  if (data_info.data_ptr_ == nullptr || (data_info.shape_.size() == 1 && data_info.shape_.front() == 0)) {
    (void)out_shapes->emplace_back();
    return lite::RET_OK;
  }
  auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
  ShapeVector out_shape;
  if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
    for (int i = 0; i < element_num; ++i) {
      out_shape.push_back(*(static_cast<int *>(data_info.data_ptr_) + i));
    }
  } else if (data_info.data_type_ == kNumberTypeInt64) {
    for (int i = 0; i < element_num; ++i) {
      out_shape.push_back(*(static_cast<int64_t *>(data_info.data_ptr_) + i));
    }
  } else {
    return lite::RET_NOT_SUPPORT;
  }
  const auto &in_shape = in_shapes.front();
  for (size_t i = 0; i < out_shape.size(); ++i) {
    if (out_shape[i] == 0) {
      MS_CHECK_TRUE_MSG(in_shape.size() > i, lite::RET_ERROR, "Reshape's in-rank is invalid.");
      out_shape[i] = in_shape[i];
    }
  }
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

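// Shape: the output is a 1-D tensor whose single dim equals the input rank.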
int ShapeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                    std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo || in_shapes.empty()) {
    MS_LOG(ERROR) << "Shape should have one input.";
    return lite::RET_ERROR;
  }
  ShapeVector out_shape = {static_cast<int64_t>(in_shapes.front().size())};
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

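// Split: derives the per-output shapes from size_splits when given, otherwise
// divides the split-axis dim evenly across output_num outputs.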
int SplitInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                    std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo || in_shapes.empty()) {
    MS_LOG(ERROR) << "Split should have one input.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Split's primitive is a nullptr.");
  auto out_num = prim->GetAttr(ops::kOutputNum) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kOutputNum));
  auto size_splits = prim->GetAttr(ops::kSizeSplits) == nullptr
                       ? std::vector<int64_t>{}
                       : GetValue<std::vector<int64_t>>(prim->GetAttr(ops::kSizeSplits));
  out_num = (out_num == 0 ? static_cast<int64_t>(size_splits.size()) : out_num);
  if (out_num <= 0) {
    return lite::RET_NOT_SUPPORT;
  }
  auto axis = prim->GetAttr(ops::kAxis) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  auto &in_shape = in_shapes.front();
  axis = axis < 0 ? static_cast<int64_t>(in_shape.size()) + axis : axis;
  MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast<int64_t>(in_shape.size()), lite::RET_ERROR,
                    "Split's axis is out of range.");
  out_shapes->clear();
  ShapeVector out_shape = in_shape;
  if (size_splits.empty()) {
    MS_CHECK_TRUE_MSG(in_shape[axis] > 0 && in_shape[axis] % out_num == 0, lite::RET_ERROR,
                      "Split's dim doesn't match split-axis.");
    out_shape[axis] = in_shape[axis] / out_num;
    (void)out_shapes->insert(out_shapes->end(), out_num, out_shape);
  } else {
    for (auto v : size_splits) {
      out_shape[axis] = v;
      out_shapes->push_back(out_shape);
    }
  }
  return lite::RET_OK;
}

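// Squeeze: with explicit axes, removes exactly those dims; with no axes, removes
// every dim equal to 1 (unsupported when a dim is dynamic).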
int SqueezeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                      std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (in_shapes.empty()) {
    MS_LOG(ERROR) << "Squeeze should have at least one input.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Squeeze's primitive is a nullptr.";
    return lite::RET_ERROR;
  }
  auto axes = prim->GetAttr(ops::kAxis) != nullptr ? GetValue<std::vector<int64_t>>(prim->GetAttr(ops::kAxis))
                                                   : std::vector<int64_t>();
  auto &in_shape = in_shapes.front();
  ShapeVector out_shape;
  if (axes.empty()) {
    for (size_t i = 0; i < in_shape.size(); ++i) {
      if (in_shape[i] < 0) {
        return lite::RET_NOT_SUPPORT;
      }
      if (in_shape[i] != 1) {
        out_shape.push_back(in_shape[i]);
      }
    }
  } else {
    auto dims = static_cast<int64_t>(in_shape.size());
    std::vector<int> flags(dims, 0);
    for (auto axis : axes) {
      axis = axis < 0 ? axis + dims : axis;
      if (axis < 0 || axis >= dims) {
        MS_LOG(ERROR) << "Squeeze's axis is invalid. node name is " << cnode->fullname_with_scope();
        return lite::RET_ERROR;
      }
      flags[axis] = 1;
    }
    for (int64_t i = 0; i < dims; ++i) {
      if (flags[i] == 0) {
        out_shape.push_back(in_shape[i]);
      }
    }
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

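// Stack: all inputs must share rank and agree on every dim (dynamic dims allowed);
// a new dim equal to the input count is inserted at the stack axis.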
int StackInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                    std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (in_shapes.empty()) {
    MS_LOG(ERROR) << "Stack should have at least one input.";
    return lite::RET_ERROR;
  }
  auto dims = in_shapes.front().size();
  if (std::any_of(in_shapes.begin(), in_shapes.end(),
                  [dims](const ShapeVector &in_shape) { return in_shape.size() != dims; })) {
    MS_LOG(ERROR) << "Stack's inputs should all have the same rank.";
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  if (std::any_of(in_shapes.begin(), in_shapes.end(), [](const ShapeVector &in_shape) {
        return std::any_of(in_shape.begin(), in_shape.end(), [](int64_t val) { return val == 0; });
      })) {
    return lite::RET_NOT_SUPPORT;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Stack's primitive is a nullptr.");
  auto axis = prim->GetAttr(ops::kAxis) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  if (axis < 0) {
    axis += static_cast<int64_t>(dims);
  }
  if (axis < 0 || axis > static_cast<int64_t>(dims)) {
    MS_LOG(ERROR) << "stack's axis is invalid.";
    return lite::RET_PARAM_INVALID;
  }
  ShapeVector out_shape;
  auto FillShape = [&out_shape, &in_shapes](int64_t start, int64_t end) mutable {
    for (; start < end; ++start) {
      ShapeVector vertical;
      for (const auto &in_shape : in_shapes) {
        if (in_shape[start] >= 0) {
          vertical.push_back(in_shape[start]);
        } else if (in_shape[start] != -1) {
          MS_LOG(ERROR) << "Stack's input-shape must not have a dim-value less than -1.";
          return lite::RET_INPUT_TENSOR_ERROR;
        }
      }
      out_shape.push_back(vertical.size() < in_shapes.size() ? -1 : vertical.front());
      if (!vertical.empty()) {
        int64_t dim = vertical.front();
        if (std::any_of(vertical.begin(), vertical.end(), [dim](const int64_t value) { return value != dim; })) {
          MS_LOG(ERROR) << "Stack's input-shapes must be the same as each other.";
          return lite::RET_INPUT_TENSOR_ERROR;
        }
      }
    }
    return lite::RET_OK;
  };
  if (FillShape(0, axis) != lite::RET_OK) {
    MS_LOG(ERROR) << "Stack do fillShape failed.";
    return lite::RET_ERROR;
  }
  out_shape.push_back(static_cast<int64_t>(in_shapes.size()));
  if (FillShape(axis, dims) != lite::RET_OK) {
    MS_LOG(ERROR) << "Stack do fillShape failed.";
    return lite::RET_ERROR;
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

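// Validates that a StridedSlice is statically analyzable: no ellipsis/new-axis masks,
// constant int32 begins/ends/strides of matching 1-D shape, and every stride equal to 1.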
int CheckStridedSlice(const CNodePtr &cnode, int64_t in_rank, lite::DataInfo *begins, lite::DataInfo *ends) {
  MS_ASSERT(cnode != nullptr);
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "StridedSlice's primitive is a nullptr.");
  int64_t ellipsis_mask = prim->GetAttr(ops::kEllipsisMask) ? GetValue<int64_t>(prim->GetAttr(ops::kEllipsisMask)) : 0;
  int64_t new_axis_mask = prim->GetAttr(ops::kNewAxisMask) ? GetValue<int64_t>(prim->GetAttr(ops::kNewAxisMask)) : 0;
  if ((ellipsis_mask | new_axis_mask) != 0) {
    return lite::RET_NOT_SUPPORT;
  }
  for (size_t i = C2NUM; i < kInputSizeFive; ++i) {
    MS_CHECK_TRUE_MSG(cnode->input(i) != nullptr, lite::RET_ERROR, "StridedSlice's input is a nullptr.");
    if (utils::isa<CNode>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  auto BasicCond = [](const lite::DataInfo &data_info) {
    return data_info.data_ptr_ != nullptr &&
           (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32);
  };
  if (lite::FetchConstData(cnode, C2NUM, converter::kFmkTypeMs, begins, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch StridedSlice's begins failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(begins->shape_.size() == C1NUM && begins->shape_.front() <= in_rank && BasicCond(*begins),
                    lite::RET_NOT_SUPPORT);
  if (lite::FetchConstData(cnode, C3NUM, converter::kFmkTypeMs, ends, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch StridedSlice's ends failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(ends->shape_ == begins->shape_ && BasicCond(*ends), lite::RET_NOT_SUPPORT);
  lite::DataInfo strides;
  if (lite::FetchConstData(cnode, C4NUM, converter::kFmkTypeMs, &strides, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch StridedSlice's strides failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(strides.shape_ == begins->shape_ && BasicCond(strides), lite::RET_NOT_SUPPORT);
  for (int i = 0; i < strides.shape_.front(); ++i) {
    if (static_cast<int *>(strides.data_ptr_)[i] != 1) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  return lite::RET_OK;
}

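// StridedSlice: computes each sliced dim as end - begin after applying the
// begin/end/shrink masks; dims beyond the sliced range are copied through.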
int StridedSliceInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                           std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() != kInputSizeFive || in_shapes.size() != kInputSizeFour) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo begins;
  lite::DataInfo ends;
  auto ret = CheckStridedSlice(cnode, in_shapes.front().size(), &begins, &ends);
  if (ret != lite::RET_OK) {
    return ret;
  }

  auto prim = GetCNodePrimitive(cnode);
  int64_t begin_mask = prim->GetAttr(ops::kBeginMask) ? GetValue<int64_t>(prim->GetAttr(ops::kBeginMask)) : 0;
  int64_t end_mask = prim->GetAttr(ops::kEndMask) ? GetValue<int64_t>(prim->GetAttr(ops::kEndMask)) : 0;
  int64_t shrink_mask =
    prim->GetAttr(ops::kShrinkAxisMask) ? GetValue<int64_t>(prim->GetAttr(ops::kShrinkAxisMask)) : 0;
  const auto &in_shape = in_shapes.front();
  ShapeVector out_shape;
  int index = 0;
  for (; index < begins.shape_.front(); ++index) {
    if (shrink_mask & (1 << index)) {
      continue;
    }
    int b_mask = begin_mask & (1 << index);
    int e_mask = end_mask & (1 << index);
    if (b_mask && e_mask) {
      out_shape.push_back(in_shape[index]);
      continue;
    }
    int64_t begin = static_cast<int *>(begins.data_ptr_)[index];
    int64_t end = static_cast<int *>(ends.data_ptr_)[index];
    if (b_mask) {
      begin = 0;
    }
    if (e_mask) {
      end = in_shape[index];
    }
    if (in_shape[index] > 0) {
      begin += (begin >= 0 ? 0 : in_shape[index]);
      end += (end >= 0 ? 0 : in_shape[index]);
    }
    if (begin < 0 || end < 0 || begin > end) {
      return lite::RET_NOT_SUPPORT;
    }
    out_shape.push_back(end - begin);
  }
  (void)out_shape.insert(out_shape.end(), in_shape.begin() + index, in_shape.end());
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

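// Transpose: with no perm input the shape is simply reversed; otherwise the
// constant perm is validated as a permutation and used to reorder the input dims.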
int TransposeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                        std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  out_shapes->clear();
  if (in_shapes.size() == 1) {
    auto in_shape = in_shapes.front();
    ShapeVector out_shape(in_shape.rbegin(), in_shape.rend());
    out_shapes->push_back(out_shape);
    return lite::RET_OK;
  }
  if (in_shapes.size() != C2NUM) {
    MS_LOG(ERROR) << "Transpose's input count should be 1 or 2, now is " << in_shapes.size();
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  if (utils::isa<CNode>(cnode->input(ops::kInputIndex2))) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  if (lite::FetchConstData(cnode, ops::kInputIndex2, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch constant info failed, " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (data_info.data_ptr_ == nullptr ||
      (data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32)) {
    return lite::RET_NOT_SUPPORT;
  }
  auto num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<>());
  auto in_shape = in_shapes.front();
  if (num != static_cast<int>(in_shape.size())) {
    MS_LOG(ERROR) << "Transpose's perm doesn't match with input.";
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  std::vector<int> visit_flags(num, 0);
  ShapeVector out_shape;
  for (int i = 0; i < num; ++i) {
    auto dim_index = static_cast<int *>(data_info.data_ptr_)[i];
    if (dim_index < 0 || dim_index >= num || visit_flags[dim_index]) {
      MS_LOG(ERROR) << "Transpose's perm is invalid.";
      return lite::RET_INPUT_TENSOR_ERROR;
    }
    visit_flags[dim_index] = 1;
    out_shape.push_back(in_shape[dim_index]);
  }
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}
}  // namespace

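// Entry point: runs only on models whose inputs have dynamic (-1) dims and records
// per-node input/output shapes for the ops this pass can infer.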
int DynamicShapePreprocessor::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  op_shape_infos_.clear();
  auto is_dynamic = CheckIsDynamicModel(func_graph);
  if (!is_dynamic) {
    return lite::RET_NOT_SUPPORT;
  }
  auto ret = ProcessOps(func_graph);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Preprocess for mul-reduce-fusion failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

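// Seeds op_shape_infos_ with the graph inputs' shapes and reports whether any
// input carries a dynamic (-1) dim.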
bool DynamicShapePreprocessor::CheckIsDynamicModel(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_inputs = func_graph->get_inputs();
  lite::DataInfo data_info;
  bool is_dynamic{false};
  for (auto &input : graph_inputs) {
    if (!utils::isa<Parameter>(input)) {
      continue;
    }
    auto ret = lite::FetchFromDefaultParam(input->cast<ParameterPtr>(), converter::kFmkTypeMs, &data_info, false);
    if (ret != lite::RET_OK) {
      return false;
    }
    ShapeVector shape(data_info.shape_.begin(), data_info.shape_.end());
    is_dynamic = is_dynamic || std::any_of(shape.begin(), shape.end(), [](int64_t v) { return v == -1; });
    op_shape_infos_[input] = std::make_pair(std::vector<ShapeVector>{}, std::vector<ShapeVector>{shape});
  }
  return is_dynamic;
}

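// Walks the graph in topological order, strips virtual inputs (Depend/Monad/MakeTuple),
// and tries shape inference on each supported op whose inputs are known.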
int DynamicShapePreprocessor::ProcessOps(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  std::set<std::string> support_ops = {
    prim::kPrimAddFusion->name(),    prim::kPrimActivation->name(), prim::kPrimCast->name(),
    prim::kPrimConcat->name(),       prim::kPrimExpandDims->name(), prim::kPrimGather->name(),
    prim::kPrimMatMulFusion->name(), prim::kPrimMulFusion->name(),  prim::kPrimNotEqual->name(),
    prim::kPrimReduceFusion->name(), prim::kPrimReshape->name(),    prim::kPrimShape->name(),
    prim::kPrimSplit->name(),        prim::kPrimSqueeze->name(),    prim::kPrimStack->name(),
    prim::kPrimStridedSlice->name(), prim::kPrimTranspose->name()};
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim = GetCNodePrimitive(cnode);
    if (prim == nullptr) {
      continue;
    }
    auto op_type = prim->name();
    if (support_ops.find(op_type) == support_ops.end()) {
      continue;
    }
    auto origin_inputs = cnode->inputs();
    if (lite::RemoveIfDepend(cnode) != RET_OK) {
      cnode->set_inputs(origin_inputs);
      continue;
    }
    RemoveIfMonad(cnode);
    if (lite::RemoveIfMakeTuple(cnode) != RET_OK) {
      cnode->set_inputs(origin_inputs);
      continue;
    }
    auto current_inputs = cnode->inputs();
    bool can_infer = std::any_of(current_inputs.begin(), current_inputs.end(), [this](AnfNodePtr &anf_node) {
      return op_shape_infos_.find(anf_node) != op_shape_infos_.end() || !utils::isa<CNode>(anf_node);
    });
    if (!can_infer) {
      cnode->set_inputs(origin_inputs);
      continue;
    }
    auto ret = DoInfer(cnode, op_type);
    cnode->set_inputs(origin_inputs);
    if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << "error occurred when infer " << op_type;
      return ret;
    }
  }
  return lite::RET_OK;
}

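// Dispatches to the per-op infer function, gathering input shapes from constant
// tensors or previously inferred producers, and caches the result in op_shape_infos_.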
int DynamicShapePreprocessor::DoInfer(const CNodePtr &cnode, const std::string &op_type) {
  MS_ASSERT(cnode != nullptr);
  std::map<std::string, std::function<int(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                                          std::vector<ShapeVector> *out_shapes)>>
    infer_func = {
      {prim::kPrimAddFusion->name(), ArithmeticInferShape},  {prim::kPrimActivation->name(), CommonInferShape},
      {prim::kPrimCast->name(), CommonInferShape},           {prim::kPrimConcat->name(), ConcatInferShape},
      {prim::kPrimExpandDims->name(), ExpandDimsInferShape}, {prim::kPrimGather->name(), GatherInferShape},
      {prim::kPrimMatMulFusion->name(), MatMulInferShape},   {prim::kPrimMulFusion->name(), ArithmeticInferShape},
      {prim::kPrimNotEqual->name(), CommonInferShape},       {prim::kPrimReduceFusion->name(), ReduceInferShape},
      {prim::kPrimReshape->name(), ReshapeInferShape},       {prim::kPrimShape->name(), ShapeInferShape},
      {prim::kPrimSplit->name(), SplitInferShape},           {prim::kPrimSqueeze->name(), SqueezeInferShape},
      {prim::kPrimStack->name(), StackInferShape},           {prim::kPrimStridedSlice->name(), StridedSliceInferShape},
      {prim::kPrimTranspose->name(), TransposeInferShape}};
  if (infer_func.find(op_type) == infer_func.end()) {
    MS_LOG(ERROR) << "Current op: " << op_type << " doesn't support infer.";
    return lite::RET_ERROR;
  }
  std::vector<ShapeVector> in_shapes;
  lite::DataInfo data_info;
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    if (input == nullptr) {
      continue;
    }
    if (utils::isa<CNode>(input)) {
      auto real_input_info = GetRealCertainVarInput(cnode, i);
      MS_CHECK_TRUE_MSG(real_input_info.first != nullptr, lite::RET_ERROR, "Current op is invalid.");
      if (op_shape_infos_.find(real_input_info.first) == op_shape_infos_.end()) {
        return lite::RET_OK;
      }
      auto &upper_node_out = op_shape_infos_[real_input_info.first].second;
      auto index = real_input_info.second;
      MS_CHECK_TRUE_MSG(index >= 0 && index < static_cast<int>(upper_node_out.size()), lite::RET_ERROR,
                        "Current op is invalid.");
      in_shapes.push_back(upper_node_out[index]);
    } else {
      auto ret = lite::FetchConstData(cnode, i, converter::kFmkTypeMs, &data_info, false);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "Fetch constant info failed, " << cnode->fullname_with_scope();
        return lite::RET_ERROR;
      }
      ShapeVector in_shape(data_info.shape_.begin(), data_info.shape_.end());
      in_shapes.push_back(in_shape);
    }
  }
  auto func = infer_func[op_type];
  MS_ASSERT(func != nullptr);
  std::vector<ShapeVector> out_shapes;
  auto ret = func(cnode, in_shapes, &out_shapes);
  if (ret == lite::RET_NOT_SUPPORT) {
    return lite::RET_OK;
  }
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "current op is invalid, " << op_type;
    return lite::RET_ERROR;
  }
  op_shape_infos_[cnode] = std::make_pair(in_shapes, out_shapes);
  return lite::RET_OK;
}
}  // namespace opt
}  // namespace mindspore