• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "common/anf_util.h"
#include <algorithm>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "third_party/securec/include/securec.h"
#include "common/op_enum.h"
#include "common/op_attr.h"
#include "common/string_util.h"
#include "ops/custom.h"
#include "ops/tuple_get_item.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "common/check_base.h"
// Forward declaration only: PrimitiveC is referenced through a std::shared_ptr alias
// later in this file, so its full definition is not required here.
namespace mindspore {
namespace ops {
class PrimitiveC;
}  // namespace ops
}  // namespace mindspore
37 namespace mindspore {
38 namespace dpico {
39 namespace {
40 const std::map<TypeId, size_t> kTypeMap = {
41   {kNumberTypeBool, 1},       {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},    {kNumberTypeInt16, 2},
42   {kNumberTypeInt32, 4},      {kNumberTypeInt64, 8},   {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},
43   {kNumberTypeUInt16, 2},     {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
44   {kNumberTypeFloat16, 2},    {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8},
45   {kNumberTypeComplex128, 16}};
46 constexpr size_t kTupleGetItemInputSize = 3;
47 constexpr size_t kInputNodeOutputIndexInTupleGetItem = 2;
48 using PrimitiveCPtr = std::shared_ptr<ops::PrimitiveC>;
TypeIdSize(const TypeId data_type)49 size_t TypeIdSize(const TypeId data_type) {
50   const size_t unsupported_type_error = 0;
51   auto iter = kTypeMap.find(data_type);
52   if (iter != kTypeMap.end()) {
53     return iter->second;
54   }
55   return unsupported_type_error;
56 }
57 }  // namespace
CheckPrimitiveType(const api::AnfNodePtr & node,const api::PrimitivePtr & primitive_type)58 bool CheckPrimitiveType(const api::AnfNodePtr &node, const api::PrimitivePtr &primitive_type) {
59   if (node == nullptr) {
60     return false;
61   }
62   if (node->isa<api::CNode>()) {
63     auto cnode = node->cast<api::CNodePtr>();
64     return IsPrimitive(cnode->input(0), primitive_type);
65   } else if (node->isa<api::ValueNode>()) {
66     return IsPrimitive(node, primitive_type);
67   }
68   return false;
69 }
70 
GetPrimitiveType(const api::AnfNodePtr & node,std::string * name)71 STATUS GetPrimitiveType(const api::AnfNodePtr &node, std::string *name) {
72   if (name == nullptr) {
73     MS_LOG(ERROR) << "name is nulltr.";
74     return RET_ERROR;
75   }
76   if (node == nullptr) {
77     MS_LOG(ERROR) << "node is nullptr.";
78     return RET_ERROR;
79   }
80   if (node->isa<api::CNode>()) {
81     auto cnode = node->cast<api::CNodePtr>();
82     auto primitive = api::GetValueNode<api::PrimitivePtr>(cnode->input(0));
83     if (primitive == nullptr) {
84       MS_LOG(ERROR) << "primitive is nullptr. " << cnode->fullname_with_scope();
85       return RET_ERROR;
86     }
87     if (CheckPrimitiveType(node, api::MakeShared<ops::Custom>())) {
88       auto custom_prim = api::utils::cast<api::SharedPtr<ops::Custom>>(primitive);
89       MS_CHECK_TRUE_MSG(custom_prim != nullptr, RET_ERROR, "custom op is nullptr.");
90       *name = custom_prim->get_type();
91       return RET_OK;
92     } else {
93       *name = primitive->name();
94       return RET_OK;
95     }
96   } else if (node->isa<api::ValueNode>()) {
97     auto fn_value = api::GetValueNode<api::PrimitivePtr>(node);
98     if (fn_value == nullptr) {
99       MS_LOG(ERROR) << "fn_value is nullptr.";
100       return RET_ERROR;
101     }
102     *name = fn_value->name();
103     return RET_OK;
104   }
105   MS_LOG(ERROR) << "There is no name for this node";
106   return RET_ERROR;
107 }
GetShapeVectorFromParameter(const api::AnfNodePtr & anode,ShapeVector * shape_vector)108 STATUS GetShapeVectorFromParameter(const api::AnfNodePtr &anode, ShapeVector *shape_vector) {
109   if (shape_vector == nullptr) {
110     MS_LOG(ERROR) << "shape vector is nullptr.";
111     return RET_ERROR;
112   }
113   if (!api::utils::isa<api::Parameter>(anode)) {
114     MS_LOG(ERROR) << "anode should be parameter node. ";
115     return RET_ERROR;
116   }
117   auto param_node = anode->cast<api::ParameterPtr>();
118   auto abstract_base = param_node->abstract();
119   if (abstract_base == nullptr) {
120     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
121     return lite::RET_PARAM_INVALID;
122   }
123   if (!api::utils::isa<api::AbstractTensorPtr>(abstract_base)) {
124     MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
125     return lite::RET_INPUT_TENSOR_ERROR;
126   }
127   auto abstract_tensor = api::utils::cast<api::AbstractTensorPtr>(abstract_base);
128   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
129   if (!api::utils::isa<api::ShapePtr>(abstract_tensor->shape())) {
130     MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
131     return lite::RET_PARAM_INVALID;
132   }
133   *shape_vector = api::utils::cast<api::ShapePtr>(abstract_tensor->shape())->shape();
134   return RET_OK;
135 }
CastToInt(const api::ValuePtr & value)136 std::vector<int> CastToInt(const api::ValuePtr &value) {
137   if (value == nullptr) {
138     MS_LOG(WARNING) << "valueptr is nullptr.";
139     return {};
140   }
141   std::vector<int> cur_value = {};
142   if (api::utils::isa<api::ValueSequencePtr>(value)) {
143     if (!value->cast<api::ValueSequencePtr>()->value().empty()) {
144       auto origin_value = api::GetValue<std::vector<int64_t>>(value);
145       (void)std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
146                            [](int64_t index) { return static_cast<int32_t>(index); });
147     }
148   } else {
149     cur_value.push_back(static_cast<int>(api::GetValue<int64_t>(value)));
150   }
151   return cur_value;
152 }
GetTupleGetItemOutIndex(const api::CNodePtr & tuple_get_item)153 size_t GetTupleGetItemOutIndex(const api::CNodePtr &tuple_get_item) {
154   MS_ASSERT(tuple_get_item != nullptr);
155   if (tuple_get_item->size() != kTupleGetItemInputSize) {
156     MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
157     return SIZE_MAX;
158   }
159   auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
160   MS_ASSERT(output_index_value_node != nullptr);
161   auto value_node = output_index_value_node->cast<api::ValueNodePtr>();
162   MS_ASSERT(value_node != nullptr);
163   auto value_vec = CastToInt(value_node->value());
164   if (value_vec.empty()) {
165     MS_LOG(ERROR) << "value vec is empty.";
166     return SIZE_MAX;
167   }
168   return IntToSize(value_vec.front());
169 }
GetOutputShapesFromCNode(const api::CNodePtr & cnode,std::vector<ShapeVector> * output_shapes)170 STATUS GetOutputShapesFromCNode(const api::CNodePtr &cnode, std::vector<ShapeVector> *output_shapes) {
171   api::AbstractBasePtr abstract = nullptr;
172   if (output_shapes == nullptr) {
173     MS_LOG(ERROR) << "output_shapes is nullptr.";
174     return RET_ERROR;
175   }
176   if (CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>())) {
177     auto tuple_inputs = cnode->inputs();
178     MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
179     auto get_item_input_cnode = tuple_inputs.at(1);
180     MS_ASSERT(get_item_input_cnode != nullptr);
181     auto idx = GetTupleGetItemOutIndex(cnode);
182     if (!api::utils::isa<api::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
183       MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
184       return RET_ERROR;
185     }
186     auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(get_item_input_cnode->abstract());
187     auto abstract_list = abstract_tuple->elements();
188     if (abstract_list.size() <= idx) {
189       MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
190       return RET_ERROR;
191     }
192     abstract = abstract_list[idx];
193   } else {
194     abstract = cnode->abstract();
195   }
196   if (abstract == nullptr) {
197     MS_LOG(ERROR) << "abstract cnode is nullptr. " << cnode->fullname_with_scope();
198     return RET_ERROR;
199   }
200   if (api::utils::isa<api::AbstractTuplePtr>(abstract)) {
201     auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(abstract);
202     auto abstract_list = abstract_tuple->elements();
203     for (const auto &elem : abstract_list) {
204       ShapeVector shape_vector;
205       if (FetchShapeFromAbstract(elem, &shape_vector) != RET_OK) {
206         MS_LOG(ERROR) << "fetch shape from abstract tuple elem failed. " << cnode->fullname_with_scope();
207         return RET_ERROR;
208       }
209       if (shape_vector.empty()) {
210         MS_LOG(ERROR) << "shape_vector is empty." << cnode->fullname_with_scope();
211         return RET_ERROR;
212       }
213       (void)output_shapes->emplace_back(shape_vector);
214     }
215     return RET_OK;
216   } else {
217     ShapeVector shape_vector;
218     if (FetchShapeFromAbstract(abstract, &shape_vector) != RET_OK) {
219       MS_LOG(ERROR) << "fetch shape from abstract failed. " << cnode->fullname_with_scope();
220       return RET_ERROR;
221     }
222     if (shape_vector.empty()) {
223       MS_LOG(ERROR) << "shape_vector is empty." << cnode->fullname_with_scope();
224       return RET_ERROR;
225     }
226     (void)output_shapes->emplace_back(shape_vector);
227   }
228   return RET_OK;
229 }
230 
GetInputShapeFromCNode(const api::CNodePtr & cnode,size_t input_idx,ShapeVector * shape)231 STATUS GetInputShapeFromCNode(const api::CNodePtr &cnode, size_t input_idx, ShapeVector *shape) {
232   if (shape == nullptr) {
233     MS_LOG(ERROR) << "shape is nullptr.";
234     return RET_ERROR;
235   }
236   auto input_abstract = GetCNodeInputAbstract(cnode, input_idx);
237   if (input_abstract == nullptr) {
238     MS_LOG(ERROR) << "input_abstract is nullptr.";
239     return RET_ERROR;
240   }
241   if (FetchShapeFromAbstract(input_abstract, shape) != RET_OK) {
242     MS_LOG(ERROR) << "fetch shape from abstract failed.";
243     return RET_ERROR;
244   }
245   return RET_OK;
246 }
247 
FetchShapeFromAbstract(const api::AbstractBasePtr & abstract,ShapeVector * shape)248 STATUS FetchShapeFromAbstract(const api::AbstractBasePtr &abstract, ShapeVector *shape) {
249   if (shape == nullptr) {
250     MS_LOG(ERROR) << "shape is nullptr.";
251     return RET_ERROR;
252   }
253   if (abstract == nullptr) {
254     MS_LOG(ERROR) << "abstract of cnode is invalid.";
255     return RET_ERROR;
256   }
257   if (!api::utils::isa<api::AbstractTensor>(abstract)) {
258     MS_LOG(ERROR) << "abstract of cnode is invalid.";
259     return RET_ERROR;
260   }
261   auto abstract_tensor = abstract->cast<api::AbstractTensorPtr>();
262   if (!api::utils::isa<api::ShapePtr>(abstract_tensor->shape())) {
263     MS_LOG(ERROR) << "shape of cnode's output is invalid.";
264     return RET_ERROR;
265   }
266   *shape = api::utils::cast<api::ShapePtr>(abstract_tensor->shape())->shape();
267   return RET_OK;
268 }
FetchTypeIdFromAbstract(const api::AbstractBasePtr & abstract,TypeId * type_id)269 STATUS FetchTypeIdFromAbstract(const api::AbstractBasePtr &abstract, TypeId *type_id) {
270   if (type_id == nullptr) {
271     MS_LOG(ERROR) << "type id is nullptr.";
272     return RET_ERROR;
273   }
274   if (abstract == nullptr) {
275     MS_LOG(ERROR) << "abstract of cnode is invalid.";
276     return RET_ERROR;
277   }
278   if (!api::utils::isa<api::AbstractTensor>(abstract)) {
279     MS_LOG(ERROR) << "abstract of cnode is invalid.";
280     return RET_ERROR;
281   }
282   auto abstract_tensor = abstract->cast<api::AbstractTensorPtr>();
283   if (abstract_tensor->element() == nullptr) {
284     MS_LOG(ERROR) << "element of abstract_tensor is nullptr.";
285     return RET_ERROR;
286   }
287   auto type_ptr = abstract_tensor->element()->type();
288   if (type_ptr == nullptr) {
289     MS_LOG(ERROR) << "type_ptr of abstract_tensor is nullptr.";
290     return RET_ERROR;
291   }
292   *type_id = type_ptr->type_id();
293   return RET_OK;
294 }
295 
GetAnfNodeOutputShape(const api::AnfNodePtr & input,ShapeVector * shape_vector)296 int GetAnfNodeOutputShape(const api::AnfNodePtr &input, ShapeVector *shape_vector) {
297   if (shape_vector == nullptr) {
298     MS_LOG(ERROR) << "shape vector is nullptr." << input->fullname_with_scope();
299     return RET_ERROR;
300   }
301   if (api::utils::isa<api::ParameterPtr>(input)) {
302     if (GetShapeVectorFromParameter(input, shape_vector) != RET_OK) {
303       MS_LOG(ERROR) << "get output shape for preprocessor failed. " << input->fullname_with_scope();
304       return RET_ERROR;
305     }
306   } else if (api::utils::isa<api::CNodePtr>(input)) {
307     auto input_cnode = input->cast<api::CNodePtr>();
308     std::vector<ShapeVector> output_shapes;
309     if (GetOutputShapesFromCNode(input_cnode, &output_shapes) != RET_OK) {
310       MS_LOG(ERROR) << "get output shapes from cnode failed. " << input_cnode->fullname_with_scope();
311       return RET_ERROR;
312     }
313     if (output_shapes.size() == 1) {
314       *shape_vector = output_shapes.at(0);
315     } else {
316       MS_LOG(ERROR) << input_cnode->fullname_with_scope() << " has " << output_shapes.size()
317                     << " output, which should be 1.";
318       return RET_ERROR;
319     }
320   }
321   if (shape_vector->empty()) {
322     MS_LOG(ERROR) << "subgraph input shape shouldn't be empty. " << input->fullname_with_scope();
323     return RET_ERROR;
324   } else if (shape_vector->at(0) < 0) {
325     MS_LOG(WARNING) << " the N axis of " << input->fullname_with_scope() << "'s output shape is " << shape_vector->at(0)
326                     << ", which will be set to 1.";
327     shape_vector->at(0) = 1;
328   }
329   return RET_OK;
330 }
331 
TypeIdToString(TypeId type_id)332 std::string TypeIdToString(TypeId type_id) {
333   const std::unordered_map<int, std::string> kTypeIdMap{
334     {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"},    {kNumberTypeFloat32, "Float32"},
335     {kNumberTypeInt8, "Int8"},       {kNumberTypeInt16, "Int16"},      {kNumberTypeInt, "Int32"},
336     {kNumberTypeInt32, "Int32"},     {kNumberTypeUInt8, "UInt8"},      {kNumberTypeUInt16, "UInt16"},
337     {kNumberTypeUInt, "UInt32"},     {kNumberTypeUInt32, "UInt32"},    {kObjectTypeString, "String"},
338     {kNumberTypeBool, "Bool"},       {kObjectTypeTensorType, "Tensor"}};
339   std::string type_str = "Unknown";
340   if (kTypeIdMap.find(static_cast<int>(type_id)) != kTypeIdMap.end()) {
341     type_str = kTypeIdMap.at(static_cast<int>(type_id));
342   }
343   return type_str;
344 }
345 
CheckInputs(const api::CNodePtr & cnode)346 bool CheckInputs(const api::CNodePtr &cnode) {
347   if (cnode == nullptr) {
348     MS_LOG(ERROR) << "cnode is nullptr.";
349     return false;
350   }
351   auto inputs = cnode->inputs();
352   if (std::any_of(inputs.begin(), inputs.end(), [](const api::AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
353     MS_LOG(ERROR) << "input is nullptr.";
354     return false;
355   }
356   return true;
357 }
GetCustomOutputName(const api::AnfNodePtr & node)358 std::string GetCustomOutputName(const api::AnfNodePtr &node) {
359   std::string output_name;
360   auto input_cnode = node->cast<api::CNodePtr>();
361   if (input_cnode == nullptr) {
362     MS_LOG(ERROR) << "custom node should be cnode. " << node->fullname_with_scope();
363     return "";
364   }
365   if (input_cnode->GetAttr(kOutputsNames) != nullptr) {
366     auto output_names = api::GetValue<std::vector<std::string>>(input_cnode->GetAttr(kOutputsNames));
367     if (output_names.size() == 1) {
368       output_name = output_names.at(0);
369     } else {
370       MS_LOG(ERROR) << "multi-output's custom node shouldn't be a subgraph's input cnode. "
371                     << node->fullname_with_scope();
372       return "";
373     }
374   }
375   return output_name;
376 }
CreateTensorInfo(const void * data,size_t data_size,const std::vector<int64_t> & shape,TypeId data_type)377 api::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
378                                 TypeId data_type) {
379   api::TensorPtr tensor_info = nullptr;
380   if (shape.empty() && data_size == TypeIdSize(data_type)) {
381     ShapeVector scalar_shape = {1};
382     tensor_info = api::MakeShared<api::Tensor>(data_type, scalar_shape);
383     if (tensor_info == nullptr) {
384       MS_LOG(ERROR) << "new tensor init failed";
385       return nullptr;
386     }
387     tensor_info->set_shape({});
388   } else {
389     tensor_info = api::MakeShared<api::Tensor>(data_type, shape);
390     if (tensor_info == nullptr) {
391       MS_LOG(ERROR) << "new tensor init failed";
392       return nullptr;
393     }
394   }
395   if (data_size == 0) {
396     return tensor_info;
397   }
398   if (data == nullptr) {
399     MS_LOG(ERROR) << "input tensor data is nullptr";
400     return nullptr;
401   }
402   auto ret = memcpy_s(tensor_info->data(), tensor_info->Size(), data, data_size);
403   if (ret != EOK) {
404     MS_LOG(ERROR) << "memcpy_s error : " << ret;
405     return nullptr;
406   }
407   return tensor_info;
408 }
409 
CreateTensorAbstract(const std::vector<int64_t> & shape,TypeId data_type)410 api::AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type) {
411   auto tensor_info = dpico::CreateTensorInfo(nullptr, 0, shape, data_type);
412   if (tensor_info == nullptr) {
413     MS_LOG(ERROR) << "Create tensor info failed";
414     return nullptr;
415   }
416   auto abstract = tensor_info->ToAbstract();
417   if (abstract == nullptr) {
418     MS_LOG(ERROR) << "Create tensor abstarct failed";
419     return nullptr;
420   }
421   return abstract;
422 }
423 
InitParameterFromTensorInfo(const api::ParameterPtr & param_node,const api::TensorPtr & tensor_info)424 int InitParameterFromTensorInfo(const api::ParameterPtr &param_node, const api::TensorPtr &tensor_info) {
425   if (tensor_info == nullptr) {
426     MS_LOG(ERROR) << "tensor info is nullptr.";
427     return RET_ERROR;
428   }
429   auto abstract_tensor = tensor_info->ToAbstract();
430   if (abstract_tensor == nullptr) {
431     MS_LOG(ERROR) << "Create abstract tensor failed.";
432     return RET_ERROR;
433   }
434   param_node->set_abstract(abstract_tensor);
435   param_node->set_default_param(tensor_info);
436   return RET_OK;
437 }
438 
/// Returns the abstract of cnode's index-th input.
/// Parameter inputs yield their own abstract. CNode inputs yield their abstract, except for a
/// TupleGetItem input, which yields the selected element of its producer's AbstractTuple.
/// @return the abstract (possibly null if the node has none), or nullptr on invalid arguments,
///         malformed TupleGetItem or unsupported input kinds.
api::AbstractBasePtr GetCNodeInputAbstract(const api::CNodePtr &cnode, size_t index) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "CNodePtr is nullptr";
    return nullptr;
  }
  // Query size/input directly instead of copying the whole inputs vector.
  if (index >= cnode->size()) {
    MS_LOG(ERROR) << "index: " << index << " is greater than inputs size " << cnode->size();
    return nullptr;
  }
  auto input = cnode->input(index);
  if (input == nullptr) {
    MS_LOG(ERROR) << "CNode input is nullptr";
    return nullptr;
  }

  api::AbstractBasePtr abstract = nullptr;
  if (api::utils::isa<api::ParameterPtr>(input)) {
    auto parameter = input->cast<api::ParameterPtr>();
    abstract = parameter->abstract();
  } else if (api::utils::isa<api::CNodePtr>(input)) {
    auto input_cnode = input->cast<api::CNodePtr>();
    if (CheckPrimitiveType(input_cnode, api::MakeShared<ops::TupleGetItem>())) {
      auto tuple_inputs = input_cnode->inputs();
      // Validated at runtime: MS_ASSERT may be compiled out, and a null producer would be dereferenced below.
      if (tuple_inputs.size() != kTupleGetItemInputSize || tuple_inputs.at(1) == nullptr) {
        MS_LOG(ERROR) << "TupleGetItem node is malformed. " << input_cnode->fullname_with_scope();
        return nullptr;
      }
      auto get_item_input_cnode = tuple_inputs.at(1);
      auto idx = GetTupleGetItemOutIndex(input_cnode);
      if (!api::utils::isa<api::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      // Also rejects SIZE_MAX, which GetTupleGetItemOutIndex returns on failure.
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = input_cnode->abstract();
    }
  } else {
    MS_LOG(ERROR) << "unsupported input node type";
    return nullptr;
  }
  return abstract;
}
487 
/// Returns the output abstract of node.
/// Parameters yield their own abstract. CNodes yield their abstract, except for a TupleGetItem
/// cnode, which yields the selected element of its producer's AbstractTuple.
/// @return the abstract, or nullptr for other node kinds, null input or a malformed TupleGetItem.
api::AbstractBasePtr GetAbstractFromAnfNode(const api::AnfNodePtr &node) {
  api::AbstractBasePtr abstract = nullptr;
  if (api::utils::isa<api::ParameterPtr>(node)) {
    auto parameter = node->cast<api::ParameterPtr>();
    abstract = parameter->abstract();
  } else if (api::utils::isa<api::CNodePtr>(node)) {
    auto cnode = node->cast<api::CNodePtr>();
    if (CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>())) {
      auto tuple_inputs = cnode->inputs();
      // Validated at runtime: MS_ASSERT may be compiled out, and a null producer would be dereferenced below.
      if (tuple_inputs.size() != kTupleGetItemInputSize || tuple_inputs.at(1) == nullptr) {
        MS_LOG(ERROR) << "TupleGetItem node is malformed. " << cnode->fullname_with_scope();
        return nullptr;
      }
      auto get_item_input_cnode = tuple_inputs.at(1);
      auto idx = GetTupleGetItemOutIndex(cnode);
      if (!api::utils::isa<api::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      // Also rejects SIZE_MAX, which GetTupleGetItemOutIndex returns on failure.
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = cnode->abstract();
    }
  }
  return abstract;
}
518 
/// Adds to func_graph a parameter named node_name holding the int32 value data with shape {1}.
/// The tensor is built before the parameter is added, so a tensor-creation failure does not
/// leave an uninitialized parameter behind in the graph.
/// @return the new parameter, or nullptr on failure.
api::ParameterPtr BuildIntValueParameterNode(const api::FuncGraphPtr &func_graph, const int32_t &data,
                                             const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
  auto tensor_info = CreateTensorInfo(&data, sizeof(int32_t), {1}, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto param_node = func_graph->add_parameter();
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "add parameter failed";
    return nullptr;
  }
  param_node->set_name(node_name);

  auto status = InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
538 
/// Adds to func_graph a parameter named node_name holding data as a 1-D int32 tensor.
/// The tensor is built before the parameter is added, so a tensor-creation failure does not
/// leave an uninitialized parameter behind in the graph.
/// @return the new parameter, or nullptr when data is empty or creation fails.
api::ParameterPtr BuildIntVecParameterNode(const api::FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
                                           const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
  MS_CHECK_TRUE_MSG(data.size() != 0, nullptr, "Data size is 0");

  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
  auto tensor_info = CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), shape_vector, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto param_node = func_graph->add_parameter();
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "add parameter failed";
    return nullptr;
  }
  param_node->set_name(node_name);

  auto status = InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
561 
/// Adds to func_graph a parameter named node_name holding data as a row-major 2-D int32 tensor
/// of shape {rows, cols}. All rows must be non-empty and of equal length so that the flattened
/// buffer matches the declared shape (the column count used to be fixed at 2 regardless of data).
/// The tensor is built before the parameter is added, so a tensor-creation failure does not
/// leave an uninitialized parameter behind in the graph.
/// @return the new parameter, or nullptr on empty/ragged data or creation failure.
api::ParameterPtr BuildIntVec2DParameterNode(const api::FuncGraphPtr &func_graph,
                                             const std::vector<std::vector<int32_t>> &data,
                                             const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
  MS_CHECK_TRUE_MSG(data.size() != 0, nullptr, "Data size is 0");
  const size_t cols = data.front().size();
  MS_CHECK_TRUE_MSG(cols != 0, nullptr, "Row size is 0");
  const bool is_rectangular = std::all_of(data.begin(), data.end(),
                                          [cols](const std::vector<int32_t> &row) { return row.size() == cols; });
  MS_CHECK_TRUE_MSG(is_rectangular, nullptr, "All rows must have the same size");

  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size()), static_cast<int64_t>(cols)};

  std::vector<int32_t> data_1d;
  data_1d.reserve(data.size() * cols);
  for (const auto &row : data) {
    (void)data_1d.insert(data_1d.end(), row.begin(), row.end());
  }

  auto size = data_1d.size() * sizeof(int32_t);
  auto tensor_info = CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto param_node = func_graph->add_parameter();
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "add parameter failed";
    return nullptr;
  }
  param_node->set_name(node_name);

  auto status = InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
