/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/common/gllo_utils.h"
#include <algorithm>
#include <vector>
#include <utility>
#include <unordered_map>
#include <functional>
#include <string>
#include <set>
#include <fstream>
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "base/float16.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "ops/ops_func_impl/gather.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
#include "include/backend/optimizer/helper.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/helper.h"
#include "ops/op_utils.h"
#include "ops/custom.h"
#include "ops/tensor_copy.h"
#include "include/common/utils/anfalgo.h"
#include "tools/optimizer/common/format_utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kDeviceTypeNone = -1;
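// Deduces the axis permutation that maps src_format to dst_format, e.g.
// NHWC -> NCHW yields perm = {0, 3, 1, 2}; 'K' is treated as an alias of 'N'.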
int DeduceDimConversion(schema::Format src_format, schema::Format dst_format, std::vector<int> *const perm) {
  MS_ASSERT(perm != nullptr);
  auto src_format_str = std::string(schema::EnumNameFormat(src_format));
  auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
  if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
    MS_LOG(ERROR) << "src_format or dst_format is invalid.";
    return lite::RET_ERROR;
  }
  std::replace(src_format_str.begin(), src_format_str.end(), 'K', 'N');
  std::replace(dst_format_str.begin(), dst_format_str.end(), 'K', 'N');
  perm->clear();
  std::unordered_map<char, int> dim_map;
  for (size_t i = 0; i < src_format_str.size(); ++i) {
    dim_map[src_format_str[i]] = i;
  }
  for (size_t i = 0; i < dst_format_str.size(); ++i) {
    if (dim_map.find(dst_format_str[i]) == dim_map.end()) {
      MS_LOG(ERROR) << "src_format and dst_format cannot match, please check.";
      return RET_ERROR;
    }
    perm->push_back(dim_map[dst_format_str[i]]);
  }
  return lite::RET_OK;
}

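// Transposes a 4-D tensor: walks the output in row-major order and gathers
// elements from the input through strides precomputed from perm.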
template <class T>
void TransposeDim4(const ShapeVector &input_shape, const ShapeVector &output_shape, const std::vector<int> &perm,
                   const T *const in_data, T *out_data) {
  auto num_axes = input_shape.size();
  std::vector<int64_t> strides;
  std::vector<int64_t> out_strides;
  strides.resize(num_axes);
  out_strides.resize(num_axes);
  strides[num_axes - 1] = 1LL;
  out_strides[num_axes - 1] = 1LL;
  for (size_t i = num_axes - 1; i >= 1; i--) {
    strides[i - 1] = input_shape[i] * strides[i];
    out_strides[i - 1] = output_shape[i] * out_strides[i];
  }
  const auto stride0 = strides[perm[kIndex0]];
  const auto stride1 = strides[perm[kIndex1]];
  const auto stride2 = strides[perm[kIndex2]];
  const auto stride3 = strides[perm[kIndex3]];
  const auto out_stride0 = out_strides[kIndex0];
  const auto out_stride1 = out_strides[kIndex1];
  const auto out_stride2 = out_strides[kIndex2];
  const auto output0 = output_shape[kIndex0];
  const auto output1 = output_shape[kIndex1];
  const auto output2 = output_shape[kIndex2];
  const auto output3 = output_shape[kIndex3];

  int64_t out_beg_i = 0;
  int64_t beg_i = 0;
  for (int64_t i = 0; i < output0; ++i) {
    int64_t out_beg_ij = out_beg_i;
    int64_t beg_ij = beg_i;
    for (int64_t j = 0; j < output1; ++j) {
      int64_t out_beg_ijk = out_beg_ij;
      int64_t beg_ijk = beg_ij;
      for (int64_t k = 0; k < output2; ++k) {
        for (int64_t m = 0; m < output3; ++m) {
          out_data[out_beg_ijk + m] = in_data[beg_ijk + m * stride3];
        }
        out_beg_ijk += out_stride2;
        beg_ijk += stride2;
      }
      out_beg_ij += out_stride1;
      beg_ij += stride1;
    }
    out_beg_i += out_stride0;
    beg_i += stride0;
  }
}

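// Transposes a weight tensor's data in place from src_format to dst_format
// and updates its shape; only 4-D tensors with positive dims are supported.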
template <typename T>
STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
  MS_ASSERT(tensor != nullptr);
  auto origin_shape = tensor->shape_c();
  if (origin_shape.size() != kInputSizeFour) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << origin_shape.size();
    return lite::RET_ERROR;
  }
  if (std::any_of(origin_shape.begin(), origin_shape.end(), [](int64_t val) { return val <= 0; })) {
    MS_LOG(ERROR) << "the tensor's shape is invalid.";
    return lite::RET_ERROR;
  }
  std::vector<int> perm;
  if (DeduceDimConversion(src_format, dst_format, &perm) != RET_OK) {
    MS_LOG(ERROR) << "deduce perm failed.";
    return lite::RET_ERROR;
  }
  ShapeVector new_shape;
  for (auto &val : perm) {
    if (val < 0 || static_cast<size_t>(val) >= origin_shape.size()) {
      MS_LOG(ERROR) << "deduced perm is invalid.";
      return lite::RET_ERROR;
    }
    new_shape.push_back(origin_shape[val]);
  }
  int64_t count = 1;
  for (const auto &dat : origin_shape) {
    if (INT_MUL_OVERFLOW(count, dat)) {
      MS_LOG(ERROR) << "Int mul overflow";
      return RET_ERROR;
    }
    count *= dat;
  }
  if (count <= 0 || count > static_cast<int64_t>(INT32_MAX)) {
    MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max.";
    return RET_ERROR;
  }
  std::vector<T> buf(count);

  void *origin_weight_data = tensor->data_c();
  MS_CHECK_TRUE_RET(origin_weight_data != nullptr, RET_ERROR);
  T *weight_data = static_cast<T *>(origin_weight_data);
  TransposeDim4<T>(origin_shape, new_shape, perm, weight_data, buf.data());
  if (memcpy_s(tensor->data_c(), tensor->Size(), buf.data(), count * sizeof(T)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed.";
    return RET_ERROR;
  }
  tensor->set_shape(new_shape);
  return RET_OK;
}

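// Returns true for nodes that represent real computation. Non-CNodes pass
// trivially; "virtual" primitives such as MakeTuple, Depend, TupleGetItem,
// Return, Partial and the summary ops are filtered out.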
bool IsRealKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // parameter and value node are treated as real kernels too
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (cnode->empty()) {
    MS_LOG(ERROR) << "Illegal null input of cnode: " << node->DebugString();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
    return false;
  }
  auto input = cnode->input(0);
#ifndef ENABLE_SECURITY
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
                         IsPrimitive(input, prim::kPrimPartial);
#else
  bool is_virtual_node = IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) ||
                         IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) ||
                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
#endif
  return !is_virtual_node;
}

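// Narrows int64 data into an int32 buffer, clamping out-of-range values to
// INT32_MAX / INT32_MIN (INT64_MAX and INT64_MIN are clamped silently).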
void CopyDataFromInt64(const int64_t *origin_data, int *tensor_data, size_t data_count) {
  for (size_t i = 0; i < data_count; ++i) {
    if (origin_data[i] == INT64_MAX) {
      tensor_data[i] = INT32_MAX;
    } else if (origin_data[i] == INT64_MIN) {
      tensor_data[i] = INT32_MIN;
    } else if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
      MS_LOG(WARNING) << "int64 data " << origin_data[i] << " cannot fit into int32";
      tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
    } else {
      tensor_data[i] = static_cast<int>(origin_data[i]);
    }
  }
}

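// Copies tensor data into tensor_info_dst: int64 is narrowed to int32 (unless
// keep_origin_dtype is set) and float64 to float32, both with saturation;
// any other dtype is copied byte-for-byte.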
int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info,
                                 const std::shared_ptr<tensor::Tensor> &tensor_info_dst, size_t data_count,
                                 bool keep_origin_dtype) {
  if (tensor_info->data_type() == kNumberTypeInt64 && !keep_origin_dtype) {
    auto *tensor_data = reinterpret_cast<int *>(tensor_info_dst->data_c());
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << "new data failed";
      return RET_ERROR;
    }
    auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
    MS_CHECK_TRUE_MSG(origin_data != nullptr, lite::RET_NULL_PTR, "origin_data is nullptr");
    CopyDataFromInt64(origin_data, tensor_data, data_count);
  } else if (tensor_info->data_type() == kNumberTypeFloat64) {
    auto *tensor_data = reinterpret_cast<float *>(tensor_info_dst->data_c());
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << "new data failed";
      return RET_ERROR;
    }
    auto *origin_data = reinterpret_cast<double_t *>(tensor_info->data_c());
    for (size_t i = 0; i < data_count; ++i) {
      if (origin_data[i] > static_cast<double_t>(FLT_MAX) || origin_data[i] < static_cast<double_t>(-FLT_MAX)) {
        MS_LOG(WARNING) << "float64 data " << origin_data[i] << " cannot fit into float32";
        tensor_data[i] = origin_data[i] > 0 ? FLT_MAX : -FLT_MAX;
      } else {
        tensor_data[i] = static_cast<float>(origin_data[i]);
      }
    }
  } else {
    tensor_info_dst->set_data_type(tensor_info->data_type());
    auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_dst->data_c());
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << "new data failed";
      return RET_ERROR;
    }
    if (memcpy_s(tensor_data, tensor_info_dst->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
      MS_LOG(ERROR) << "memcpy data failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
}  // namespace

bool CheckInputs(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "cnode is nullptr.";
    return false;
  }
  if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
                  [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
    MS_LOG(ERROR) << "input is nullptr.";
    return false;
  }
  return true;
}

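// Flattens an integer-like Value (scalar, bool, or int32/int64 sequence) into
// std::vector<int>, casting int64 down to int; returns {} for other types.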
std::vector<int> CastToInt(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(WARNING) << "valueptr is nullptr.";
    return {};
  }
  std::vector<int> cur_value = {};
  if (utils::isa<ValueSequencePtr>(value)) {
    if (!value->cast<ValueSequencePtr>()->value().empty()) {
      auto data_type = value->cast<ValueSequencePtr>()->value().front()->type()->number_type();
      if (data_type == kNumberTypeInt64) {
        auto origin_value = GetValue<std::vector<int64_t>>(value);
        std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
                       [](int64_t index) { return static_cast<int>(index); });
      } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
        cur_value = GetValue<std::vector<int>>(value);
      } else {
        MS_LOG(ERROR) << "the function only processes integer data.";
        return {};
      }
    }
  } else {
    auto data_type = value->type()->number_type();
    switch (data_type) {
      case kNumberTypeInt64:
        cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
        break;
      case kNumberTypeInt:
      case kNumberTypeInt32:
        cur_value.push_back(GetValue<int>(value));
        break;
      case kNumberTypeBool:
        cur_value.push_back(GetValue<bool>(value));
        break;
      default:
        MS_LOG(ERROR) << "the function only processes integer data.";
        return {};
    }
  }
  return cur_value;
}

std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(WARNING) << "valueptr is nullptr.";
    return {};
  }

  std::vector<std::vector<int>> result_value;
  if (utils::isa<ValueSequencePtr>(value)) {
    auto data_type = value->cast<ValueSequencePtr>()
                       ->value()
                       .front()
                       ->cast<ValueSequencePtr>()
                       ->value()
                       .front()
                       ->type()
                       ->number_type();
    if (data_type == kNumberTypeInt64) {
      auto origin_value = GetValue<std::vector<std::vector<int64_t>>>(value);
      for (auto &i : origin_value) {
        std::vector<int> cur_value;
        std::transform(i.begin(), i.end(), std::back_inserter(cur_value),
                       [](int64_t j) { return static_cast<int>(j); });
        result_value.push_back(cur_value);
      }
    } else if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
      result_value = GetValue<std::vector<std::vector<int>>>(value);
    } else {
      MS_LOG(ERROR) << "the function only processes integer data.";
      return result_value;
    }
  }
  return result_value;
}

std::vector<float> CastToFloat(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(WARNING) << "valueptr is nullptr.";
    return {};
  }
  std::vector<float> cur_value = {};
  if (utils::isa<ValueSequencePtr>(value)) {
    if (!value->cast<ValueSequencePtr>()->value().empty()) {
      auto data_type = value->cast<ValueSequencePtr>()->value().front()->type()->number_type();
      if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
        cur_value = GetValue<std::vector<float>>(value);
      } else {
        MS_LOG(ERROR) << "the function only processes float data.";
        return {};
      }
    }
  } else {
    auto data_type = value->type()->number_type();
    if (data_type == kNumberTypeFloat || data_type == kNumberTypeFloat32) {
      cur_value.push_back(GetValue<float>(value));
    } else {
      MS_LOG(ERROR) << "the function only processes float data.";
      return {};
    }
  }
  return cur_value;
}

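// Returns true when node is a CNode whose primitive input matches
// primitive_type, or a ValueNode holding that primitive itself.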
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  if (node == nullptr || primitive_type == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  } else if (node->isa<ValueNode>()) {
    return IsPrimitive(node, primitive_type);
  }
  return false;
}

STATUS GetPrimitiveType(const AnfNodePtr &node, std::string *name) {
  if (name == nullptr) {
    MS_LOG(ERROR) << "name is nullptr.";
    return RET_ERROR;
  }
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is nullptr. " << cnode->fullname_with_scope();
      return RET_ERROR;
    }
    if (CheckPrimitiveType(node, prim::kPrimCustom)) {
      auto custom_prim = api::MakeShared<ops::Custom>(primitive);
      MS_CHECK_TRUE_MSG(custom_prim != nullptr, RET_ERROR, "custom op is nullptr.");
      *name = custom_prim->get_type();
      return RET_OK;
    } else {
      *name = primitive->name();
      return RET_OK;
    }
  } else if (node->isa<ValueNode>()) {
    auto fn_value = GetValueNode<PrimitivePtr>(node);
    CHECK_NULL_RETURN(fn_value);
    *name = fn_value->name();
    return RET_OK;
  }
  MS_LOG(ERROR) << "There is no name for this node";
  return RET_ERROR;
}

bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    return CheckPrimitiveType(anf_node, prim);
  }
  return false;
}

bool IsRealCNodeKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    MS_LOG(ERROR) << "node is nullptr";
    return false;
  }
  // parameter and value node is not a real cnode kernel
  if (!node->isa<CNode>()) {
    return false;
  }
  // return is considered a real node
  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
    return true;
  }
  return IsRealKernel(node);
}
bool IsGraphKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // graph kernel should be a real cnode kernel.
  if (!IsRealCNodeKernel(node)) {
    return false;
  }

  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    MS_LOG(ERROR) << "node is nullptr";
    return false;
  }
  auto input = cnode->input(kAnfPrimitiveIndex);
  // graph kernel should have a func_graph as its first input.
  if (!IsValueNode<FuncGraph>(input)) {
    return false;
  }

  auto func_graph = GetValueNode<FuncGraphPtr>(input);
  if (func_graph == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}

ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
  if (bias_data == nullptr || func_graph == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr.";
    return nullptr;
  }
  auto bias_parameter = func_graph->add_parameter();
  MS_ASSERT(bias_parameter != nullptr);
  std::vector<int64_t> shape_vector = {kernel_num};
  auto tensor_info =
    lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, type_id);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(bias_parameter, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }

  return bias_parameter;
}

tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node) {
  MS_CHECK_TRUE_RET(node != nullptr, nullptr);
  if (!utils::isa<ParameterPtr>(node)) {
    if (utils::isa<ValueNodePtr>(node)) {
      auto value_node = node->cast<ValueNodePtr>();
      auto value_ptr = value_node->value();
      MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
      auto value = value_ptr->cast<tensor::TensorPtr>();
      if (value != nullptr) {
        return value;
      }
    }
    MS_LOG(DEBUG) << "node is neither a parameter node nor a value node holding a tensor";
    return nullptr;
  }
  auto param = node->cast<ParameterPtr>();
  MS_ASSERT(param != nullptr);
  if (!param->has_default() || param->default_param() == nullptr) {
    return nullptr;
  }
  auto tensor_info = param->default_param()->cast<tensor::TensorPtr>();
  return tensor_info;
}

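// Returns the abstract of cnode's input at index; for a TupleGetItem input,
// resolves to the selected element of the producer's AbstractTuple.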
AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "CNodePtr is nullptr";
    return nullptr;
  }
  if (!(index > 0 && index < cnode->size())) {
    return nullptr;
  }
  auto input = cnode->input(index);
  if (input == nullptr) {
    MS_LOG(ERROR) << "CNode input is nullptr";
    return nullptr;
  }

  AbstractBasePtr abstract = nullptr;
  if (utils::isa<ParameterPtr>(input)) {
    auto parameter = input->cast<ParameterPtr>();
    abstract = parameter->abstract();
  } else if (utils::isa<ValueNodePtr>(input)) {
    auto value_node = input->cast<ValueNodePtr>();
    abstract = value_node->abstract();
  } else if (utils::isa<CNodePtr>(input)) {
    auto input_cnode = input->cast<CNodePtr>();
    if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
      MS_ASSERT(input_cnode->size() == kTupleGetItemInputSize);
      auto get_item_input_cnode = input_cnode->input(1);
      MS_ASSERT(get_item_input_cnode != nullptr);
      auto idx = GetTupleGetItemOutIndex(input_cnode);
      if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expected";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = input_cnode->abstract();
    }
  } else {
    MS_LOG(ERROR) << "unsupported input node type";
    return nullptr;
  }
  return abstract;
}

bool IsParamNode(const BaseRef &n) {
  if (!utils::isa<ParameterPtr>(n)) {
    return false;
  }
  auto parameter = utils::cast<ParameterPtr>(n);
  if (!parameter->has_default() || parameter->default_param() == nullptr) {
    return false;
  }
  auto tensor = parameter->default_param()->cast<tensor::TensorPtr>();
  if (tensor == nullptr) {
    return false;
  }
  return tensor->data_c() != nullptr;
}

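// Fetches the constant tensor attached to cnode's input at index via its
// abstract; fails when the abstract carries no tensor value yet (i.e.
// infershape has not run for the producer).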
STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *const tensor_info, const CNodePtr &cnode, size_t index) {
  CHECK_NULL_RETURN(tensor_info);
  CHECK_NULL_RETURN(cnode);
  AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
  if (abstract == nullptr) {
    MS_LOG(WARNING) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr, infershape is delayed.";
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
    MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
  if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) {  // input node has not completed infershape
    MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, which indicates that infershape has failed";
    return RET_ERROR;
  }
  *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
  if (*tensor_info == nullptr) {
    MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
    return RET_ERROR;
  }
  return RET_OK;
}

bool IsParamOrValueNodeWithData(const BaseRef &n) {
  if (utils::isa<ValueNode>(n)) {
    auto value_node = utils::cast<ValueNodePtr>(n);
    auto value = value_node->value();
    if (value == nullptr) {
      return false;
    }
    if (value->isa<tensor::Tensor>()) {
      auto tensor = value->cast<tensor::TensorPtr>();
      return tensor != nullptr && tensor->data_c() != nullptr;
    } else if (value->isa<ValueSequence>()) {
      auto sequence_ptr = value->cast<ValueSequencePtr>();
      return sequence_ptr != nullptr && !sequence_ptr->value().empty();
    } else {
      return false;
    }
  }
  if (utils::isa<ParameterPtr>(n)) {
    return IsParamNode(n);
  }
  return false;
}

bool IsParallelSplitConvNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    PrimitivePtr prim = nullptr;
    if (utils::isa<CNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
    }
    if (utils::isa<ValueNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node);
    }
    if (prim == nullptr) {
      return false;
    }
    int device_type = prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int32_t>(prim->GetAttr(ops::kDeviceType))
                                                                 : kDeviceTypeNone;
    if (device_type != kDeviceTypeNone) {
      return false;
    }
    return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) || CheckPrimitiveType(anf_node, prim::kPrimConv2D);
  }
  return false;
}

bool IsConvNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    PrimitivePtr prim = nullptr;
    if (utils::isa<CNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node->cast<CNodePtr>()->input(kAnfPrimitiveIndex));
    }
    if (utils::isa<ValueNodePtr>(anf_node)) {
      prim = GetValueNode<PrimitivePtr>(anf_node);
    }
    if (prim == nullptr) {
      return false;
    }

    if (prim->GetAttr(ops::kActivationType) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kActivationType)) != NO_ACTIVATION) {
      return false;
    }

    bool is_depth_wise =
      prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
    return CheckPrimitiveType(anf_node, prim::kPrimConv2DFusion) ||
           (CheckPrimitiveType(anf_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise);
  }
  return false;
}

bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    MS_LOG(ERROR) << "node is nullptr";
    return false;
  }
  if (utils::isa<CNode>(node)) {
    auto cnode = node->cast<CNodePtr>();
    for (size_t i = 1; i < cnode->size(); i++) {
      if (!utils::isa<Parameter>(cnode->input(i)) && !utils::isa<ValueNodePtr>(cnode->input(i))) {
        return false;
      }
    }
    return true;
  }
  return false;
}

size_t GetOutputTensorNum(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    MS_LOG(ERROR) << "node is nullptr";
    return 0;
  }
  auto type = node->Type();
  if (type == nullptr) {
    return 1;
  }
  if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    if (tuple_type == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      MS_LOG(ERROR) << "tuple_type is nullptr";
      return 0;
    }
    return tuple_type->size();
  } else if (type->isa<TensorType>() || type->isa<Number>()) {
    return 1;
  } else if (type->isa<TypeNone>()) {
    return 0;
  } else {
    return 1;
  }
}

bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  if (graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  auto output_node_list = Helper::GetRealNodeUsedList(graph, node);
  if (output_node_list == nullptr) {
    MS_LOG(ERROR) << "output node list is nullptr";
    return false;
  }
  if (output_node_list->size() != 1) {
    MS_LOG(DEBUG) << "fusion node has multi output nodes";
    return true;
  }
  return false;
}

AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
  if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
    return nullptr;
  }
  return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
}

size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return SIZE_MAX;
  }
  auto output_index_value_node = tuple_get_item->input(kInputIndexTwo);
  if (output_index_value_node == nullptr) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return SIZE_MAX;
  }
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return SIZE_MAX;
  }
  auto indexes = CastToInt(value_node->value());
  if (indexes.empty()) {
    MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
    return SIZE_MAX;
  }
  return indexes.front();
}

size_t GetListGetItemOutIndex(const CNodePtr &list_get_item) {
  if (list_get_item == nullptr || list_get_item->size() != kInputSizeThree) {
    MS_LOG(ERROR) << "The node list_get_item is invalid.";
    return SIZE_MAX;
  }
  auto output_index_value_node = list_get_item->input(kInputIndexTwo);
  if (output_index_value_node == nullptr) {
    MS_LOG(ERROR) << "The node list_get_item is invalid.";
    return SIZE_MAX;
  }
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "The node list_get_item is invalid.";
    return SIZE_MAX;
  }
  auto indexes = CastToInt(value_node->value());
  if (indexes.empty()) {
    MS_LOG(ERROR) << "The node list_get_item is invalid.";
    return SIZE_MAX;
  }
  return indexes.front();
}

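// Dispatches DoTransposeData on the tensor's data type; only fp32, fp16,
// int8 and uint8 filters are supported.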
STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
  MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
  std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
    trans_func = {{kNumberTypeFloat32, DoTransposeData<float>},
                  {kNumberTypeUInt8, DoTransposeData<uint8_t>},
                  {kNumberTypeInt8, DoTransposeData<int8_t>},
                  {kNumberTypeFloat16, DoTransposeData<float16>}};
  auto data_type = tensor->data_type();
  auto iter = trans_func.find(data_type);
  if (iter == trans_func.end()) {
    MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
    return RET_ERROR;
  }
  return iter->second(tensor, src_format, dst_format);
}

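// Builds a graph parameter initialized from tensor_info, narrowing 64-bit
// integer and floating-point data to the 32-bit counterparts unless
// keep_origin_dtype requests otherwise.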
ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
                                const std::string &node_name, bool keep_origin_dtype) {
  if (func_graph == nullptr || tensor_info == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr.";
    return nullptr;
  }
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  auto shape = tensor_info->shape();
  std::vector<int64_t> shape_vector;
  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
                 [](const int &val) { return static_cast<int64_t>(val); });
  auto data_type = tensor_info->data_type();
  if (tensor_info->data_type() == kNumberTypeFloat64 && !keep_origin_dtype) {
    data_type = kNumberTypeFloat32;
  }
  if (tensor_info->data_type() == kNumberTypeInt64) {
    data_type = kNumberTypeInt32;
  }
  param_node->set_name(node_name);
  param_node->debug_info()->set_name(node_name);
  auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
  if (tensor_info_new == nullptr) {
    MS_LOG(ERROR) << "new tensor::Tensor failed.";
    return nullptr;
  }
  int data_count = 1;
  for (const auto &dat : shape) {
    if (INT_MUL_OVERFLOW(data_count, static_cast<int>(dat))) {
      MS_LOG(ERROR) << "Int mul overflow.";
      return nullptr;
    }
    data_count *= static_cast<int>(dat);
  }
  if (data_count < 0) {
    MS_LOG(ERROR) << "Invalid shape.";
    return nullptr;
  }
  if (tensor_info->Size() == 0) {
    auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "init parameter from tensor info failed";
      return nullptr;
    }
    return param_node;
  }

  if (CopyTensorDataFromTensorInfo(tensor_info, tensor_info_new, static_cast<size_t>(data_count), keep_origin_dtype) !=
      RET_OK) {
    MS_LOG(ERROR) << "copy tensor data failed";
    return nullptr;
  }

  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  param_node->set_default_param(tensor_info_new);
  return param_node;
}

ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
                                        const std::string &node_name, bool empty_shape) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);
  ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
  auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int32_t), shape, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}

ValueNodePtr BuildIntVecValueNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto value = MakeValue(data);
  MS_CHECK_TRUE_RET(value != nullptr, nullptr);
  auto value_node = std::make_shared<ValueNode>(value);
  MS_EXCEPTION_IF_NULL(value_node);
  value_node->set_abstract(value->ToAbstract());
  func_graph->AddValueNode(value_node);
  return value_node;
}

ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
                                      const std::string &node_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
  auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), shape_vector, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }

  return param_node;
}

ParameterPtr BuildInt64VecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &data,
                                        const std::string &node_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
  auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int64_t), shape_vector, kNumberTypeInt64);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed!";
    return nullptr;
  }

  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed!";
    return nullptr;
  }

  return param_node;
}

ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
                                        const std::string &node_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  MS_CHECK_TRUE_RET(!data.empty(), nullptr);
  std::vector<int64_t> shape_vector;
  shape_vector.push_back(data.size());
  shape_vector.push_back(data.at(0).size());

  std::vector<int32_t> data_1d;
  for (const auto &pair : data) {
    data_1d.insert(data_1d.end(), pair.begin(), pair.end());
  }

  auto size = data_1d.size() * sizeof(int32_t);
  auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}

ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
                                          const std::string &node_name, bool empty_shape) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
  auto tensor_info = lite::CreateTensorInfo(&data, sizeof(float), shape, kNumberTypeFloat32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}

ParameterPtr BuildInt64ValueParameterNode(const FuncGraphPtr &func_graph, const int64_t &data,
                                          const std::string &node_name, bool empty_shape) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);
  ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
  auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int64_t), shape, kNumberTypeInt64);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed!";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed!";
    return nullptr;
  }
  return param_node;
}

ParameterPtr BuildFloat16ValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
                                            const std::string &node_name, bool empty_shape) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  ShapeVector shape = empty_shape ? std::vector<int64_t>{} : std::vector<int64_t>{1};
  // convert first: copying sizeof(float16) bytes straight from a 4-byte float would be garbage
  float16 data_fp16(data);
  auto tensor_info = lite::CreateTensorInfo(&data_fp16, sizeof(float16), shape, kNumberTypeFloat16);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}

ParameterPtr BuildFloat16VecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float16> &data,
                                          const std::string &node_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
  auto tensor_info =
    lite::CreateTensorInfo(data.data(), data.size() * sizeof(float16), shape_vector, kNumberTypeFloat16);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }

  return param_node;
}

ParameterPtr BuildFloatVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float> &data,
                                        const std::string &node_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
  auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(float), shape_vector, kNumberTypeFloat);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }

  return param_node;
}

ParameterPtr BuildFloatVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<float>> &data,
                                          const std::string &node_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  auto param_node = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
  param_node->set_name(node_name);

  MS_CHECK_TRUE_RET(!data.empty(), nullptr);
  std::vector<int64_t> shape_vector;
  shape_vector.push_back(data.size());
  shape_vector.push_back(data.at(0).size());

  std::vector<float> data_1d;
  for (const auto &pair : data) {
    data_1d.insert(data_1d.end(), pair.begin(), pair.end());
  }

  auto size = data_1d.size() * sizeof(float);
  auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeFloat32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}

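// Creates a Transpose cnode over input_node with a constant perm parameter;
// for NHWC<->NCHW perms the cloned output abstract is format-converted too.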
CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
                          const std::string &cnode_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
  auto perm_node = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm");
  MS_ASSERT(perm_node != nullptr);
  ops::Transpose transpose_node;
  auto trans_prim = transpose_node.GetPrim();
  MS_CHECK_TRUE_RET(trans_prim != nullptr, nullptr);
  auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
  MS_ASSERT(cnode != nullptr);
  auto manager = Manage(func_graph);
  MS_ASSERT(manager != nullptr);
  manager->SetEdge(cnode, 1, input_node);
  manager->SetEdge(cnode, kInputIndexTwo, perm_node);
  cnode->set_fullname_with_scope(cnode_name);
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeTwo, 1);
  MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
  trans_prim->AddAttr("quant_params", quant_params_holder);
  auto input_abstract = input_node->abstract();
  if (input_abstract != nullptr) {
    auto abstract = input_abstract->Clone();
    MS_CHECK_TRUE_RET(abstract != nullptr, nullptr);
    FormatTransNodeType perm_type = perm == kNC2NH ? kNCHW2NHWC : (perm == kNH2NC ? kNHWC2NCHW : kNONE);
    if (ConvertAbstractFormatShape(abstract, perm_type) != RET_OK) {
      MS_LOG(WARNING) << "Convert abstract failed for node: " << cnode->fullname_with_scope();
      return cnode;
    }
    cnode->set_abstract(abstract);
  }
  return cnode;
}

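// Creates a Cast cnode to dst_type, replaces all uses of input_node with it,
// and rewires input_node (plus a dtype value node) as the cast's operands.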
CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const std::string &cnode_name,
                     const TypeId dst_type, const AbstractBasePtr &abstract) {
  MS_CHECK_TRUE_RET(graph != nullptr, nullptr);
  MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
  ops::Cast cast_node;
  auto new_cast_c = cast_node.GetPrim();
  if (new_cast_c == nullptr) {
    MS_LOG(ERROR) << "new_cast_c is nullptr";
    return nullptr;
  }
  TypePtr dst_type_ptr = TypeIdToType(dst_type);
  if (dst_type_ptr == nullptr) {
    MS_LOG(ERROR) << "dst_type_ptr is nullptr";
    return nullptr;
  }
  new_cast_c->AddAttr(ops::kDstType, dst_type_ptr);
  ValueNodePtr value_node = NewValueNode(new_cast_c);
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "NewValueNode Failed";
    return nullptr;
  }

  auto dtype_value = MakeValue(dst_type_ptr);
  auto dtype_value_node = NewValueNode(dtype_value);
  MS_CHECK_TRUE_RET(dtype_value_node != nullptr, nullptr);
  dtype_value_node->set_abstract(dtype_value->ToAbstract());
  graph->AddValueNode(dtype_value_node);

  auto cast_cnode = graph->NewCNode({value_node});
  if (cast_cnode == nullptr) {
    MS_LOG(ERROR) << "new_cnode is nullptr";
    return nullptr;
  }
  cast_cnode->set_fullname_with_scope(cnode_name);
  cast_cnode->set_abstract(abstract);
  auto manager = Manage(graph);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
  (void)manager->Replace(input_node, cast_cnode);
  manager->AddEdge(cast_cnode, input_node);
  manager->AddEdge(cast_cnode, dtype_value_node);
  return cast_cnode;
}

CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
                        const std::string &cnode_name) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
  auto reshape_prim = std::make_shared<ops::Reshape>();
  if (reshape_prim == nullptr) {
    MS_LOG(ERROR) << "create reshape failed.";
    return nullptr;
  }
  auto prim_c = reshape_prim->GetPrim();
  MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr);
  prim_c->set_attr("shape", MakeValue(shape));
  ValueNodePtr value_node = NewValueNode(prim_c);
  MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "Create value_node return nullptr");
  auto new_shape_node = opt::BuildIntVecParameterNode(func_graph, shape, cnode_name + "_shape");
  MS_CHECK_TRUE_MSG(new_shape_node != nullptr, nullptr, "Create shape parameter return nullptr");
  std::vector<AnfNodePtr> op_inputs = {value_node, input_node, new_shape_node};
  auto reshape_cnode = func_graph->NewCNode(op_inputs);
  MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "Create cnode return nullptr");
  reshape_cnode->set_fullname_with_scope(cnode_name);
  return reshape_cnode;
}

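// Creates a Gather cnode over input_node with constant indices and axis
// parameters; a quant-param holder is attached (consumed later by the
// quantization passes).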
CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
                       const std::string &cnode_name, const std::vector<int> &axis) {
  if (func_graph == nullptr || input_node == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
    return nullptr;
  }
  auto indices_node = BuildIntVecParameterNode(func_graph, indices, cnode_name + "_indices");
  if (indices_node == nullptr) {
    MS_LOG(ERROR) << "make indices node failed.";
    return nullptr;
  }
  auto axis_node = BuildIntVecParameterNode(func_graph, axis, cnode_name + "_axis");
  if (axis_node == nullptr) {
    MS_LOG(ERROR) << "make axis node failed.";
    return nullptr;
  }
  ops::Gather gather_node;
  auto gather_prim = gather_node.GetPrim();
  MS_CHECK_TRUE_RET(gather_prim != nullptr, nullptr);
  auto cnode = func_graph->NewCNode(gather_prim, {input_node, indices_node, axis_node});
  MS_ASSERT(cnode != nullptr);
  auto manager = Manage(func_graph);
  MS_ASSERT(manager != nullptr);
  manager->SetEdge(cnode, 1, input_node);
  manager->SetEdge(cnode, kInputIndexTwo, indices_node);
  manager->SetEdge(cnode, kInputIndexThree, axis_node);
  cnode->set_fullname_with_scope(cnode_name);
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeThree, 1);
  MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
  gather_prim->AddAttr("quant_params", quant_params_holder);
  return cnode;
}

CNodePtr GenGatherNodeDynamicIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                                   const AnfNodePtr &indices_node, const std::string &cnode_name,
                                   const std::vector<int> &axis) {
  if (func_graph == nullptr || input_node == nullptr || indices_node == nullptr) {
    MS_LOG(ERROR) << "Input parameter is nullptr, which is invalid!";
    return nullptr;
  }
  auto axis_node = BuildIntVecParameterNode(func_graph, axis, cnode_name + "_axis");
  if (axis_node == nullptr) {
    MS_LOG(ERROR) << "Build axis node failed!";
    return nullptr;
  }
  ops::Gather gather_node;
  auto gather_prim = gather_node.GetPrim();
  MS_CHECK_TRUE_RET(gather_prim != nullptr, nullptr);
  auto cnode = func_graph->NewCNode(gather_prim, {input_node, indices_node, axis_node});
  MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
  auto manager = Manage(func_graph);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
  manager->SetEdge(cnode, 1, input_node);
  manager->SetEdge(cnode, kInputIndexTwo, indices_node);
  manager->SetEdge(cnode, kInputIndexThree, axis_node);
  cnode->set_fullname_with_scope(cnode_name);
  if (input_node->abstract() != nullptr) {
    cnode->set_abstract(input_node->abstract()->Clone());
  }
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSizeThree, 1);
  MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
  gather_prim->AddAttr("quant_params", quant_params_holder);
  return cnode;
}

CNodePtr GenConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &input_node_vec,
                       const std::string &cnode_name, int64_t axis) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "func_graph is nullptr, which is invalid.";
    return nullptr;
  }
  ops::Concat concat_node;
  concat_node.set_axis(axis);
  auto concat_prim = concat_node.GetPrim();
  MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
  auto cnode = func_graph->NewCNode(concat_prim, input_node_vec);
  MS_ASSERT(cnode != nullptr);
  auto manager = Manage(func_graph);
  MS_ASSERT(manager != nullptr);
  cnode->set_fullname_with_scope(cnode_name);
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(input_node_vec.size(), 1);
  MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
  concat_prim->AddAttr("quant_params", quant_params_holder);
  return cnode;
}

CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
  if (func_graph == nullptr || input == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
    return nullptr;
  }
  auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
  MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr);
  auto second_input = NewValueNode(MakeValue<int64_t>(index));
  MS_CHECK_TRUE_RET(second_input != nullptr, nullptr);
  auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim();
  MS_CHECK_TRUE_RET(tuple_get_item_prim_c != nullptr, nullptr);
  auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim_c, {input, second_input});
  MS_ASSERT(tuple_cnode != nullptr);
  tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
  return tuple_cnode;
}

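// Extracts the shape vector from an AbstractTensor's BuildShape(); fails for
// non-tensor abstracts.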
FetchShapeFromAbstract(const abstract::AbstractBasePtr & abstract,ShapeVector * shape)1338 STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) {
1339 if (abstract == nullptr || shape == nullptr) {
1340 MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
1341 return lite::RET_ERROR;
1342 }
1343 if (!utils::isa<abstract::AbstractTensor>(abstract)) {
1344 MS_LOG(ERROR) << "abstract of cnode is invalid.";
1345 return lite::RET_ERROR;
1346 }
1347 auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
1348 if (abstract_tensor->BuildShape() == nullptr || !utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
1349 MS_LOG(ERROR) << "shape of cnode's output is invalid.";
1350 return lite::RET_ERROR;
1351 }
1352 *shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
1353 return lite::RET_OK;
1354 }
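
// Illustrative sketch: reading a node's inferred output shape through its
// abstract; assumes shape inference has already attached a tensor abstract.
//   ShapeVector shape;
//   if (FetchShapeFromAbstract(node->abstract(), &shape) != lite::RET_OK) {
//     MS_LOG(WARNING) << "cannot fetch shape of " << node->fullname_with_scope();
//   }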
1355
1356 bool IsTrainOp(const CNodePtr &cnode) {
1357 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1358 if (prim == nullptr) {
1359 return false;
1360 }
1361 auto cnode_type = prim->name();
1362 // optimizer op
1363 if (cnode_type == "Adam" || cnode_type == "SGD" || cnode_type == "ApplyMomentum") {
1364 return true;
1365 }
1366 // loss op
1367 if (cnode_type == "SoftmaxCrossEntropyWithLogits" || cnode_type == "SparseSoftmaxCrossEntropyWithLogits" ||
1368 cnode_type == "SmoothL1Loss" || cnode_type == "SmoothL1LossGrad" ||
1369 cnode_type == "SigmoidCrossEntropyWithLogits" || cnode_type == "SigmoidCrossEntropyWithLogitsGrad") {
1370 return true;
1371 }
1372 // grad op
1373 if (cnode_type.find("Grad") != std::string::npos ||
1374 cnode->fullname_with_scope().find("Gradients") != std::string::npos) {
1375 return true;
1376 }
1377 return false;
1378 }
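
// Illustrative sketch: a fusion pass would typically skip training-related
// nodes up front, e.g.
//   if (IsTrainOp(cnode) || IsMarkedTrainOp(cnode)) {
//     return nullptr;  // do not fuse optimizer, loss or gradient ops
//   }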
1379
1380 bool IsMarkedTrainOp(const CNodePtr &cnode) {
1381 if (cnode == nullptr) {
1382 return false;
1383 }
1384 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1385 MS_CHECK_TRUE_RET(prim != nullptr, false);
1386 if (prim->GetAttr("trainOp") != nullptr && GetValue<bool>(prim->GetAttr("trainOp"))) {
1387     MS_LOG(DEBUG) << "train op, will not be fused.";
1388 return true;
1389 }
1390 return false;
1391 }
1392
1393 size_t GetOutputSize(const AnfNodePtr &anf_node) {
1394 if (anf_node == nullptr) {
1395 MS_LOG(ERROR) << "anf_node is nullptr.";
1396 return RET_ERROR;
1397 }
1398 AbstractBasePtr abstract_base;
1399 if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
1400 abstract_base = anf_node->cast<CNodePtr>()->input(1)->abstract();
1401 } else {
1402 abstract_base = anf_node->abstract();
1403 }
1404   // Used for multi-output nodes, e.g. Split.
1405 if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
1406 auto abstract_tuple = abstract_base->cast<abstract::AbstractTuplePtr>();
1407 return abstract_tuple->elements().size();
1408 }
1409 return 1;
1410 }
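
// Illustrative sketch: deciding whether TupleGetItem unpacking is required.
//   if (GetOutputSize(cnode) > 1) {
//     // multi-output node (e.g. Split): create one TupleGetItem per output.
//   }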
1411
1412 ShapeVector GetAnfNodeOutputShape(const AnfNodePtr &node, size_t output_idx) {
1413 if (node == nullptr) {
1414 MS_LOG(ERROR) << "anf_node is nullptr.";
1415 return {};
1416 }
1417 auto as_value_node = node->cast<ValueNodePtr>();
1418 if (as_value_node) {
1419 auto value = as_value_node->value();
1420 auto tensor = value->cast<tensor::TensorPtr>();
1421 if (tensor) {
1422 return tensor->shape_c();
1423 }
1424 return {};
1425 }
1426 auto base_shape = node->Shape();
1427 if (base_shape == nullptr) {
1428 MS_LOG(INFO) << "Failed to get shape from node " << node->fullname_with_scope();
1429 return {};
1430 }
1431 if (base_shape->isa<abstract::Shape>()) {
1432 if (output_idx != 0) {
1433       MS_LOG(EXCEPTION) << "The node " << node->fullname_with_scope() << " is a single-output node, but got index ["
1434                         << output_idx << "].";
1435 }
1436 auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
1437 MS_EXCEPTION_IF_NULL(shape_ptr);
1438 return shape_ptr->shape();
1439 } else if (base_shape->isa<abstract::NoShape>()) {
1440 return ShapeVector();
1441 } else if (base_shape->isa<abstract::TupleShape>()) {
1442 auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
1443 MS_EXCEPTION_IF_NULL(tuple_shape);
1444 if (output_idx >= tuple_shape->size()) {
1445       MS_LOG(EXCEPTION) << "Output index " << output_idx << " is larger than output number " << tuple_shape->size()
1446                         << " of node " << node->fullname_with_scope() << ".";
1447 }
1448 auto b_shp = (*tuple_shape)[output_idx];
1449 if (b_shp->isa<abstract::Shape>()) {
1450 auto shape_ptr = b_shp->cast<abstract::ShapePtr>();
1451 MS_EXCEPTION_IF_NULL(shape_ptr);
1452 return shape_ptr->shape();
1453 } else if (b_shp->isa<abstract::NoShape>()) {
1454 return ShapeVector();
1455 } else if (b_shp->isa<abstract::TupleShape>()) {
1456 MS_LOG(INFO) << "The output shape of node:" << node->fullname_with_scope() << " index:" << output_idx
1457 << " is a TupleShape:" << base_shape->ToString();
1458 return ShapeVector();
1459 } else {
1460       MS_LOG(EXCEPTION) << "The output shape at index " << output_idx
1461                         << " should be a Shape, NoShape or TupleShape, but it is " << base_shape->ToString()
1462                         << ", node: " << node->fullname_with_scope() << ".";
1463 }
1464 }
1465 return ShapeVector();
1466 }
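
// Illustrative sketch: querying the shape of the second output of a
// multi-output node. An empty vector can mean a scalar, an unknown shape or a
// value node without tensor data, so {} is not by itself an error.
//   auto second_output_shape = GetAnfNodeOutputShape(node, 1);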
1467
1468 int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
1469 if (anf_node == nullptr || type_id == nullptr) {
1470 MS_LOG(ERROR) << "anf_node or type_id is nullptr.";
1471 return RET_ERROR;
1472 }
1473 AbstractBasePtr abstract_base;
1474 if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
1475 abstract_base = anf_node->cast<CNodePtr>()->input(1)->abstract();
1476 } else {
1477 abstract_base = anf_node->abstract();
1478 }
1479   // Used for multi-output nodes, e.g. Split.
1480 if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
1481 auto abstract_tuple = abstract_base->cast<abstract::AbstractTuplePtr>();
1482 if (abstract_tuple->elements().empty()) {
1483 MS_LOG(ERROR) << "abstract_tuple elements is empty.";
1484 return RET_ERROR;
1485 }
1486 abstract_base = abstract_tuple->elements().front();
1487 }
1488 if (abstract_base == nullptr) {
1489 MS_LOG(INFO) << "Abstract of parameter is nullptr, " << anf_node->fullname_with_scope();
1490 *type_id = kTypeUnknown;
1491 return lite::RET_NOT_SUPPORT;
1492 }
1493 if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
1494 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
1495 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
1496 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
1497 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1498 *type_id = type_ptr->type_id();
1499 } else if (utils::isa<abstract::AbstractScalarPtr>(abstract_base)) {
1500 auto abstract_scalar = utils::cast<abstract::AbstractScalarPtr>(abstract_base);
1501 auto type_ptr = abstract_scalar->GetTypeTrack();
1502 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1503 *type_id = type_ptr->type_id();
1504 } else {
1505 MS_LOG(ERROR) << anf_node->fullname_with_scope() << " is unsupported type:" << abstract_base->type_name();
1506 return RET_ERROR;
1507 }
1508 return RET_OK;
1509 }
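
// Illustrative sketch: RET_NOT_SUPPORT (missing abstract) is recoverable while
// RET_ERROR is not, so callers usually distinguish the two results:
//   TypeId type_id = kTypeUnknown;
//   auto ret = GetDataTypeFromAnfNode(node, &type_id);
//   if (ret == RET_ERROR) {
//     return RET_ERROR;
//   }
//   bool is_fp16 = (ret == RET_OK && type_id == kNumberTypeFloat16);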
1510
1511 bool IsQuantParameterNode(const PrimitivePtr &prim) {
1512 MS_CHECK_TRUE_RET(prim != nullptr, false);
1513 auto quant_attr = prim->GetAttr("quant_params");
1514 if (quant_attr != nullptr) {
1515 auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
1516 MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
1517 auto quant_params = quant_param_holder->get_input_quant_params();
1518     bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
1519 return !params.empty() && params.front().inited;
1520 });
1521 if (is_quant) {
1522 return true;
1523 }
1524 }
1525 return false;
1526 }
1527
1528 void UpdateManager(const FuncGraphPtr &func_graph) {
1529 auto manager = func_graph->manager();
1530 if (manager == nullptr) {
1531 manager = Manage(func_graph, true);
1532 } else {
1533 manager->Clear();
1534 manager->AddFuncGraph(func_graph, true);
1535 }
1536 std::set<FuncGraphPtr> all_func_graphs;
1537 mindspore::lite::GetAllFuncGraph(func_graph, &all_func_graphs);
1538 for (auto &one_func_graph : all_func_graphs) {
1539 manager->AddFuncGraph(one_func_graph);
1540 }
1541 }
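
// Illustrative sketch: after structural rewrites that add nodes or subgraphs,
// refresh the manager so node-user bookkeeping covers the whole graph family:
//   UpdateManager(func_graph);
//   auto manager = func_graph->manager();  // now tracks all reachable subgraphs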
1542
1543 std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index) {
1544 MS_CHECK_TRUE_MSG(cnode != nullptr, {}, "function's parameter is nullptr.");
1545 MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, {}, "required input is nullptr");
1546 auto real_input_cnode = cnode->input(index)->cast<CNodePtr>();
1547 if (real_input_cnode == nullptr) {
1548 MS_LOG(DEBUG) << "input node is not a cnode.";
1549 return {};
1550 }
1551 int item_index = 0;
1552 if (opt::CheckPrimitiveType(real_input_cnode, prim::kPrimTupleGetItem)) {
1553 auto index_node = real_input_cnode->input(opt::kInputIndexTwo);
1554 MS_CHECK_TRUE_MSG(index_node != nullptr, {}, "tuple_get_item's second input is nullptr.");
1555     MS_CHECK_TRUE_MSG(index_node->isa<ValueNode>(), {}, "tuple_get_item's second input should be a ValueNode.");
1556 auto index_ptr = index_node->cast<ValueNodePtr>()->value();
1557 MS_CHECK_TRUE_MSG(index_ptr != nullptr, {}, "tuple_get_item's second input val is nullptr.");
1558 auto value = CastToInt(index_ptr);
1559 MS_CHECK_TRUE_MSG(value.size() == 1, {}, "tuple_get_item's second input is invalid.");
1560 item_index = value.front();
1561 MS_CHECK_TRUE_MSG(real_input_cnode->input(1) != nullptr, {}, "tuple_get_item's first input is nullptr");
1562 real_input_cnode = real_input_cnode->input(1)->cast<CNodePtr>();
1563     MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "tuple_get_item's first input is not a cnode.");
1564 }
1565 return {real_input_cnode, item_index};
1566 }
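
// Illustrative sketch: resolving input 1 of `cnode` through a possible
// TupleGetItem wrapper; `real.second` is the output index on the producer.
//   auto real = GetRealCertainVarInput(cnode, 1);
//   if (real.first != nullptr) {
//     MS_LOG(DEBUG) << "input produced by " << real.first->fullname_with_scope()
//                   << ", output index " << real.second;
//   }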
1567
1568 int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ) {
1569 MS_CHECK_TRUE_MSG(cnode != nullptr && infer_succ != nullptr, RET_ERROR, "function's parameter is nullptr.");
1570 auto var_input_info = GetRealCertainVarInput(cnode, index);
1571 if (var_input_info.first == nullptr) {
1572 MS_LOG(ERROR) << "cannot get the real var input.";
1573 return RET_ERROR;
1574 }
1575 auto real_input_cnode = var_input_info.first;
1576 auto item_index = var_input_info.second;
1577 auto input_node_prim = GetValueNode<PrimitivePtr>((real_input_cnode->input(0)));
1578 MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed.");
1579 *infer_succ = false;
1580 auto value_ptr = input_node_prim->GetAttr(kInferDone);
1581 if (value_ptr != nullptr) {
1582 MS_CHECK_TRUE_MSG(value_ptr->isa<BoolImm>(), RET_ERROR, "value is not a boolean.");
1583 *infer_succ = GetValue<bool>(value_ptr);
1584 }
1585 value_ptr = input_node_prim->GetAttr(kInferFlags);
1586 if (value_ptr == nullptr) {
1587 return RET_OK;
1588 }
1589 MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "infer flag should be a vector.");
1590 auto value_sequence = value_ptr->cast<ValueSequeuePtr>();
1591 auto elements = value_sequence->value();
1592 MS_CHECK_TRUE_MSG(!elements.empty(), RET_ERROR, "infer_info has no content.");
1593 auto first_element = elements.front();
1594 MS_CHECK_TRUE_MSG(first_element != nullptr, RET_ERROR, "element is a nullptr.");
1595 MS_CHECK_TRUE_MSG(first_element->isa<BoolImm>(), RET_ERROR, "each element is not a boolean.");
1596 auto infer_infos = GetValue<std::vector<bool>>(value_ptr);
1597 MS_CHECK_TRUE_MSG(item_index >= 0 && static_cast<size_t>(item_index) < infer_infos.size(), RET_ERROR,
1598 "item index is out of range.");
1599 *infer_succ = infer_infos[item_index];
1600 return RET_OK;
1601 }
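
// Illustrative sketch: checking whether shape inference already succeeded for
// the producer of input 1 before trusting its output shape; using
// lite::RET_NO_CHANGE as the bail-out code is an assumption of the example.
//   bool inferred = false;
//   if (DetermineCertainVarInputHasInferred(cnode, 1, &inferred) != RET_OK || !inferred) {
//     return lite::RET_NO_CHANGE;
//   }
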
1602 bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type) {
1603 MS_CHECK_TRUE_RET(cnode != nullptr, false);
1604 MS_CHECK_TRUE_RET(index != nullptr, false);
1605 if (cnode->size() != kInputSizeThree) {
1606 return false;
1607 }
1608 size_t dst_index = 0;
1609 for (size_t i = 1; i < cnode->size(); ++i) {
1610 if (CheckPrimitiveType(cnode->input(i), primitive_type)) {
1611 dst_index = i;
1612 break;
1613 }
1614 }
1615 if (dst_index == 0) {
1616 return false;
1617 }
1618 *index = dst_index;
1619 return true;
1620 }
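
// Illustrative sketch: in a two-input pattern such as Mul(conv_out, const),
// locating which input is produced by a given primitive:
//   size_t conv_index = 0;
//   if (CheckAndGetCnodeIndex(mul_cnode, &conv_index, prim::kPrimConv2DFusion)) {
//     auto conv_input = mul_cnode->input(conv_index);
//   }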
1621
1622 void PrintFuncGraph(const FuncGraphPtr &func_graph, const std::string &output_file) {
1623 if (func_graph == nullptr) {
1624 MS_LOG(WARNING) << "input func_graph is nullptr";
1625 return;
1626 }
1627 static int index = 0;
1628 auto real_file = std::to_string(index++) + "_" + output_file + ".txt";
1629 std::ofstream fp(real_file);
1630 if (!fp.is_open()) {
1631 MS_LOG(ERROR) << "Failed to create file " << real_file;
1632 return;
1633 }
1634 auto nodes = func_graph->TopoSort(func_graph->get_return());
1635 auto type_name = [](const AnfNodePtr &anf_node) -> std::string {
1636 if (anf_node->cast<CNodePtr>()) {
1637 return GetCNodeFuncName(anf_node->cast<CNodePtr>());
1638 } else if (anf_node->cast<ParameterPtr>()) {
1639 if (anf_node->cast<ParameterPtr>()->has_default()) {
1640 return "Parameter_Constant";
1641 } else {
1642 return "Parameter_Variable";
1643 }
1644 } else if (anf_node->cast<ValueNodePtr>()) {
1645 return "ValueNode";
1646 }
1647 return anf_node->ToString();
1648 };
1649 for (auto &node : nodes) {
1650 if (IsValueNode<Primitive>(node)) {
1651 continue;
1652 }
1653 auto cnode = node->cast<CNodePtr>();
1654 if (cnode == nullptr) {
1655 fp << node->fullname_with_scope() << ", type: " << type_name(node)
1656 << ", shape: " << GetAnfNodeOutputShape(node, 0) << std::endl;
1657 fp << std::endl;
1658 continue;
1659 }
1660 TypeId type_id = kTypeUnknown;
1661 GetDataTypeFromAnfNode(node, &type_id);
1662 fp << node->fullname_with_scope() << ", type: " << type_name(node) << ", shape: " << GetAnfNodeOutputShape(node, 0)
1663 << ", data type: " << static_cast<int>(type_id) << std::endl;
1664 auto &inputs = cnode->inputs();
1665 for (auto &input : inputs) {
1666 if (IsValueNode<Primitive>(input)) {
1667 continue;
1668 }
1669 type_id = kTypeUnknown;
1670       GetDataTypeFromAnfNode(input, &type_id);  // query the input's data type, not the consuming node's
1671 fp << "---input " << input->fullname_with_scope() << ", type: " << type_name(input)
1672 << ", shape: " << GetAnfNodeOutputShape(input, 0) << ", data type: " << static_cast<int>(type_id) << std::endl;
1673 }
1674 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1675 if (prim != nullptr) {
1676 for (auto &attr : prim->attrs()) {
1677 if (attr.second) {
1678 fp << "---attr " << attr.first << ": " << attr.second->ToString() << std::endl;
1679 } else {
1680 fp << "---attr " << attr.first << ": value nullptr" << std::endl;
1681 }
1682 }
1683 }
1684 fp << std::endl;
1685 }
1686 }
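
// Illustrative debugging sketch: dumping the graph before and after a pass;
// the static counter prefixes files as 0_before.txt, 1_after.txt, and so on.
// `pass` is an assumption of the example.
//   PrintFuncGraph(func_graph, "before");
//   (void)pass->Run(func_graph);
//   PrintFuncGraph(func_graph, "after");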
1687
1688 #if !defined(_WIN32) && !defined(_WIN64)
1689 std::vector<KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node) {
1690 if (!anf_node) {
1691 return {};
1692 }
1693 if (!anf_node->isa<CNode>()) {
1694 return {{anf_node, 0}};
1695 }
1696 auto cnode = anf_node->cast<CNodePtr>();
1697 std::vector<common::KernelWithIndex> inputs;
1698 size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
1699 for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
1700 const auto &pre_node_output = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
1701 auto pre_node = pre_node_output.first;
1702 if (opt::CheckPrimitiveType(pre_node, prim::kPrimMakeTuple) ||
1703 opt::CheckPrimitiveType(pre_node, prim::kPrimMakeTupleV2)) {
1704 auto tuple_inputs = GetNodeInputs(pre_node);
1705 std::copy(tuple_inputs.begin(), tuple_inputs.end(), std::back_inserter(inputs));
1706 } else {
1707 inputs.push_back(pre_node_output);
1708 }
1709 }
1710 return inputs;
1711 }
1712 #else
1713 std::vector<KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node) { return {}; }
1714 #endif
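
// Illustrative sketch: GetNodeInputs flattens MakeTuple producers, so callers
// see one (node, output index) pair per real tensor input:
//   for (const auto &input_with_index : GetNodeInputs(node)) {
//     MS_LOG(DEBUG) << input_with_index.first->fullname_with_scope() << ":"
//                   << input_with_index.second;
//   }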
1715
1716 bool IsReduceModeMeetOutEqualIn(const PrimitivePtr &prim) {
1717 if (prim == nullptr) {
1718 return false;
1719 }
1720 if (prim->GetAttr(ops::kMode) == nullptr) {
1721 return false;
1722 }
1723 auto mode = GetValue<int64_t>(prim->GetAttr(ops::kMode));
1724 std::set<int64_t> meet_mode = {Reduce_Mean, Reduce_Max, Reduce_Min, Reduce_Prod, Reduce_Sum};
1725 return meet_mode.find(mode) != meet_mode.end();
1726 }
1727
1728 STATUS AdjustInputToCnode(const CNodePtr &cnode, size_t input_index) {
1729 auto func_graph = cnode->func_graph();
1730 if (func_graph == nullptr) {
1731 MS_LOG(ERROR) << "func graph is nullptr.";
1732 return RET_ERROR;
1733 }
1734 ops::TensorMove tensor_move;
1735 auto tensor_move_prim = tensor_move.GetPrim();
1736 if (tensor_move_prim == nullptr) {
1737 MS_LOG(ERROR) << "tensor move prim is nullptr.";
1738 return RET_ERROR;
1739 }
1740 auto tensor_move_cnode = func_graph->NewCNode(tensor_move_prim, {cnode->input(input_index)});
1741 if (tensor_move_cnode == nullptr) {
1742 MS_LOG(ERROR) << "new cnode failed.";
1743 return RET_ERROR;
1744 }
1745 auto manager = Manage(func_graph);
1746 if (manager == nullptr) {
1747 MS_LOG(ERROR) << "manager is nullptr.";
1748 return RET_ERROR;
1749 }
1750   tensor_move_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_move_" +
1751                                              std::to_string(input_index));
1752   auto input_abstract = cnode->input(input_index)->abstract();
1753   if (input_abstract == nullptr) {
1754     MS_LOG(ERROR) << "abstract of input is nullptr, cannot clone.";
1755     return RET_ERROR;
1756   }
1757   tensor_move_cnode->set_abstract(input_abstract->Clone());
1758 manager->SetEdge(cnode, input_index, tensor_move_cnode);
1759 return RET_OK;
1760 }
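
// Illustrative sketch: inserting a TensorMove copy on input 1 of a node that
// may modify its input in place, so the original tensor stays untouched:
//   if (AdjustInputToCnode(cnode, 1) != RET_OK) {
//     return RET_ERROR;
//   }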
1761
1762 tensor::TensorPtr GetTensorFromParameterNode(const EquivPtr &equiv, const VarPtr &input) {
1763 MS_CHECK_TRUE_RET(equiv != nullptr && input != nullptr, nullptr);
1764 auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
1765 if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
1766 MS_LOG(ERROR) << "node is nullptr or node is not a parameter node.";
1767 return nullptr;
1768 }
1769 auto parameter_node = node->cast<ParameterPtr>();
1770 if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
1771 MS_LOG(ERROR) << "parameter_node has no default or its default_param() is nullptr.";
1772 return nullptr;
1773 }
1774 auto param_value_lite = parameter_node->default_param()->cast<tensor::TensorPtr>();
1775 return param_value_lite;
1776 }
1777
1778 const int GetIntParameterValue(const EquivPtr &equiv, const VarPtr &input) {
1779 MS_CHECK_TRUE_RET(equiv != nullptr && input != nullptr, INT_MIN);
1780 auto param_value_lite = GetTensorFromParameterNode(equiv, input);
1781 const int value = INT_MIN;
1782 if (param_value_lite == nullptr) {
1783 return value;
1784 }
1785 if (param_value_lite->data_type() != kNumberTypeInt32 && param_value_lite->data_type() != kNumberTypeInt) {
1786 return value;
1787 }
1788 if (param_value_lite->Size() != sizeof(int)) {
1789 return value;
1790 }
1791 return *static_cast<int *>(param_value_lite->data_c());
1792 }
1793
1794 const float GetFloatParameterValue(const EquivPtr &equiv, const VarPtr &input) {
1795 const float value = -1;
1796 MS_CHECK_TRUE_RET(equiv != nullptr && input != nullptr, value);
1797 auto param_value_lite = GetTensorFromParameterNode(equiv, input);
1798 if (param_value_lite == nullptr) {
1799 return value;
1800 }
1801 if (param_value_lite->data_type() != kNumberTypeFloat32 && param_value_lite->data_type() != kNumberTypeFloat) {
1802 return value;
1803 }
1804 if (param_value_lite->Size() != sizeof(float)) {
1805 return value;
1806 }
1807 return *static_cast<float *>(param_value_lite->data_c());
1808 }
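
// Illustrative sketch: reading scalar pattern inputs captured in an Equiv
// during pattern matching; INT_MIN and -1 are the respective "not found"
// sentinels, so patterns should avoid them as legal constant values.
// axis_var and eps_var are VarPtr assumptions of the example.
//   auto axis = GetIntParameterValue(equiv, axis_var);
//   auto epsilon = GetFloatParameterValue(equiv, eps_var);
//   if (axis == INT_MIN) {
//     return nullptr;
//   }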
1809
1810 }  // namespace opt
1811 } // namespace mindspore
1812