1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/acl/common/utils.h"
18 #include <functional>
19 #include "tools/optimizer/common/gllo_utils.h"
20 #include "base/base_ref.h"
21 #include "base/core_ops.h"
22 #include "abstract/dshape.h"
23 #include "abstract/abstract_value.h"
24 #include "utils/utils.h"
25 #include "src/common/log_util.h"
26
27 namespace mindspore {
28 namespace lite {
29 namespace acl {
namespace {
// Expected input count of a TupleGetItem CNode (the error message in GetTupleGetItemOutIndex calls
// this "2 inputs"; presumably the extra slot is the primitive at input 0 — TODO confirm).
constexpr size_t kTupleGetItemInputSize{3};
// Position, within a TupleGetItem CNode's inputs, of the node whose abstract is expected to be a tuple.
constexpr size_t kSecondIndex{1};
// Sentinel returned by GetTupleGetItemOutIndex when no valid index can be read.
constexpr size_t kInvalidSize{SIZE_MAX};
}  // namespace
35
GetTupleGetItemOutIndex(const mindspore::CNodePtr & tuple_get_item)36 static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
37 MS_ASSERT(tuple_get_item != nullptr);
38 if (tuple_get_item->size() != mindspore::kTupleGetItemInputSize) {
39 MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
40 return kInvalidSize;
41 }
42 auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
43 MS_ASSERT(output_index_value_node != nullptr);
44 auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
45 MS_ASSERT(value_node != nullptr);
46 return IntToSize(opt::CastToInt(value_node->value()).front());
47 }
48
CheckPrimitiveType(const mindspore::AnfNodePtr & node,const mindspore::PrimitivePtr & primitive_type)49 static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
50 if (node == nullptr) {
51 return false;
52 }
53 if (node->isa<mindspore::CNode>()) {
54 auto cnode = node->cast<mindspore::CNodePtr>();
55 return IsPrimitive(cnode->input(0), primitive_type);
56 } else if (node->isa<mindspore::ValueNode>()) {
57 return IsPrimitive(node, primitive_type);
58 }
59 return false;
60 }
61
GetShapeVectorFromCNode(const mindspore::CNodePtr & cnode,std::vector<int64_t> * shape_vector)62 STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector) {
63 mindspore::AbstractBasePtr cnode_abstract;
64 if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
65 auto tuple_inputs = cnode->inputs();
66 MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
67 auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
68 MS_ASSERT(get_item_input_cnode != nullptr);
69 auto idx = GetTupleGetItemOutIndex(cnode);
70 if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
71 MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
72 return lite::RET_ERROR;
73 }
74 auto abstract_tuple =
75 mindspore::utils::cast<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
76 auto abstract_list = abstract_tuple->elements();
77 if (abstract_list.size() <= idx) {
78 MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
79 return lite::RET_ERROR;
80 }
81 cnode_abstract = abstract_list[idx];
82 } else {
83 cnode_abstract = cnode->abstract();
84 }
85 CHECK_NULL_RETURN(cnode_abstract);
86 if (!mindspore::utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
87 MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
88 return lite::RET_ERROR;
89 }
90 auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
91 CHECK_NULL_RETURN(cnode_abstract_tensor);
92 if (!mindspore::utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
93 MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
94 return lite::RET_ERROR;
95 }
96 auto shape_ptr = mindspore::utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
97 CHECK_NULL_RETURN(shape_ptr);
98 if (shape_ptr->shape().empty()) {
99 MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
100 }
101
102 *shape_vector = shape_ptr->shape();
103 return lite::RET_OK;
104 }
105
GetTypeFromNode(const AnfNodePtr & node)106 TypeId GetTypeFromNode(const AnfNodePtr &node) {
107 TypeId type = kNumberTypeFloat32;
108 if (utils::isa<CNodePtr>(node)) {
109 auto cnode = node->cast<CNodePtr>();
110 if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
111 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
112 if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
113 MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
114 return type;
115 }
116 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
117 type = type_ptr->type_id();
118 }
119 MS_LOG(INFO) << "node type id is " << type;
120 }
121 return type;
122 }
123
GetIntParameterData(const ParameterPtr & param_ptr)124 std::vector<int> GetIntParameterData(const ParameterPtr ¶m_ptr) {
125 std::vector<int> result;
126 if (param_ptr == nullptr) {
127 MS_LOG(DEBUG) << "Param is nullptr.";
128 return result;
129 }
130
131 if (!param_ptr->has_default()) {
132 MS_LOG(DEBUG) << "Param has not default.";
133 return result;
134 }
135 auto default_param = param_ptr->default_param();
136 if (!utils::isa<tensor::TensorPtr>(default_param)) {
137 MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
138 return result;
139 }
140 auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
141 if (default_param_ptr == nullptr) {
142 MS_LOG(DEBUG) << "Default param ptr is nullptr.";
143 return result;
144 }
145 if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
146 MS_LOG(DEBUG) << "Default param is not int.";
147 return result;
148 }
149
150 auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
151 int shape_size =
152 std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
153 for (int i = 0; i < shape_size; i++) {
154 result.emplace_back(ptr[i]);
155 }
156 return result;
157 }
158
IsCaseNode(const CNodePtr node)159 bool IsCaseNode(const CNodePtr node) {
160 if (node->input(0) == nullptr) {
161 MS_LOG(WARNING) << "The input of node is nullptr.";
162 return false;
163 }
164 if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
165 GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
166 return true;
167 }
168 return false;
169 }
170
GetCNodeTargetFuncName(const CNodePtr & cnode)171 std::string GetCNodeTargetFuncName(const CNodePtr &cnode) {
172 if (IsCaseNode(cnode)) {
173 return string("Case");
174 }
175 auto name = GetCNodeFuncName(cnode);
176 if (name == "switch_layer") {
177 name = "";
178 }
179 return name;
180 }
181 } // namespace acl
182 } // namespace lite
183 } // namespace mindspore
184