1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "common/anf_util.h"
18 #include <memory>
19 #include <vector>
20 #include <unordered_map>
21 #include <string>
22 #include <map>
23 #include <algorithm>
24 #include "third_party/securec/include/securec.h"
25 #include "common/op_enum.h"
26 #include "common/op_attr.h"
27 #include "common/string_util.h"
28 #include "ops/custom.h"
29 #include "ops/tuple_get_item.h"
30 #include "ops/auto_generate/gen_lite_ops.h"
31 #include "common/check_base.h"
namespace mindspore {
namespace ops {
// Forward declaration only: PrimitiveC is referred to exclusively through the
// shared_ptr alias PrimitiveCPtr, so the full definition is not required here.
class PrimitiveC;
}  // namespace ops
}  // namespace mindspore
37 namespace mindspore {
38 namespace dpico {
namespace {
// Byte width of each supported numeric TypeId; unsupported types are absent.
const std::map<TypeId, size_t> kTypeMap = {
  {kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2},
  {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1},
  {kNumberTypeUInt16, 2}, {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
  {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8},
  {kNumberTypeComplex128, 16}};
// A TupleGetItem cnode holds 3 inputs: primitive, source tuple, element index.
constexpr size_t kTupleGetItemInputSize = 3;
// Position of the element-index operand inside a TupleGetItem cnode.
constexpr size_t kInputNodeOutputIndexInTupleGetItem = 2;
using PrimitiveCPtr = std::shared_ptr<ops::PrimitiveC>;
// Returns the size in bytes of data_type, or 0 when the type is unsupported.
size_t TypeIdSize(const TypeId data_type) {
  const size_t unsupported_type_error = 0;
  auto iter = kTypeMap.find(data_type);
  if (iter != kTypeMap.end()) {
    return iter->second;
  }
  return unsupported_type_error;
}
}  // namespace
CheckPrimitiveType(const api::AnfNodePtr & node,const api::PrimitivePtr & primitive_type)58 bool CheckPrimitiveType(const api::AnfNodePtr &node, const api::PrimitivePtr &primitive_type) {
59 if (node == nullptr) {
60 return false;
61 }
62 if (node->isa<api::CNode>()) {
63 auto cnode = node->cast<api::CNodePtr>();
64 return IsPrimitive(cnode->input(0), primitive_type);
65 } else if (node->isa<api::ValueNode>()) {
66 return IsPrimitive(node, primitive_type);
67 }
68 return false;
69 }
70
GetPrimitiveType(const api::AnfNodePtr & node,std::string * name)71 STATUS GetPrimitiveType(const api::AnfNodePtr &node, std::string *name) {
72 if (name == nullptr) {
73 MS_LOG(ERROR) << "name is nulltr.";
74 return RET_ERROR;
75 }
76 if (node == nullptr) {
77 MS_LOG(ERROR) << "node is nullptr.";
78 return RET_ERROR;
79 }
80 if (node->isa<api::CNode>()) {
81 auto cnode = node->cast<api::CNodePtr>();
82 auto primitive = api::GetValueNode<api::PrimitivePtr>(cnode->input(0));
83 if (primitive == nullptr) {
84 MS_LOG(ERROR) << "primitive is nullptr. " << cnode->fullname_with_scope();
85 return RET_ERROR;
86 }
87 if (CheckPrimitiveType(node, api::MakeShared<ops::Custom>())) {
88 auto custom_prim = api::utils::cast<api::SharedPtr<ops::Custom>>(primitive);
89 MS_CHECK_TRUE_MSG(custom_prim != nullptr, RET_ERROR, "custom op is nullptr.");
90 *name = custom_prim->get_type();
91 return RET_OK;
92 } else {
93 *name = primitive->name();
94 return RET_OK;
95 }
96 } else if (node->isa<api::ValueNode>()) {
97 auto fn_value = api::GetValueNode<api::PrimitivePtr>(node);
98 if (fn_value == nullptr) {
99 MS_LOG(ERROR) << "fn_value is nullptr.";
100 return RET_ERROR;
101 }
102 *name = fn_value->name();
103 return RET_OK;
104 }
105 MS_LOG(ERROR) << "There is no name for this node";
106 return RET_ERROR;
107 }
GetShapeVectorFromParameter(const api::AnfNodePtr & anode,ShapeVector * shape_vector)108 STATUS GetShapeVectorFromParameter(const api::AnfNodePtr &anode, ShapeVector *shape_vector) {
109 if (shape_vector == nullptr) {
110 MS_LOG(ERROR) << "shape vector is nullptr.";
111 return RET_ERROR;
112 }
113 if (!api::utils::isa<api::Parameter>(anode)) {
114 MS_LOG(ERROR) << "anode should be parameter node. ";
115 return RET_ERROR;
116 }
117 auto param_node = anode->cast<api::ParameterPtr>();
118 auto abstract_base = param_node->abstract();
119 if (abstract_base == nullptr) {
120 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
121 return lite::RET_PARAM_INVALID;
122 }
123 if (!api::utils::isa<api::AbstractTensorPtr>(abstract_base)) {
124 MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
125 return lite::RET_INPUT_TENSOR_ERROR;
126 }
127 auto abstract_tensor = api::utils::cast<api::AbstractTensorPtr>(abstract_base);
128 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
129 if (!api::utils::isa<api::ShapePtr>(abstract_tensor->shape())) {
130 MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
131 return lite::RET_PARAM_INVALID;
132 }
133 *shape_vector = api::utils::cast<api::ShapePtr>(abstract_tensor->shape())->shape();
134 return RET_OK;
135 }
CastToInt(const api::ValuePtr & value)136 std::vector<int> CastToInt(const api::ValuePtr &value) {
137 if (value == nullptr) {
138 MS_LOG(WARNING) << "valueptr is nullptr.";
139 return {};
140 }
141 std::vector<int> cur_value = {};
142 if (api::utils::isa<api::ValueSequencePtr>(value)) {
143 if (!value->cast<api::ValueSequencePtr>()->value().empty()) {
144 auto origin_value = api::GetValue<std::vector<int64_t>>(value);
145 (void)std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
146 [](int64_t index) { return static_cast<int32_t>(index); });
147 }
148 } else {
149 cur_value.push_back(static_cast<int>(api::GetValue<int64_t>(value)));
150 }
151 return cur_value;
152 }
GetTupleGetItemOutIndex(const api::CNodePtr & tuple_get_item)153 size_t GetTupleGetItemOutIndex(const api::CNodePtr &tuple_get_item) {
154 MS_ASSERT(tuple_get_item != nullptr);
155 if (tuple_get_item->size() != kTupleGetItemInputSize) {
156 MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
157 return SIZE_MAX;
158 }
159 auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
160 MS_ASSERT(output_index_value_node != nullptr);
161 auto value_node = output_index_value_node->cast<api::ValueNodePtr>();
162 MS_ASSERT(value_node != nullptr);
163 auto value_vec = CastToInt(value_node->value());
164 if (value_vec.empty()) {
165 MS_LOG(ERROR) << "value vec is empty.";
166 return SIZE_MAX;
167 }
168 return IntToSize(value_vec.front());
169 }
// Collects the output shape(s) of `cnode` into *output_shapes.
// For a TupleGetItem node the shape of the selected tuple element is used;
// for a node with a tuple abstract every element's shape is appended;
// otherwise the node's single output shape is appended. Empty shapes are
// rejected. Returns RET_OK on success, RET_ERROR otherwise.
STATUS GetOutputShapesFromCNode(const api::CNodePtr &cnode, std::vector<ShapeVector> *output_shapes) {
  api::AbstractBasePtr abstract = nullptr;
  if (output_shapes == nullptr) {
    MS_LOG(ERROR) << "output_shapes is nullptr.";
    return RET_ERROR;
  }
  if (CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>())) {
    // TupleGetItem: look through to the producing node's tuple abstract and
    // pick the element chosen by the index operand.
    auto tuple_inputs = cnode->inputs();
    MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
    auto get_item_input_cnode = tuple_inputs.at(1);
    MS_ASSERT(get_item_input_cnode != nullptr);
    auto idx = GetTupleGetItemOutIndex(cnode);
    if (!api::utils::isa<api::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
      MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
      return RET_ERROR;
    }
    auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(get_item_input_cnode->abstract());
    auto abstract_list = abstract_tuple->elements();
    // Also guards idx == SIZE_MAX (the error value of GetTupleGetItemOutIndex).
    if (abstract_list.size() <= idx) {
      MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
      return RET_ERROR;
    }
    abstract = abstract_list[idx];
  } else {
    abstract = cnode->abstract();
  }
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "abstract cnode is nullptr. " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  if (api::utils::isa<api::AbstractTuplePtr>(abstract)) {
    // Multi-output node: append one shape per tuple element.
    auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(abstract);
    auto abstract_list = abstract_tuple->elements();
    for (const auto &elem : abstract_list) {
      ShapeVector shape_vector;
      if (FetchShapeFromAbstract(elem, &shape_vector) != RET_OK) {
        MS_LOG(ERROR) << "fetch shape from abstract tuple elem failed. " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
      if (shape_vector.empty()) {
        MS_LOG(ERROR) << "shape_vector is empty." << cnode->fullname_with_scope();
        return RET_ERROR;
      }
      (void)output_shapes->emplace_back(shape_vector);
    }
    return RET_OK;
  } else {
    // Single-output node.
    ShapeVector shape_vector;
    if (FetchShapeFromAbstract(abstract, &shape_vector) != RET_OK) {
      MS_LOG(ERROR) << "fetch shape from abstract failed. " << cnode->fullname_with_scope();
      return RET_ERROR;
    }
    if (shape_vector.empty()) {
      MS_LOG(ERROR) << "shape_vector is empty." << cnode->fullname_with_scope();
      return RET_ERROR;
    }
    (void)output_shapes->emplace_back(shape_vector);
  }
  return RET_OK;
}
230
GetInputShapeFromCNode(const api::CNodePtr & cnode,size_t input_idx,ShapeVector * shape)231 STATUS GetInputShapeFromCNode(const api::CNodePtr &cnode, size_t input_idx, ShapeVector *shape) {
232 if (shape == nullptr) {
233 MS_LOG(ERROR) << "shape is nullptr.";
234 return RET_ERROR;
235 }
236 auto input_abstract = GetCNodeInputAbstract(cnode, input_idx);
237 if (input_abstract == nullptr) {
238 MS_LOG(ERROR) << "input_abstract is nullptr.";
239 return RET_ERROR;
240 }
241 if (FetchShapeFromAbstract(input_abstract, shape) != RET_OK) {
242 MS_LOG(ERROR) << "fetch shape from abstract failed.";
243 return RET_ERROR;
244 }
245 return RET_OK;
246 }
247
FetchShapeFromAbstract(const api::AbstractBasePtr & abstract,ShapeVector * shape)248 STATUS FetchShapeFromAbstract(const api::AbstractBasePtr &abstract, ShapeVector *shape) {
249 if (shape == nullptr) {
250 MS_LOG(ERROR) << "shape is nullptr.";
251 return RET_ERROR;
252 }
253 if (abstract == nullptr) {
254 MS_LOG(ERROR) << "abstract of cnode is invalid.";
255 return RET_ERROR;
256 }
257 if (!api::utils::isa<api::AbstractTensor>(abstract)) {
258 MS_LOG(ERROR) << "abstract of cnode is invalid.";
259 return RET_ERROR;
260 }
261 auto abstract_tensor = abstract->cast<api::AbstractTensorPtr>();
262 if (!api::utils::isa<api::ShapePtr>(abstract_tensor->shape())) {
263 MS_LOG(ERROR) << "shape of cnode's output is invalid.";
264 return RET_ERROR;
265 }
266 *shape = api::utils::cast<api::ShapePtr>(abstract_tensor->shape())->shape();
267 return RET_OK;
268 }
FetchTypeIdFromAbstract(const api::AbstractBasePtr & abstract,TypeId * type_id)269 STATUS FetchTypeIdFromAbstract(const api::AbstractBasePtr &abstract, TypeId *type_id) {
270 if (type_id == nullptr) {
271 MS_LOG(ERROR) << "type id is nullptr.";
272 return RET_ERROR;
273 }
274 if (abstract == nullptr) {
275 MS_LOG(ERROR) << "abstract of cnode is invalid.";
276 return RET_ERROR;
277 }
278 if (!api::utils::isa<api::AbstractTensor>(abstract)) {
279 MS_LOG(ERROR) << "abstract of cnode is invalid.";
280 return RET_ERROR;
281 }
282 auto abstract_tensor = abstract->cast<api::AbstractTensorPtr>();
283 if (abstract_tensor->element() == nullptr) {
284 MS_LOG(ERROR) << "element of abstract_tensor is nullptr.";
285 return RET_ERROR;
286 }
287 auto type_ptr = abstract_tensor->element()->type();
288 if (type_ptr == nullptr) {
289 MS_LOG(ERROR) << "type_ptr of abstract_tensor is nullptr.";
290 return RET_ERROR;
291 }
292 *type_id = type_ptr->type_id();
293 return RET_OK;
294 }
295
GetAnfNodeOutputShape(const api::AnfNodePtr & input,ShapeVector * shape_vector)296 int GetAnfNodeOutputShape(const api::AnfNodePtr &input, ShapeVector *shape_vector) {
297 if (shape_vector == nullptr) {
298 MS_LOG(ERROR) << "shape vector is nullptr." << input->fullname_with_scope();
299 return RET_ERROR;
300 }
301 if (api::utils::isa<api::ParameterPtr>(input)) {
302 if (GetShapeVectorFromParameter(input, shape_vector) != RET_OK) {
303 MS_LOG(ERROR) << "get output shape for preprocessor failed. " << input->fullname_with_scope();
304 return RET_ERROR;
305 }
306 } else if (api::utils::isa<api::CNodePtr>(input)) {
307 auto input_cnode = input->cast<api::CNodePtr>();
308 std::vector<ShapeVector> output_shapes;
309 if (GetOutputShapesFromCNode(input_cnode, &output_shapes) != RET_OK) {
310 MS_LOG(ERROR) << "get output shapes from cnode failed. " << input_cnode->fullname_with_scope();
311 return RET_ERROR;
312 }
313 if (output_shapes.size() == 1) {
314 *shape_vector = output_shapes.at(0);
315 } else {
316 MS_LOG(ERROR) << input_cnode->fullname_with_scope() << " has " << output_shapes.size()
317 << " output, which should be 1.";
318 return RET_ERROR;
319 }
320 }
321 if (shape_vector->empty()) {
322 MS_LOG(ERROR) << "subgraph input shape shouldn't be empty. " << input->fullname_with_scope();
323 return RET_ERROR;
324 } else if (shape_vector->at(0) < 0) {
325 MS_LOG(WARNING) << " the N axis of " << input->fullname_with_scope() << "'s output shape is " << shape_vector->at(0)
326 << ", which will be set to 1.";
327 shape_vector->at(0) = 1;
328 }
329 return RET_OK;
330 }
331
TypeIdToString(TypeId type_id)332 std::string TypeIdToString(TypeId type_id) {
333 const std::unordered_map<int, std::string> kTypeIdMap{
334 {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"}, {kNumberTypeFloat32, "Float32"},
335 {kNumberTypeInt8, "Int8"}, {kNumberTypeInt16, "Int16"}, {kNumberTypeInt, "Int32"},
336 {kNumberTypeInt32, "Int32"}, {kNumberTypeUInt8, "UInt8"}, {kNumberTypeUInt16, "UInt16"},
337 {kNumberTypeUInt, "UInt32"}, {kNumberTypeUInt32, "UInt32"}, {kObjectTypeString, "String"},
338 {kNumberTypeBool, "Bool"}, {kObjectTypeTensorType, "Tensor"}};
339 std::string type_str = "Unknown";
340 if (kTypeIdMap.find(static_cast<int>(type_id)) != kTypeIdMap.end()) {
341 type_str = kTypeIdMap.at(static_cast<int>(type_id));
342 }
343 return type_str;
344 }
345
CheckInputs(const api::CNodePtr & cnode)346 bool CheckInputs(const api::CNodePtr &cnode) {
347 if (cnode == nullptr) {
348 MS_LOG(ERROR) << "cnode is nullptr.";
349 return false;
350 }
351 auto inputs = cnode->inputs();
352 if (std::any_of(inputs.begin(), inputs.end(), [](const api::AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
353 MS_LOG(ERROR) << "input is nullptr.";
354 return false;
355 }
356 return true;
357 }
GetCustomOutputName(const api::AnfNodePtr & node)358 std::string GetCustomOutputName(const api::AnfNodePtr &node) {
359 std::string output_name;
360 auto input_cnode = node->cast<api::CNodePtr>();
361 if (input_cnode == nullptr) {
362 MS_LOG(ERROR) << "custom node should be cnode. " << node->fullname_with_scope();
363 return "";
364 }
365 if (input_cnode->GetAttr(kOutputsNames) != nullptr) {
366 auto output_names = api::GetValue<std::vector<std::string>>(input_cnode->GetAttr(kOutputsNames));
367 if (output_names.size() == 1) {
368 output_name = output_names.at(0);
369 } else {
370 MS_LOG(ERROR) << "multi-output's custom node shouldn't be a subgraph's input cnode. "
371 << node->fullname_with_scope();
372 return "";
373 }
374 }
375 return output_name;
376 }
CreateTensorInfo(const void * data,size_t data_size,const std::vector<int64_t> & shape,TypeId data_type)377 api::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
378 TypeId data_type) {
379 api::TensorPtr tensor_info = nullptr;
380 if (shape.empty() && data_size == TypeIdSize(data_type)) {
381 ShapeVector scalar_shape = {1};
382 tensor_info = api::MakeShared<api::Tensor>(data_type, scalar_shape);
383 if (tensor_info == nullptr) {
384 MS_LOG(ERROR) << "new tensor init failed";
385 return nullptr;
386 }
387 tensor_info->set_shape({});
388 } else {
389 tensor_info = api::MakeShared<api::Tensor>(data_type, shape);
390 if (tensor_info == nullptr) {
391 MS_LOG(ERROR) << "new tensor init failed";
392 return nullptr;
393 }
394 }
395 if (data_size == 0) {
396 return tensor_info;
397 }
398 if (data == nullptr) {
399 MS_LOG(ERROR) << "input tensor data is nullptr";
400 return nullptr;
401 }
402 auto ret = memcpy_s(tensor_info->data(), tensor_info->Size(), data, data_size);
403 if (ret != EOK) {
404 MS_LOG(ERROR) << "memcpy_s error : " << ret;
405 return nullptr;
406 }
407 return tensor_info;
408 }
409
CreateTensorAbstract(const std::vector<int64_t> & shape,TypeId data_type)410 api::AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type) {
411 auto tensor_info = dpico::CreateTensorInfo(nullptr, 0, shape, data_type);
412 if (tensor_info == nullptr) {
413 MS_LOG(ERROR) << "Create tensor info failed";
414 return nullptr;
415 }
416 auto abstract = tensor_info->ToAbstract();
417 if (abstract == nullptr) {
418 MS_LOG(ERROR) << "Create tensor abstarct failed";
419 return nullptr;
420 }
421 return abstract;
422 }
423
InitParameterFromTensorInfo(const api::ParameterPtr & param_node,const api::TensorPtr & tensor_info)424 int InitParameterFromTensorInfo(const api::ParameterPtr ¶m_node, const api::TensorPtr &tensor_info) {
425 if (tensor_info == nullptr) {
426 MS_LOG(ERROR) << "tensor info is nullptr.";
427 return RET_ERROR;
428 }
429 auto abstract_tensor = tensor_info->ToAbstract();
430 if (abstract_tensor == nullptr) {
431 MS_LOG(ERROR) << "Create abstract tensor failed.";
432 return RET_ERROR;
433 }
434 param_node->set_abstract(abstract_tensor);
435 param_node->set_default_param(tensor_info);
436 return RET_OK;
437 }
438
// Returns the abstract of cnode's index-th input.
// Parameter inputs yield their own abstract; cnode inputs yield theirs, except
// TupleGetItem inputs, which are looked through to the selected element of the
// producing node's tuple abstract. Returns nullptr on any error.
api::AbstractBasePtr GetCNodeInputAbstract(const api::CNodePtr &cnode, size_t index) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "CNodePtr is nullptr";
    return nullptr;
  }
  auto inputs = cnode->inputs();
  if (index >= inputs.size()) {
    MS_LOG(ERROR) << "index: " << index << " is greater than inputs size " << inputs.size();
    return nullptr;
  }
  auto input = inputs[index];
  if (input == nullptr) {
    MS_LOG(ERROR) << "CNode input is nullptr";
    return nullptr;
  }

  api::AbstractBasePtr abstract = nullptr;
  if (api::utils::isa<api::ParameterPtr>(input)) {
    auto parameter = input->cast<api::ParameterPtr>();
    abstract = parameter->abstract();
  } else if (api::utils::isa<api::CNodePtr>(input)) {
    auto input_cnode = input->cast<api::CNodePtr>();
    if (CheckPrimitiveType(input_cnode, api::MakeShared<ops::TupleGetItem>())) {
      // TupleGetItem: resolve the element selected by the index operand.
      auto tuple_inputs = input_cnode->inputs();
      MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
      auto get_item_input_cnode = tuple_inputs.at(1);
      MS_ASSERT(get_item_input_cnode != nullptr);
      auto idx = GetTupleGetItemOutIndex(input_cnode);
      if (!api::utils::isa<api::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      // Also guards idx == SIZE_MAX (the error value of GetTupleGetItemOutIndex).
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = input_cnode->abstract();
    }
  } else {
    MS_LOG(ERROR) << "unsupported input node type";
    return nullptr;
  }
  return abstract;
}
487
// Returns the abstract of `node`: a parameter's own abstract, a cnode's
// abstract, or — for a TupleGetItem cnode — the selected element of the
// producing node's tuple abstract. Returns nullptr for other node kinds
// or on error.
api::AbstractBasePtr GetAbstractFromAnfNode(const api::AnfNodePtr &node) {
  api::AbstractBasePtr abstract = nullptr;
  if (api::utils::isa<api::ParameterPtr>(node)) {
    auto parameter = node->cast<api::ParameterPtr>();
    abstract = parameter->abstract();
  } else if (api::utils::isa<api::CNodePtr>(node)) {
    auto cnode = node->cast<api::CNodePtr>();
    if (CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>())) {
      // TupleGetItem: resolve the element selected by the index operand.
      auto tuple_inputs = cnode->inputs();
      MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
      auto get_item_input_cnode = tuple_inputs.at(1);
      MS_ASSERT(get_item_input_cnode != nullptr);
      auto idx = GetTupleGetItemOutIndex(cnode);
      if (!api::utils::isa<api::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = api::utils::cast<api::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      // Also guards idx == SIZE_MAX (the error value of GetTupleGetItemOutIndex).
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = cnode->abstract();
    }
  }
  return abstract;
}
518
// Adds a graph parameter named node_name holding a single int32 scalar tensor
// initialized with `data`. Returns nullptr on failure.
api::ParameterPtr BuildIntValueParameterNode(const api::FuncGraphPtr &func_graph, const int32_t &data,
                                             const std::string &node_name) {
  MS_ASSERT(func_graph != nullptr);
  auto param_node = func_graph->add_parameter();
  param_node->set_name(node_name);

  auto tensor_info = CreateTensorInfo(&data, sizeof(int32_t), {1}, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  if (InitParameterFromTensorInfo(param_node, tensor_info) != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
538
// Adds a graph parameter named node_name holding a 1-D int32 tensor with the
// contents of `data`. Returns nullptr on failure or when `data` is empty.
api::ParameterPtr BuildIntVecParameterNode(const api::FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
                                           const std::string &node_name) {
  MS_ASSERT(func_graph != nullptr);
  MS_CHECK_TRUE_MSG(data.size() != 0, nullptr, "Data size is 0");
  auto param_node = func_graph->add_parameter();
  param_node->set_name(node_name);

  std::vector<int64_t> tensor_shape{static_cast<int64_t>(data.size())};
  auto tensor_info = CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), tensor_shape, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  if (InitParameterFromTensorInfo(param_node, tensor_info) != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
561
// Adds a graph parameter named node_name holding an int32 tensor of shape
// [data.size(), kDims2], flattened from `data` in row-major order.
// Returns nullptr on failure or when `data` is empty.
api::ParameterPtr BuildIntVec2DParameterNode(const api::FuncGraphPtr &func_graph,
                                             const std::vector<std::vector<int32_t>> &data,
                                             const std::string &node_name) {
  MS_ASSERT(func_graph != nullptr);
  MS_CHECK_TRUE_MSG(data.size() != 0, nullptr, "Data size is 0");
  auto param_node = func_graph->add_parameter();
  param_node->set_name(node_name);

  std::vector<int64_t> shape_vector;
  shape_vector.push_back(static_cast<int64_t>(data.size()));
  shape_vector.push_back(kDims2);

  std::vector<int32_t> data_1d;
  data_1d.reserve(data.size() * kDims2);  // each row is expected to be kDims2 wide
  for (const auto &row : data) {          // const ref: avoid copying every inner vector
    (void)data_1d.insert(data_1d.end(), row.begin(), row.end());
  }

  auto size = data_1d.size() * sizeof(int32_t);
  auto tensor_info = CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  auto status = InitParameterFromTensorInfo(param_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
592
// Adds a graph parameter named node_name holding a single float32 scalar
// tensor initialized with `data`. Returns nullptr on failure.
api::ParameterPtr BuildFloatValueParameterNode(const api::FuncGraphPtr &func_graph, const float &data,
                                               const std::string &node_name) {
  MS_ASSERT(func_graph != nullptr);
  auto param_node = func_graph->add_parameter();
  param_node->set_name(node_name);

  auto tensor_info = CreateTensorInfo(&data, sizeof(float), {1}, kNumberTypeFloat32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  if (InitParameterFromTensorInfo(param_node, tensor_info) != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return nullptr;
  }
  return param_node;
}
611
// Creates a Transpose cnode applying permutation `perm` to `input_node`.
// The permutation is materialized as an int32 parameter named "<name>_perm".
// Returns the new cnode, or nullptr on any failure.
api::CNodePtr GenTransposeNode(const api::FuncGraphPtr &func_graph, const api::AnfNodePtr &input_node,
                               const std::vector<int> &perm, const std::string &cnode_name) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto perm_node = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm");
  if (perm_node == nullptr) {
    MS_LOG(ERROR) << "new perm_node error";
    return nullptr;
  }
  auto trans_prim = api::MakeShared<ops::Transpose>();
  if (trans_prim == nullptr) {
    MS_LOG(ERROR) << "new trans_prim failed";
    return nullptr;
  }
  auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "new cnode error";
    return nullptr;
  }
  auto manager = api::FuncGraphManager::Manage(func_graph);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return nullptr;
  }
  // NOTE(review): the inputs were already passed to NewCNode above; SetEdge
  // presumably re-registers them so the manager's node-user records stay
  // consistent — confirm against FuncGraphManager semantics.
  manager->SetEdge(cnode, 1, input_node);
  manager->SetEdge(cnode, kInputIndex2, perm_node);
  cnode->set_fullname_with_scope(cnode_name);
  return cnode;
}
640
// Returns the tensor held by `node`: the wrapped tensor of a value node, or
// the default param of a parameter node. Returns nullptr when the node holds
// no tensor (including parameters without data, e.g. graph inputs).
api::TensorPtr GetTensorInfo(const api::AnfNodePtr &node) {
  MS_ASSERT(node != nullptr);
  if (!api::utils::isa<api::ParameterPtr>(node)) {
    if (api::utils::isa<api::ValueNodePtr>(node)) {
      auto valueNode = node->cast<api::ValueNodePtr>();
      auto node_value = valueNode->value();
      if (node_value != nullptr) {  // guard: value() may be null; cast on null is UB
        auto value = node_value->cast<api::TensorPtr>();
        if (value != nullptr) {
          return value;
        }
      }
    }
    MS_LOG(DEBUG) << "get lite param value node neither parameter node or value node";
    return nullptr;
  }
  auto param = node->cast<api::ParameterPtr>();
  if (param == nullptr) {
    MS_LOG(ERROR) << "param is nullptr.";
    return nullptr;
  }
  auto default_param = param->default_param();
  if (default_param == nullptr) {
    // Parameters without a default value (e.g. graph inputs) carry no tensor;
    // previously this dereferenced a null shared_ptr.
    MS_LOG(DEBUG) << "default_param is nullptr.";
    return nullptr;
  }
  return default_param->cast<api::TensorPtr>();
}
662
CastToVec2DInt(const api::ValuePtr & value)663 std::vector<std::vector<int>> CastToVec2DInt(const api::ValuePtr &value) {
664 if (value == nullptr) {
665 MS_LOG(WARNING) << "valueptr is nullptr.";
666 return {};
667 }
668
669 std::vector<std::vector<int>> result_value;
670 if (api::utils::isa<api::ValueSequencePtr>(value)) {
671 auto origin_value = api::GetValue<std::vector<std::vector<int64_t>>>(value);
672 for (auto &vec : origin_value) {
673 std::vector<int> cur_value;
674 for (size_t j = 0; j < vec.size(); ++j) {
675 cur_value.push_back(static_cast<int>(vec[j]));
676 }
677 result_value.push_back(cur_value);
678 }
679 }
680 return result_value;
681 }
682
GetBoolAttr(const api::AnfNodePtr & node,const std::string & attr_name)683 bool GetBoolAttr(const api::AnfNodePtr &node, const std::string &attr_name) {
684 auto cnode = node->cast<api::CNodePtr>();
685 if (cnode == nullptr) {
686 MS_LOG(ERROR) << "cur node is not a cnode. " << node->fullname_with_scope();
687 return false;
688 }
689 auto primitive = api::GetValueNode<api::PrimitivePtr>(cnode->input(0));
690 if (primitive == nullptr) {
691 MS_LOG(ERROR) << "primitive is nullptr:" << cnode->fullname_with_scope();
692 return false;
693 }
694 auto value_ptr = primitive->GetAttr(attr_name);
695 if (value_ptr == nullptr) {
696 MS_LOG(ERROR) << "There is no attr named " << attr_name << " for node " << cnode->fullname_with_scope();
697 return false;
698 }
699 return api::GetValue<bool>(value_ptr);
700 }
701
GetDataTypeAndShape(const api::ParameterPtr & param_node,TypeId * data_type,ShapeVector * shape_vector)702 STATUS GetDataTypeAndShape(const api::ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape_vector) {
703 if (param_node == nullptr) {
704 MS_LOG(ERROR) << "param node is nullptr.";
705 return RET_ERROR;
706 }
707 if (data_type == nullptr) {
708 MS_LOG(ERROR) << "data type is nullptr.";
709 return RET_ERROR;
710 }
711 if (shape_vector == nullptr) {
712 MS_LOG(ERROR) << "shape vector is nullptr.";
713 return RET_ERROR;
714 }
715 auto abstract_base = param_node->abstract();
716 if (abstract_base == nullptr) {
717 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
718 return RET_ERROR;
719 }
720 if (!api::utils::isa<api::AbstractTensorPtr>(abstract_base)) {
721 MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
722 return RET_ERROR;
723 }
724 auto abstract_tensor = api::utils::cast<api::AbstractTensorPtr>(abstract_base);
725 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
726 auto typePtr = abstract_tensor->element()->type();
727 MS_ASSERT(typePtr != nullptr);
728 *data_type = typePtr->type_id();
729 if (!api::utils::isa<api::ShapePtr>(abstract_tensor->shape())) {
730 MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
731 return RET_ERROR;
732 }
733 *shape_vector = api::utils::cast<api::ShapePtr>(abstract_tensor->shape())->shape();
734 return RET_OK;
735 }
736
// Parses the shape prefix encoded at the start of a string tensor.
// The parsing code expects an ASCII layout of the form
// "<dim_count>,<dim_0>,<dim_1>,...,<dim_{n-1}>,<payload...>".
// On success *shape_vector holds the parsed dims and *offset is the index of
// the first payload byte after the shape prefix.
STATUS GetShapeVectorFromStringTensor(const api::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "tensor info is nullptr.";
    return RET_ERROR;
  }
  if (shape_vector == nullptr) {
    MS_LOG(ERROR) << "shape vector is nullptr.";
    return RET_ERROR;
  }
  if (offset == nullptr) {
    MS_LOG(ERROR) << "offset is nullptr.";
    return RET_ERROR;
  }
  auto data_type = tensor_info->data_type();
  if (data_type != kObjectTypeString) {
    MS_LOG(ERROR) << "This function only used for string tensor.";
    return RET_ERROR;
  }
  shape_vector->clear();
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data());
  std::string shape_str;
  std::string shape_size_str;
  *offset = 0;
  size_t cnt = 0;
  // First field: the number of dimensions, terminated by ','.
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      (*offset)++;  // step past the separator
      break;
    }
    shape_size_str.push_back(static_cast<char>(tensor_data[*offset]));
  }
  if (*offset == 0) {
    MS_LOG(ERROR) << "string tensor's dim size not found.";
    return RET_ERROR;
  }
  if (!IsValidUnsignedNum(shape_size_str)) {
    MS_LOG(ERROR) << "shape_size str must an unsigned int.";
    return RET_ERROR;
  }
  size_t shape_size = std::stoi(shape_size_str);
  // Next shape_size comma-terminated fields: the individual dimensions.
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      cnt++;
      if (!IsValidUnsignedNum(shape_str)) {
        MS_LOG(ERROR) << "shape str must an unsigned int.";
        return RET_ERROR;
      }
      shape_vector->push_back(std::stoi(shape_str));
      shape_str.clear();
    } else {
      shape_str.push_back(static_cast<char>(tensor_data[*offset]));
    }
    if (cnt == shape_size) {
      (*offset)++;  // leave *offset at the first byte after the shape prefix
      break;
    }
  }
  if (shape_vector->empty()) {
    MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
    return RET_ERROR;
  }
  return RET_OK;
}
800 } // namespace dpico
801 } // namespace mindspore
802