• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "tools/optimizer/common/gllo_utils.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base/float16.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/transpose.h"
#include "ops/gather.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"
#include "tools/converter/quant_param_holder.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
34 
namespace mindspore {
namespace opt {
namespace {
// Index of the primitive value node within a CNode's input list.
constexpr auto kAnfPrimitiveIndex = 0;
// Sentinel meaning the primitive carries no device-type attribute.
constexpr auto kDeviceTypeNone = -1;
DeduceDimConvertion(schema::Format src_format,schema::Format dst_format,std::vector<int> * perm)40 int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
41   MS_ASSERT(perm != nullptr);
42   auto src_format_str = std::string(schema::EnumNameFormat(src_format));
43   auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
44   if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
45     MS_LOG(ERROR) << "src_format or dst_format is error.";
46     return lite::RET_ERROR;
47   }
48   std::replace(src_format_str.begin(), src_format_str.end(), 'K', 'N');
49   std::replace(dst_format_str.begin(), dst_format_str.end(), 'K', 'N');
50   perm->clear();
51   std::unordered_map<char, int> dim_map;
52   for (size_t i = 0; i < src_format_str.size(); ++i) {
53     dim_map[src_format_str[i]] = i;
54   }
55   for (size_t i = 0; i < dst_format_str.size(); ++i) {
56     if (dim_map.find(dst_format_str[i]) == dim_map.end()) {
57       MS_LOG(ERROR) << "src_format and dst_format cannot match, please check.";
58       return RET_ERROR;
59     }
60     perm->push_back(dim_map[dst_format_str[i]]);
61   }
62   return lite::RET_OK;
63 }
64 
65 template <typename T>
TransposeData(const ShapeVector & origin_shape,const ShapeVector & cur_shape,const std::vector<int> & perm,T * weight_data,std::vector<T> * buf)66 void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape, const std::vector<int> &perm,
67                    T *weight_data, std::vector<T> *buf) {
68   MS_ASSERT(weight_data != nullptr && buf != nullptr);
69   MS_ASSERT(origin_shape.size() == cur_shape.size() && cur_shape.size() == perm.size());
70   int count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int>());
71   ShapeVector post_multiply(cur_shape.size());
72   std::unordered_map<int, int> dim_map;
73   for (int i = cur_shape.size() - 1; i >= 0; --i) {
74     if (i == static_cast<int>(cur_shape.size() - 1)) {
75       post_multiply[i] = 1;
76     } else {
77       post_multiply[i] = cur_shape[i + 1] * post_multiply[i + 1];
78     }
79     dim_map[perm[i]] = i;
80   }
81   std::unordered_map<int, int> position_map;
82   for (int i = 0; i < count; ++i) {
83     int temp = i;
84     for (int j = static_cast<int>(origin_shape.size()) - 1; j >= 0; --j) {
85       MS_ASSERT(origin_shape[j] > 0);
86       position_map[j] = temp % origin_shape[j];
87       temp /= origin_shape[j];
88     }
89     int64_t new_pos = std::accumulate(position_map.begin(), position_map.end(), 0,
90                                       [&post_multiply, &dim_map](int64_t res, const std::pair<int, int> &pair_y) {
91                                         return res + post_multiply[dim_map[pair_y.first]] * pair_y.second;
92                                       });
93     buf->at(new_pos) = weight_data[i];
94   }
95 }
96 
// Transposes the data of a 4-D weight tensor in place from src_format to dst_format and
// updates the tensor's shape. T is the element type of the tensor's storage.
// Returns RET_OK on success; RET_ERROR when the tensor is not 4-D, a dimension is
// non-positive, the permutation cannot be deduced, the element count exceeds INT32_MAX,
// or copying the transposed buffer back fails.
template <typename T>
STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
  MS_ASSERT(tensor != nullptr);
  auto origin_shape = tensor->shape_c();
  if (origin_shape.size() != kInputSizeFour) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << origin_shape.size();
    return lite::RET_ERROR;
  }
  if (std::any_of(origin_shape.begin(), origin_shape.end(), [](int64_t val) { return val <= 0; })) {
    MS_LOG(ERROR) << "the tensor's shape is invalid.";
    return lite::RET_ERROR;
  }
  std::vector<int> perm;
  if (DeduceDimConvertion(src_format, dst_format, &perm) != RET_OK) {
    MS_LOG(ERROR) << "deduce perm failed.";
    return lite::RET_ERROR;
  }
  // The transposed shape picks the origin dims in permutation order.
  ShapeVector new_shape;
  for (auto &val : perm) {
    if (val < 0 || static_cast<size_t>(val) >= origin_shape.size()) {
      MS_LOG(ERROR) << "deduce perm is invalid.";
      return lite::RET_ERROR;
    }
    new_shape.push_back(origin_shape[val]);
  }
  // Guard the total element count so it fits in int: TransposeData iterates with int.
  auto count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1LL, std::multiplies<int64_t>());
  if (count <= 0 || count > static_cast<int64_t>(INT32_MAX)) {
    MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max.";
    return RET_ERROR;
  }
  std::vector<T> buf(count);

  void *originWeightData = tensor->data_c();
  MS_CHECK_TRUE_RET(originWeightData != nullptr, RET_ERROR);
  T *weightData = static_cast<T *>(originWeightData);
  TransposeData<T>(origin_shape, new_shape, perm, weightData, &buf);
  // Copy the transposed buffer back over the tensor's own storage before changing shape.
  if (memcpy_s(tensor->data_c(), tensor->Size(), buf.data(), count * sizeof(T)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed.";
    return RET_ERROR;
  }
  tensor->set_shape(new_shape);
  return RET_OK;
}
140 
IsRealKernel(const AnfNodePtr & node)141 bool IsRealKernel(const AnfNodePtr &node) {
142   if (node == nullptr) {
143     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
144     return false;
145   }
146   // parameter and value node is not a real kernel too
147   if (!node->isa<CNode>()) {
148     return true;
149   }
150   auto cnode = node->cast<CNodePtr>();
151   if (cnode == nullptr) {
152     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
153     return false;
154   }
155   if (cnode->inputs().empty()) {
156     MS_LOG(ERROR) << "Illegal null input of cnode(%s)" << node->DebugString();
157     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
158     return false;
159   }
160   auto input = cnode->inputs()[0];
161 #ifndef ENABLE_SECURITY
162   bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
163                          IsPrimitive(input, prim::kPrimTensorSummary) ||
164                          IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
165                          IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
166                          IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
167                          IsPrimitive(input, prim::kPrimPartial);
168 #else
169   bool is_virtual_node = IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) ||
170                          IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) ||
171                          IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
172 #endif
173   return !is_virtual_node;
174 }
175 
CreateValueNodeWithSexp(const BaseRef & sexp)176 ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
177   if (utils::isa<int>(sexp)) {
178     return NewValueNode(utils::cast<int>(sexp));
179   }
180   if (utils::isa<float>(sexp)) {
181     return NewValueNode(utils::cast<float>(sexp));
182   }
183   if (utils::isa<bool>(sexp)) {
184     return NewValueNode(utils::cast<bool>(sexp));
185   }
186   if (utils::isa<ValuePtr>(sexp)) {
187     return NewValueNode(utils::cast<ValuePtr>(sexp));
188   }
189   return nullptr;
190 }
191 
CreateCNodeWithGraph(const std::vector<AnfNodePtr> & input_nodes,const BaseRef & graph)192 CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
193   if (utils::isa<FuncGraphPtr>(graph)) {
194     return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
195   }
196   if (utils::isa<VarPtr>(graph)) {
197     return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
198   }
199   return nullptr;
200 }
201 
// Builds a VarNode from a pattern expression. When 'graph' itself is a Var, the node is
// created without an owning graph; when 'graph' is a FuncGraph, the node is bound to it.
// Returns nullptr (with an ERROR log) for any other 'graph' kind.
// NOTE(review): 'sexp' is cast to VarPtr without an isa-check in both branches — callers
// appear to guarantee it holds a Var; confirm before reusing elsewhere.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}
214 
// Recursively converts a VectorRef pattern expression into a CNode: each tuple element
// becomes an input node via SexpToNode. When 'multigraph' is set and 'graph' is a Var,
// every child is converted against a fresh graph variable ("G") so the pattern can match
// nodes across different graphs, and the resulting CNode is owned by the graph variable.
// Returns nullptr if primitive_vars is null (records RET_NULL_PTR).
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  if (primitive_vars == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector<AnfNodePtr> input_nodes;
  const auto &tuple = utils::cast<VectorRef>(sexp);
  if (multigraph && utils::isa<VarPtr>(graph)) {
    for (auto &x : tuple) {
      // Each child gets its own graph variable so the sub-pattern is graph-agnostic.
      auto is_var = std::make_shared<Var>("G");
      MS_CHECK_TRUE_RET(is_var != nullptr, nullptr);
      AnfNodePtr node = SexpToNode(x, is_var, primitive_vars, true);
      input_nodes.push_back(node);
    }
    auto var_ptr = utils::cast<VarPtr>(graph);
    return std::make_shared<CNode>(input_nodes, var_ptr);
  }

  for (auto &x : tuple) {
    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
    input_nodes.push_back(node);
  }
  return CreateCNodeWithGraph(input_nodes, graph);
}
241 
AnfEqualPrimitive(const AnfNodePtr & a_node,const AnfNodePtr & b_node)242 bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
243   auto a_value_node = a_node->cast<ValueNodePtr>();
244   auto b_value_node = b_node->cast<ValueNodePtr>();
245   if (a_value_node == nullptr || b_value_node == nullptr) {
246     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
247     return false;
248   }
249 
250   auto a_value = a_value_node->value();
251   auto b_value = b_value_node->value();
252   if (a_value == nullptr || b_value == nullptr) {
253     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
254     return false;
255   }
256 
257   auto a_prim = a_value->cast<PrimitivePtr>();
258   auto b_prim = b_value->cast<PrimitivePtr>();
259   if (a_prim == nullptr || b_prim == nullptr) {
260     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
261     return false;
262   }
263   return a_prim->name() == b_prim->name();
264 }
265 
AnfEqualValueNode(const AnfNodePtr & a_node,const AnfNodePtr & b_node)266 bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
267   auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
268   auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
269   if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
270     MS_LOG(ERROR) << "cast value node ptr fail";
271     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
272     return false;
273   }
274   auto a_value_ptr = a_value_node_ptr->value();
275   auto b_value_ptr = b_value_node_ptr->value();
276   if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
277     MS_LOG(ERROR) << "value ptr is nullptr";
278     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
279     return false;
280   }
281 
282   if (utils::isa<ops::PrimitiveC>(a_value_ptr) && utils::isa<ops::PrimitiveC>(b_value_ptr)) {
283     auto a_obj = (ops::PrimitiveC *)(a_value_ptr.get());
284     auto b_obj = (ops::PrimitiveC *)(b_value_ptr.get());
285     return (*a_obj) == (*b_obj);
286   } else {
287     return (*a_value_ptr) == (*b_value_ptr);
288   }
289 }
290 }  // namespace
291 
CheckInputs(const CNodePtr & cnode)292 bool CheckInputs(const CNodePtr &cnode) {
293   if (cnode == nullptr) {
294     MS_LOG(ERROR) << "cnode is nullptr.";
295     return false;
296   }
297   if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
298                   [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
299     MS_LOG(ERROR) << "input is nullptr.";
300     return false;
301   }
302   return true;
303 }
304 
CastToInt(const ValuePtr & value)305 std::vector<int> CastToInt(const ValuePtr &value) {
306   if (value == nullptr) {
307     MS_LOG(WARNING) << "valueptr is nullptr.";
308     return {};
309   }
310   std::vector<int> cur_value = {};
311   if (utils::isa<ValueSequeuePtr>(value)) {
312     if (!value->cast<ValueSequeuePtr>()->value().empty()) {
313       auto data_type = value->cast<ValueSequeuePtr>()->value().front()->type()->number_type();
314       if (data_type == kNumberTypeInt64) {
315         auto origin_value = GetValue<std::vector<int64_t>>(value);
316         std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
317                        [](int64_t index) { return static_cast<int>(index); });
318       } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
319         cur_value = GetValue<std::vector<int>>(value);
320       } else {
321         MS_LOG(ERROR) << "he function only process integer data.";
322         return {};
323       }
324     }
325   } else {
326     auto data_type = value->type()->number_type();
327     if (data_type == kNumberTypeInt64) {
328       cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
329     } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
330       cur_value.push_back(GetValue<int>(value));
331     } else {
332       MS_LOG(ERROR) << "the function only process integer data.";
333       return {};
334     }
335   }
336   return cur_value;
337 }
338 
CastToVec2DInt(const ValuePtr & value)339 std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value) {
340   if (value == nullptr) {
341     MS_LOG(WARNING) << "valueptr is nullptr.";
342     return {};
343   }
344 
345   std::vector<std::vector<int>> result_value;
346   if (utils::isa<ValueSequeuePtr>(value)) {
347     auto data_type =
348       value->cast<ValueSequeuePtr>()->value().front()->cast<ValueSequeuePtr>()->value().front()->type()->number_type();
349     if (data_type == kNumberTypeInt64) {
350       auto origin_value = GetValue<std::vector<std::vector<int64_t>>>(value);
351       for (auto &i : origin_value) {
352         std::vector<int> cur_value;
353         std::transform(i.begin(), i.end(), std::back_inserter(cur_value),
354                        [](int64_t j) { return static_cast<int>(j); });
355         result_value.push_back(cur_value);
356       }
357     } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
358       result_value = GetValue<std::vector<std::vector<int>>>(value);
359     } else {
360       MS_LOG(ERROR) << "he function only process integer data.";
361       return result_value;
362     }
363   }
364   return result_value;
365 }
366 
CastToFloat(const ValuePtr & value)367 std::vector<float> CastToFloat(const ValuePtr &value) {
368   if (value == nullptr) {
369     MS_LOG(WARNING) << "valueptr is nullptr.";
370     return {};
371   }
372   std::vector<float> cur_value = {};
373   if (utils::isa<ValueSequeuePtr>(value)) {
374     if (!value->cast<ValueSequeuePtr>()->value().empty()) {
375       auto data_type = value->cast<ValueSequeuePtr>()->value().front()->type()->number_type();
376       if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
377         cur_value = GetValue<std::vector<float>>(value);
378       } else {
379         MS_LOG(ERROR) << "the function only process float data.";
380         return {};
381       }
382     }
383   } else {
384     auto data_type = value->type()->number_type();
385     if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
386       cur_value.push_back(GetValue<float>(value));
387     } else {
388       MS_LOG(ERROR) << "the function only process float data.";
389       return {};
390     }
391   }
392   return cur_value;
393 }
394 
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)395 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
396   if (node == nullptr || primitive_type == nullptr) {
397     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
398     return false;
399   }
400   if (node->isa<CNode>()) {
401     auto cnode = node->cast<CNodePtr>();
402     return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
403   } else if (node->isa<ValueNode>()) {
404     return IsPrimitive(node, primitive_type);
405   }
406   return false;
407 }
408 
// Pattern-matching equality between two BaseRefs:
//  - both AnfNodes: primitive value-nodes compare by primitive name, other value-nodes
//    by value; non-value nodes fall through to BaseRef::operator==.
//  - both PrimitiveC payloads: compare by primitive name.
//  - otherwise: BaseRef::operator==.
// NOTE(review): a.m_ptr / b.m_ptr are dereferenced without a null check — presumably a
// BaseRef always carries a payload at this point; confirm.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    if (a_node == nullptr || b_node == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return false;
    }
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      return AnfEqualPrimitive(a_node, b_node);
    }
    if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      return AnfEqualValueNode(a_node, b_node);
    }
  }
  if (a.m_ptr->isa<mindspore::ops::PrimitiveC>() && b.m_ptr->isa<mindspore::ops::PrimitiveC>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
    return a_value_node_ptr->name() == b_value_node_ptr->name();
  }

  return a == b;
}
432 
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)433 bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
434   // To matchCNode and Kernel's type
435   if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
436     return true;
437   }
438   return a.type() == b.type();
439 }
440 
// Converts a pattern expression (sexp) into an AnfNode:
//  - VectorRef                  -> CNode (via HandleSexpVector, recursing on elements)
//  - VarPtr with a primitive    -> ValueNode of that primitive (also recorded in
//                                  *primitive_vars so the matcher can bind it)
//  - VarPtr without a primitive -> VarNode (via CreateVarNodeWithSexp)
//  - AnfNodePtr                 -> returned unchanged
//  - int/float/bool/ValuePtr    -> ValueNode
// Returns nullptr (recording RET_NULL_PTR) for a null primitive_vars or an
// unconvertible expression.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  if (primitive_vars == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto var_ptr = utils::cast<VarPtr>(sexp);
    if (var_ptr == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return nullptr;
    }
    if (var_ptr->primitive()) {
      // Remember which Var stands for this primitive so the matcher can bind it later.
      (*primitive_vars)[var_ptr->primitive()] = var_ptr;
      return NewValueNode(var_ptr->primitive());
    }
    return CreateVarNodeWithSexp(sexp, graph);
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto value_node = CreateValueNodeWithSexp(sexp);
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  return value_node;
}
473 
IsOpType(const BaseRef & n,const PrimitivePtr & prim)474 bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
475   if (utils::isa<AnfNodePtr>(n)) {
476     auto anf_node = utils::cast<AnfNodePtr>(n);
477     return CheckPrimitiveType(anf_node, prim);
478   }
479   return false;
480 }
481 
IsRealCNodeKernel(const AnfNodePtr & node)482 bool IsRealCNodeKernel(const AnfNodePtr &node) {
483   if (node == nullptr) {
484     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
485     return false;
486   }
487   // parameter and value node is not a real cnode kernel
488   if (!node->isa<CNode>()) {
489     return false;
490   }
491   // return considered as a real node
492   if (CheckPrimitiveType(node, prim::kPrimReturn)) {
493     return true;
494   }
495   return IsRealKernel(node);
496 }
// Returns true when 'node' is a graph-kernel CNode: a real cnode kernel whose first
// input is a FuncGraph value node carrying the FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute.
// Null input or failed casts record RET_NULL_PTR and yield false.
bool IsGraphKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // graph kernel should be a real cnode kernel.
  if (!IsRealCNodeKernel(node)) {
    return false;
  }

  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  auto input = cnode->input(kAnfPrimitiveIndex);
  // graph kernel should has func_graph as first input.
  if (!IsValueNode<FuncGraph>(input)) {
    return false;
  }

  auto func_graph = GetValueNode<FuncGraphPtr>(input);
  if (func_graph == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
525 
// Creates a new Parameter in 'func_graph' holding a 1-D tensor of 'kernel_num' elements
// copied from 'bias_data'. Returns the parameter, or nullptr on any failure.
// NOTE(review): the byte size is computed as kernel_num * sizeof(float) regardless of
// type_id — presumably callers only pass float32 bias data; confirm before passing
// other element types.
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
  if (bias_data == nullptr || func_graph == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr.";
    return nullptr;
  }
  auto bias_parameter = func_graph->add_parameter();
  MS_ASSERT(bias_parameter != nullptr);
  std::vector<int64_t> shape_vector = {kernel_num};
  auto tensor_info =
    lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, type_id);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(bias_parameter, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }

  return bias_parameter;
}
548 
GetTensorInfo(const AnfNodePtr & node)549 tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node) {
550   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
551   if (!utils::isa<ParameterPtr>(node)) {
552     if (utils::isa<ValueNodePtr>(node)) {
553       auto valueNode = node->cast<ValueNodePtr>();
554       auto value_ptr = valueNode->value();
555       MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
556       auto value = value_ptr->cast<tensor::TensorPtr>();
557       if (value != nullptr) {
558         return value;
559       }
560     }
561     MS_LOG(DEBUG) << "get lite param value node neither parameternode or valuenode";
562     return nullptr;
563   }
564   auto param = node->cast<ParameterPtr>();
565   MS_ASSERT(param != nullptr);
566   if (!param->has_default() || param->default_param() == nullptr) {
567     return nullptr;
568   }
569   auto tensor_info = param->default_param()->cast<tensor::TensorPtr>();
570   return tensor_info;
571 }
572 
// Returns the abstract of cnode's input at 'index' (1-based; input 0 is the primitive).
// For a TupleGetItem producer, resolves the element abstract out of the producer's
// AbstractTuple using the get-item index. Returns nullptr when index is out of range,
// the input kind is unsupported, or resolution fails.
AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "CNodePtr is nullptr";
    return nullptr;
  }
  auto inputs = cnode->inputs();
  if (!(index > 0 && index < inputs.size())) {
    return nullptr;
  }
  auto input = inputs[index];
  if (input == nullptr) {
    MS_LOG(ERROR) << "CNode input is nullptr";
    return nullptr;
  }

  AbstractBasePtr abstract = nullptr;
  if (utils::isa<ParameterPtr>(input)) {
    auto parameter = input->cast<ParameterPtr>();
    abstract = parameter->abstract();
  } else if (utils::isa<CNodePtr>(input)) {
    auto input_cnode = input->cast<CNodePtr>();
    if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
      // The real producer is TupleGetItem's first operand; pick the tuple element the
      // get-item selects.
      auto tuple_inputs = input_cnode->inputs();
      MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
      auto get_item_input_cnode = tuple_inputs.at(1);
      MS_ASSERT(get_item_input_cnode != nullptr);
      auto idx = GetTupleGetItemOutIndex(input_cnode);
      if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = input_cnode->abstract();
    }
  } else {
    MS_LOG(ERROR) << "unsupported input node type";
    return nullptr;
  }
  return abstract;
}
620 
IsParamNode(const BaseRef & n)621 bool IsParamNode(const BaseRef &n) {
622   if (!utils::isa<ParameterPtr>(n)) {
623     return false;
624   }
625   auto parameter = utils::cast<ParameterPtr>(n);
626   if (!parameter->has_default() || parameter->default_param() == nullptr) {
627     return false;
628   }
629   auto tensor = parameter->default_param()->cast<tensor::TensorPtr>();
630   if (tensor == nullptr) {
631     return false;
632   }
633   return tensor->data_c() != nullptr;
634 }
635 
// Fetches the tensor value recorded in the abstract of cnode's input at 'index'.
// Succeeds only when infer-shape has already attached a tensor::Tensor value track to
// that abstract; otherwise returns RET_ERROR (WARNING vs DEBUG logs distinguish
// "no abstract yet" from "abstract present but not a tensor value").
STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) {
  CHECK_NULL_RETURN(tensor_info);
  CHECK_NULL_RETURN(cnode);
  AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
  if (abstract == nullptr) {
    MS_LOG(WARNING) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr, infershape is delayed.";
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
    MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
  if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) {  // input node not complete infershape
    MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
    return RET_ERROR;
  }
  *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
  if (*tensor_info == nullptr) {
    MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
    return RET_ERROR;
  }
  return RET_OK;
}
660 
IsParamOrValueNodeWithData(const BaseRef & n)661 bool IsParamOrValueNodeWithData(const BaseRef &n) {
662   if (utils::isa<ValueNode>(n)) {
663     auto value_node = utils::cast<ValueNodePtr>(n);
664     auto value = value_node->value();
665     if (value != nullptr && value->isa<tensor::Tensor>()) {
666       auto tensor = value->cast<tensor::TensorPtr>();
667       if (tensor == nullptr || tensor->data_c() == nullptr) {
668         return false;
669       }
670       return true;
671     } else {
672       return false;
673     }
674   }
675   if (utils::isa<ParameterPtr>(n)) {
676     return IsParamNode(n);
677   }
678   return false;
679 }
680 
// Pattern predicate for the parallel-split pass: matches Conv2D/Conv2DFusion nodes whose
// primitive has not yet been assigned a device-type attribute (kDeviceTypeNone), i.e.
// convolutions that are still candidates for splitting.
bool IsParallelSplitConvNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    PrimitivePtr prim;
    if (utils::isa<CNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
    }
    if (utils::isa<ValueNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node);
    }
    if (prim == nullptr) {
      return false;
    }
    // A conv already tagged with a device type has been split/assigned; skip it.
    int device_type =
      prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int32_t>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
    if (device_type != kDeviceTypeNone) {
      return false;
    }
    return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) || CheckPrimitiveType(anf_node, prim::kPrimConv2D);
  }
  return false;
}
703 
// Pattern predicate: matches Conv2DFusion nodes, and non-depthwise
// Conv2dTransposeFusion nodes, whose primitive carries no fused activation.
bool IsConvNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    PrimitivePtr prim;
    if (utils::isa<CNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
    }
    if (utils::isa<ValueNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node);
    }
    if (prim == nullptr) {
      return false;
    }

    // Convs with a fused activation are excluded from this match.
    if (prim->GetAttr(ops::kActivationType) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kActivationType)) != NO_ACTIVATION) {
      return false;
    }

    bool is_depth_wise =
      prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
    return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) ||
           (CheckPrimitiveType(anf_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise);
  }
  return false;
}
730 
CheckIsAllInputsParam(const AnfNodePtr & node)731 bool CheckIsAllInputsParam(const AnfNodePtr &node) {
732   if (node == nullptr) {
733     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
734     return false;
735   }
736   if (utils::isa<CNode>(node)) {
737     auto cnode = node->cast<CNodePtr>();
738     for (size_t i = 1; i < cnode->inputs().size(); i++) {
739       if (!utils::isa<Parameter>(cnode->input(i)) && !utils::isa<ValueNodePtr>(cnode->input(i))) {
740         return false;
741       }
742     }
743     return true;
744   }
745   return false;
746 }
747 
GetOutputTensorNum(const AnfNodePtr & node)748 size_t GetOutputTensorNum(const AnfNodePtr &node) {
749   if (node == nullptr) {
750     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
751     return 0;
752   }
753   auto type = node->Type();
754   if (type == nullptr) {
755     return 1;
756   }
757   if (type->isa<Tuple>()) {
758     auto tuple_type = type->cast<TuplePtr>();
759     if (tuple_type == nullptr) {
760       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
761       return 0;
762     }
763     return tuple_type->size();
764   } else if (type->isa<TensorType>() || type->isa<Number>()) {
765     return 1;
766   } else if (type->isa<TypeNone>()) {
767     return 0;
768   } else {
769     return 1;
770   }
771 }
772 
IsMultiOutputTensors(const FuncGraphPtr & graph,const AnfNodePtr & node)773 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
774   if (graph == nullptr || node == nullptr) {
775     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
776     return false;
777   }
778   auto output_node_list = GetRealNodeUsedList(graph, node);
779   if (output_node_list == nullptr) {
780     MS_LOG(ERROR) << "output node list is nullptr";
781     return false;
782   }
783   if (output_node_list->size() != 1) {
784     MS_LOG(DEBUG) << "fusion node has multi output nodes";
785     return true;
786   }
787   return false;
788 }
789 
GetRealNodeUsedList(const FuncGraphPtr & graph,const AnfNodePtr & node)790 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
791                                                                              const AnfNodePtr &node) {
792   if (graph == nullptr || node == nullptr) {
793     MS_LOG(ERROR) << "input parameter is nullptr.";
794     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
795     return nullptr;
796   }
797   auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
798   MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
799   auto manager = graph->manager();
800   if (manager == nullptr) {
801     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
802     return nullptr;
803   }
804   auto iter = manager->node_users().find(node);
805   if (iter == manager->node_users().end()) {
806     MS_LOG(ERROR) << "node has no output in manager";
807     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
808     return nullptr;
809   }
810   auto output_info_list = iter->second;
811   std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
812   return output_node_list;
813 }
814 
// Extracts the constant output index carried by a TupleGetItem cnode, i.e.
// the integer value of its second input.
// On malformed input the function logs an error and returns -1, which wraps
// to SIZE_MAX because the return type is size_t; callers compare the result
// against an expected index, so the sentinel never matches a valid one.
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  // A well-formed TupleGetItem has exactly three inputs: prim, tuple, index.
  if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return -1;
  }
  auto output_index_value_node = tuple_get_item->input(kInputIndexTwo);
  if (output_index_value_node == nullptr) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return -1;
  }
  // The index operand must be a constant (ValueNode).
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return -1;
  }
  // CastToInt normalizes the stored value (int32/int64/scalar/vector) to a
  // vector of int; the first element is the output index.
  auto indexes = CastToInt(value_node->value());
  if (indexes.empty()) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return -1;
  }
  return indexes.front();
}
837 
GetRealNodeUsedListByOutputIdx(const FuncGraphPtr & graph,const AnfNodePtr & node,size_t output_index)838 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
839                                                                                         const AnfNodePtr &node,
840                                                                                         size_t output_index) {
841   if (graph == nullptr || node == nullptr) {
842     MS_LOG(ERROR) << "input parameter is nullptr.";
843     return nullptr;
844   }
845   auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
846   MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
847   auto manager = graph->manager();
848   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
849   auto iter = manager->node_users().find(node);
850   if (iter == manager->node_users().end()) {
851     MS_LOG(ERROR) << "node has no output in manager";
852     return output_node_list;
853   }
854   auto output_info_list = iter->second;
855   for (const auto &output_info : output_info_list) {
856     size_t used_output_index;
857     if (CheckPrimitiveType(output_info.first, prim::kPrimTupleGetItem)) {
858       used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
859     } else if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
860       used_output_index = output_index;
861     } else {
862       if (output_index != 0) {
863         MS_LOG(ERROR) << "node has no output in manager";
864         return output_node_list;
865       }
866       return output_node_list;
867     }
868     if (used_output_index == output_index) {
869       output_node_list->push_back(output_info);
870     }
871   }
872   return output_node_list;
873 }
874 
TransFilterFormat(const tensor::TensorPtr & tensor,schema::Format src_format,schema::Format dst_format)875 STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
876   MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
877   std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
878     trans_func = {{kNumberTypeFloat32, DoTransposeData<float>},
879                   {kNumberTypeUInt8, DoTransposeData<uint8_t>},
880                   {kNumberTypeInt8, DoTransposeData<int8_t>},
881                   {kNumberTypeFloat16, DoTransposeData<float16>}};
882   auto data_type = tensor->data_type();
883   auto iter = trans_func.find(data_type);
884   if (iter == trans_func.end()) {
885     MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
886     return RET_ERROR;
887   }
888   return iter->second(tensor, src_format, dst_format);
889 }
890 
BuildParameterNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const tensor::TensorPtr & tensor_info)891 ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
892                                 const tensor::TensorPtr &tensor_info) {
893   if (func_graph == nullptr || node == nullptr || tensor_info == nullptr) {
894     MS_LOG(ERROR) << "input parameter is nullptr.";
895     return nullptr;
896   }
897   auto param_node = func_graph->add_parameter();
898   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
899   auto shape = tensor_info->shape();
900   std::vector<int64_t> shape_vector;
901   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
902                  [](const int &val) { return static_cast<int64_t>(val); });
903   auto data_type = tensor_info->data_type() == kNumberTypeInt64 ? kNumberTypeInt32 : tensor_info->data_type();
904   param_node->set_name(node->fullname_with_scope());
905   auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
906   if (tensor_info_new == nullptr) {
907     MS_LOG(ERROR) << "new tensor::Tensor failed.";
908     return nullptr;
909   }
910   size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
911   if (tensor_info->Size() == 0) {
912     auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
913     if (status != RET_OK) {
914       MS_LOG(ERROR) << "init parameter from tensor info failed";
915       return nullptr;
916     }
917     return param_node;
918   }
919   if (tensor_info->data_type() == kNumberTypeInt64) {
920     auto *tensor_data = reinterpret_cast<int *>(tensor_info_new->data_c());
921     if (tensor_data == nullptr) {
922       MS_LOG(ERROR) << "new data failed";
923       return nullptr;
924     }
925     auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
926     for (size_t i = 0; i < data_count; ++i) {
927       if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
928         MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32";
929         tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
930       } else {
931         tensor_data[i] = static_cast<int>(origin_data[i]);
932       }
933     }
934   } else {
935     tensor_info_new->set_data_type(tensor_info->data_type());
936     auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_new->data_c());
937     if (tensor_data == nullptr) {
938       MS_LOG(ERROR) << "new data failed";
939       return nullptr;
940     }
941     if (memcpy_s(tensor_data, tensor_info_new->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
942       MS_LOG(ERROR) << "memcpy data failed.";
943       return nullptr;
944     }
945   }
946   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
947   if (status != RET_OK) {
948     MS_LOG(ERROR) << "init parameter from tensor info failed";
949     return nullptr;
950   }
951   param_node->set_default_param(tensor_info_new);
952   return param_node;
953 }
954 
BuildIntValueParameterNode(const FuncGraphPtr & func_graph,const int32_t & data,const std::string & node_name)955 ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
956                                         const std::string &node_name) {
957   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
958   auto param_node = func_graph->add_parameter();
959   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
960   param_node->set_name(node_name);
961 
962   auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int32_t), {1}, kNumberTypeInt32);
963   if (tensor_info == nullptr) {
964     MS_LOG(ERROR) << "Create tensor info failed";
965     return nullptr;
966   }
967 
968   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
969   if (status != RET_OK) {
970     MS_LOG(ERROR) << "init parameter from tensor info failed";
971     return nullptr;
972   }
973   return param_node;
974 }
975 
BuildIntVecParameterNode(const FuncGraphPtr & func_graph,const std::vector<int32_t> & data,const std::string & node_name)976 ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
977                                       const std::string &node_name) {
978   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
979   auto param_node = func_graph->add_parameter();
980   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
981   param_node->set_name(node_name);
982 
983   std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
984   auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), shape_vector, kNumberTypeInt32);
985   if (tensor_info == nullptr) {
986     MS_LOG(ERROR) << "Create tensor info failed";
987     return nullptr;
988   }
989 
990   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
991   if (status != RET_OK) {
992     MS_LOG(ERROR) << "init parameter from tensor info failed";
993     return nullptr;
994   }
995 
996   return param_node;
997 }
998 
BuildIntVec2DParameterNode(const FuncGraphPtr & func_graph,const std::vector<std::vector<int32_t>> & data,const std::string & node_name)999 ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
1000                                         const std::string &node_name) {
1001   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1002   auto param_node = func_graph->add_parameter();
1003   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1004   param_node->set_name(node_name);
1005 
1006   std::vector<int64_t> shape_vector;
1007   shape_vector.push_back(data.size());
1008   shape_vector.push_back(2);
1009 
1010   std::vector<int32_t> data_1d;
1011   for (auto pair : data) {
1012     data_1d.insert(data_1d.end(), pair.begin(), pair.end());
1013   }
1014 
1015   auto size = data_1d.size() * sizeof(int32_t);
1016   auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32);
1017   if (tensor_info == nullptr) {
1018     MS_LOG(ERROR) << "Create tensor info failed";
1019     return nullptr;
1020   }
1021   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1022   if (status != RET_OK) {
1023     MS_LOG(ERROR) << "init parameter from tensor info failed";
1024     return nullptr;
1025   }
1026   return param_node;
1027 }
1028 
BuildFloatValueParameterNode(const FuncGraphPtr & func_graph,const float & data,const std::string & node_name)1029 ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
1030                                           const std::string &node_name) {
1031   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1032   auto param_node = func_graph->add_parameter();
1033   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1034   param_node->set_name(node_name);
1035 
1036   auto tensor_info = lite::CreateTensorInfo(&data, sizeof(float), {1}, kNumberTypeFloat32);
1037   if (tensor_info == nullptr) {
1038     MS_LOG(ERROR) << "Create tensor info failed";
1039     return nullptr;
1040   }
1041   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1042   if (status != RET_OK) {
1043     MS_LOG(ERROR) << "init parameter from tensor info failed";
1044     return nullptr;
1045   }
1046   return param_node;
1047 }
1048 
BuildFloatVecParameterNode(const FuncGraphPtr & func_graph,const std::vector<float> & data,const std::string & node_name)1049 ParameterPtr BuildFloatVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float> &data,
1050                                         const std::string &node_name) {
1051   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1052   auto param_node = func_graph->add_parameter();
1053   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1054   param_node->set_name(node_name);
1055 
1056   std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
1057   auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(float), shape_vector, kNumberTypeFloat);
1058   if (tensor_info == nullptr) {
1059     MS_LOG(ERROR) << "Create tensor info failed";
1060     return nullptr;
1061   }
1062 
1063   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1064   if (status != RET_OK) {
1065     MS_LOG(ERROR) << "init parameter from tensor info failed";
1066     return nullptr;
1067   }
1068 
1069   return param_node;
1070 }
1071 
GenTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & perm,const std::string & cnode_name)1072 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
1073                           const std::string &cnode_name) {
1074   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1075   MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
1076   auto perm_node = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm");
1077   MS_ASSERT(perm_node != nullptr);
1078   auto trans_prim = std::make_shared<ops::Transpose>();
1079   MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
1080   auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
1081   MS_ASSERT(cnode != nullptr);
1082   auto manager = Manage(func_graph);
1083   MS_ASSERT(manager != nullptr);
1084   manager->SetEdge(cnode, 1, input_node);
1085   manager->SetEdge(cnode, kInputIndexTwo, perm_node);
1086   cnode->set_fullname_with_scope(cnode_name);
1087   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeTwo, 1);
1088   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1089   trans_prim->AddAttr("quant_params", quant_params_holder);
1090   return cnode;
1091 }
1092 
GenGatherNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & indices,const std::string & cnode_name)1093 CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
1094                        const std::string &cnode_name) {
1095   if (func_graph == nullptr || input_node == nullptr) {
1096     MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1097     return nullptr;
1098   }
1099   auto indices_node = BuildIntVecParameterNode(func_graph, indices, cnode_name + "_indices");
1100   if (indices_node == nullptr) {
1101     MS_LOG(ERROR) << "make indices node failed.";
1102     return nullptr;
1103   }
1104   auto axis_node = BuildIntVecParameterNode(func_graph, {0}, cnode_name + "_indices");
1105   if (axis_node == nullptr) {
1106     MS_LOG(ERROR) << "make indices node failed.";
1107     return nullptr;
1108   }
1109   auto gather_prim = std::make_shared<ops::Gather>();
1110   MS_CHECK_TRUE_RET(gather_prim != nullptr, nullptr);
1111   auto cnode = func_graph->NewCNode(gather_prim, {input_node, indices_node, axis_node});
1112   MS_ASSERT(cnode != nullptr);
1113   auto manager = Manage(func_graph);
1114   MS_ASSERT(manager != nullptr);
1115   manager->SetEdge(cnode, 1, input_node);
1116   manager->SetEdge(cnode, kInputIndexTwo, indices_node);
1117   manager->SetEdge(cnode, kInputIndexThree, axis_node);
1118   cnode->set_fullname_with_scope(cnode_name);
1119   auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeThree, 1);
1120   MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1121   gather_prim->AddAttr("quant_params", quant_params_holder);
1122   return cnode;
1123 }
1124 
// Creates a TupleGetItem cnode selecting output `index` of `input`. The new
// node is named "<input-name>_getitem_<index>". Returns nullptr on failure.
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
  if (func_graph == nullptr || input == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
    return nullptr;
  }
  auto getitem_prim = std::make_shared<lite::TupleGetItem>();
  MS_CHECK_TRUE_RET(getitem_prim != nullptr, nullptr);
  // The output index is carried as a constant second input.
  auto index_input = NewValueNode(MakeValue<int>(index));
  MS_CHECK_TRUE_RET(index_input != nullptr, nullptr);
  auto getitem_cnode = func_graph->NewCNode(getitem_prim, {input, index_input});
  MS_ASSERT(getitem_cnode != nullptr);
  getitem_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
  return getitem_cnode;
}
1139 
FetchShapeFromAbstract(const abstract::AbstractBasePtr & abstract,ShapeVector * shape)1140 STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) {
1141   if (abstract == nullptr || shape == nullptr) {
1142     MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1143     return lite::RET_ERROR;
1144   }
1145   if (!utils::isa<abstract::AbstractTensor>(abstract)) {
1146     MS_LOG(ERROR) << "abstract of cnode is invalid.";
1147     return lite::RET_ERROR;
1148   }
1149   auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
1150   if (abstract_tensor->BuildShape() == nullptr || !utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
1151     MS_LOG(ERROR) << "shape of cnode's output is invalid.";
1152     return lite::RET_ERROR;
1153   }
1154   *shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
1155   return lite::RET_OK;
1156 }
1157 
IsTrainOp(const CNodePtr & cnode)1158 bool IsTrainOp(const CNodePtr &cnode) {
1159   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1160   auto cnode_type = prim->type_name();
1161   // optimizer op
1162   if (cnode_type == "Adam" || cnode_type == "SGD" || cnode_type == "ApplyMomentum") {
1163     return true;
1164   }
1165   // loss op
1166   if (cnode_type == "SoftmaxCrossEntropyWithLogits" || cnode_type == "SpareSoftmaxCrossEntropyWithLogits" ||
1167       cnode_type == "SmoothL1Loss" || cnode_type == "SmoothL1LossGrad" ||
1168       cnode_type == "SigmoidCrossEntropyWithLogits" || cnode_type == "SigmoidCrossEntropyWithLogpitsGrad") {
1169     return true;
1170   }
1171   // grad op
1172   if (cnode_type.find("Grad") != std::string::npos ||
1173       cnode->fullname_with_scope().find("Gradients") != std::string::npos) {
1174     return true;
1175   }
1176   return false;
1177 }
1178 
IsMarkedTrainOp(const CNodePtr & cnode)1179 bool IsMarkedTrainOp(const CNodePtr &cnode) {
1180   if (cnode == nullptr) {
1181     return false;
1182   }
1183   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1184   MS_CHECK_TRUE_RET(prim != nullptr, false);
1185   if (prim->GetAttr("trainOp") != nullptr && GetValue<bool>(prim->GetAttr("trainOp"))) {
1186     MS_LOG(DEBUG) << "train op not fusion.";
1187     return true;
1188   }
1189   return false;
1190 }
1191 
// Resolves the output data type of `anf_node` from its abstract.
// For multi-output nodes whose abstract is a tuple (e.g. Split), only the
// first element's type is reported. Writes *type_id and returns RET_OK on
// success; returns RET_ERROR otherwise.
int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
  if (anf_node == nullptr || type_id == nullptr) {
    MS_LOG(ERROR) << "anf_node or type_id is nullptr.";
    return RET_ERROR;
  }
  auto abstract_base = anf_node->abstract();
  // used for multi output e.g. split.
  if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
    auto abstract_tuple = abstract_base->cast<abstract::AbstractTuplePtr>();
    if (abstract_tuple->elements().empty()) {
      MS_LOG(ERROR) << "abstract_tuple elements is empty.";
      return RET_ERROR;
    }
    // Only the first output's type is considered from here on.
    abstract_base = abstract_tuple->elements().front();
  }
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << anf_node->fullname_with_scope();
    return RET_ERROR;
  }
  // NOTE(review): this isa<> is given the Ptr alias (AbstractTensorPtr) while
  // other checks in this file use the abstract type itself — presumably
  // equivalent under utils::isa; confirm.
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << anf_node->fullname_with_scope();
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  // The element abstract tracks the scalar type of the tensor's entries.
  auto type_ptr = abstract_tensor->element()->GetTypeTrack();
  MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
  *type_id = type_ptr->type_id();
  return RET_OK;
}
1221 }  // namespace opt
1222 }  // namespace mindspore
1223