1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "tools/optimizer/common/gllo_utils.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base/float16.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/transpose.h"
#include "ops/gather.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"
#include "tools/converter/quant_param_holder.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
34
35 namespace mindspore {
36 namespace opt {
37 namespace {
// Position of the primitive value node inside a CNode's input list.
constexpr int kAnfPrimitiveIndex = 0;
// Device type used when a primitive carries no kDeviceType attribute.
constexpr int kDeviceTypeNone = -1;
DeduceDimConvertion(schema::Format src_format,schema::Format dst_format,std::vector<int> * perm)40 int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
41 MS_ASSERT(perm != nullptr);
42 auto src_format_str = std::string(schema::EnumNameFormat(src_format));
43 auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
44 if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
45 MS_LOG(ERROR) << "src_format or dst_format is error.";
46 return lite::RET_ERROR;
47 }
48 std::replace(src_format_str.begin(), src_format_str.end(), 'K', 'N');
49 std::replace(dst_format_str.begin(), dst_format_str.end(), 'K', 'N');
50 perm->clear();
51 std::unordered_map<char, int> dim_map;
52 for (size_t i = 0; i < src_format_str.size(); ++i) {
53 dim_map[src_format_str[i]] = i;
54 }
55 for (size_t i = 0; i < dst_format_str.size(); ++i) {
56 if (dim_map.find(dst_format_str[i]) == dim_map.end()) {
57 MS_LOG(ERROR) << "src_format and dst_format cannot match, please check.";
58 return RET_ERROR;
59 }
60 perm->push_back(dim_map[dst_format_str[i]]);
61 }
62 return lite::RET_OK;
63 }
64
65 template <typename T>
TransposeData(const ShapeVector & origin_shape,const ShapeVector & cur_shape,const std::vector<int> & perm,T * weight_data,std::vector<T> * buf)66 void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape, const std::vector<int> &perm,
67 T *weight_data, std::vector<T> *buf) {
68 MS_ASSERT(weight_data != nullptr && buf != nullptr);
69 MS_ASSERT(origin_shape.size() == cur_shape.size() && cur_shape.size() == perm.size());
70 int count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int>());
71 ShapeVector post_multiply(cur_shape.size());
72 std::unordered_map<int, int> dim_map;
73 for (int i = cur_shape.size() - 1; i >= 0; --i) {
74 if (i == static_cast<int>(cur_shape.size() - 1)) {
75 post_multiply[i] = 1;
76 } else {
77 post_multiply[i] = cur_shape[i + 1] * post_multiply[i + 1];
78 }
79 dim_map[perm[i]] = i;
80 }
81 std::unordered_map<int, int> position_map;
82 for (int i = 0; i < count; ++i) {
83 int temp = i;
84 for (int j = static_cast<int>(origin_shape.size()) - 1; j >= 0; --j) {
85 MS_ASSERT(origin_shape[j] > 0);
86 position_map[j] = temp % origin_shape[j];
87 temp /= origin_shape[j];
88 }
89 int64_t new_pos = std::accumulate(position_map.begin(), position_map.end(), 0,
90 [&post_multiply, &dim_map](int64_t res, const std::pair<int, int> &pair_y) {
91 return res + post_multiply[dim_map[pair_y.first]] * pair_y.second;
92 });
93 buf->at(new_pos) = weight_data[i];
94 }
95 }
96
97 template <typename T>
DoTransposeData(const tensor::TensorPtr & tensor,schema::Format src_format,schema::Format dst_format)98 STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
99 MS_ASSERT(tensor != nullptr);
100 auto origin_shape = tensor->shape_c();
101 if (origin_shape.size() != kInputSizeFour) {
102 MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << origin_shape.size();
103 return lite::RET_ERROR;
104 }
105 if (std::any_of(origin_shape.begin(), origin_shape.end(), [](int64_t val) { return val <= 0; })) {
106 MS_LOG(ERROR) << "the tensor's shape is invalid.";
107 return lite::RET_ERROR;
108 }
109 std::vector<int> perm;
110 if (DeduceDimConvertion(src_format, dst_format, &perm) != RET_OK) {
111 MS_LOG(ERROR) << "deduce perm failed.";
112 return lite::RET_ERROR;
113 }
114 ShapeVector new_shape;
115 for (auto &val : perm) {
116 if (val < 0 || static_cast<size_t>(val) >= origin_shape.size()) {
117 MS_LOG(ERROR) << "deduce perm is invalid.";
118 return lite::RET_ERROR;
119 }
120 new_shape.push_back(origin_shape[val]);
121 }
122 auto count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1LL, std::multiplies<int64_t>());
123 if (count <= 0 || count > static_cast<int64_t>(INT32_MAX)) {
124 MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max.";
125 return RET_ERROR;
126 }
127 std::vector<T> buf(count);
128
129 void *originWeightData = tensor->data_c();
130 MS_CHECK_TRUE_RET(originWeightData != nullptr, RET_ERROR);
131 T *weightData = static_cast<T *>(originWeightData);
132 TransposeData<T>(origin_shape, new_shape, perm, weightData, &buf);
133 if (memcpy_s(tensor->data_c(), tensor->Size(), buf.data(), count * sizeof(T)) != EOK) {
134 MS_LOG(ERROR) << "memcpy_s failed.";
135 return RET_ERROR;
136 }
137 tensor->set_shape(new_shape);
138 return RET_OK;
139 }
140
IsRealKernel(const AnfNodePtr & node)141 bool IsRealKernel(const AnfNodePtr &node) {
142 if (node == nullptr) {
143 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
144 return false;
145 }
146 // parameter and value node is not a real kernel too
147 if (!node->isa<CNode>()) {
148 return true;
149 }
150 auto cnode = node->cast<CNodePtr>();
151 if (cnode == nullptr) {
152 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
153 return false;
154 }
155 if (cnode->inputs().empty()) {
156 MS_LOG(ERROR) << "Illegal null input of cnode(%s)" << node->DebugString();
157 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
158 return false;
159 }
160 auto input = cnode->inputs()[0];
161 #ifndef ENABLE_SECURITY
162 bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
163 IsPrimitive(input, prim::kPrimTensorSummary) ||
164 IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
165 IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
166 IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
167 IsPrimitive(input, prim::kPrimPartial);
168 #else
169 bool is_virtual_node = IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) ||
170 IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) ||
171 IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
172 #endif
173 return !is_virtual_node;
174 }
175
CreateValueNodeWithSexp(const BaseRef & sexp)176 ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
177 if (utils::isa<int>(sexp)) {
178 return NewValueNode(utils::cast<int>(sexp));
179 }
180 if (utils::isa<float>(sexp)) {
181 return NewValueNode(utils::cast<float>(sexp));
182 }
183 if (utils::isa<bool>(sexp)) {
184 return NewValueNode(utils::cast<bool>(sexp));
185 }
186 if (utils::isa<ValuePtr>(sexp)) {
187 return NewValueNode(utils::cast<ValuePtr>(sexp));
188 }
189 return nullptr;
190 }
191
CreateCNodeWithGraph(const std::vector<AnfNodePtr> & input_nodes,const BaseRef & graph)192 CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
193 if (utils::isa<FuncGraphPtr>(graph)) {
194 return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
195 }
196 if (utils::isa<VarPtr>(graph)) {
197 return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
198 }
199 return nullptr;
200 }
201
CreateVarNodeWithSexp(const BaseRef & sexp,const BaseRef & graph)202 VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
203 if (utils::isa<VarPtr>(graph)) {
204 MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
205 return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
206 }
207 if (utils::isa<FuncGraphPtr>(graph)) {
208 MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
209 return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
210 }
211 MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
212 return nullptr;
213 }
214
HandleSexpVector(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)215 AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
216 bool multigraph) {
217 if (primitive_vars == nullptr) {
218 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
219 return nullptr;
220 }
221 MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
222 std::vector<AnfNodePtr> input_nodes;
223 const auto &tuple = utils::cast<VectorRef>(sexp);
224 if (multigraph && utils::isa<VarPtr>(graph)) {
225 for (auto &x : tuple) {
226 auto is_var = std::make_shared<Var>("G");
227 MS_CHECK_TRUE_RET(is_var != nullptr, nullptr);
228 AnfNodePtr node = SexpToNode(x, is_var, primitive_vars, true);
229 input_nodes.push_back(node);
230 }
231 auto var_ptr = utils::cast<VarPtr>(graph);
232 return std::make_shared<CNode>(input_nodes, var_ptr);
233 }
234
235 for (auto &x : tuple) {
236 AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
237 input_nodes.push_back(node);
238 }
239 return CreateCNodeWithGraph(input_nodes, graph);
240 }
241
AnfEqualPrimitive(const AnfNodePtr & a_node,const AnfNodePtr & b_node)242 bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
243 auto a_value_node = a_node->cast<ValueNodePtr>();
244 auto b_value_node = b_node->cast<ValueNodePtr>();
245 if (a_value_node == nullptr || b_value_node == nullptr) {
246 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
247 return false;
248 }
249
250 auto a_value = a_value_node->value();
251 auto b_value = b_value_node->value();
252 if (a_value == nullptr || b_value == nullptr) {
253 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
254 return false;
255 }
256
257 auto a_prim = a_value->cast<PrimitivePtr>();
258 auto b_prim = b_value->cast<PrimitivePtr>();
259 if (a_prim == nullptr || b_prim == nullptr) {
260 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
261 return false;
262 }
263 return a_prim->name() == b_prim->name();
264 }
265
AnfEqualValueNode(const AnfNodePtr & a_node,const AnfNodePtr & b_node)266 bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
267 auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
268 auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
269 if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
270 MS_LOG(ERROR) << "cast value node ptr fail";
271 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
272 return false;
273 }
274 auto a_value_ptr = a_value_node_ptr->value();
275 auto b_value_ptr = b_value_node_ptr->value();
276 if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
277 MS_LOG(ERROR) << "value ptr is nullptr";
278 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
279 return false;
280 }
281
282 if (utils::isa<ops::PrimitiveC>(a_value_ptr) && utils::isa<ops::PrimitiveC>(b_value_ptr)) {
283 auto a_obj = (ops::PrimitiveC *)(a_value_ptr.get());
284 auto b_obj = (ops::PrimitiveC *)(b_value_ptr.get());
285 return (*a_obj) == (*b_obj);
286 } else {
287 return (*a_value_ptr) == (*b_value_ptr);
288 }
289 }
290 } // namespace
291
CheckInputs(const CNodePtr & cnode)292 bool CheckInputs(const CNodePtr &cnode) {
293 if (cnode == nullptr) {
294 MS_LOG(ERROR) << "cnode is nullptr.";
295 return false;
296 }
297 if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
298 [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
299 MS_LOG(ERROR) << "input is nullptr.";
300 return false;
301 }
302 return true;
303 }
304
CastToInt(const ValuePtr & value)305 std::vector<int> CastToInt(const ValuePtr &value) {
306 if (value == nullptr) {
307 MS_LOG(WARNING) << "valueptr is nullptr.";
308 return {};
309 }
310 std::vector<int> cur_value = {};
311 if (utils::isa<ValueSequeuePtr>(value)) {
312 if (!value->cast<ValueSequeuePtr>()->value().empty()) {
313 auto data_type = value->cast<ValueSequeuePtr>()->value().front()->type()->number_type();
314 if (data_type == kNumberTypeInt64) {
315 auto origin_value = GetValue<std::vector<int64_t>>(value);
316 std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
317 [](int64_t index) { return static_cast<int>(index); });
318 } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
319 cur_value = GetValue<std::vector<int>>(value);
320 } else {
321 MS_LOG(ERROR) << "he function only process integer data.";
322 return {};
323 }
324 }
325 } else {
326 auto data_type = value->type()->number_type();
327 if (data_type == kNumberTypeInt64) {
328 cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
329 } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
330 cur_value.push_back(GetValue<int>(value));
331 } else {
332 MS_LOG(ERROR) << "the function only process integer data.";
333 return {};
334 }
335 }
336 return cur_value;
337 }
338
CastToVec2DInt(const ValuePtr & value)339 std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value) {
340 if (value == nullptr) {
341 MS_LOG(WARNING) << "valueptr is nullptr.";
342 return {};
343 }
344
345 std::vector<std::vector<int>> result_value;
346 if (utils::isa<ValueSequeuePtr>(value)) {
347 auto data_type =
348 value->cast<ValueSequeuePtr>()->value().front()->cast<ValueSequeuePtr>()->value().front()->type()->number_type();
349 if (data_type == kNumberTypeInt64) {
350 auto origin_value = GetValue<std::vector<std::vector<int64_t>>>(value);
351 for (auto &i : origin_value) {
352 std::vector<int> cur_value;
353 std::transform(i.begin(), i.end(), std::back_inserter(cur_value),
354 [](int64_t j) { return static_cast<int>(j); });
355 result_value.push_back(cur_value);
356 }
357 } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
358 result_value = GetValue<std::vector<std::vector<int>>>(value);
359 } else {
360 MS_LOG(ERROR) << "he function only process integer data.";
361 return result_value;
362 }
363 }
364 return result_value;
365 }
366
CastToFloat(const ValuePtr & value)367 std::vector<float> CastToFloat(const ValuePtr &value) {
368 if (value == nullptr) {
369 MS_LOG(WARNING) << "valueptr is nullptr.";
370 return {};
371 }
372 std::vector<float> cur_value = {};
373 if (utils::isa<ValueSequeuePtr>(value)) {
374 if (!value->cast<ValueSequeuePtr>()->value().empty()) {
375 auto data_type = value->cast<ValueSequeuePtr>()->value().front()->type()->number_type();
376 if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
377 cur_value = GetValue<std::vector<float>>(value);
378 } else {
379 MS_LOG(ERROR) << "the function only process float data.";
380 return {};
381 }
382 }
383 } else {
384 auto data_type = value->type()->number_type();
385 if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
386 cur_value.push_back(GetValue<float>(value));
387 } else {
388 MS_LOG(ERROR) << "the function only process float data.";
389 return {};
390 }
391 }
392 return cur_value;
393 }
394
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)395 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
396 if (node == nullptr || primitive_type == nullptr) {
397 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
398 return false;
399 }
400 if (node->isa<CNode>()) {
401 auto cnode = node->cast<CNodePtr>();
402 return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
403 } else if (node->isa<ValueNode>()) {
404 return IsPrimitive(node, primitive_type);
405 }
406 return false;
407 }
408
AnfEqual(const BaseRef & a,const BaseRef & b)409 bool AnfEqual(const BaseRef &a, const BaseRef &b) {
410 if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
411 auto a_node = utils::cast<AnfNodePtr>(a);
412 auto b_node = utils::cast<AnfNodePtr>(b);
413 if (a_node == nullptr || b_node == nullptr) {
414 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
415 return false;
416 }
417 if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
418 return AnfEqualPrimitive(a_node, b_node);
419 }
420 if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
421 return AnfEqualValueNode(a_node, b_node);
422 }
423 }
424 if (a.m_ptr->isa<mindspore::ops::PrimitiveC>() && b.m_ptr->isa<mindspore::ops::PrimitiveC>()) {
425 auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
426 auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
427 return a_value_node_ptr->name() == b_value_node_ptr->name();
428 }
429
430 return a == b;
431 }
432
CNodeTypeEqual(const BaseRef & a,const BaseRef & b)433 bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
434 // To matchCNode and Kernel's type
435 if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
436 return true;
437 }
438 return a.type() == b.type();
439 }
440
SexpToNode(const BaseRef & sexp,const BaseRef & graph,PrimitiveVarMap * primitive_vars,bool multigraph)441 AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
442 MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
443 if (primitive_vars == nullptr) {
444 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
445 return nullptr;
446 }
447 if (utils::isa<VectorRef>(sexp)) {
448 return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
449 }
450 if (utils::isa<VarPtr>(sexp)) {
451 auto var_ptr = utils::cast<VarPtr>(sexp);
452 if (var_ptr == nullptr) {
453 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
454 return nullptr;
455 }
456 if (var_ptr->primitive()) {
457 (*primitive_vars)[var_ptr->primitive()] = var_ptr;
458 return NewValueNode(var_ptr->primitive());
459 }
460 return CreateVarNodeWithSexp(sexp, graph);
461 }
462 if (utils::isa<AnfNodePtr>(sexp)) {
463 return utils::cast<AnfNodePtr>(sexp);
464 }
465 auto value_node = CreateValueNodeWithSexp(sexp);
466 if (value_node == nullptr) {
467 MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
468 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
469 return nullptr;
470 }
471 return value_node;
472 }
473
IsOpType(const BaseRef & n,const PrimitivePtr & prim)474 bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
475 if (utils::isa<AnfNodePtr>(n)) {
476 auto anf_node = utils::cast<AnfNodePtr>(n);
477 return CheckPrimitiveType(anf_node, prim);
478 }
479 return false;
480 }
481
IsRealCNodeKernel(const AnfNodePtr & node)482 bool IsRealCNodeKernel(const AnfNodePtr &node) {
483 if (node == nullptr) {
484 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
485 return false;
486 }
487 // parameter and value node is not a real cnode kernel
488 if (!node->isa<CNode>()) {
489 return false;
490 }
491 // return considered as a real node
492 if (CheckPrimitiveType(node, prim::kPrimReturn)) {
493 return true;
494 }
495 return IsRealKernel(node);
496 }
IsGraphKernel(const AnfNodePtr & node)497 bool IsGraphKernel(const AnfNodePtr &node) {
498 if (node == nullptr) {
499 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
500 return false;
501 }
502 // graph kernel should be a real cnode kernel.
503 if (!IsRealCNodeKernel(node)) {
504 return false;
505 }
506
507 auto cnode = node->cast<CNodePtr>();
508 if (cnode == nullptr) {
509 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
510 return false;
511 }
512 auto input = cnode->input(kAnfPrimitiveIndex);
513 // graph kernel should has func_graph as first input.
514 if (!IsValueNode<FuncGraph>(input)) {
515 return false;
516 }
517
518 auto func_graph = GetValueNode<FuncGraphPtr>(input);
519 if (func_graph == nullptr) {
520 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
521 return false;
522 }
523 return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
524 }
525
AddNewBiasNode(float * bias_data,const FuncGraphPtr & func_graph,int kernel_num,TypeId type_id)526 ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
527 if (bias_data == nullptr || func_graph == nullptr) {
528 MS_LOG(ERROR) << "input parameter is nullptr.";
529 return nullptr;
530 }
531 auto bias_parameter = func_graph->add_parameter();
532 MS_ASSERT(bias_parameter != nullptr);
533 std::vector<int64_t> shape_vector = {kernel_num};
534 auto tensor_info =
535 lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, type_id);
536 if (tensor_info == nullptr) {
537 MS_LOG(ERROR) << "create tensor info failed.";
538 return nullptr;
539 }
540 auto status = lite::InitParameterFromTensorInfo(bias_parameter, tensor_info);
541 if (status != RET_OK) {
542 MS_LOG(ERROR) << "init parameter from tensor info failed";
543 return nullptr;
544 }
545
546 return bias_parameter;
547 }
548
GetTensorInfo(const AnfNodePtr & node)549 tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node) {
550 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
551 if (!utils::isa<ParameterPtr>(node)) {
552 if (utils::isa<ValueNodePtr>(node)) {
553 auto valueNode = node->cast<ValueNodePtr>();
554 auto value_ptr = valueNode->value();
555 MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
556 auto value = value_ptr->cast<tensor::TensorPtr>();
557 if (value != nullptr) {
558 return value;
559 }
560 }
561 MS_LOG(DEBUG) << "get lite param value node neither parameternode or valuenode";
562 return nullptr;
563 }
564 auto param = node->cast<ParameterPtr>();
565 MS_ASSERT(param != nullptr);
566 if (!param->has_default() || param->default_param() == nullptr) {
567 return nullptr;
568 }
569 auto tensor_info = param->default_param()->cast<tensor::TensorPtr>();
570 return tensor_info;
571 }
572
GetCNodeInputAbstract(const CNodePtr & cnode,size_t index)573 AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
574 if (cnode == nullptr) {
575 MS_LOG(ERROR) << "CNodePtr is nullptr";
576 return nullptr;
577 }
578 auto inputs = cnode->inputs();
579 if (!(index > 0 && index < inputs.size())) {
580 return nullptr;
581 }
582 auto input = inputs[index];
583 if (input == nullptr) {
584 MS_LOG(ERROR) << "CNode input is nullptr";
585 return nullptr;
586 }
587
588 AbstractBasePtr abstract = nullptr;
589 if (utils::isa<ParameterPtr>(input)) {
590 auto parameter = input->cast<ParameterPtr>();
591 abstract = parameter->abstract();
592 } else if (utils::isa<CNodePtr>(input)) {
593 auto input_cnode = input->cast<CNodePtr>();
594 if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
595 auto tuple_inputs = input_cnode->inputs();
596 MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
597 auto get_item_input_cnode = tuple_inputs.at(1);
598 MS_ASSERT(get_item_input_cnode != nullptr);
599 auto idx = GetTupleGetItemOutIndex(input_cnode);
600 if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
601 MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
602 return nullptr;
603 }
604 auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
605 auto abstract_list = abstract_tuple->elements();
606 if (abstract_list.size() <= idx) {
607 MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
608 return nullptr;
609 }
610 abstract = abstract_list[idx];
611 } else {
612 abstract = input_cnode->abstract();
613 }
614 } else {
615 MS_LOG(ERROR) << "unsupported input node type";
616 return nullptr;
617 }
618 return abstract;
619 }
620
IsParamNode(const BaseRef & n)621 bool IsParamNode(const BaseRef &n) {
622 if (!utils::isa<ParameterPtr>(n)) {
623 return false;
624 }
625 auto parameter = utils::cast<ParameterPtr>(n);
626 if (!parameter->has_default() || parameter->default_param() == nullptr) {
627 return false;
628 }
629 auto tensor = parameter->default_param()->cast<tensor::TensorPtr>();
630 if (tensor == nullptr) {
631 return false;
632 }
633 return tensor->data_c() != nullptr;
634 }
635
GetTensorInfoFromAbstract(tensor::TensorPtr * tensor_info,const CNodePtr & cnode,size_t index)636 STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) {
637 CHECK_NULL_RETURN(tensor_info);
638 CHECK_NULL_RETURN(cnode);
639 AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
640 if (abstract == nullptr) {
641 MS_LOG(WARNING) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr, infershape is delayed.";
642 return RET_ERROR;
643 }
644 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
645 MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
646 return RET_ERROR;
647 }
648 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
649 if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
650 MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
651 return RET_ERROR;
652 }
653 *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
654 if (*tensor_info == nullptr) {
655 MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
656 return RET_ERROR;
657 }
658 return RET_OK;
659 }
660
IsParamOrValueNodeWithData(const BaseRef & n)661 bool IsParamOrValueNodeWithData(const BaseRef &n) {
662 if (utils::isa<ValueNode>(n)) {
663 auto value_node = utils::cast<ValueNodePtr>(n);
664 auto value = value_node->value();
665 if (value != nullptr && value->isa<tensor::Tensor>()) {
666 auto tensor = value->cast<tensor::TensorPtr>();
667 if (tensor == nullptr || tensor->data_c() == nullptr) {
668 return false;
669 }
670 return true;
671 } else {
672 return false;
673 }
674 }
675 if (utils::isa<ParameterPtr>(n)) {
676 return IsParamNode(n);
677 }
678 return false;
679 }
680
IsParallelSplitConvNode(const BaseRef & n)681 bool IsParallelSplitConvNode(const BaseRef &n) {
682 if (utils::isa<AnfNodePtr>(n)) {
683 auto anf_node = utils::cast<AnfNodePtr>(n);
684 PrimitivePtr prim;
685 if (utils::isa<CNodePtr>(anf_node)) {
686 prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
687 }
688 if (utils::isa<ValueNodePtr>(anf_node)) {
689 prim = GetValueNode<PrimitivePtr>(anf_node);
690 }
691 if (prim == nullptr) {
692 return false;
693 }
694 int device_type =
695 prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int32_t>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
696 if (device_type != kDeviceTypeNone) {
697 return false;
698 }
699 return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) || CheckPrimitiveType(anf_node, prim::kPrimConv2D);
700 }
701 return false;
702 }
703
IsConvNode(const BaseRef & n)704 bool IsConvNode(const BaseRef &n) {
705 if (utils::isa<AnfNodePtr>(n)) {
706 auto anf_node = utils::cast<AnfNodePtr>(n);
707 PrimitivePtr prim;
708 if (utils::isa<CNodePtr>(anf_node)) {
709 prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
710 }
711 if (utils::isa<ValueNodePtr>(anf_node)) {
712 prim = GetValueNode<PrimitivePtr>(anf_node);
713 }
714 if (prim == nullptr) {
715 return false;
716 }
717
718 if (prim->GetAttr(ops::kActivationType) != nullptr &&
719 GetValue<int64_t>(prim->GetAttr(ops::kActivationType)) != NO_ACTIVATION) {
720 return false;
721 }
722
723 bool is_depth_wise =
724 prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
725 return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) ||
726 (CheckPrimitiveType(anf_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise);
727 }
728 return false;
729 }
730
CheckIsAllInputsParam(const AnfNodePtr & node)731 bool CheckIsAllInputsParam(const AnfNodePtr &node) {
732 if (node == nullptr) {
733 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
734 return false;
735 }
736 if (utils::isa<CNode>(node)) {
737 auto cnode = node->cast<CNodePtr>();
738 for (size_t i = 1; i < cnode->inputs().size(); i++) {
739 if (!utils::isa<Parameter>(cnode->input(i)) && !utils::isa<ValueNodePtr>(cnode->input(i))) {
740 return false;
741 }
742 }
743 return true;
744 }
745 return false;
746 }
747
GetOutputTensorNum(const AnfNodePtr & node)748 size_t GetOutputTensorNum(const AnfNodePtr &node) {
749 if (node == nullptr) {
750 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
751 return 0;
752 }
753 auto type = node->Type();
754 if (type == nullptr) {
755 return 1;
756 }
757 if (type->isa<Tuple>()) {
758 auto tuple_type = type->cast<TuplePtr>();
759 if (tuple_type == nullptr) {
760 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
761 return 0;
762 }
763 return tuple_type->size();
764 } else if (type->isa<TensorType>() || type->isa<Number>()) {
765 return 1;
766 } else if (type->isa<TypeNone>()) {
767 return 0;
768 } else {
769 return 1;
770 }
771 }
772
IsMultiOutputTensors(const FuncGraphPtr & graph,const AnfNodePtr & node)773 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
774 if (graph == nullptr || node == nullptr) {
775 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
776 return false;
777 }
778 auto output_node_list = GetRealNodeUsedList(graph, node);
779 if (output_node_list == nullptr) {
780 MS_LOG(ERROR) << "output node list is nullptr";
781 return false;
782 }
783 if (output_node_list->size() != 1) {
784 MS_LOG(DEBUG) << "fusion node has multi output nodes";
785 return true;
786 }
787 return false;
788 }
789
GetRealNodeUsedList(const FuncGraphPtr & graph,const AnfNodePtr & node)790 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
791 const AnfNodePtr &node) {
792 if (graph == nullptr || node == nullptr) {
793 MS_LOG(ERROR) << "input parameter is nullptr.";
794 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
795 return nullptr;
796 }
797 auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
798 MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
799 auto manager = graph->manager();
800 if (manager == nullptr) {
801 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
802 return nullptr;
803 }
804 auto iter = manager->node_users().find(node);
805 if (iter == manager->node_users().end()) {
806 MS_LOG(ERROR) << "node has no output in manager";
807 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
808 return nullptr;
809 }
810 auto output_info_list = iter->second;
811 std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
812 return output_node_list;
813 }
814
GetTupleGetItemOutIndex(const CNodePtr & tuple_get_item)815 size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
816 if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
817 MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
818 return -1;
819 }
820 auto output_index_value_node = tuple_get_item->input(kInputIndexTwo);
821 if (output_index_value_node == nullptr) {
822 MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
823 return -1;
824 }
825 auto value_node = output_index_value_node->cast<ValueNodePtr>();
826 if (value_node == nullptr) {
827 MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
828 return -1;
829 }
830 auto indexes = CastToInt(value_node->value());
831 if (indexes.empty()) {
832 MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
833 return -1;
834 }
835 return indexes.front();
836 }
837
GetRealNodeUsedListByOutputIdx(const FuncGraphPtr & graph,const AnfNodePtr & node,size_t output_index)838 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
839 const AnfNodePtr &node,
840 size_t output_index) {
841 if (graph == nullptr || node == nullptr) {
842 MS_LOG(ERROR) << "input parameter is nullptr.";
843 return nullptr;
844 }
845 auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
846 MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
847 auto manager = graph->manager();
848 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
849 auto iter = manager->node_users().find(node);
850 if (iter == manager->node_users().end()) {
851 MS_LOG(ERROR) << "node has no output in manager";
852 return output_node_list;
853 }
854 auto output_info_list = iter->second;
855 for (const auto &output_info : output_info_list) {
856 size_t used_output_index;
857 if (CheckPrimitiveType(output_info.first, prim::kPrimTupleGetItem)) {
858 used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
859 } else if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
860 used_output_index = output_index;
861 } else {
862 if (output_index != 0) {
863 MS_LOG(ERROR) << "node has no output in manager";
864 return output_node_list;
865 }
866 return output_node_list;
867 }
868 if (used_output_index == output_index) {
869 output_node_list->push_back(output_info);
870 }
871 }
872 return output_node_list;
873 }
874
TransFilterFormat(const tensor::TensorPtr & tensor,schema::Format src_format,schema::Format dst_format)875 STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
876 MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
877 std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
878 trans_func = {{kNumberTypeFloat32, DoTransposeData<float>},
879 {kNumberTypeUInt8, DoTransposeData<uint8_t>},
880 {kNumberTypeInt8, DoTransposeData<int8_t>},
881 {kNumberTypeFloat16, DoTransposeData<float16>}};
882 auto data_type = tensor->data_type();
883 auto iter = trans_func.find(data_type);
884 if (iter == trans_func.end()) {
885 MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
886 return RET_ERROR;
887 }
888 return iter->second(tensor, src_format, dst_format);
889 }
890
BuildParameterNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const tensor::TensorPtr & tensor_info)891 ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
892 const tensor::TensorPtr &tensor_info) {
893 if (func_graph == nullptr || node == nullptr || tensor_info == nullptr) {
894 MS_LOG(ERROR) << "input parameter is nullptr.";
895 return nullptr;
896 }
897 auto param_node = func_graph->add_parameter();
898 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
899 auto shape = tensor_info->shape();
900 std::vector<int64_t> shape_vector;
901 std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
902 [](const int &val) { return static_cast<int64_t>(val); });
903 auto data_type = tensor_info->data_type() == kNumberTypeInt64 ? kNumberTypeInt32 : tensor_info->data_type();
904 param_node->set_name(node->fullname_with_scope());
905 auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
906 if (tensor_info_new == nullptr) {
907 MS_LOG(ERROR) << "new tensor::Tensor failed.";
908 return nullptr;
909 }
910 size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
911 if (tensor_info->Size() == 0) {
912 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
913 if (status != RET_OK) {
914 MS_LOG(ERROR) << "init parameter from tensor info failed";
915 return nullptr;
916 }
917 return param_node;
918 }
919 if (tensor_info->data_type() == kNumberTypeInt64) {
920 auto *tensor_data = reinterpret_cast<int *>(tensor_info_new->data_c());
921 if (tensor_data == nullptr) {
922 MS_LOG(ERROR) << "new data failed";
923 return nullptr;
924 }
925 auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
926 for (size_t i = 0; i < data_count; ++i) {
927 if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
928 MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32";
929 tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
930 } else {
931 tensor_data[i] = static_cast<int>(origin_data[i]);
932 }
933 }
934 } else {
935 tensor_info_new->set_data_type(tensor_info->data_type());
936 auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_new->data_c());
937 if (tensor_data == nullptr) {
938 MS_LOG(ERROR) << "new data failed";
939 return nullptr;
940 }
941 if (memcpy_s(tensor_data, tensor_info_new->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
942 MS_LOG(ERROR) << "memcpy data failed.";
943 return nullptr;
944 }
945 }
946 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
947 if (status != RET_OK) {
948 MS_LOG(ERROR) << "init parameter from tensor info failed";
949 return nullptr;
950 }
951 param_node->set_default_param(tensor_info_new);
952 return param_node;
953 }
954
BuildIntValueParameterNode(const FuncGraphPtr & func_graph,const int32_t & data,const std::string & node_name)955 ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
956 const std::string &node_name) {
957 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
958 auto param_node = func_graph->add_parameter();
959 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
960 param_node->set_name(node_name);
961
962 auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int32_t), {1}, kNumberTypeInt32);
963 if (tensor_info == nullptr) {
964 MS_LOG(ERROR) << "Create tensor info failed";
965 return nullptr;
966 }
967
968 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
969 if (status != RET_OK) {
970 MS_LOG(ERROR) << "init parameter from tensor info failed";
971 return nullptr;
972 }
973 return param_node;
974 }
975
BuildIntVecParameterNode(const FuncGraphPtr & func_graph,const std::vector<int32_t> & data,const std::string & node_name)976 ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
977 const std::string &node_name) {
978 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
979 auto param_node = func_graph->add_parameter();
980 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
981 param_node->set_name(node_name);
982
983 std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
984 auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), shape_vector, kNumberTypeInt32);
985 if (tensor_info == nullptr) {
986 MS_LOG(ERROR) << "Create tensor info failed";
987 return nullptr;
988 }
989
990 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
991 if (status != RET_OK) {
992 MS_LOG(ERROR) << "init parameter from tensor info failed";
993 return nullptr;
994 }
995
996 return param_node;
997 }
998
BuildIntVec2DParameterNode(const FuncGraphPtr & func_graph,const std::vector<std::vector<int32_t>> & data,const std::string & node_name)999 ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
1000 const std::string &node_name) {
1001 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1002 auto param_node = func_graph->add_parameter();
1003 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1004 param_node->set_name(node_name);
1005
1006 std::vector<int64_t> shape_vector;
1007 shape_vector.push_back(data.size());
1008 shape_vector.push_back(2);
1009
1010 std::vector<int32_t> data_1d;
1011 for (auto pair : data) {
1012 data_1d.insert(data_1d.end(), pair.begin(), pair.end());
1013 }
1014
1015 auto size = data_1d.size() * sizeof(int32_t);
1016 auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32);
1017 if (tensor_info == nullptr) {
1018 MS_LOG(ERROR) << "Create tensor info failed";
1019 return nullptr;
1020 }
1021 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1022 if (status != RET_OK) {
1023 MS_LOG(ERROR) << "init parameter from tensor info failed";
1024 return nullptr;
1025 }
1026 return param_node;
1027 }
1028
BuildFloatValueParameterNode(const FuncGraphPtr & func_graph,const float & data,const std::string & node_name)1029 ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
1030 const std::string &node_name) {
1031 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1032 auto param_node = func_graph->add_parameter();
1033 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1034 param_node->set_name(node_name);
1035
1036 auto tensor_info = lite::CreateTensorInfo(&data, sizeof(float), {1}, kNumberTypeFloat32);
1037 if (tensor_info == nullptr) {
1038 MS_LOG(ERROR) << "Create tensor info failed";
1039 return nullptr;
1040 }
1041 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1042 if (status != RET_OK) {
1043 MS_LOG(ERROR) << "init parameter from tensor info failed";
1044 return nullptr;
1045 }
1046 return param_node;
1047 }
1048
BuildFloatVecParameterNode(const FuncGraphPtr & func_graph,const std::vector<float> & data,const std::string & node_name)1049 ParameterPtr BuildFloatVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float> &data,
1050 const std::string &node_name) {
1051 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1052 auto param_node = func_graph->add_parameter();
1053 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
1054 param_node->set_name(node_name);
1055
1056 std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
1057 auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(float), shape_vector, kNumberTypeFloat);
1058 if (tensor_info == nullptr) {
1059 MS_LOG(ERROR) << "Create tensor info failed";
1060 return nullptr;
1061 }
1062
1063 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
1064 if (status != RET_OK) {
1065 MS_LOG(ERROR) << "init parameter from tensor info failed";
1066 return nullptr;
1067 }
1068
1069 return param_node;
1070 }
1071
GenTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & perm,const std::string & cnode_name)1072 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
1073 const std::string &cnode_name) {
1074 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
1075 MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
1076 auto perm_node = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm");
1077 MS_ASSERT(perm_node != nullptr);
1078 auto trans_prim = std::make_shared<ops::Transpose>();
1079 MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
1080 auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
1081 MS_ASSERT(cnode != nullptr);
1082 auto manager = Manage(func_graph);
1083 MS_ASSERT(manager != nullptr);
1084 manager->SetEdge(cnode, 1, input_node);
1085 manager->SetEdge(cnode, kInputIndexTwo, perm_node);
1086 cnode->set_fullname_with_scope(cnode_name);
1087 auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeTwo, 1);
1088 MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1089 trans_prim->AddAttr("quant_params", quant_params_holder);
1090 return cnode;
1091 }
1092
GenGatherNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,const std::vector<int> & indices,const std::string & cnode_name)1093 CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
1094 const std::string &cnode_name) {
1095 if (func_graph == nullptr || input_node == nullptr) {
1096 MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1097 return nullptr;
1098 }
1099 auto indices_node = BuildIntVecParameterNode(func_graph, indices, cnode_name + "_indices");
1100 if (indices_node == nullptr) {
1101 MS_LOG(ERROR) << "make indices node failed.";
1102 return nullptr;
1103 }
1104 auto axis_node = BuildIntVecParameterNode(func_graph, {0}, cnode_name + "_indices");
1105 if (axis_node == nullptr) {
1106 MS_LOG(ERROR) << "make indices node failed.";
1107 return nullptr;
1108 }
1109 auto gather_prim = std::make_shared<ops::Gather>();
1110 MS_CHECK_TRUE_RET(gather_prim != nullptr, nullptr);
1111 auto cnode = func_graph->NewCNode(gather_prim, {input_node, indices_node, axis_node});
1112 MS_ASSERT(cnode != nullptr);
1113 auto manager = Manage(func_graph);
1114 MS_ASSERT(manager != nullptr);
1115 manager->SetEdge(cnode, 1, input_node);
1116 manager->SetEdge(cnode, kInputIndexTwo, indices_node);
1117 manager->SetEdge(cnode, kInputIndexThree, axis_node);
1118 cnode->set_fullname_with_scope(cnode_name);
1119 auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeThree, 1);
1120 MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
1121 gather_prim->AddAttr("quant_params", quant_params_holder);
1122 return cnode;
1123 }
1124
// Creates a TupleGetItem cnode in `func_graph` selecting element `index` of `input`.
// The cnode is named `<input fullname>_getitem_<index>`. Returns nullptr on null arguments,
// an index that does not fit the int value node, or allocation failure.
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
  if (func_graph == nullptr || input == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
    return nullptr;
  }
  // The index is stored as an int value node; reject values that would be silently truncated.
  if (index > static_cast<size_t>(INT32_MAX)) {
    MS_LOG(ERROR) << "tuple get item index " << index << " exceeds int32 range.";
    return nullptr;
  }
  auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>();
  MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr);
  auto second_input = NewValueNode(MakeValue<int>(static_cast<int>(index)));
  MS_CHECK_TRUE_RET(second_input != nullptr, nullptr);
  auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim, {input, second_input});
  MS_CHECK_TRUE_RET(tuple_cnode != nullptr, nullptr);
  tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
  return tuple_cnode;
}
1139
FetchShapeFromAbstract(const abstract::AbstractBasePtr & abstract,ShapeVector * shape)1140 STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) {
1141 if (abstract == nullptr || shape == nullptr) {
1142 MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1143 return lite::RET_ERROR;
1144 }
1145 if (!utils::isa<abstract::AbstractTensor>(abstract)) {
1146 MS_LOG(ERROR) << "abstract of cnode is invalid.";
1147 return lite::RET_ERROR;
1148 }
1149 auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
1150 if (abstract_tensor->BuildShape() == nullptr || !utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
1151 MS_LOG(ERROR) << "shape of cnode's output is invalid.";
1152 return lite::RET_ERROR;
1153 }
1154 *shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
1155 return lite::RET_OK;
1156 }
1157
IsTrainOp(const CNodePtr & cnode)1158 bool IsTrainOp(const CNodePtr &cnode) {
1159 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1160 auto cnode_type = prim->type_name();
1161 // optimizer op
1162 if (cnode_type == "Adam" || cnode_type == "SGD" || cnode_type == "ApplyMomentum") {
1163 return true;
1164 }
1165 // loss op
1166 if (cnode_type == "SoftmaxCrossEntropyWithLogits" || cnode_type == "SpareSoftmaxCrossEntropyWithLogits" ||
1167 cnode_type == "SmoothL1Loss" || cnode_type == "SmoothL1LossGrad" ||
1168 cnode_type == "SigmoidCrossEntropyWithLogits" || cnode_type == "SigmoidCrossEntropyWithLogpitsGrad") {
1169 return true;
1170 }
1171 // grad op
1172 if (cnode_type.find("Grad") != std::string::npos ||
1173 cnode->fullname_with_scope().find("Gradients") != std::string::npos) {
1174 return true;
1175 }
1176 return false;
1177 }
1178
IsMarkedTrainOp(const CNodePtr & cnode)1179 bool IsMarkedTrainOp(const CNodePtr &cnode) {
1180 if (cnode == nullptr) {
1181 return false;
1182 }
1183 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1184 MS_CHECK_TRUE_RET(prim != nullptr, false);
1185 if (prim->GetAttr("trainOp") != nullptr && GetValue<bool>(prim->GetAttr("trainOp"))) {
1186 MS_LOG(DEBUG) << "train op not fusion.";
1187 return true;
1188 }
1189 return false;
1190 }
1191
GetDataTypeFromAnfNode(const AnfNodePtr & anf_node,TypeId * type_id)1192 int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
1193 if (anf_node == nullptr || type_id == nullptr) {
1194 MS_LOG(ERROR) << "anf_node or type_id is nullptr.";
1195 return RET_ERROR;
1196 }
1197 auto abstract_base = anf_node->abstract();
1198 // used for multi output e.g. split.
1199 if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
1200 auto abstract_tuple = abstract_base->cast<abstract::AbstractTuplePtr>();
1201 if (abstract_tuple->elements().empty()) {
1202 MS_LOG(ERROR) << "abstract_tuple elements is empty.";
1203 return RET_ERROR;
1204 }
1205 abstract_base = abstract_tuple->elements().front();
1206 }
1207 if (abstract_base == nullptr) {
1208 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << anf_node->fullname_with_scope();
1209 return RET_ERROR;
1210 }
1211 if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
1212 MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << anf_node->fullname_with_scope();
1213 return RET_ERROR;
1214 }
1215 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
1216 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
1217 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1218 *type_id = type_ptr->type_id();
1219 return RET_OK;
1220 }
1221 } // namespace opt
1222 } // namespace mindspore
1223