592 
/// Adds to func_graph a parameter named node_name holding the float32 value data with shape {1}.
/// The tensor is built before the parameter is added, so a tensor-creation failure does not
/// leave an uninitialized parameter behind in the graph.
/// @return the new parameter, or nullptr on failure.
api::ParameterPtr BuildFloatValueParameterNode(const api::FuncGraphPtr &func_graph, const float &data,
                                               const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
  auto tensor_info = CreateTensorInfo(&data, sizeof(float), {1}, kNumberTypeFloat32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }

  auto param_node = func_graph->add_parameter();
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "add parameter failed";
    return nullptr;
  }
  param_node->set_name(node_name);

  auto status = InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
611 
/// Creates a Transpose cnode in func_graph that permutes input_node by perm.
/// perm is materialized as an int32 parameter named cnode_name + "_perm"; both inputs are then
/// set through the graph's FuncGraphManager, and the cnode is named cnode_name.
/// @return the new cnode, or nullptr if any creation step fails.
api::CNodePtr GenTransposeNode(const api::FuncGraphPtr &func_graph, const api::AnfNodePtr &input_node,
                               const std::vector<int> &perm, const std::string &cnode_name) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  const std::string perm_name = cnode_name + "_perm";
  auto perm_param = BuildIntVecParameterNode(func_graph, perm, perm_name);
  if (perm_param == nullptr) {
    MS_LOG(ERROR) << "new perm_node error";
    return nullptr;
  }
  auto transpose_prim = api::MakeShared<ops::Transpose>();
  if (transpose_prim == nullptr) {
    MS_LOG(ERROR) << "new trans_prim failed";
    return nullptr;
  }
  auto transpose_cnode = func_graph->NewCNode(transpose_prim, {input_node, perm_param});
  if (transpose_cnode == nullptr) {
    MS_LOG(ERROR) << "new cnode error";
    return nullptr;
  }
  auto graph_manager = api::FuncGraphManager::Manage(func_graph);
  if (graph_manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return nullptr;
  }
  // Re-set both data edges through the manager (presumably so it tracks the new node's uses).
  graph_manager->SetEdge(transpose_cnode, 1, input_node);
  graph_manager->SetEdge(transpose_cnode, kInputIndex2, perm_param);
  transpose_cnode->set_fullname_with_scope(cnode_name);
  return transpose_cnode;
}
640 
/// Returns the tensor held by a Parameter's default value, or by a ValueNode whose value is a tensor.
/// @return the tensor, or nullptr when node is null, is neither kind, or holds no tensor.
api::TensorPtr GetTensorInfo(const api::AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is nullptr.";
    return nullptr;
  }
  if (!api::utils::isa<api::ParameterPtr>(node)) {
    if (api::utils::isa<api::ValueNodePtr>(node)) {
      auto value_node = node->cast<api::ValueNodePtr>();
      auto node_value = value_node->value();
      // Guard against a value node without a value before casting.
      if (node_value != nullptr) {
        auto value = node_value->cast<api::TensorPtr>();
        if (value != nullptr) {
          return value;
        }
      }
    }
    MS_LOG(DEBUG) << "get lite param value node neither parameter node or value node";
    return nullptr;
  }
  auto param = node->cast<api::ParameterPtr>();
  if (param == nullptr) {
    MS_LOG(ERROR) << "param is nullptr.";
    return nullptr;
  }
  // A parameter may carry no default value; report "no tensor" instead of dereferencing null.
  auto default_param = param->default_param();
  if (default_param == nullptr) {
    MS_LOG(DEBUG) << "param has no default value. " << param->name();
    return nullptr;
  }
  return default_param->cast<api::TensorPtr>();
}
662 
CastToVec2DInt(const api::ValuePtr & value)663 std::vector<std::vector<int>> CastToVec2DInt(const api::ValuePtr &value) {
664   if (value == nullptr) {
665     MS_LOG(WARNING) << "valueptr is nullptr.";
666     return {};
667   }
668 
669   std::vector<std::vector<int>> result_value;
670   if (api::utils::isa<api::ValueSequencePtr>(value)) {
671     auto origin_value = api::GetValue<std::vector<std::vector<int64_t>>>(value);
672     for (auto &vec : origin_value) {
673       std::vector<int> cur_value;
674       for (size_t j = 0; j < vec.size(); ++j) {
675         cur_value.push_back(static_cast<int>(vec[j]));
676       }
677       result_value.push_back(cur_value);
678     }
679   }
680   return result_value;
681 }
682 
GetBoolAttr(const api::AnfNodePtr & node,const std::string & attr_name)683 bool GetBoolAttr(const api::AnfNodePtr &node, const std::string &attr_name) {
684   auto cnode = node->cast<api::CNodePtr>();
685   if (cnode == nullptr) {
686     MS_LOG(ERROR) << "cur node is not a cnode. " << node->fullname_with_scope();
687     return false;
688   }
689   auto primitive = api::GetValueNode<api::PrimitivePtr>(cnode->input(0));
690   if (primitive == nullptr) {
691     MS_LOG(ERROR) << "primitive is nullptr:" << cnode->fullname_with_scope();
692     return false;
693   }
694   auto value_ptr = primitive->GetAttr(attr_name);
695   if (value_ptr == nullptr) {
696     MS_LOG(ERROR) << "There is no attr named " << attr_name << " for node " << cnode->fullname_with_scope();
697     return false;
698   }
699   return api::GetValue<bool>(value_ptr);
700 }
701 
GetDataTypeAndShape(const api::ParameterPtr & param_node,TypeId * data_type,ShapeVector * shape_vector)702 STATUS GetDataTypeAndShape(const api::ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
703   if (param_node == nullptr) {
704     MS_LOG(ERROR) << "param node is nullptr.";
705     return RET_ERROR;
706   }
707   if (data_type == nullptr) {
708     MS_LOG(ERROR) << "data type is nullptr.";
709     return RET_ERROR;
710   }
711   if (shape_vector == nullptr) {
712     MS_LOG(ERROR) << "shape vector is nullptr.";
713     return RET_ERROR;
714   }
715   auto abstract_base = param_node->abstract();
716   if (abstract_base == nullptr) {
717     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
718     return RET_ERROR;
719   }
720   if (!api::utils::isa<api::AbstractTensorPtr>(abstract_base)) {
721     MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
722     return RET_ERROR;
723   }
724   auto abstract_tensor = api::utils::cast<api::AbstractTensorPtr>(abstract_base);
725   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
726   auto typePtr = abstract_tensor->element()->type();
727   MS_ASSERT(typePtr != nullptr);
728   *data_type = typePtr->type_id();
729   if (!api::utils::isa<api::ShapePtr>(abstract_tensor->shape())) {
730     MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
731     return RET_ERROR;
732   }
733   *shape_vector = api::utils::cast<api::ShapePtr>(abstract_tensor->shape())->shape();
734   return RET_OK;
735 }
736 
GetShapeVectorFromStringTensor(const api::TensorPtr & tensor_info,ShapeVector * shape_vector,size_t * offset)737 STATUS GetShapeVectorFromStringTensor(const api::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
738   if (tensor_info == nullptr) {
739     MS_LOG(ERROR) << "tensor info is nullptr.";
740     return RET_ERROR;
741   }
742   if (shape_vector == nullptr) {
743     MS_LOG(ERROR) << "shape vector is nullptr.";
744     return RET_ERROR;
745   }
746   if (offset == nullptr) {
747     MS_LOG(ERROR) << "offset is nullptr.";
748     return RET_ERROR;
749   }
750   auto data_type = tensor_info->data_type();
751   if (data_type != kObjectTypeString) {
752     MS_LOG(ERROR) << "This function only used for string tensor.";
753     return RET_ERROR;
754   }
755   shape_vector->clear();
756   auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data());
757   std::string shape_str;
758   std::string shape_size_str;
759   *offset = 0;
760   size_t cnt = 0;
761   for (; *offset < tensor_info->Size(); (*offset)++) {
762     if (tensor_data[*offset] == ',') {
763       (*offset)++;
764       break;
765     }
766     shape_size_str.push_back(static_cast<char>(tensor_data[*offset]));
767   }
768   if (*offset == 0) {
769     MS_LOG(ERROR) << "string tensor's dim size not found.";
770     return RET_ERROR;
771   }
772   if (!IsValidUnsignedNum(shape_size_str)) {
773     MS_LOG(ERROR) << "shape_size str must an unsigned int.";
774     return RET_ERROR;
775   }
776   size_t shape_size = std::stoi(shape_size_str);
777   for (; *offset < tensor_info->Size(); (*offset)++) {
778     if (tensor_data[*offset] == ',') {
779       cnt++;
780       if (!IsValidUnsignedNum(shape_str)) {
781         MS_LOG(ERROR) << "shape str must an unsigned int.";
782         return RET_ERROR;
783       }
784       shape_vector->push_back(std::stoi(shape_str));
785       shape_str.clear();
786     } else {
787       shape_str.push_back(static_cast<char>(tensor_data[*offset]));
788     }
789     if (cnt == shape_size) {
790       (*offset)++;
791       break;
792     }
793   }
794   if (shape_vector->empty()) {
795     MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
796     return RET_ERROR;
797   }
798   return RET_OK;
799 }
800 }  // namespace dpico
801 }  // namespace mindspore
802