/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/graph/preprocess_dynamic_shape.h"
#include <algorithm>
#include <functional>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ops/op_name.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
namespace {
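// Tries to constant-fold a Stack node that produces a 1-D shape tensor.
// Each constant scalar int input contributes its value to *out_data; a
// variable (CNode) input contributes -1, the marker used throughout this
// file for a dynamic dimension.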
int DoStack(const CNodePtr &cnode, const ShapeVector &out_shape, ShapeVector *out_data) {
  MS_ASSERT(cnode != nullptr && out_data != nullptr);
  if (!CheckPrimitiveType(cnode, prim::kPrimStack)) {
    return lite::RET_NOT_SUPPORT;
  }
  if (out_shape.size() != 1 || out_shape.front() <= 0) {
    return lite::RET_NOT_SUPPORT;
  }
  auto origin_inputs = cnode->inputs();
  if (lite::RemoveIfDepend(cnode) != RET_OK) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NOT_SUPPORT;
  }
  RemoveIfMonad(cnode);
  if (lite::RemoveIfMakeTuple(cnode) != RET_OK) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NOT_SUPPORT;
  }
  auto current_inputs = cnode->inputs();
  for (size_t i = 1; i < current_inputs.size(); ++i) {
    if (utils::isa<CNode>(current_inputs[i])) {
      out_data->push_back(-1);
      continue;
    }
    lite::DataInfo data_info;
    if (lite::FetchConstData(cnode, i, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) {
      cnode->set_inputs(origin_inputs);
      MS_LOG(ERROR) << "Fetch stack's const data failed.";
      return lite::RET_ERROR;
    }
    if (data_info.data_ptr_ == nullptr ||
        (data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
        std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<>()) != 1) {
      cnode->set_inputs(origin_inputs);
      return lite::RET_NOT_SUPPORT;
    }
    out_data->push_back(*static_cast<int *>(data_info.data_ptr_));
  }
  cnode->set_inputs(origin_inputs);
  return lite::RET_OK;
}

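// Broadcast shape inference shared by AddFusion and MulFusion. Both input
// shapes are right-aligned by padding with leading 1s; each dimension pair
// resolves to the common value, or to the non-1 side when one side is 1.
// For example, {2, 1, 8} and {4, 8} broadcast to {2, 4, 8}. Two distinct
// non-1 values (e.g. a dynamic -1 against a static dim) are unsupported.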
int ArithmeticInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                         std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "Arithmetic ops should have two inputs.";
    return lite::RET_ERROR;
  }
  const auto &first_shape = in_shapes.front();
  const auto &second_shape = in_shapes[1];
  size_t out_shape_size = first_shape.size() >= second_shape.size() ? first_shape.size() : second_shape.size();
  ShapeVector first_shape_expand;
  for (size_t i = 0; i < (out_shape_size - first_shape.size()); ++i) {
    first_shape_expand.push_back(1);
  }
  (void)first_shape_expand.insert(first_shape_expand.end(), first_shape.begin(), first_shape.end());
  ShapeVector second_shape_expand;
  for (size_t i = 0; i < (out_shape_size - second_shape.size()); ++i) {
    second_shape_expand.push_back(1);
  }
  (void)second_shape_expand.insert(second_shape_expand.end(), second_shape.begin(), second_shape.end());
  ShapeVector out_shape;
  for (size_t i = 0; i < out_shape_size; ++i) {
    if (first_shape_expand[i] == second_shape_expand[i]) {
      out_shape.push_back(first_shape_expand[i]);
      continue;
    }
    if (first_shape_expand[i] == 1) {
      out_shape.push_back(second_shape_expand[i]);
      continue;
    }
    if (second_shape_expand[i] == 1) {
      out_shape.push_back(first_shape_expand[i]);
      continue;
    }
    MS_LOG(INFO) << "Arithmetic op cannot determine out-shape.";
    return lite::RET_NOT_SUPPORT;
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

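// Pass-through inference for ops such as Activation, Cast and NotEqual:
// the input shapes are copied to the outputs unchanged.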
int CommonInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  out_shapes->clear();
  (void)out_shapes->insert(out_shapes->begin(), in_shapes.begin(), in_shapes.end());
  return lite::RET_OK;
}

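// Concat's output copies the first input's shape, with the concat axis set
// to the sum of all inputs' axis dims, or to -1 as soon as any of them is
// dynamic. For example, {4, 3} and {4, -1} concatenated on axis 1 give {4, -1}.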
int ConcatInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo || in_shapes.empty()) {
    MS_LOG(ERROR) << "Concat should have at least one input.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Concat's primitive is a nullptr.");
  int axis = 0;
  if (prim->GetAttr(ops::kAxis) != nullptr) {
    axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  }
  ShapeVector out_shape = in_shapes.front();
  size_t rank = out_shape.size();
  if (axis < 0) {
    axis += rank;
  }
  MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast<int>(rank), lite::RET_ERROR,
                    "Concat's axis doesn't match with shape.");
  int64_t axis_sum = 0;
  for (const auto &in_shape : in_shapes) {
    if (in_shape.size() != rank) {
      return lite::RET_NOT_SUPPORT;
    }
    if (in_shape[axis] < 0) {
      axis_sum = -1;
      break;
    }
    axis_sum += in_shape[axis];
  }
  out_shape[axis] = axis_sum;
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

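// ExpandDims inserts a length-1 dim at the axis given by the constant
// scalar second input. A negative axis counts from rank + 1, so axis -1
// appends: {4, 3} with axis -1 becomes {4, 3, 1}.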
int ExpandDimsInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                         std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "ExpandDims should have two inputs.";
    return lite::RET_ERROR;
  }
  auto second_input = cnode->input(kInputIndexTwo);
  MS_CHECK_TRUE_MSG(second_input != nullptr, lite::RET_ERROR, "ExpandDims's second input is a nullptr.");
  if (second_input->isa<CNode>()) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "ExpandDims fetch second-input's data failed.");
  MS_CHECK_TRUE_MSG(data_info.data_ptr_ != nullptr, lite::RET_ERROR,
                    "ExpandDims's second-input's data shouldn't be a nullptr.");
  MS_CHECK_TRUE_MSG(data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32,
                    lite::RET_ERROR, "ExpandDims's second-input's data-type should be int.");
  auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
  MS_CHECK_TRUE_MSG(element_num == 1, lite::RET_ERROR, "ExpandDims's second-input should be a scalar.");
  auto axis = *static_cast<int *>(data_info.data_ptr_);
  auto first_shape = in_shapes.front();
  auto first_shape_size = static_cast<int>(first_shape.size());
  if (axis < 0) {
    axis = first_shape_size + axis + 1;
  }
  MS_CHECK_TRUE_MSG(axis >= 0 && axis <= first_shape_size, lite::RET_ERROR, "ExpandDims's second-input is invalid.");
  out_shapes->clear();
  (void)first_shape.insert(first_shape.begin() + axis, 1);
  out_shapes->push_back(first_shape);
  return lite::RET_OK;
}

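// Gather replaces the axis dim of the params shape with the whole indices
// shape: for params {6, 8}, indices {2, 3} and axis 0, the result is
// {2, 3, 8}. The axis (third input) must be a constant scalar; an empty
// axis tensor defaults to 0.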
int GatherInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeFour || in_shapes.size() < kInputSizeThree) {
    MS_LOG(ERROR) << "Gather should have three inputs.";
    return lite::RET_ERROR;
  }
  auto third_input = cnode->input(kInputIndexThree);
  MS_CHECK_TRUE_MSG(third_input != nullptr, lite::RET_ERROR, "Gather's third input is a nullptr.");
  if (third_input->isa<CNode>()) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexThree, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Gather fetch third-input's data failed.");
  auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
  MS_CHECK_TRUE_MSG(element_num <= 1, lite::RET_ERROR, "Gather's third-input should be a scalar.");
  int axis{0};
  if (element_num == 1) {
    MS_CHECK_TRUE_MSG(data_info.data_ptr_ != nullptr, lite::RET_ERROR,
                      "Gather's third-input's data shouldn't be a nullptr.");
    if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
      axis = *static_cast<int *>(data_info.data_ptr_);
    } else if (data_info.data_type_ == kNumberTypeInt64) {
      axis = *static_cast<int64_t *>(data_info.data_ptr_);
    } else {
      MS_LOG(ERROR) << "Gather's axis is invalid, which should be int or int64.";
      return lite::RET_ERROR;
    }
  }
  const auto &first_shape = in_shapes.front();
  auto first_shape_size = static_cast<int>(first_shape.size());
  if (axis < 0) {
    axis = first_shape_size + axis;
  }
  MS_CHECK_TRUE_MSG(axis >= 0 && axis < first_shape_size, lite::RET_ERROR, "Gather's axis out of range.");
  const auto &second_shape = in_shapes[1];
  ShapeVector out_shape;
  for (int i = 0; i < axis; ++i) {
    out_shape.push_back(first_shape[i]);
  }
  (void)out_shape.insert(out_shape.end(), second_shape.begin(), second_shape.end());
  for (int i = axis + 1; i < first_shape_size; ++i) {
    out_shape.push_back(first_shape[i]);
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

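// Batched MatMul inference: the leading (batch) dims are broadcast the same
// way as in ArithmeticInferShape, then the trailing two dims become (M, N),
// with M and N picked according to the transpose_a/transpose_b attributes.
// For example, {2, 1, 4, 8} x {3, 8, 5} gives {2, 3, 4, 5}.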
int MatMulInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "MatMul should have at least two inputs.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "MatMul's primitive is a nullptr.");
  bool a_trans = prim->GetAttr(ops::kTransposeA) && GetValue<bool>(prim->GetAttr(ops::kTransposeA));
  bool b_trans = prim->GetAttr(ops::kTransposeB) && GetValue<bool>(prim->GetAttr(ops::kTransposeB));
  const auto &a_shape = in_shapes.front();
  MS_CHECK_TRUE_RET(a_shape.size() >= kInputSizeTwo, lite::RET_NOT_SUPPORT);
  const auto &b_shape = in_shapes[1];
  MS_CHECK_TRUE_RET(b_shape.size() >= kInputSizeTwo, lite::RET_NOT_SUPPORT);
  size_t a_rank = a_shape.size();
  size_t b_rank = b_shape.size();
  size_t out_rank = std::max(a_rank, b_rank);
  ShapeVector a_pre_shape;
  (void)a_pre_shape.insert(a_pre_shape.end(), out_rank - a_rank, 1);
  (void)a_pre_shape.insert(a_pre_shape.end(), a_shape.begin(), a_shape.begin() + a_rank - C2NUM);
  ShapeVector b_pre_shape;
  (void)b_pre_shape.insert(b_pre_shape.end(), out_rank - b_rank, 1);
  (void)b_pre_shape.insert(b_pre_shape.end(), b_shape.begin(), b_shape.begin() + b_rank - C2NUM);
  ShapeVector out_shape;
  MS_ASSERT(a_pre_shape.size() == b_pre_shape.size());
  for (size_t i = 0; i < out_rank - C2NUM; ++i) {
    if (a_pre_shape[i] == b_pre_shape[i]) {
      out_shape.push_back(a_pre_shape[i]);
      continue;
    }
    if (a_pre_shape[i] == 1) {
      out_shape.push_back(b_pre_shape[i]);
      continue;
    }
    if (b_pre_shape[i] == 1) {
      out_shape.push_back(a_pre_shape[i]);
      continue;
    }
    return lite::RET_NOT_SUPPORT;
  }
  out_shape.push_back(a_trans ? a_shape.back() : a_shape[a_rank - C2NUM]);
  out_shape.push_back(b_trans ? b_shape[b_rank - C2NUM] : b_shape.back());
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

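// Reduce drops (or keeps as 1, when keep_dims is set) every axis listed in
// the constant second input; an absent axis tensor reduces all dims. For
// example, {4, 5, 6} reduced over axis 1 gives {4, 6}, or {4, 1, 6} with
// keep_dims. ReduceToEnd and variable axis inputs are unsupported.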
int ReduceInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                     std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  MS_CHECK_FALSE_MSG(cnode->size() < kInputSizeThree || in_shapes.size() < kInputSizeTwo, lite::RET_ERROR,
                     "Reduce should have two inputs");
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Reduce's primitive is a nullptr.");
  bool keep_dim = prim->GetAttr(ops::kKeepDims) != nullptr && GetValue<bool>(prim->GetAttr(ops::kKeepDims));
  bool reduce_to_end = prim->GetAttr(ops::kReduceToEnd) != nullptr && GetValue<bool>(prim->GetAttr(ops::kReduceToEnd));
  if (reduce_to_end) {
    return lite::RET_NOT_SUPPORT;
  }
  auto second_input = cnode->input(kInputIndexTwo);
  MS_CHECK_TRUE_MSG(second_input != nullptr, lite::RET_ERROR, "Reduce's second input is a nullptr.");
  if (second_input->isa<CNode>()) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Reduce fetch second-input's data failed.");
  MS_CHECK_TRUE_MSG(data_info.shape_.size() <= 1, lite::RET_ERROR, "Reduce second-input should be <= 1D.");
  std::set<int> reduce_axes;
  int rank = static_cast<int>(in_shapes.front().size());
  if (data_info.data_ptr_ == nullptr) {
    MS_LOG(INFO) << "reduce op rank is: " << rank << ", cnode name: " << cnode->fullname_with_scope();
    for (int dim = 0; dim < rank; dim++) {
      (void)reduce_axes.insert(dim);
    }
  } else {
    int element_num = data_info.shape_.empty() ? 1 : data_info.shape_.front();
    std::vector<int> temp;
    int *axes{nullptr};
    if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
      axes = static_cast<int *>(data_info.data_ptr_);
    } else if (data_info.data_type_ == kNumberTypeInt64) {
      (void)temp.insert(temp.begin(), static_cast<int64_t *>(data_info.data_ptr_),
                        static_cast<int64_t *>(data_info.data_ptr_) + element_num);
      axes = temp.data();
    } else {
      return lite::RET_NOT_SUPPORT;
    }
    for (int i = 0; i < element_num; ++i) {
      int axis = axes[i] >= 0 ? axes[i] : axes[i] + rank;
      MS_CHECK_TRUE_MSG(axis >= 0 && axis < rank, lite::RET_ERROR, "Reduce's axis is out of range.");
      (void)reduce_axes.insert(axis);
    }
  }
  int start = 0;
  ShapeVector out_shape;
  for (auto iter = reduce_axes.begin(); iter != reduce_axes.end(); ++iter) {
    int end = *iter;
    for (; start < end; ++start) {
      out_shape.push_back(in_shapes.front()[start]);
    }
    if (keep_dim) {
      out_shape.push_back(1);
    }
    ++start;
  }
  for (; start < rank; ++start) {
    out_shape.push_back(in_shapes.front()[start]);
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

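// Reshape takes the target shape from its second input: a constant tensor
// is read directly (a 0 entry inherits the corresponding input dim), while
// a variable shape tensor is folded via DoStack, falling back to all -1
// when folding is not possible.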
int ReshapeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                      std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo) {
    (void)out_shapes->emplace_back();
    return lite::RET_OK;
  }
  if (in_shapes.size() < kInputSizeTwo) {
    MS_LOG(ERROR) << "Reshape should have two inputs.";
    return lite::RET_ERROR;
  }
  out_shapes->clear();
  auto second_input = cnode->input(kInputIndexTwo);
  MS_CHECK_TRUE_MSG(second_input != nullptr, lite::RET_ERROR, "Reshape's second input is a nullptr.");
  if (second_input->isa<CNode>()) {
    const auto &second_in_shape = in_shapes[1];
    if (second_in_shape.size() != 1 || second_in_shape.front() <= 0) {
      return lite::RET_NOT_SUPPORT;
    }
    ShapeVector out_shape;
    auto ret = DoStack(second_input->cast<CNodePtr>(), second_in_shape, &out_shape);
    if (ret == lite::RET_NOT_SUPPORT) {
      out_shape = ShapeVector(second_in_shape.front(), -1);
    } else if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << "Do stack failed.";
      return ret;
    }
    out_shapes->push_back(out_shape);
    return lite::RET_OK;
  }
  lite::DataInfo data_info;
  auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
  MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Reshape fetch second-input's data failed.");
  MS_CHECK_TRUE_MSG(data_info.shape_.size() <= 1, lite::RET_ERROR, "Reshape second-input should be <= 1D.");
  if (data_info.data_ptr_ == nullptr || (data_info.shape_.size() == 1 && data_info.shape_.front() == 0)) {
    (void)out_shapes->emplace_back();
    return lite::RET_OK;
  }
  auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
  ShapeVector out_shape;
  if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
    for (int i = 0; i < element_num; ++i) {
      out_shape.push_back(*(static_cast<int *>(data_info.data_ptr_) + i));
    }
  } else if (data_info.data_type_ == kNumberTypeInt64) {
    for (int i = 0; i < element_num; ++i) {
      out_shape.push_back(*(static_cast<int64_t *>(data_info.data_ptr_) + i));
    }
  } else {
    return lite::RET_NOT_SUPPORT;
  }
  const auto &in_shape = in_shapes.front();
  for (size_t i = 0; i < out_shape.size(); ++i) {
    if (out_shape[i] == 0) {
      MS_CHECK_TRUE_MSG(in_shape.size() > i, lite::RET_ERROR, "Reshape's in-rank is invalid.");
      out_shape[i] = in_shape[i];
    }
  }
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

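// Shape's output is a 1-D tensor whose single dim equals the input's rank.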
int ShapeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                    std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo || in_shapes.empty()) {
    MS_LOG(ERROR) << "Shape should have one input.";
    return lite::RET_ERROR;
  }
  ShapeVector out_shape = {static_cast<int64_t>(in_shapes.front().size())};
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

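// Split replicates the input shape out_num times, dividing the split axis
// evenly, or applies the size_splits attribute entry by entry. For example,
// {6, 8} split into 3 on axis 0 gives three {2, 8} outputs.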
int SplitInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                    std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeTwo || in_shapes.empty()) {
    MS_LOG(ERROR) << "Split should have one input.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Split's primitive is a nullptr.");
  auto out_num = prim->GetAttr(ops::kOutputNum) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kOutputNum));
  auto size_splits = prim->GetAttr(ops::kSizeSplits) == nullptr
                       ? std::vector<int64_t>{}
                       : GetValue<std::vector<int64_t>>(prim->GetAttr(ops::kSizeSplits));
  out_num = (out_num == 0 ? static_cast<int64_t>(size_splits.size()) : out_num);
  if (out_num <= 0) {
    return lite::RET_NOT_SUPPORT;
  }
  auto axis = prim->GetAttr(ops::kAxis) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  auto &in_shape = in_shapes.front();
  axis = axis < 0 ? static_cast<int64_t>(in_shape.size()) + axis : axis;
  MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast<int64_t>(in_shape.size()), lite::RET_ERROR,
                    "Split's axis is out of range.");
  out_shapes->clear();
  ShapeVector out_shape = in_shape;
  if (size_splits.empty()) {
    MS_CHECK_TRUE_MSG(in_shape[axis] > 0 && in_shape[axis] % out_num == 0, lite::RET_ERROR,
                      "Split's dim doesn't match split-axis.");
    out_shape[axis] = in_shape[axis] / out_num;
    (void)out_shapes->insert(out_shapes->end(), out_num, out_shape);
  } else {
    for (auto v : size_splits) {
      out_shape[axis] = v;
      out_shapes->push_back(out_shape);
    }
  }
  return lite::RET_OK;
}

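// Squeeze removes the listed axes, or every dim equal to 1 when no axes
// attribute is given; the attribute-less form bails out on dynamic dims,
// since a -1 may or may not be 1 at runtime.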
int SqueezeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                      std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (in_shapes.empty()) {
    MS_LOG(ERROR) << "Squeeze should have at least one input.";
    return lite::RET_ERROR;
  }
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Squeeze's primitive is a nullptr.";
    return lite::RET_ERROR;
  }
  auto axes = prim->GetAttr(ops::kAxis) != nullptr ? GetValue<std::vector<int64_t>>(prim->GetAttr(ops::kAxis))
                                                   : std::vector<int64_t>();
  auto &in_shape = in_shapes.front();
  ShapeVector out_shape;
  if (axes.empty()) {
    for (size_t i = 0; i < in_shape.size(); ++i) {
      if (in_shape[i] < 0) {
        return lite::RET_NOT_SUPPORT;
      }
      if (in_shape[i] != 1) {
        out_shape.push_back(in_shape[i]);
      }
    }
  } else {
    auto dims = static_cast<int64_t>(in_shape.size());
    std::vector<int> flags(dims, 0);
    for (auto axis : axes) {
      axis = axis < 0 ? axis + dims : axis;
      if (axis < 0 || axis >= dims) {
        MS_LOG(ERROR) << "Squeeze's axis is invalid. node name is " << cnode->fullname_with_scope();
        return lite::RET_ERROR;
      }
      flags[axis] = 1;
    }
    for (int64_t i = 0; i < dims; ++i) {
      if (flags[i] == 0) {
        out_shape.push_back(in_shape[i]);
      }
    }
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

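// Stack inserts a new dim of size N (the input count) at the stack axis.
// All inputs must share a rank and agree on every static dim; a dim where
// any input is -1 stays -1. For example, three {4, 5} inputs stacked on
// axis 0 give {3, 4, 5}.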
int StackInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                    std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (in_shapes.empty()) {
    MS_LOG(ERROR) << "Stack should have at least one input.";
    return lite::RET_ERROR;
  }
  auto dims = in_shapes.front().size();
  if (std::any_of(in_shapes.begin(), in_shapes.end(),
                  [dims](const ShapeVector &in_shape) { return in_shape.size() != dims; })) {
    MS_LOG(ERROR) << "Stack's inputs should all have the same rank.";
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  if (std::any_of(in_shapes.begin(), in_shapes.end(), [](const ShapeVector &in_shape) {
        return std::any_of(in_shape.begin(), in_shape.end(), [](int64_t val) { return val == 0; });
      })) {
    return lite::RET_NOT_SUPPORT;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "Stack's primitive is a nullptr.");
  auto axis = prim->GetAttr(ops::kAxis) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  if (axis < 0) {
    axis += static_cast<int64_t>(dims);
  }
  if (axis < 0 || axis > static_cast<int64_t>(dims)) {
    MS_LOG(ERROR) << "Stack's axis is invalid.";
    return lite::RET_PARAM_INVALID;
  }
  ShapeVector out_shape;
  auto FillShape = [&out_shape, &in_shapes](int64_t start, int64_t end) mutable {
    for (; start < end; ++start) {
      ShapeVector vertical;
      for (const auto &in_shape : in_shapes) {
        if (in_shape[start] >= 0) {
          vertical.push_back(in_shape[start]);
        } else if (in_shape[start] != -1) {
          MS_LOG(ERROR) << "Stack's input-shape must not have a dim-value less than -1.";
          return lite::RET_INPUT_TENSOR_ERROR;
        }
      }
      out_shape.push_back(vertical.size() < in_shapes.size() ? -1 : vertical.front());
      if (!vertical.empty()) {
        int64_t dim = vertical.front();
        if (std::any_of(vertical.begin(), vertical.end(), [dim](const int64_t value) { return value != dim; })) {
          MS_LOG(ERROR) << "Stack's input-shapes must be the same as each other.";
          return lite::RET_INPUT_TENSOR_ERROR;
        }
      }
    }
    return lite::RET_OK;
  };
  if (FillShape(0, axis) != lite::RET_OK) {
    MS_LOG(ERROR) << "Stack do fillShape failed.";
    return lite::RET_ERROR;
  }
  out_shape.push_back(static_cast<int64_t>(in_shapes.size()));
  if (FillShape(axis, dims) != lite::RET_OK) {
    MS_LOG(ERROR) << "Stack do fillShape failed.";
    return lite::RET_ERROR;
  }
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

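// Validates that a StridedSlice is simple enough to infer: no ellipsis or
// new-axis masks, constant int32 begin/end/stride inputs of matching 1-D
// shape, and all strides equal to 1.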
int CheckStridedSlice(const CNodePtr &cnode, int64_t in_rank, lite::DataInfo *begins, lite::DataInfo *ends) {
  MS_ASSERT(cnode != nullptr);
  auto prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "StridedSlice's primitive is a nullptr.");
  int64_t ellipsis_mask = prim->GetAttr(ops::kEllipsisMask) ? GetValue<int64_t>(prim->GetAttr(ops::kEllipsisMask)) : 0;
  int64_t new_axis_mask = prim->GetAttr(ops::kNewAxisMask) ? GetValue<int64_t>(prim->GetAttr(ops::kNewAxisMask)) : 0;
  if ((ellipsis_mask | new_axis_mask) != 0) {
    return lite::RET_NOT_SUPPORT;
  }
  for (size_t i = C2NUM; i < kInputSizeFive; ++i) {
    MS_CHECK_TRUE_MSG(cnode->input(i) != nullptr, lite::RET_ERROR, "StridedSlice's input is a nullptr.");
    if (utils::isa<CNode>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  auto BasicCond = [](const lite::DataInfo &data_info) {
    return data_info.data_ptr_ != nullptr &&
           (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32);
  };
  if (lite::FetchConstData(cnode, C2NUM, converter::kFmkTypeMs, begins, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch StridedSlice's begins failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(begins->shape_.size() == C1NUM && begins->shape_.front() <= in_rank && BasicCond(*begins),
                    lite::RET_NOT_SUPPORT);
  if (lite::FetchConstData(cnode, C3NUM, converter::kFmkTypeMs, ends, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch StridedSlice's ends failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(ends->shape_ == begins->shape_ && BasicCond(*ends), lite::RET_NOT_SUPPORT);
  lite::DataInfo strides;
  if (lite::FetchConstData(cnode, C4NUM, converter::kFmkTypeMs, &strides, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch StridedSlice's strides failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(strides.shape_ == begins->shape_ && BasicCond(strides), lite::RET_NOT_SUPPORT);
  for (int i = 0; i < strides.shape_.front(); ++i) {
    if (static_cast<int *>(strides.data_ptr_)[i] != 1) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  return lite::RET_OK;
}

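// Each sliced dim becomes end - begin after resolving negative indices and
// the begin/end masks; shrink-axis dims are dropped and trailing dims are
// copied through. For example, slicing {10, 8} with begin {2} and end {7}
// gives {5, 8}.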
int StridedSliceInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                           std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() != kInputSizeFive || in_shapes.size() != kInputSizeFour) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo begins;
  lite::DataInfo ends;
  auto ret = CheckStridedSlice(cnode, in_shapes.front().size(), &begins, &ends);
  if (ret != lite::RET_OK) {
    return ret;
  }

  auto prim = GetCNodePrimitive(cnode);
  int64_t begin_mask = prim->GetAttr(ops::kBeginMask) ? GetValue<int64_t>(prim->GetAttr(ops::kBeginMask)) : 0;
  int64_t end_mask = prim->GetAttr(ops::kEndMask) ? GetValue<int64_t>(prim->GetAttr(ops::kEndMask)) : 0;
  int64_t shrink_mask =
    prim->GetAttr(ops::kShrinkAxisMask) ? GetValue<int64_t>(prim->GetAttr(ops::kShrinkAxisMask)) : 0;
  const auto &in_shape = in_shapes.front();
  ShapeVector out_shape;
  int index = 0;
  for (; index < begins.shape_.front(); ++index) {
    if (shrink_mask & (1 << index)) {
      continue;
    }
    int b_mask = begin_mask & (1 << index);
    int e_mask = end_mask & (1 << index);
    if (b_mask && e_mask) {
      out_shape.push_back(in_shape[index]);
      continue;
    }
    int64_t begin = static_cast<int *>(begins.data_ptr_)[index];
    int64_t end = static_cast<int *>(ends.data_ptr_)[index];
    if (b_mask) {
      begin = 0;
    }
    if (e_mask) {
      end = in_shape[index];
    }
    if (in_shape[index] > 0) {
      begin += (begin >= 0 ? 0 : in_shape[index]);
      end += (end >= 0 ? 0 : in_shape[index]);
    }
    if (begin < 0 || end < 0 || begin > end) {
      return lite::RET_NOT_SUPPORT;
    }
    out_shape.push_back(end - begin);
  }
  (void)out_shape.insert(out_shape.end(), in_shape.begin() + index, in_shape.end());
  out_shapes->clear();
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}

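// Without a perm input the shape is simply reversed; otherwise the constant
// perm permutes the input dims, e.g. {2, 3, 4} with perm {2, 0, 1} gives
// {4, 2, 3}.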
int TransposeInferShape(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                        std::vector<ShapeVector> *out_shapes) {
  MS_ASSERT(cnode != nullptr);
  out_shapes->clear();
  if (in_shapes.size() == 1) {
    auto in_shape = in_shapes.front();
    ShapeVector out_shape(in_shape.rbegin(), in_shape.rend());
    out_shapes->push_back(out_shape);
    return lite::RET_OK;
  }
  if (in_shapes.size() != C2NUM) {
    MS_LOG(ERROR) << "Transpose's input count should be 1 or 2, now is " << in_shapes.size();
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  if (utils::isa<CNode>(cnode->input(ops::kInputIndex2))) {
    return lite::RET_NOT_SUPPORT;
  }
  lite::DataInfo data_info;
  if (lite::FetchConstData(cnode, ops::kInputIndex2, converter::kFmkTypeMs, &data_info, false)) {
    MS_LOG(ERROR) << "Fetch constant info failed, " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (data_info.data_ptr_ == nullptr ||
      (data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32)) {
    return lite::RET_NOT_SUPPORT;
  }
  auto num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<>());
  auto in_shape = in_shapes.front();
  if (num != static_cast<int>(in_shape.size())) {
    MS_LOG(ERROR) << "Transpose's perm doesn't match with input.";
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  std::vector<int> visit_flags(num, 0);
  ShapeVector out_shape;
  for (int i = 0; i < num; ++i) {
    auto dim_index = static_cast<int *>(data_info.data_ptr_)[i];
    if (dim_index < 0 || dim_index >= num || visit_flags[dim_index]) {
      MS_LOG(ERROR) << "Transpose's perm is invalid.";
      return lite::RET_INPUT_TENSOR_ERROR;
    }
    visit_flags[dim_index] = 1;
    out_shape.push_back(in_shape[dim_index]);
  }
  out_shapes->push_back(out_shape);
  return lite::RET_OK;
}
}  // namespace

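// Entry point: bails out unless the graph has at least one -1 (dynamic)
// input dim, then records per-node shape info consumed later by the
// mul-reduce-fusion pass.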
int DynamicShapePreprocessor::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  op_shape_infos_.clear();
  auto is_dynamic = CheckIsDynamicModel(func_graph);
  if (!is_dynamic) {
    return lite::RET_NOT_SUPPORT;
  }
  auto ret = ProcessOps(func_graph);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Preprocess for mul-reduce-fusion failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

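// Seeds op_shape_infos_ with the graph inputs' shapes and reports whether
// any input dim is -1, i.e. whether the model is dynamic at all.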
bool DynamicShapePreprocessor::CheckIsDynamicModel(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_inputs = func_graph->get_inputs();
  lite::DataInfo data_info;
  bool is_dynamic{false};
  for (auto &input : graph_inputs) {
    if (!utils::isa<Parameter>(input)) {
      continue;
    }
    auto ret = lite::FetchFromDefaultParam(input->cast<ParameterPtr>(), converter::kFmkTypeMs, &data_info, false);
    if (ret != lite::RET_OK) {
      return false;
    }
    ShapeVector shape(data_info.shape_.begin(), data_info.shape_.end());
    is_dynamic = is_dynamic || std::any_of(shape.begin(), shape.end(), [](int64_t v) { return v == -1; });
    op_shape_infos_[input] = std::make_pair(std::vector<ShapeVector>{}, std::vector<ShapeVector>{shape});
  }
  return is_dynamic;
}

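// Walks the graph in topological order and runs DoInfer on every supported
// op whose inputs are constants or already-inferred nodes; Depend, monad
// and MakeTuple inputs are stripped for the check and restored afterwards.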
int DynamicShapePreprocessor::ProcessOps(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  std::set<std::string> support_ops = {
    prim::kPrimAddFusion->name(),    prim::kPrimActivation->name(), prim::kPrimCast->name(),
    prim::kPrimConcat->name(),       prim::kPrimExpandDims->name(), prim::kPrimGather->name(),
    prim::kPrimMatMulFusion->name(), prim::kPrimMulFusion->name(),  prim::kPrimNotEqual->name(),
    prim::kPrimReduceFusion->name(), prim::kPrimReshape->name(),    prim::kPrimShape->name(),
    prim::kPrimSplit->name(),        prim::kPrimSqueeze->name(),    prim::kPrimStack->name(),
    prim::kPrimStridedSlice->name(), prim::kPrimTranspose->name()};
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim = GetCNodePrimitive(cnode);
    if (prim == nullptr) {
      continue;
    }
    auto op_type = prim->name();
    if (support_ops.find(op_type) == support_ops.end()) {
      continue;
    }
    auto origin_inputs = cnode->inputs();
    if (lite::RemoveIfDepend(cnode) != RET_OK) {
      cnode->set_inputs(origin_inputs);
      continue;
    }
    RemoveIfMonad(cnode);
    if (lite::RemoveIfMakeTuple(cnode) != RET_OK) {
      cnode->set_inputs(origin_inputs);
      continue;
    }
    auto current_inputs = cnode->inputs();
    bool can_infer = std::any_of(current_inputs.begin(), current_inputs.end(), [this](AnfNodePtr &anf_node) {
      return op_shape_infos_.find(anf_node) != op_shape_infos_.end() || !utils::isa<CNode>(anf_node);
    });
    if (!can_infer) {
      cnode->set_inputs(origin_inputs);
      continue;
    }
    auto ret = DoInfer(cnode, op_type);
    cnode->set_inputs(origin_inputs);
    if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << "Error occurred when inferring " << op_type;
      return ret;
    }
  }
  return lite::RET_OK;
}

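// Dispatches to the per-op infer function above: input shapes come from
// op_shape_infos_ for variable inputs and from the constant tensors
// themselves otherwise. RET_NOT_SUPPORT from an infer function is not an
// error; the node is simply left without recorded shape info.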
int DynamicShapePreprocessor::DoInfer(const CNodePtr &cnode, const std::string &op_type) {
  MS_ASSERT(cnode != nullptr);
  std::map<std::string, std::function<int(const CNodePtr &cnode, const std::vector<ShapeVector> &in_shapes,
                                          std::vector<ShapeVector> *out_shapes)>>
    infer_func = {
      {prim::kPrimAddFusion->name(), ArithmeticInferShape},  {prim::kPrimActivation->name(), CommonInferShape},
      {prim::kPrimCast->name(), CommonInferShape},           {prim::kPrimConcat->name(), ConcatInferShape},
      {prim::kPrimExpandDims->name(), ExpandDimsInferShape}, {prim::kPrimGather->name(), GatherInferShape},
      {prim::kPrimMatMulFusion->name(), MatMulInferShape},   {prim::kPrimMulFusion->name(), ArithmeticInferShape},
      {prim::kPrimNotEqual->name(), CommonInferShape},       {prim::kPrimReduceFusion->name(), ReduceInferShape},
      {prim::kPrimReshape->name(), ReshapeInferShape},       {prim::kPrimShape->name(), ShapeInferShape},
      {prim::kPrimSplit->name(), SplitInferShape},           {prim::kPrimSqueeze->name(), SqueezeInferShape},
      {prim::kPrimStack->name(), StackInferShape},           {prim::kPrimStridedSlice->name(), StridedSliceInferShape},
      {prim::kPrimTranspose->name(), TransposeInferShape}};
  if (infer_func.find(op_type) == infer_func.end()) {
    MS_LOG(ERROR) << "Current op: " << op_type << " doesn't support infer.";
    return lite::RET_ERROR;
  }
  std::vector<ShapeVector> in_shapes;
  lite::DataInfo data_info;
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    if (input == nullptr) {
      continue;
    }
    if (utils::isa<CNode>(input)) {
      auto real_input_info = GetRealCertainVarInput(cnode, i);
      MS_CHECK_TRUE_MSG(real_input_info.first != nullptr, lite::RET_ERROR, "Current op is invalid.");
      if (op_shape_infos_.find(real_input_info.first) == op_shape_infos_.end()) {
        return lite::RET_OK;
      }
      auto &upper_node_out = op_shape_infos_[real_input_info.first].second;
      auto index = real_input_info.second;
      MS_CHECK_TRUE_MSG(index >= 0 && index < static_cast<int>(upper_node_out.size()), lite::RET_ERROR,
                        "Current op is invalid.");
      in_shapes.push_back(upper_node_out[index]);
    } else {
      auto ret = lite::FetchConstData(cnode, i, converter::kFmkTypeMs, &data_info, false);
      if (ret != lite::RET_OK) {
        MS_LOG(ERROR) << "Fetch constant info failed, " << cnode->fullname_with_scope();
        return lite::RET_ERROR;
      }
      ShapeVector in_shape(data_info.shape_.begin(), data_info.shape_.end());
      in_shapes.push_back(in_shape);
    }
  }
  auto func = infer_func[op_type];
  MS_ASSERT(func != nullptr);
  std::vector<ShapeVector> out_shapes;
  auto ret = func(cnode, in_shapes, &out_shapes);
  if (ret == lite::RET_NOT_SUPPORT) {
    return lite::RET_OK;
  }
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Current op is invalid, " << op_type;
    return lite::RET_ERROR;
  }
  op_shape_infos_[cnode] = std::make_pair(in_shapes, out_shapes);
  return lite::RET_OK;
}
}  // namespace opt
}  // namespace mindspore