• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/optimizer/common/gllo_utils.h"
17 #include <algorithm>
18 #include <vector>
19 #include <utility>
20 #include <unordered_map>
21 #include <functional>
22 #include <string>
23 #include <set>
24 #include <fstream>
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/conv_pool_ops.h"
28 #include "mindspore/core/ops/lite_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "base/float16.h"
31 #include "ops/fusion/conv2d_fusion.h"
32 #include "ops/auto_generate/gen_lite_ops.h"
33 #include "ops/ops_func_impl/gather.h"
34 #include "ops/tuple_get_item.h"
35 #include "tools/common/tensor_util.h"
36 #include "frontend/operator/ops.h"
37 #include "include/backend/optimizer/helper.h"
38 #include "tools/converter/quantizer/quant_param_holder.h"
39 #include "nnacl/op_base.h"
40 #include "src/common/log_util.h"
41 #include "tools/converter/parser/parser_utils.h"
42 #include "tools/optimizer/common/helper.h"
43 #include "ops/op_utils.h"
44 #include "ops/custom.h"
45 #include "ops/tensor_copy.h"
46 #include "include/common/utils/anfalgo.h"
47 #include "tools/optimizer/common/format_utils.h"
48 
49 namespace mindspore {
50 namespace opt {
51 namespace {
// Position of the primitive value node within a CNode's input list.
constexpr int kAnfPrimitiveIndex = 0;
// Device type assumed when a primitive has no ops::kDeviceType attribute; IsParallelSplitConvNode only
// accepts nodes whose device type equals this value.
constexpr int kDeviceTypeNone = -1;
DeduceDimConvertion(schema::Format src_format,schema::Format dst_format,std::vector<int> * const perm)54 int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *const perm) {
55   MS_ASSERT(perm != nullptr);
56   auto src_format_str = std::string(schema::EnumNameFormat(src_format));
57   auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
58   if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
59     MS_LOG(ERROR) << "src_format or dst_format is error.";
60     return lite::RET_ERROR;
61   }
62   std::replace(src_format_str.begin(), src_format_str.end(), 'K', 'N');
63   std::replace(dst_format_str.begin(), dst_format_str.end(), 'K', 'N');
64   perm->clear();
65   std::unordered_map<char, int> dim_map;
66   for (size_t i = 0; i < src_format_str.size(); ++i) {
67     dim_map[src_format_str[i]] = i;
68   }
69   for (size_t i = 0; i < dst_format_str.size(); ++i) {
70     if (dim_map.find(dst_format_str[i]) == dim_map.end()) {
71       MS_LOG(ERROR) << "src_format and dst_format cannot match, please check.";
72       return RET_ERROR;
73     }
74     perm->push_back(dim_map[dst_format_str[i]]);
75   }
76   return lite::RET_OK;
77 }
78 
79 template <class T>
TransposeDim4(const ShapeVector & input_shape,const ShapeVector & output_shape,const std::vector<int> & perm,const T * const in_data,T * out_data)80 void TransposeDim4(const ShapeVector &input_shape, const ShapeVector &output_shape, const std::vector<int> &perm,
81                    const T *const in_data, T *out_data) {
82   auto num_axes = input_shape.size();
83   std::vector<int64_t> strides;
84   std::vector<int64_t> out_strides;
85   strides.resize(num_axes);
86   out_strides.resize(num_axes);
87   strides[num_axes - 1] = 1LL;
88   out_strides[num_axes - 1] = 1LL;
89   for (size_t i = num_axes - 1; i >= 1; i--) {
90     strides[i - 1] = input_shape[i] * strides[i];
91     out_strides[i - 1] = output_shape[i] * out_strides[i];
92   }
93   const auto stride0 = strides[perm[kIndex0]];
94   const auto stride1 = strides[perm[kIndex1]];
95   const auto stride2 = strides[perm[kIndex2]];
96   const auto stride3 = strides[perm[kIndex3]];
97   const auto out_stride0 = out_strides[kIndex0];
98   const auto out_stride1 = out_strides[kIndex1];
99   const auto out_stride2 = out_strides[kIndex2];
100   const auto output0 = output_shape[kIndex0];
101   const auto output1 = output_shape[kIndex1];
102   const auto output2 = output_shape[kIndex2];
103   const auto output3 = output_shape[kIndex3];
104 
105   int64_t out_beg_i = 0;
106   int64_t beg_i = 0;
107   for (int64_t i = 0; i < output0; ++i) {
108     int64_t out_beg_ij = out_beg_i;
109     int64_t beg_ij = beg_i;
110     for (int64_t j = 0; j < output1; ++j) {
111       int64_t out_beg_ijk = out_beg_ij;
112       int64_t beg_ijk = beg_ij;
113       for (int64_t k = 0; k < output2; ++k) {
114         for (int64_t m = 0; m < output3; ++m) {
115           out_data[out_beg_ijk + m] = in_data[beg_ijk + m * stride3];
116         }
117         out_beg_ijk += out_stride2;
118         beg_ijk += stride2;
119       }
120       out_beg_ij += out_stride1;
121       beg_ij += stride1;
122     }
123     out_beg_i += out_stride0;
124     beg_i += stride0;
125   }
126 }
127 
128 template <typename T>
DoTransposeData(const tensor::TensorPtr & tensor,schema::Format src_format,schema::Format dst_format)129 STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
130   MS_ASSERT(tensor != nullptr);
131   auto origin_shape = tensor->shape_c();
132   if (origin_shape.size() != kInputSizeFour) {
133     MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << origin_shape.size();
134     return lite::RET_ERROR;
135   }
136   if (std::any_of(origin_shape.begin(), origin_shape.end(), [](int64_t val) { return val <= 0; })) {
137     MS_LOG(ERROR) << "the tensor's shape is invalid.";
138     return lite::RET_ERROR;
139   }
140   std::vector<int> perm;
141   if (DeduceDimConvertion(src_format, dst_format, &perm) != RET_OK) {
142     MS_LOG(ERROR) << "deduce perm failed.";
143     return lite::RET_ERROR;
144   }
145   ShapeVector new_shape;
146   for (auto &val : perm) {
147     if (val < 0 || static_cast<size_t>(val) >= origin_shape.size()) {
148       MS_LOG(ERROR) << "deduce perm is invalid.";
149       return lite::RET_ERROR;
150     }
151     new_shape.push_back(origin_shape[val]);
152   }
153   int64_t count = 1;
154   for (const auto &dat : origin_shape) {
155     if (INT_MUL_OVERFLOW(count, dat)) {
156       MS_LOG(ERROR) << "Int mul overflow";
157       return RET_ERROR;
158     }
159     count *= dat;
160   }
161   if (count <= 0 || count > static_cast<int64_t>(INT32_MAX)) {
162     MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max.";
163     return RET_ERROR;
164   }
165   std::vector<T> buf(count);
166 
167   void *originWeightData = tensor->data_c();
168   MS_CHECK_TRUE_RET(originWeightData != nullptr, RET_ERROR);
169   T *weightData = static_cast<T *>(originWeightData);
170   TransposeDim4<T>(origin_shape, new_shape, perm, weightData, buf.data());
171   if (memcpy_s(tensor->data_c(), tensor->Size(), buf.data(), count * sizeof(T)) != EOK) {
172     MS_LOG(ERROR) << "memcpy_s failed.";
173     return RET_ERROR;
174   }
175   tensor->set_shape(new_shape);
176   return RET_OK;
177 }
178 
IsRealKernel(const AnfNodePtr & node)179 bool IsRealKernel(const AnfNodePtr &node) {
180   if (node == nullptr) {
181     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
182     return false;
183   }
184   // parameter and value node is not a real kernel too
185   if (!node->isa<CNode>()) {
186     return true;
187   }
188   auto cnode = node->cast<CNodePtr>();
189   if (cnode == nullptr) {
190     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
191     return false;
192   }
193   if (cnode->empty()) {
194     MS_LOG(ERROR) << "Illegal null input of cnode(%s)" << node->DebugString();
195     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
196     return false;
197   }
198   auto input = cnode->input(0);
199 #ifndef ENABLE_SECURITY
200   bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
201                          IsPrimitive(input, prim::kPrimTensorSummary) ||
202                          IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
203                          IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
204                          IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
205                          IsPrimitive(input, prim::kPrimPartial);
206 #else
207   bool is_virtual_node = IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) ||
208                          IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) ||
209                          IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
210 #endif
211   return !is_virtual_node;
212 }
213 
CopyDataFromInt64(const int64_t * origin_data,int * tensor_data,size_t data_count)214 void CopyDataFromInt64(const int64_t *origin_data, int *tensor_data, size_t data_count) {
215   for (size_t i = 0; i < data_count; ++i) {
216     if (origin_data[i] == INT64_MAX) {
217       tensor_data[i] = INT32_MAX;
218     } else if (origin_data[i] == INT64_MIN) {
219       tensor_data[i] = INT32_MIN;
220     } else if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
221       MS_LOG(WARNING) << "int64 data " << origin_data[i] << " cannot fit into int32";
222       tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
223     } else {
224       tensor_data[i] = static_cast<int>(origin_data[i]);
225     }
226   }
227 }
228 
CopyTensorDataFromTensorInfo(const tensor::TensorPtr & tensor_info,const std::shared_ptr<tensor::Tensor> & tensor_info_dst,size_t data_count,bool keep_origin_dtype)229 int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info,
230                                  const std::shared_ptr<tensor::Tensor> &tensor_info_dst, size_t data_count,
231                                  bool keep_origin_dtype) {
232   if (tensor_info->data_type() == kNumberTypeInt64 && !keep_origin_dtype) {
233     auto *tensor_data = reinterpret_cast<int *>(tensor_info_dst->data_c());
234     if (tensor_data == nullptr) {
235       MS_LOG(ERROR) << "new data failed";
236       return RET_ERROR;
237     }
238     auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
239     MS_CHECK_TRUE_MSG(origin_data != nullptr, lite::RET_NULL_PTR, "origin_data is nullptr");
240     CopyDataFromInt64(origin_data, tensor_data, data_count);
241   } else if (tensor_info->data_type() == kNumberTypeFloat64) {
242     auto *tensor_data = reinterpret_cast<float *>(tensor_info_dst->data_c());
243     if (tensor_data == nullptr) {
244       MS_LOG(ERROR) << "new data failed";
245       return RET_ERROR;
246     }
247     auto *origin_data = reinterpret_cast<double_t *>(tensor_info->data_c());
248     for (size_t i = 0; i < data_count; ++i) {
249       if (origin_data[i] > static_cast<double_t>(FLT_MAX) || origin_data[i] < static_cast<double_t>(-FLT_MAX)) {
250         MS_LOG(WARNING) << "float64 data " << origin_data[i] << " cannot fit into float32";
251         tensor_data[i] = origin_data[i] > 0 ? FLT_MAX : -FLT_MAX;
252       } else {
253         tensor_data[i] = static_cast<float>(origin_data[i]);
254       }
255     }
256   } else {
257     tensor_info_dst->set_data_type(tensor_info->data_type());
258     auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_dst->data_c());
259     if (tensor_data == nullptr) {
260       MS_LOG(ERROR) << "new data failed";
261       return RET_ERROR;
262     }
263     if (memcpy_s(tensor_data, tensor_info_dst->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
264       MS_LOG(ERROR) << "memcpy data failed.";
265       return RET_ERROR;
266     }
267   }
268   return RET_OK;
269 }
270 }  // namespace
271 
CheckInputs(const CNodePtr & cnode)272 bool CheckInputs(const CNodePtr &cnode) {
273   if (cnode == nullptr) {
274     MS_LOG(ERROR) << "cnode is nullptr.";
275     return false;
276   }
277   if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
278                   [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
279     MS_LOG(ERROR) << "input is nullptr.";
280     return false;
281   }
282   return true;
283 }
284 
CastToInt(const ValuePtr & value)285 std::vector<int> CastToInt(const ValuePtr &value) {
286   if (value == nullptr) {
287     MS_LOG(WARNING) << "valueptr is nullptr.";
288     return {};
289   }
290   std::vector<int> cur_value = {};
291   if (utils::isa<ValueSequencePtr>(value)) {
292     if (!value->cast<ValueSequencePtr>()->value().empty()) {
293       auto data_type = value->cast<ValueSequencePtr>()->value().front()->type()->number_type();
294       if (data_type == kNumberTypeInt64) {
295         auto origin_value = GetValue<std::vector<int64_t>>(value);
296         std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
297                        [](int64_t index) { return static_cast<int>(index); });
298       } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
299         cur_value = GetValue<std::vector<int>>(value);
300       } else {
301         MS_LOG(ERROR) << "he function only process integer data.";
302         return {};
303       }
304     }
305   } else {
306     auto data_type = value->type()->number_type();
307     switch (data_type) {
308       case kNumberTypeInt64:
309         cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
310         break;
311       case kNumberTypeInt:
312       case kNumberTypeInt32:
313         cur_value.push_back(GetValue<int>(value));
314         break;
315       case kNumberTypeBool:
316         cur_value.push_back(GetValue<bool>(value));
317         break;
318       default:
319         MS_LOG(ERROR) << "the function only process integer data.";
320         return {};
321     }
322   }
323   return cur_value;
324 }
325 
CastToVec2DInt(const ValuePtr & value)326 std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value) {
327   if (value == nullptr) {
328     MS_LOG(WARNING) << "valueptr is nullptr.";
329     return {};
330   }
331 
332   std::vector<std::vector<int>> result_value;
333   if (utils::isa<ValueSequencePtr>(value)) {
334     auto data_type = value->cast<ValueSequencePtr>()
335                        ->value()
336                        .front()
337                        ->cast<ValueSequencePtr>()
338                        ->value()
339                        .front()
340                        ->type()
341                        ->number_type();
342     if (data_type == kNumberTypeInt64) {
343       auto origin_value = GetValue<std::vector<std::vector<int64_t>>>(value);
344       for (auto &i : origin_value) {
345         std::vector<int> cur_value;
346         std::transform(i.begin(), i.end(), std::back_inserter(cur_value),
347                        [](int64_t j) { return static_cast<int>(j); });
348         result_value.push_back(cur_value);
349       }
350     } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
351       result_value = GetValue<std::vector<std::vector<int>>>(value);
352     } else {
353       MS_LOG(ERROR) << "he function only process integer data.";
354       return result_value;
355     }
356   }
357   return result_value;
358 }
359 
CastToFloat(const ValuePtr & value)360 std::vector<float> CastToFloat(const ValuePtr &value) {
361   if (value == nullptr) {
362     MS_LOG(WARNING) << "valueptr is nullptr.";
363     return {};
364   }
365   std::vector<float> cur_value = {};
366   if (utils::isa<ValueSequencePtr>(value)) {
367     if (!value->cast<ValueSequencePtr>()->value().empty()) {
368       auto data_type = value->cast<ValueSequencePtr>()->value().front()->type()->number_type();
369       if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
370         cur_value = GetValue<std::vector<float>>(value);
371       } else {
372         MS_LOG(ERROR) << "the function only process float data.";
373         return {};
374       }
375     }
376   } else {
377     auto data_type = value->type()->number_type();
378     if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
379       cur_value.push_back(GetValue<float>(value));
380     } else {
381       MS_LOG(ERROR) << "the function only process float data.";
382       return {};
383     }
384   }
385   return cur_value;
386 }
387 
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)388 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
389   if (node == nullptr || primitive_type == nullptr) {
390     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
391     return false;
392   }
393   if (node->isa<CNode>()) {
394     auto cnode = node->cast<CNodePtr>();
395     return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
396   } else if (node->isa<ValueNode>()) {
397     return IsPrimitive(node, primitive_type);
398   }
399   return false;
400 }
401 
GetPrimitiveType(const AnfNodePtr & node,std::string * name)402 STATUS GetPrimitiveType(const AnfNodePtr &node, std::string *name) {
403   if (name == nullptr) {
404     MS_LOG(ERROR) << "name is nulltr.";
405     return RET_ERROR;
406   }
407   if (node->isa<CNode>()) {
408     auto cnode = node->cast<CNodePtr>();
409     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
410     if (primitive == nullptr) {
411       MS_LOG(ERROR) << "primitive is nullptr. " << cnode->fullname_with_scope();
412       return RET_ERROR;
413     }
414     if (CheckPrimitiveType(node, prim::kPrimCustom)) {
415       auto custom_prim = api::MakeShared<ops::Custom>(primitive);
416       MS_CHECK_TRUE_MSG(custom_prim != nullptr, RET_ERROR, "custom op is nullptr.");
417       *name = custom_prim->get_type();
418       return RET_OK;
419     } else {
420       *name = primitive->name();
421       return RET_OK;
422     }
423   } else if (node->isa<ValueNode>()) {
424     auto fn_value = GetValueNode<PrimitivePtr>(node);
425     CHECK_NULL_RETURN(fn_value);
426     *name = fn_value->name();
427     return RET_OK;
428   }
429   MS_LOG(ERROR) << "There is no name for this node";
430   return RET_ERROR;
431 }
432 
IsOpType(const BaseRef & n,const PrimitivePtr & prim)433 bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
434   if (utils::isa<AnfNodePtr>(n)) {
435     auto anf_node = utils::cast<AnfNodePtr>(n);
436     return CheckPrimitiveType(anf_node, prim);
437   }
438   return false;
439 }
440 
IsRealCNodeKernel(const AnfNodePtr & node)441 bool IsRealCNodeKernel(const AnfNodePtr &node) {
442   if (node == nullptr) {
443     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
444     MS_LOG(ERROR) << "node is nullptr";
445     return false;
446   }
447   // parameter and value node is not a real cnode kernel
448   if (!node->isa<CNode>()) {
449     return false;
450   }
451   // return considered as a real node
452   if (CheckPrimitiveType(node, prim::kPrimReturn)) {
453     return true;
454   }
455   return IsRealKernel(node);
456 }
IsGraphKernel(const AnfNodePtr & node)457 bool IsGraphKernel(const AnfNodePtr &node) {
458   if (node == nullptr) {
459     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
460     return false;
461   }
462   // graph kernel should be a real cnode kernel.
463   if (!IsRealCNodeKernel(node)) {
464     return false;
465   }
466 
467   auto cnode = node->cast<CNodePtr>();
468   if (cnode == nullptr) {
469     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
470     MS_LOG(ERROR) << "node is nullptr";
471     return false;
472   }
473   auto input = cnode->input(kAnfPrimitiveIndex);
474   // graph kernel should has func_graph as first input.
475   if (!IsValueNode<FuncGraph>(input)) {
476     return false;
477   }
478 
479   auto func_graph = GetValueNode<FuncGraphPtr>(input);
480   if (func_graph == nullptr) {
481     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
482     return false;
483   }
484   return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
485 }
486 
AddNewBiasNode(const float * bias_data,const FuncGraphPtr & func_graph,int kernel_num,TypeId type_id)487 ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
488   if (bias_data == nullptr || func_graph == nullptr) {
489     MS_LOG(ERROR) << "input parameter is nullptr.";
490     return nullptr;
491   }
492   auto bias_parameter = func_graph->add_parameter();
493   MS_ASSERT(bias_parameter != nullptr);
494   std::vector<int64_t> shape_vector = {kernel_num};
495   auto tensor_info =
496     lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, type_id);
497   if (tensor_info == nullptr) {
498     MS_LOG(ERROR) << "create tensor info failed.";
499     return nullptr;
500   }
501   auto status = lite::InitParameterFromTensorInfo(bias_parameter, tensor_info);
502   if (status != RET_OK) {
503     MS_LOG(ERROR) << "init parameter from tensor info failed";
504     return nullptr;
505   }
506 
507   return bias_parameter;
508 }
509 
GetTensorInfo(const AnfNodePtr & node)510 tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node) {
511   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
512   if (!utils::isa<ParameterPtr>(node)) {
513     if (utils::isa<ValueNodePtr>(node)) {
514       auto valueNode = node->cast<ValueNodePtr>();
515       auto value_ptr = valueNode->value();
516       MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
517       auto value = value_ptr->cast<tensor::TensorPtr>();
518       if (value != nullptr) {
519         return value;
520       }
521     }
522     MS_LOG(DEBUG) << "get lite param value node neither parameternode or valuenode";
523     return nullptr;
524   }
525   auto param = node->cast<ParameterPtr>();
526   MS_ASSERT(param != nullptr);
527   if (!param->has_default() || param->default_param() == nullptr) {
528     return nullptr;
529   }
530   auto tensor_info = param->default_param()->cast<tensor::TensorPtr>();
531   return tensor_info;
532 }
533 
GetCNodeInputAbstract(const CNodePtr & cnode,size_t index)534 AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
535   if (cnode == nullptr) {
536     MS_LOG(ERROR) << "CNodePtr is nullptr";
537     return nullptr;
538   }
539   if (!(index > 0 && index < cnode->size())) {
540     return nullptr;
541   }
542   auto input = cnode->input(index);
543   if (input == nullptr) {
544     MS_LOG(ERROR) << "CNode input is nullptr";
545     return nullptr;
546   }
547 
548   AbstractBasePtr abstract = nullptr;
549   if (utils::isa<ParameterPtr>(input)) {
550     auto parameter = input->cast<ParameterPtr>();
551     abstract = parameter->abstract();
552   } else if (utils::isa<ValueNodePtr>(input)) {
553     auto value_node = input->cast<ValueNodePtr>();
554     abstract = value_node->abstract();
555   } else if (utils::isa<CNodePtr>(input)) {
556     auto input_cnode = input->cast<CNodePtr>();
557     if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
558       MS_ASSERT(input_cnode->size() == kTupleGetItemInputSize);
559       auto get_item_input_cnode = input_cnode->input(1);
560       MS_ASSERT(get_item_input_cnode != nullptr);
561       auto idx = GetTupleGetItemOutIndex(input_cnode);
562       if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
563         MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
564         return nullptr;
565       }
566       auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
567       auto abstract_list = abstract_tuple->elements();
568       if (abstract_list.size() <= idx) {
569         MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
570         return nullptr;
571       }
572       abstract = abstract_list[idx];
573     } else {
574       abstract = input_cnode->abstract();
575     }
576   } else {
577     MS_LOG(ERROR) << "unsupported input node type";
578     return nullptr;
579   }
580   return abstract;
581 }
582 
IsParamNode(const BaseRef & n)583 bool IsParamNode(const BaseRef &n) {
584   if (!utils::isa<ParameterPtr>(n)) {
585     return false;
586   }
587   auto parameter = utils::cast<ParameterPtr>(n);
588   if (!parameter->has_default() || parameter->default_param() == nullptr) {
589     return false;
590   }
591   auto tensor = parameter->default_param()->cast<tensor::TensorPtr>();
592   if (tensor == nullptr) {
593     return false;
594   }
595   return tensor->data_c() != nullptr;
596 }
597 
GetTensorInfoFromAbstract(tensor::TensorPtr * const tensor_info,const CNodePtr & cnode,size_t index)598 STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *const tensor_info, const CNodePtr &cnode, size_t index) {
599   CHECK_NULL_RETURN(tensor_info);
600   CHECK_NULL_RETURN(cnode);
601   AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
602   if (abstract == nullptr) {
603     MS_LOG(WARNING) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr, infershape is delayed.";
604     return RET_ERROR;
605   }
606   if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
607     MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
608     return RET_ERROR;
609   }
610   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
611   if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) {  // input node not complete infershape
612     MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
613     return RET_ERROR;
614   }
615   *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
616   if (*tensor_info == nullptr) {
617     MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
618     return RET_ERROR;
619   }
620   return RET_OK;
621 }
622 
IsParamOrValueNodeWithData(const BaseRef & n)623 bool IsParamOrValueNodeWithData(const BaseRef &n) {
624   if (utils::isa<ValueNode>(n)) {
625     auto value_node = utils::cast<ValueNodePtr>(n);
626     auto value = value_node->value();
627     if (value == nullptr) {
628       return false;
629     }
630     if (value->isa<tensor::Tensor>()) {
631       auto tensor = value->cast<tensor::TensorPtr>();
632       return tensor != nullptr && tensor->data_c() != nullptr;
633     } else if (value->isa<ValueSequence>()) {
634       auto sequence_ptr = value->cast<ValueSequencePtr>();
635       return sequence_ptr != nullptr && !sequence_ptr->value().empty();
636     } else {
637       return false;
638     }
639   }
640   if (utils::isa<ParameterPtr>(n)) {
641     return IsParamNode(n);
642   }
643   return false;
644 }
645 
IsParallelSplitConvNode(const BaseRef & n)646 bool IsParallelSplitConvNode(const BaseRef &n) {
647   if (utils::isa<AnfNodePtr>(n)) {
648     auto anf_node = utils::cast<AnfNodePtr>(n);
649     PrimitivePtr prim = nullptr;
650     if (utils::isa<CNodePtr>(anf_node)) {
651       prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
652     }
653     if (utils::isa<ValueNodePtr>(anf_node)) {
654       prim = GetValueNode<PrimitivePtr>(anf_node);
655     }
656     if (prim == nullptr) {
657       return false;
658     }
659     int device_type =
660       prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int32_t>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
661     if (device_type != kDeviceTypeNone) {
662       return false;
663     }
664     return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) || CheckPrimitiveType(anf_node, prim::kPrimConv2D);
665   }
666   return false;
667 }
668 
IsConvNode(const BaseRef & n)669 bool IsConvNode(const BaseRef &n) {
670   if (utils::isa<AnfNodePtr>(n)) {
671     auto anf_node = utils::cast<AnfNodePtr>(n);
672     PrimitivePtr prim = nullptr;
673     if (utils::isa<CNodePtr>(anf_node)) {
674       prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
675     }
676     if (utils::isa<ValueNodePtr>(anf_node)) {
677       prim = GetValueNode<PrimitivePtr>(anf_node);
678     }
679     if (prim == nullptr) {
680       return false;
681     }
682 
683     if (prim->GetAttr(ops::kActivationType) != nullptr &&
684         GetValue<int64_t>(prim->GetAttr(ops::kActivationType)) != NO_ACTIVATION) {
685       return false;
686     }
687 
688     bool is_depth_wise =
689       prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
690     return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) ||
691            (CheckPrimitiveType(anf_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise);
692   }
693   return false;
694 }
695 
CheckIsAllInputsParam(const AnfNodePtr & node)696 bool CheckIsAllInputsParam(const AnfNodePtr &node) {
697   if (node == nullptr) {
698     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
699     MS_LOG(ERROR) << "node is nullptr";
700     return false;
701   }
702   if (utils::isa<CNode>(node)) {
703     auto cnode = node->cast<CNodePtr>();
704     for (size_t i = 1; i < cnode->size(); i++) {
705       if (!utils::isa<Parameter>(cnode->input(i)) && !utils::isa<ValueNodePtr>(cnode->input(i))) {
706         return false;
707       }
708     }
709     return true;
710   }
711   return false;
712 }
713 
GetOutputTensorNum(const AnfNodePtr & node)714 size_t GetOutputTensorNum(const AnfNodePtr &node) {
715   if (node == nullptr) {
716     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
717     MS_LOG(ERROR) << "node is nullptr";
718     return 0;
719   }
720   auto type = node->Type();
721   if (type == nullptr) {
722     return 1;
723   }
724   if (type->isa<Tuple>()) {
725     auto tuple_type = type->cast<TuplePtr>();
726     if (tuple_type == nullptr) {
727       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
728       MS_LOG(ERROR) << "typle_type is nullptr";
729       return 0;
730     }
731     return tuple_type->size();
732   } else if (type->isa<TensorType>() || type->isa<Number>()) {
733     return 1;
734   } else if (type->isa<TypeNone>()) {
735     return 0;
736   } else {
737     return 1;
738   }
739 }
740 
IsMultiOutputTensors(const FuncGraphPtr & graph,const AnfNodePtr & node)741 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
742   if (graph == nullptr || node == nullptr) {
743     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
744     return false;
745   }
746   auto output_node_list = Helper::GetRealNodeUsedList(graph, node);
747   if (output_node_list == nullptr) {
748     MS_LOG(ERROR) << "output node list is nullptr";
749     return false;
750   }
751   if (output_node_list->size() != 1) {
752     MS_LOG(DEBUG) << "fusion node has multi output nodes";
753     return true;
754   }
755   return false;
756 }
757 
GetTupleGetItemRealInput(const CNodePtr & tuple_get_item)758 AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
759   if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
760     MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
761     return nullptr;
762   }
763   return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
764 }
765 
GetTupleGetItemOutIndex(const CNodePtr & tuple_get_item)766 size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
767   if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
768     MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
769     return -1;
770   }
771   auto output_index_value_node = tuple_get_item->input(kInputIndexTwo);
772   if (output_index_value_node == nullptr) {
773     MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
774     return -1;
775   }
776   auto value_node = output_index_value_node->cast<ValueNodePtr>();
777   if (value_node == nullptr) {
778     MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
779     return -1;
780   }
781   auto indexes = CastToInt(value_node->value());
782   if (indexes.empty()) {
783     MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
784     return -1;
785   }
786   return indexes.front();
787 }
788 
GetListGetItemOutIndex(const CNodePtr & list_get_item)789 size_t GetListGetItemOutIndex(const CNodePtr &list_get_item) {
790   if (list_get_item == nullptr || list_get_item->size() != kInputSizeThree) {
791     MS_LOG(ERROR) << "The node list_get_item is invalid.";
792     return SIZE_MAX;
793   }
794   auto output_index_value_node = list_get_item->input(kInputIndexTwo);
795   if (output_index_value_node == nullptr) {
796     MS_LOG(ERROR) << "The node list_get_item is invalid.";
797     return SIZE_MAX;
798   }
799   auto value_node = output_index_value_node->cast<ValueNodePtr>();
800   if (value_node == nullptr) {
801     MS_LOG(ERROR) << "The node list_get_item is invalid.";
802     return SIZE_MAX;
803   }
804   auto indexes = CastToInt(value_node->value());
805   if (indexes.empty()) {
806     MS_LOG(ERROR) << "The node list_get_item is invalid.";
807     return SIZE_MAX;
808   }
809   return indexes.front();
810 }
811 
TransFilterFormat(const tensor::TensorPtr & tensor,schema::Format src_format,schema::Format dst_format)812 STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
813   MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
814   std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
815     trans_func = {{kNumberTypeFloat32, DoTransposeData<float>},
816                   {kNumberTypeUInt8, DoTransposeData<uint8_t>},
817                   {kNumberTypeInt8, DoTransposeData<int8_t>},
818                   {kNumberTypeFloat16, DoTransposeData<float16>}};
819   auto data_type = tensor->data_type();
820   auto iter = trans_func.find(data_type);
821   if (iter == trans_func.end()) {
822     MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
823     return RET_ERROR;
824   }
825   return iter->second(tensor, src_format, dst_format);
826 }
827 
BuildParameterNode(const FuncGraphPtr & func_graph,const tensor::TensorPtr & tensor_info,const std::string & node_name,bool keep_origin_dtype)828 ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
829                                 const std::string &node_name, bool keep_origin_dtype) {
830   if (func_graph == nullptr || tensor_info == nullptr) {
831     MS_LOG(ERROR) << "input parameter is nullptr.";
832     return nullptr;
833   }
834   auto param_node = func_graph->add_parameter();
835   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
836   auto shape = tensor_info->shape();
837   std::vector<int64_t> shape_vector;
838   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
839                  [](const int &val) { return static_cast<int64_t>(val); });
840   auto data_type = tensor_info->data_type();
841   if (tensor_info->data_type() == kNumberTypeFloat64 && !keep_origin_dtype) {
842     data_type = kNumberTypeFloat32;
843   }
844   if (tensor_info->data_type() == kNumberTypeInt64) {
845     data_type = kNumberTypeInt32;
846   }
847   param_node->set_name(node_name);
848   param_node->debug_info()->set_name(node_name);
849   auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
850   if (tensor_info_new == nullptr) {
851     MS_LOG(ERROR) << "new tensor::Tensor failed.";
852     return nullptr;
853   }
854   int data_count = 1;
855   for (const auto &dat : shape) {
856     if (INT_MUL_OVERFLOW(data_count, static_cast<int>(dat))) {
857       MS_LOG(ERROR) << "Int mul overflow.";
858       return nullptr;
859     }
860     data_count *= static_cast<int>(dat);
861   }
862   if (data_count < 0) {
863     MS_LOG(ERROR) << "Invalid shape.";
864     return nullptr;
865   }
866   if (tensor_info->Size() == 0) {
867     auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
868     if (status != RET_OK) {
869       MS_LOG(ERROR) << "init parameter from tensor info failed";
870       return nullptr;
871     }
872     return param_node;
873   }
874 
875   if (CopyTensorDataFromTensorInfo(tensor_info, tensor_info_new, static_cast<size_t>(data_count), keep_origin_dtype) !=
876       RET_OK) {
877     MS_LOG(ERROR) << "copy tensor data failed";
878     return nullptr;
879   }
880 
881   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
882   if (status != RET_OK) {
883     MS_LOG(ERROR) << "init parameter from tensor info failed";
884     return nullptr;
885   }
886   param_node->set_default_param(tensor_info_new);
887   return param_node;
888 }
889 
BuildIntValueParameterNode(const FuncGraphPtr & func_graph,const int32_t & data,const std::string & node_name,bool empty_shape)890 ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
891                                         const std::string &node_name, bool empty_shape) {
892   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
893   auto param_node = func_graph->add_parameter();
894   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
895   param_node->set_name(node_name);
896   ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
897   auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int32_t), shape, kNumberTypeInt32);
898   if (tensor_info == nullptr) {
899     MS_LOG(ERROR) << "Create tensor info failed";
900     return nullptr;
901   }
902 
903   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
904   if (status != RET_OK) {
905     MS_LOG(ERROR) << "init parameter from tensor info failed";
906     return nullptr;
907   }
908   return param_node;
909 }
910 
BuildIntVecValueNode(const FuncGraphPtr & func_graph,const std::vector<int32_t> & data)911 ValueNodePtr BuildIntVecValueNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data) {
912   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
913   auto value = MakeValue(data);
914   MS_CHECK_TRUE_RET(value != nullptr, nullptr);
915   auto value_node = std::make_shared<ValueNode>(value);
916   value_node->set_abstract(value->ToAbstract());
917   MS_EXCEPTION_IF_NULL(value_node);
918   func_graph->AddValueNode(value_node);
919   return value_node;
920 }
921 
BuildIntVecParameterNode(const FuncGraphPtr & func_graph,const std::vector<int32_t> & data,const std::string & node_name)922 ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
923                                       const std::string &node_name) {
924   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
925   auto param_node = func_graph->add_parameter();
926   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
927   param_node->set_name(node_name);
928 
929   std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
930   auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), shape_vector, kNumberTypeInt32);
931   if (tensor_info == nullptr) {
932     MS_LOG(ERROR) << "Create tensor info failed";
933     return nullptr;
934   }
935 
936   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
937   if (status != RET_OK) {
938     MS_LOG(ERROR) << "init parameter from tensor info failed";
939     return nullptr;
940   }
941 
942   return param_node;
943 }
944 
BuildInt64VecParameterNode(const FuncGraphPtr & func_graph,const std::vector<int64_t> & data,const std::string & node_name)945 ParameterPtr BuildInt64VecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &data,
946                                         const std::string &node_name) {
947   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
948   auto param_node = func_graph->add_parameter();
949   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
950   param_node->set_name(node_name);
951 
952   std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
953   auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int64_t), shape_vector, kNumberTypeInt64);
954   if (tensor_info == nullptr) {
955     MS_LOG(ERROR) << "Create tensor info failed!";
956     return nullptr;
957   }
958 
959   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
960   if (status != RET_OK) {
961     MS_LOG(ERROR) << "init parameter from tensor info failed!";
962     return nullptr;
963   }
964 
965   return param_node;
966 }
967 
BuildIntVec2DParameterNode(const FuncGraphPtr & func_graph,const std::vector<std::vector<int32_t>> & data,const std::string & node_name)968 ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
969                                         const std::string &node_name) {
970   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
971   auto param_node = func_graph->add_parameter();
972   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
973   param_node->set_name(node_name);
974 
975   MS_CHECK_TRUE_RET(!data.empty(), nullptr);
976   std::vector<int64_t> shape_vector;
977   shape_vector.push_back(data.size());
978   shape_vector.push_back(data.at(0).size());
979 
980   std::vector<int32_t> data_1d;
981   for (auto pair : data) {
982     data_1d.insert(data_1d.end(), pair.begin(), pair.end());
983   }
984 
985   auto size = data_1d.size() * sizeof(int32_t);
986   auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32);
987   if (tensor_info == nullptr) {
988     MS_LOG(ERROR) << "Create tensor info failed";
989     return nullptr;
990   }
991   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
992   if (status != RET_OK) {
993     MS_LOG(ERROR) << "init parameter from tensor info failed";
994     return nullptr;
995   }
996   return param_node;
997 }
998 
BuildFloatValueParameterNode(const FuncGraphPtr & func_graph,const float & data,const std::string & node_name,bool empty_shape)999 ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
1000                                           const std::string &node_name, bool empty_shape) {
1001   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1002   auto param_node = func_graph->add_parameter();
1003   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1004   param_node->set_name(node_name);
1005 
1006   ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
1007   auto tensor_info = lite::CreateTensorInfo(&data, sizeof(float), shape, kNumberTypeFloat32);
1008   if (tensor_info == nullptr) {
1009     MS_LOG(ERROR) << "Create tensor info failed";
1010     return nullptr;
1011   }
1012   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1013   if (status != RET_OK) {
1014     MS_LOG(ERROR) << "init parameter from tensor info failed";
1015     return nullptr;
1016   }
1017   return param_node;
1018 }
1019 
BuildInt64ValueParameterNode(const FuncGraphPtr & func_graph,const int64_t & data,const std::string & node_name,bool empty_shape)1020 ParameterPtr BuildInt64ValueParameterNode(const FuncGraphPtr &func_graph, const int64_t &data,
1021                                           const std::string &node_name, bool empty_shape) {
1022   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1023   auto param_node = func_graph->add_parameter();
1024   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1025   param_node->set_name(node_name);
1026   ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
1027   auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int64_t), shape, kNumberTypeInt64);
1028   if (tensor_info == nullptr) {
1029     MS_LOG(ERROR) << "Create tensor info failed!";
1030     return nullptr;
1031   }
1032   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1033   if (status != RET_OK) {
1034     MS_LOG(ERROR) << "init parameter from tensor info failed!";
1035     return nullptr;
1036   }
1037   return param_node;
1038 }
1039 
BuildFloat16ValueParameterNode(const FuncGraphPtr & func_graph,const float & data,const std::string & node_name,bool empty_shape)1040 ParameterPtr BuildFloat16ValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
1041                                             const std::string &node_name, bool empty_shape) {
1042   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1043   auto param_node = func_graph->add_parameter();
1044   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1045   param_node->set_name(node_name);
1046 
1047   ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
1048   auto tensor_info = lite::CreateTensorInfo(&data, sizeof(float16), shape, kNumberTypeFloat16);
1049   if (tensor_info == nullptr) {
1050     MS_LOG(ERROR) << "Create tensor info failed";
1051     return nullptr;
1052   }
1053   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1054   if (status != RET_OK) {
1055     MS_LOG(ERROR) << "init parameter from tensor info failed";
1056     return nullptr;
1057   }
1058   return param_node;
1059 }
1060 
BuildFloat16VecParameterNode(const FuncGraphPtr & func_graph,const std::vector<float16> & data,const std::string & node_name)1061 ParameterPtr BuildFloat16VecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float16> &data,
1062                                           const std::string &node_name) {
1063   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1064   auto param_node = func_graph->add_parameter();
1065   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1066   param_node->set_name(node_name);
1067 
1068   std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
1069   auto tensor_info =
1070     lite::CreateTensorInfo(data.data(), data.size() * sizeof(float16), shape_vector, kNumberTypeFloat16);
1071   if (tensor_info == nullptr) {
1072     MS_LOG(ERROR) << "Create tensor info failed";
1073     return nullptr;
1074   }
1075 
1076   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1077   if (status != RET_OK) {
1078     MS_LOG(ERROR) << "init parameter from tensor info failed";
1079     return nullptr;
1080   }
1081 
1082   return param_node;
1083 }
1084 
BuildFloatVecParameterNode(const FuncGraphPtr & func_graph,const std::vector<float> & data,const std::string & node_name)1085 ParameterPtr BuildFloatVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float> &data,
1086                                         const std::string &node_name) {
1087   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1088   auto param_node = func_graph->add_parameter();
1089   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1090   param_node->set_name(node_name);
1091 
1092   std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
1093   auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(float), shape_vector, kNumberTypeFloat);
1094   if (tensor_info == nullptr) {
1095     MS_LOG(ERROR) << "Create tensor info failed";
1096     return nullptr;
1097   }
1098 
1099   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1100   if (status != RET_OK) {
1101     MS_LOG(ERROR) << "init parameter from tensor info failed";
1102     return nullptr;
1103   }
1104 
1105   return param_node;
1106 }
1107 
BuildFloatVec2DParameterNode(const FuncGraphPtr & func_graph,const std::vector<std::vector<float>> & data,const std::string & node_name)1108 ParameterPtr BuildFloatVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<float>> &data,
1109                                           const std::string &node_name) {
1110   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1111   auto param_node = func_graph->add_parameter();
1112   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1113   param_node->set_name(node_name);
1114 
1115   MS_CHECK_TRUE_RET(!data.empty(), nullptr);
1116   std::vector<int64_t> shape_vector;
1117   shape_vector.push_back(data.size());
1118   shape_vector.push_back(data.at(0).size());
1119 
1120   std::vector<float> data_1d;
1121   for (auto pair : data) {
1122     data_1d.insert(data_1d.end(), pair.begin(), pair.end());
1123   }
1124 
1125   auto size = data_1d.size() * sizeof(float);
1126   auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeFloat32);
1127   if (tensor_info == nullptr) {
1128     MS_LOG(ERROR) << "Create tensor info failed";
1129     return nullptr;
1130   }
1131   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1132   if (status != RET_OK) {
1133     MS_LOG(ERROR) << "init parameter from tensor info failed";
1134     return nullptr;
1135   }
1136   return param_node;
1137 }
1138 
GenTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & perm,const std::string & cnode_name)1139 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
1140                           const std::string &cnode_name) {
1141   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1142   MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
1143   auto perm_node = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm");
1144   MS_ASSERT(perm_node != nullptr);
1145   ops::Transpose transpose_node;
1146   auto trans_prim = transpose_node.GetPrim();
1147   MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
1148   auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
1149   MS_ASSERT(cnode != nullptr);
1150   auto manager = Manage(func_graph);
1151   MS_ASSERT(manager != nullptr);
1152   manager->SetEdge(cnode, 1, input_node);
1153   manager->SetEdge(cnode, kInputIndexTwo, perm_node);
1154   cnode->set_fullname_with_scope(cnode_name);
1155   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeTwo, 1);
1156   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1157   trans_prim->AddAttr("quant_params", quant_params_holder);
1158   auto input_abstract = input_node->abstract();
1159   if (input_abstract != nullptr) {
1160     auto abstract = input_abstract->Clone();
1161     MS_CHECK_TRUE_RET(abstract != nullptr, nullptr);
1162     FormatTransNodeType perm_type = perm == kNC2NH ? kNCHW2NHWC : (perm == kNH2NC ? kNHWC2NCHW : kNONE);
1163     if (ConvertAbstractFormatShape(abstract, perm_type) != RET_OK) {
1164       MS_LOG(WARNING) << "Convert abstract failed for node: " << cnode->fullname_with_scope();
1165       return cnode;
1166     }
1167     cnode->set_abstract(abstract);
1168   }
1169   return cnode;
1170 }
1171 
GenCastNode(const FuncGraphPtr & graph,const AnfNodePtr & input_node,const std::string & cnode_name,const TypeId dst_type,const AbstractBasePtr & abstract)1172 CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const std::string &cnode_name,
1173                      const TypeId dst_type, const AbstractBasePtr &abstract) {
1174   MS_CHECK_TRUE_RET(graph != nullptr, nullptr);
1175   MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
1176   ops::Cast cast_node;
1177   auto new_cast_c = cast_node.GetPrim();
1178   if (new_cast_c == nullptr) {
1179     MS_LOG(ERROR) << "new_cast_c is nullptr";
1180     return nullptr;
1181   }
1182   TypePtr dst_type_ptr = TypeIdToType(dst_type);
1183   if (dst_type_ptr == nullptr) {
1184     MS_LOG(ERROR) << "dst_type_ptr is nullptr";
1185     return nullptr;
1186   }
1187   new_cast_c->AddAttr(ops::kDstType, dst_type_ptr);
1188   ValueNodePtr value_node = NewValueNode(new_cast_c);
1189   if (value_node == nullptr) {
1190     MS_LOG(ERROR) << "NewValueNode Failed";
1191     return nullptr;
1192   }
1193 
1194   auto dtype_value = MakeValue(dst_type_ptr);
1195   auto dtype_value_node = NewValueNode(dtype_value);
1196   dtype_value_node->set_abstract(dtype_value->ToAbstract());
1197   graph->AddValueNode(dtype_value_node);
1198 
1199   auto cast_cnode = graph->NewCNode({value_node});
1200   if (cast_cnode == nullptr) {
1201     MS_LOG(ERROR) << "new_cnode is nullptr";
1202     return nullptr;
1203   }
1204   cast_cnode->set_fullname_with_scope(cnode_name);
1205   cast_cnode->set_abstract(abstract);
1206   auto manager = Manage(graph);
1207   (void)manager->Replace(input_node, cast_cnode);
1208   manager->AddEdge(cast_cnode, input_node);
1209   manager->AddEdge(cast_cnode, dtype_value_node);
1210   return cast_cnode;
1211 }
1212 
GenReshapeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & shape,const std::string & cnode_name)1213 CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
1214                         const std::string &cnode_name) {
1215   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1216   MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
1217   auto reshape_prim = std::make_shared<ops::Reshape>();
1218   if (reshape_prim == nullptr) {
1219     MS_LOG(ERROR) << "create reshape failed.";
1220     return nullptr;
1221   }
1222   auto prim_c = reshape_prim->GetPrim();
1223   prim_c->set_attr("shape", MakeValue(shape));
1224   ValueNodePtr value_node = NewValueNode(prim_c);
1225   MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "Create value_node return nullptr");
1226   auto new_shape_node = opt::BuildIntVecParameterNode(func_graph, shape, cnode_name + "_shape");
1227   MS_CHECK_TRUE_MSG(new_shape_node != nullptr, nullptr, "Create shape parameter return nullptr");
1228   std::vector<AnfNodePtr> op_inputs = {value_node, input_node, new_shape_node};
1229   auto reshape_cnode = func_graph->NewCNode(op_inputs);
1230   MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "Create cnode return nullptr");
1231   reshape_cnode->set_fullname_with_scope(cnode_name);
1232   return reshape_cnode;
1233 }
1234 
GenGatherNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & indices,const std::string & cnode_name,const std::vector<int> & axis)1235 CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
1236                        const std::string &cnode_name, const std::vector<int> &axis) {
1237   if (func_graph == nullptr || input_node == nullptr) {
1238     MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1239     return nullptr;
1240   }
1241   auto indices_node = BuildIntVecParameterNode(func_graph, indices, cnode_name + "_indices");
1242   if (indices_node == nullptr) {
1243     MS_LOG(ERROR) << "make indices node failed.";
1244     return nullptr;
1245   }
1246   auto axis_node = BuildIntVecParameterNode(func_graph, axis, cnode_name + "_axis");
1247   if (axis_node == nullptr) {
1248     MS_LOG(ERROR) << "make indices node failed.";
1249     return nullptr;
1250   }
1251   ops::Gather gather_node;
1252   auto gather_prim = gather_node.GetPrim();
1253   MS_CHECK_TRUE_RET(gather_prim != nullptr, nullptr);
1254   auto cnode = func_graph->NewCNode(gather_prim, {input_node, indices_node, axis_node});
1255   MS_ASSERT(cnode != nullptr);
1256   auto manager = Manage(func_graph);
1257   MS_ASSERT(manager != nullptr);
1258   manager->SetEdge(cnode, 1, input_node);
1259   manager->SetEdge(cnode, kInputIndexTwo, indices_node);
1260   manager->SetEdge(cnode, kInputIndexThree, axis_node);
1261   cnode->set_fullname_with_scope(cnode_name);
1262   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeThree, 1);
1263   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1264   gather_prim->AddAttr("quant_params", quant_params_holder);
1265   return cnode;
1266 }
1267 
GenGatherNodeDynamicIndex(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const AnfNodePtr & indices_node,const std::string & cnode_name,const std::vector<int> & axis)1268 CNodePtr GenGatherNodeDynamicIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
1269                                    const AnfNodePtr &indices_node, const std::string &cnode_name,
1270                                    const std::vector<int> &axis) {
1271   if (func_graph == nullptr || input_node == nullptr || indices_node == nullptr) {
1272     MS_LOG(ERROR) << "Input parameter is nullptr, which is nullptr!";
1273     return nullptr;
1274   }
1275   auto axis_node = BuildIntVecParameterNode(func_graph, axis, cnode_name + "_axis");
1276   if (axis_node == nullptr) {
1277     MS_LOG(ERROR) << "Build axis node failed!";
1278     return nullptr;
1279   }
1280   ops::Gather gather_node;
1281   auto gather_prim = gather_node.GetPrim();
1282   MS_CHECK_TRUE_RET(gather_prim != nullptr, nullptr);
1283   auto cnode = func_graph->NewCNode(gather_prim, {input_node, indices_node, axis_node});
1284   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
1285   auto manager = Manage(func_graph);
1286   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
1287   manager->SetEdge(cnode, 1, input_node);
1288   manager->SetEdge(cnode, kInputIndexTwo, indices_node);
1289   manager->SetEdge(cnode, kInputIndexThree, axis_node);
1290   cnode->set_fullname_with_scope(cnode_name);
1291   if (input_node->abstract() != nullptr) {
1292     cnode->set_abstract(input_node->abstract()->Clone());
1293   }
1294   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeThree, 1);
1295   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1296   gather_prim->AddAttr("quant_params", quant_params_holder);
1297   return cnode;
1298 }
1299 
GenConcatNode(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & input_node_vec,const std::string & cnode_name,int64_t axis)1300 CNodePtr GenConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &input_node_vec,
1301                        const std::string &cnode_name, int64_t axis) {
1302   if (func_graph == nullptr) {
1303     MS_LOG(ERROR) << "func_graph is nullptr, which is invalid.";
1304     return nullptr;
1305   }
1306   ops::Concat concat_node;
1307   concat_node.set_axis(axis);
1308   auto concat_prim = concat_node.GetPrim();
1309   MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
1310   auto cnode = func_graph->NewCNode(concat_prim, input_node_vec);
1311   MS_ASSERT(cnode != nullptr);
1312   auto manager = Manage(func_graph);
1313   MS_ASSERT(manager != nullptr);
1314   cnode->set_fullname_with_scope(cnode_name);
1315   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(input_node_vec.size(), 1);
1316   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1317   concat_prim->AddAttr("quant_params", quant_params_holder);
1318   return cnode;
1319 }
1320 
// Creates a TupleGetItem cnode extracting element `index` from `input`; the new node
// is named "<input fullname>_getitem_<index>".
// Returns nullptr on null inputs or any construction failure.
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
  if (func_graph == nullptr || input == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
    return nullptr;
  }
  auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
  MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr);
  auto second_input = NewValueNode(MakeValue<int64_t>(index));
  MS_CHECK_TRUE_RET(second_input != nullptr, nullptr);
  auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim();
  MS_CHECK_TRUE_RET(tuple_get_item_prim_c != nullptr, nullptr);
  auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim_c, {input, second_input});
  // Explicit check rather than MS_ASSERT, which may be compiled out in release builds.
  MS_CHECK_TRUE_RET(tuple_cnode != nullptr, nullptr);
  tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
  return tuple_cnode;
}
1337 
FetchShapeFromAbstract(const abstract::AbstractBasePtr & abstract,ShapeVector * shape)1338 STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) {
1339   if (abstract == nullptr || shape == nullptr) {
1340     MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1341     return lite::RET_ERROR;
1342   }
1343   if (!utils::isa<abstract::AbstractTensor>(abstract)) {
1344     MS_LOG(ERROR) << "abstract of cnode is invalid.";
1345     return lite::RET_ERROR;
1346   }
1347   auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
1348   if (abstract_tensor->BuildShape() == nullptr || !utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
1349     MS_LOG(ERROR) << "shape of cnode's output is invalid.";
1350     return lite::RET_ERROR;
1351   }
1352   *shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
1353   return lite::RET_OK;
1354 }
1355 
IsTrainOp(const CNodePtr & cnode)1356 bool IsTrainOp(const CNodePtr &cnode) {
1357   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1358   if (prim == nullptr) {
1359     return false;
1360   }
1361   auto cnode_type = prim->name();
1362   // optimizer op
1363   if (cnode_type == "Adam" || cnode_type == "SGD" || cnode_type == "ApplyMomentum") {
1364     return true;
1365   }
1366   // loss op
1367   if (cnode_type == "SoftmaxCrossEntropyWithLogits" || cnode_type == "SparseSoftmaxCrossEntropyWithLogits" ||
1368       cnode_type == "SmoothL1Loss" || cnode_type == "SmoothL1LossGrad" ||
1369       cnode_type == "SigmoidCrossEntropyWithLogits" || cnode_type == "SigmoidCrossEntropyWithLogitsGrad") {
1370     return true;
1371   }
1372   // grad op
1373   if (cnode_type.find("Grad") != std::string::npos ||
1374       cnode->fullname_with_scope().find("Gradients") != std::string::npos) {
1375     return true;
1376   }
1377   return false;
1378 }
1379 
IsMarkedTrainOp(const CNodePtr & cnode)1380 bool IsMarkedTrainOp(const CNodePtr &cnode) {
1381   if (cnode == nullptr) {
1382     return false;
1383   }
1384   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1385   MS_CHECK_TRUE_RET(prim != nullptr, false);
1386   if (prim->GetAttr("trainOp") != nullptr && GetValue<bool>(prim->GetAttr("trainOp"))) {
1387     MS_LOG(DEBUG) << "train op not fusion.";
1388     return true;
1389   }
1390   return false;
1391 }
1392 
GetOutputSize(const AnfNodePtr & anf_node)1393 size_t GetOutputSize(const AnfNodePtr &anf_node) {
1394   if (anf_node == nullptr) {
1395     MS_LOG(ERROR) << "anf_node is nullptr.";
1396     return RET_ERROR;
1397   }
1398   AbstractBasePtr abstract_base;
1399   if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
1400     abstract_base = anf_node->cast<CNodePtr>()->input(1)->abstract();
1401   } else {
1402     abstract_base = anf_node->abstract();
1403   }
1404   // used for multi output e.g. split.
1405   if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
1406     auto abstract_tuple = abstract_base->cast<abstract::AbstractTuplePtr>();
1407     return abstract_tuple->elements().size();
1408   }
1409   return 1;
1410 }
1411 
GetAnfNodeOutputShape(const AnfNodePtr & node,size_t output_idx)1412 ShapeVector GetAnfNodeOutputShape(const AnfNodePtr &node, size_t output_idx) {
1413   if (node == nullptr) {
1414     MS_LOG(ERROR) << "anf_node is nullptr.";
1415     return {};
1416   }
1417   auto as_value_node = node->cast<ValueNodePtr>();
1418   if (as_value_node) {
1419     auto value = as_value_node->value();
1420     auto tensor = value->cast<tensor::TensorPtr>();
1421     if (tensor) {
1422       return tensor->shape_c();
1423     }
1424     return {};
1425   }
1426   auto base_shape = node->Shape();
1427   if (base_shape == nullptr) {
1428     MS_LOG(INFO) << "Failed to get shape from node " << node->fullname_with_scope();
1429     return {};
1430   }
1431   if (base_shape->isa<abstract::Shape>()) {
1432     if (output_idx != 0) {
1433       MS_LOG(EXCEPTION) << "The node " << node->fullname_with_scope() << "is a single output node but got index ["
1434                         << output_idx;
1435     }
1436     auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
1437     MS_EXCEPTION_IF_NULL(shape_ptr);
1438     return shape_ptr->shape();
1439   } else if (base_shape->isa<abstract::NoShape>()) {
1440     return ShapeVector();
1441   } else if (base_shape->isa<abstract::TupleShape>()) {
1442     auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
1443     MS_EXCEPTION_IF_NULL(tuple_shape);
1444     if (output_idx >= tuple_shape->size()) {
1445       MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
1446                         << node->fullname_with_scope();
1447     }
1448     auto b_shp = (*tuple_shape)[output_idx];
1449     if (b_shp->isa<abstract::Shape>()) {
1450       auto shape_ptr = b_shp->cast<abstract::ShapePtr>();
1451       MS_EXCEPTION_IF_NULL(shape_ptr);
1452       return shape_ptr->shape();
1453     } else if (b_shp->isa<abstract::NoShape>()) {
1454       return ShapeVector();
1455     } else if (b_shp->isa<abstract::TupleShape>()) {
1456       MS_LOG(INFO) << "The output shape of node:" << node->fullname_with_scope() << " index:" << output_idx
1457                    << " is a TupleShape:" << base_shape->ToString();
1458       return ShapeVector();
1459     } else {
1460       MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
1461                         << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
1462                         << "node :" << node->fullname_with_scope() << ".";
1463     }
1464   }
1465   return ShapeVector();
1466 }
1467 
GetDataTypeFromAnfNode(const AnfNodePtr & anf_node,TypeId * type_id)1468 int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
1469   if (anf_node == nullptr || type_id == nullptr) {
1470     MS_LOG(ERROR) << "anf_node or type_id is nullptr.";
1471     return RET_ERROR;
1472   }
1473   AbstractBasePtr abstract_base;
1474   if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
1475     abstract_base = anf_node->cast<CNodePtr>()->input(1)->abstract();
1476   } else {
1477     abstract_base = anf_node->abstract();
1478   }
1479   // used for multi output e.g. split.
1480   if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
1481     auto abstract_tuple = abstract_base->cast<abstract::AbstractTuplePtr>();
1482     if (abstract_tuple->elements().empty()) {
1483       MS_LOG(ERROR) << "abstract_tuple elements is empty.";
1484       return RET_ERROR;
1485     }
1486     abstract_base = abstract_tuple->elements().front();
1487   }
1488   if (abstract_base == nullptr) {
1489     MS_LOG(INFO) << "Abstract of parameter is nullptr, " << anf_node->fullname_with_scope();
1490     *type_id = kTypeUnknown;
1491     return lite::RET_NOT_SUPPORT;
1492   }
1493   if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
1494     auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
1495     MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
1496     auto type_ptr = abstract_tensor->element()->GetTypeTrack();
1497     MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1498     *type_id = type_ptr->type_id();
1499   } else if (utils::isa<abstract::AbstractScalarPtr>(abstract_base)) {
1500     auto abstract_scalar = utils::cast<abstract::AbstractScalarPtr>(abstract_base);
1501     auto type_ptr = abstract_scalar->GetTypeTrack();
1502     MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1503     *type_id = type_ptr->type_id();
1504   } else {
1505     MS_LOG(ERROR) << anf_node->fullname_with_scope() << " is unsupported type:" << abstract_base->type_name();
1506     return RET_ERROR;
1507   }
1508   return RET_OK;
1509 }
1510 
IsQuantParameterNode(const PrimitivePtr & prim)1511 bool IsQuantParameterNode(const PrimitivePtr &prim) {
1512   MS_CHECK_TRUE_RET(prim != nullptr, false);
1513   auto quant_attr = prim->GetAttr("quant_params");
1514   if (quant_attr != nullptr) {
1515     auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
1516     MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
1517     auto quant_params = quant_param_holder->get_input_quant_params();
1518     bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
1519       return !params.empty() && params.front().inited;
1520     });
1521     if (is_quant) {
1522       return true;
1523     }
1524   }
1525   return false;
1526 }
1527 
UpdateManager(const FuncGraphPtr & func_graph)1528 void UpdateManager(const FuncGraphPtr &func_graph) {
1529   auto manager = func_graph->manager();
1530   if (manager == nullptr) {
1531     manager = Manage(func_graph, true);
1532   } else {
1533     manager->Clear();
1534     manager->AddFuncGraph(func_graph, true);
1535   }
1536   std::set<FuncGraphPtr> all_func_graphs;
1537   mindspore::lite::GetAllFuncGraph(func_graph, &all_func_graphs);
1538   for (auto &one_func_graph : all_func_graphs) {
1539     manager->AddFuncGraph(one_func_graph);
1540   }
1541 }
1542 
GetRealCertainVarInput(const CNodePtr & cnode,size_t index)1543 std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index) {
1544   MS_CHECK_TRUE_MSG(cnode != nullptr, {}, "function's parameter is nullptr.");
1545   MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, {}, "required input is nullptr");
1546   auto real_input_cnode = cnode->input(index)->cast<CNodePtr>();
1547   if (real_input_cnode == nullptr) {
1548     MS_LOG(DEBUG) << "input node is not a cnode.";
1549     return {};
1550   }
1551   int item_index = 0;
1552   if (opt::CheckPrimitiveType(real_input_cnode, prim::kPrimTupleGetItem)) {
1553     auto index_node = real_input_cnode->input(opt::kInputIndexTwo);
1554     MS_CHECK_TRUE_MSG(index_node != nullptr, {}, "tuple_get_item's second input is nullptr.");
1555     MS_CHECK_TRUE_MSG(index_node->isa<ValueNode>(), {}, "tuple_get_item's second input should be valuenode.");
1556     auto index_ptr = index_node->cast<ValueNodePtr>()->value();
1557     MS_CHECK_TRUE_MSG(index_ptr != nullptr, {}, "tuple_get_item's second input val is nullptr.");
1558     auto value = CastToInt(index_ptr);
1559     MS_CHECK_TRUE_MSG(value.size() == 1, {}, "tuple_get_item's second input is invalid.");
1560     item_index = value.front();
1561     MS_CHECK_TRUE_MSG(real_input_cnode->input(1) != nullptr, {}, "tuple_get_item's first input is nullptr");
1562     real_input_cnode = real_input_cnode->input(1)->cast<CNodePtr>();
1563     MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "tuple_get_item first input is not cnode.");
1564   }
1565   return {real_input_cnode, item_index};
1566 }
1567 
DetermineCertainVarInputHasInferred(const CNodePtr & cnode,size_t index,bool * infer_succ)1568 int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ) {
1569   MS_CHECK_TRUE_MSG(cnode != nullptr && infer_succ != nullptr, RET_ERROR, "function's parameter is nullptr.");
1570   auto var_input_info = GetRealCertainVarInput(cnode, index);
1571   if (var_input_info.first == nullptr) {
1572     MS_LOG(ERROR) << "cannot get the real var input.";
1573     return RET_ERROR;
1574   }
1575   auto real_input_cnode = var_input_info.first;
1576   auto item_index = var_input_info.second;
1577   auto input_node_prim = GetValueNode<PrimitivePtr>((real_input_cnode->input(0)));
1578   MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed.");
1579   *infer_succ = false;
1580   auto value_ptr = input_node_prim->GetAttr(kInferDone);
1581   if (value_ptr != nullptr) {
1582     MS_CHECK_TRUE_MSG(value_ptr->isa<BoolImm>(), RET_ERROR, "value is not a boolean.");
1583     *infer_succ = GetValue<bool>(value_ptr);
1584   }
1585   value_ptr = input_node_prim->GetAttr(kInferFlags);
1586   if (value_ptr == nullptr) {
1587     return RET_OK;
1588   }
1589   MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "infer flag should be a vector.");
1590   auto value_sequence = value_ptr->cast<ValueSequeuePtr>();
1591   auto elements = value_sequence->value();
1592   MS_CHECK_TRUE_MSG(!elements.empty(), RET_ERROR, "infer_info has no content.");
1593   auto first_element = elements.front();
1594   MS_CHECK_TRUE_MSG(first_element != nullptr, RET_ERROR, "element is a nullptr.");
1595   MS_CHECK_TRUE_MSG(first_element->isa<BoolImm>(), RET_ERROR, "each element is not a boolean.");
1596   auto infer_infos = GetValue<std::vector<bool>>(value_ptr);
1597   MS_CHECK_TRUE_MSG(item_index >= 0 && static_cast<size_t>(item_index) < infer_infos.size(), RET_ERROR,
1598                     "item index is out of range.");
1599   *infer_succ = infer_infos[item_index];
1600   return RET_OK;
1601 }
CheckAndGetCnodeIndex(const CNodePtr & cnode,size_t * index,const PrimitivePtr & primitive_type)1602 bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type) {
1603   MS_CHECK_TRUE_RET(cnode != nullptr, false);
1604   MS_CHECK_TRUE_RET(index != nullptr, false);
1605   if (cnode->size() != kInputSizeThree) {
1606     return false;
1607   }
1608   size_t dst_index = 0;
1609   for (size_t i = 1; i < cnode->size(); ++i) {
1610     if (CheckPrimitiveType(cnode->input(i), primitive_type)) {
1611       dst_index = i;
1612       break;
1613     }
1614   }
1615   if (dst_index == 0) {
1616     return false;
1617   }
1618   *index = dst_index;
1619   return true;
1620 }
1621 
PrintFuncGraph(const FuncGraphPtr & func_graph,const std::string & output_file)1622 void PrintFuncGraph(const FuncGraphPtr &func_graph, const std::string &output_file) {
1623   if (func_graph == nullptr) {
1624     MS_LOG(WARNING) << "input func_graph is nullptr";
1625     return;
1626   }
1627   static int index = 0;
1628   auto real_file = std::to_string(index++) + "_" + output_file + ".txt";
1629   std::ofstream fp(real_file);
1630   if (!fp.is_open()) {
1631     MS_LOG(ERROR) << "Failed to create file " << real_file;
1632     return;
1633   }
1634   auto nodes = func_graph->TopoSort(func_graph->get_return());
1635   auto type_name = [](const AnfNodePtr &anf_node) -> std::string {
1636     if (anf_node->cast<CNodePtr>()) {
1637       return GetCNodeFuncName(anf_node->cast<CNodePtr>());
1638     } else if (anf_node->cast<ParameterPtr>()) {
1639       if (anf_node->cast<ParameterPtr>()->has_default()) {
1640         return "Parameter_Constant";
1641       } else {
1642         return "Parameter_Variable";
1643       }
1644     } else if (anf_node->cast<ValueNodePtr>()) {
1645       return "ValueNode";
1646     }
1647     return anf_node->ToString();
1648   };
1649   for (auto &node : nodes) {
1650     if (IsValueNode<Primitive>(node)) {
1651       continue;
1652     }
1653     auto cnode = node->cast<CNodePtr>();
1654     if (cnode == nullptr) {
1655       fp << node->fullname_with_scope() << ", type: " << type_name(node)
1656          << ", shape: " << GetAnfNodeOutputShape(node, 0) << std::endl;
1657       fp << std::endl;
1658       continue;
1659     }
1660     TypeId type_id = kTypeUnknown;
1661     GetDataTypeFromAnfNode(node, &type_id);
1662     fp << node->fullname_with_scope() << ", type: " << type_name(node) << ", shape: " << GetAnfNodeOutputShape(node, 0)
1663        << ", data type: " << static_cast<int>(type_id) << std::endl;
1664     auto &inputs = cnode->inputs();
1665     for (auto &input : inputs) {
1666       if (IsValueNode<Primitive>(input)) {
1667         continue;
1668       }
1669       type_id = kTypeUnknown;
1670       GetDataTypeFromAnfNode(node, &type_id);
1671       fp << "---input " << input->fullname_with_scope() << ", type: " << type_name(input)
1672          << ", shape: " << GetAnfNodeOutputShape(input, 0) << ", data type: " << static_cast<int>(type_id) << std::endl;
1673     }
1674     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1675     if (prim != nullptr) {
1676       for (auto &attr : prim->attrs()) {
1677         if (attr.second) {
1678           fp << "---attr " << attr.first << ": " << attr.second->ToString() << std::endl;
1679         } else {
1680           fp << "---attr " << attr.first << ": value nullptr" << std::endl;
1681         }
1682       }
1683     }
1684     fp << std::endl;
1685   }
1686 }
1687 
1688 #if !defined(_WIN32) && !defined(_WIN64)
GetNodeInputs(const AnfNodePtr & anf_node)1689 std::vector<KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node) {
1690   if (!anf_node) {
1691     return {};
1692   }
1693   if (!anf_node->isa<CNode>()) {
1694     return {{anf_node, 0}};
1695   }
1696   auto cnode = anf_node->cast<CNodePtr>();
1697   std::vector<common::KernelWithIndex> inputs;
1698   size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
1699   for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
1700     const auto &pre_node_output = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
1701     auto pre_node = pre_node_output.first;
1702     if (opt::CheckPrimitiveType(pre_node, prim::kPrimMakeTuple) ||
1703         opt::CheckPrimitiveType(pre_node, prim::kPrimMakeTupleV2)) {
1704       auto tuple_inputs = GetNodeInputs(pre_node);
1705       std::copy(tuple_inputs.begin(), tuple_inputs.end(), std::back_inserter(inputs));
1706     } else {
1707       inputs.push_back(pre_node_output);
1708     }
1709   }
1710   return inputs;
1711 }
1712 #else
GetNodeInputs(const AnfNodePtr & anf_node)1713 std::vector<KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node) { return {}; }
1714 #endif
1715 
IsReduceModeMeetOutEqualIn(const PrimitivePtr & prim)1716 bool IsReduceModeMeetOutEqualIn(const PrimitivePtr &prim) {
1717   if (prim == nullptr) {
1718     return false;
1719   }
1720   if (prim->GetAttr(ops::kMode) == nullptr) {
1721     return false;
1722   }
1723   auto mode = GetValue<int64_t>(prim->GetAttr(ops::kMode));
1724   std::set<int64_t> meet_mode = {Reduce_Mean, Reduce_Max, Reduce_Min, Reduce_Prod, Reduce_Sum};
1725   return meet_mode.find(mode) != meet_mode.end();
1726 }
1727 
AdjustInputToCnode(const CNodePtr & cnode,size_t input_index)1728 STATUS AdjustInputToCnode(const CNodePtr &cnode, size_t input_index) {
1729   auto func_graph = cnode->func_graph();
1730   if (func_graph == nullptr) {
1731     MS_LOG(ERROR) << "func graph is nullptr.";
1732     return RET_ERROR;
1733   }
1734   ops::TensorMove tensor_move;
1735   auto tensor_move_prim = tensor_move.GetPrim();
1736   if (tensor_move_prim == nullptr) {
1737     MS_LOG(ERROR) << "tensor move prim is nullptr.";
1738     return RET_ERROR;
1739   }
1740   auto tensor_move_cnode = func_graph->NewCNode(tensor_move_prim, {cnode->input(input_index)});
1741   if (tensor_move_cnode == nullptr) {
1742     MS_LOG(ERROR) << "new cnode failed.";
1743     return RET_ERROR;
1744   }
1745   auto manager = Manage(func_graph);
1746   if (manager == nullptr) {
1747     MS_LOG(ERROR) << "manager is nullptr.";
1748     return RET_ERROR;
1749   }
1750   tensor_move_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "tensor_move" +
1751                                              std::to_string(input_index));
1752   auto temp_abstract = cnode->input(input_index)->abstract()->Clone();
1753   if (temp_abstract == nullptr) {
1754     MS_LOG(ERROR) << "abstract clone failed.";
1755     return RET_ERROR;
1756   }
1757   tensor_move_cnode->set_abstract(temp_abstract);
1758   manager->SetEdge(cnode, input_index, tensor_move_cnode);
1759   return RET_OK;
1760 }
1761 
GetTensorFromParameterNode(const EquivPtr & equiv,const VarPtr & input)1762 tensor::TensorPtr GetTensorFromParameterNode(const EquivPtr &equiv, const VarPtr &input) {
1763   MS_CHECK_TRUE_RET(equiv != nullptr && input != nullptr, nullptr);
1764   auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
1765   if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
1766     MS_LOG(ERROR) << "node is nullptr or node is not a parameter node.";
1767     return nullptr;
1768   }
1769   auto parameter_node = node->cast<ParameterPtr>();
1770   if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
1771     MS_LOG(ERROR) << "parameter_node has no default or its default_param() is nullptr.";
1772     return nullptr;
1773   }
1774   auto param_value_lite = parameter_node->default_param()->cast<tensor::TensorPtr>();
1775   return param_value_lite;
1776 }
1777 
GetIntParameterValue(const EquivPtr & equiv,const VarPtr & input)1778 const int GetIntParameterValue(const EquivPtr &equiv, const VarPtr &input) {
1779   MS_CHECK_TRUE_RET(equiv != nullptr && input != nullptr, INT_MIN);
1780   auto param_value_lite = GetTensorFromParameterNode(equiv, input);
1781   const int value = INT_MIN;
1782   if (param_value_lite == nullptr) {
1783     return value;
1784   }
1785   if (param_value_lite->data_type() != kNumberTypeInt32 && param_value_lite->data_type() != kNumberTypeInt) {
1786     return value;
1787   }
1788   if (param_value_lite->Size() != sizeof(int)) {
1789     return value;
1790   }
1791   return *static_cast<int *>(param_value_lite->data_c());
1792 }
1793 
GetFloatParameterValue(const EquivPtr & equiv,const VarPtr & input)1794 const float GetFloatParameterValue(const EquivPtr &equiv, const VarPtr &input) {
1795   const float value = -1;
1796   MS_CHECK_TRUE_RET(equiv != nullptr && input != nullptr, value);
1797   auto param_value_lite = GetTensorFromParameterNode(equiv, input);
1798   if (param_value_lite == nullptr) {
1799     return value;
1800   }
1801   if (param_value_lite->data_type() != kNumberTypeFloat32 && param_value_lite->data_type() != kNumberTypeFloat) {
1802     return value;
1803   }
1804   if (param_value_lite->Size() != sizeof(float)) {
1805     return value;
1806   }
1807   return *static_cast<float *>(param_value_lite->data_c());
1808 }
1809 
1810 };  // namespace opt
1811 }  // namespace mindspore
1812