1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/converter/adapter/acl/common/utils.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_set>
#include <vector>
#include "mindspore/core/ops/sequence_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/node_util.h"
#include "base/base_ref.h"
#include "abstract/dshape.h"
#include "abstract/abstract_value.h"
#include "include/common/utils/utils.h"
#include "src/common/log_util.h"
#include "ir/func_graph.h"
#include "nnacl/op_base.h"
29
30 namespace mindspore {
31 namespace lite {
32 namespace acl {
namespace {
// Expected input count of a TupleGetItem CNode (checked before indexing its inputs).
constexpr size_t kTupleGetItemInputSize{3};
// Position of the tuple-producing node within a TupleGetItem's inputs.
constexpr size_t kSecondIndex{1};
// Sentinel returned by helpers when a valid size/index cannot be determined.
constexpr size_t kInvalidSize{SIZE_MAX};
}  // namespace
38
GetTupleGetItemOutIndex(const mindspore::CNodePtr & tuple_get_item)39 static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
40 MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, kInvalidSize, "tuple_get_item is nullptr.");
41 MS_CHECK_TRUE_MSG(tuple_get_item->size() == mindspore::kTupleGetItemInputSize, kInvalidSize,
42 "The node tuple_get_item must have 3 inputs!");
43 auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
44 MS_CHECK_TRUE_MSG(output_index_value_node != nullptr, kInvalidSize, "output_index_value_node is nullptr.");
45 auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
46 MS_CHECK_TRUE_MSG(value_node != nullptr, kInvalidSize, "value_node is nullptr.");
47 auto values = opt::CastToInt(value_node->value());
48 MS_CHECK_TRUE_MSG(values.size() > 0, kInvalidSize, "value_node has no value.");
49 return IntToSize(values.front());
50 }
51
CheckPrimitiveType(const mindspore::AnfNodePtr & node,const mindspore::PrimitivePtr & primitive_type)52 static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
53 MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
54 if (node->isa<mindspore::CNode>()) {
55 auto cnode = node->cast<mindspore::CNodePtr>();
56 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
57 return IsPrimitive(cnode->input(0), primitive_type);
58 } else if (node->isa<mindspore::ValueNode>()) {
59 return IsPrimitive(node, primitive_type);
60 }
61 return false;
62 }
63
GetShapeVectorFromCNode(const mindspore::CNodePtr & cnode,std::vector<int64_t> * shape_vector)64 STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector) {
65 mindspore::AbstractBasePtr cnode_abstract;
66 if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
67 auto tuple_inputs = cnode->inputs();
68 MS_CHECK_TRUE_MSG(tuple_inputs.size() == kTupleGetItemInputSize, lite::RET_ERROR, "The node must have 3 inputs!");
69 auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
70 MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input node is nullptr.");
71 auto idx = GetTupleGetItemOutIndex(cnode);
72 if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
73 MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple, cnode name: "
74 << get_item_input_cnode->fullname_with_scope();
75 return lite::RET_ERROR;
76 }
77 auto abstract_tuple =
78 mindspore::utils::cast<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
79 auto abstract_list = abstract_tuple->elements();
80 if (abstract_list.size() <= idx) {
81 MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
82 return lite::RET_ERROR;
83 }
84 cnode_abstract = abstract_list[idx];
85 } else {
86 cnode_abstract = cnode->abstract();
87 }
88 CHECK_NULL_RETURN(cnode_abstract);
89 if (cnode_abstract->BuildShape() == mindspore::abstract::kNoShape) {
90 *shape_vector = std::vector<int64_t>();
91 return lite::RET_OK;
92 }
93 if (!mindspore::utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
94 MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
95 return lite::RET_ERROR;
96 }
97 auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
98 CHECK_NULL_RETURN(cnode_abstract_tensor);
99 if (!mindspore::utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
100 MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
101 return lite::RET_ERROR;
102 }
103 auto shape_ptr = mindspore::utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
104 CHECK_NULL_RETURN(shape_ptr);
105 if (shape_ptr->shape().empty()) {
106 MS_LOG(INFO) << "Shape is empty " << cnode->fullname_with_scope();
107 }
108
109 *shape_vector = shape_ptr->shape();
110 return lite::RET_OK;
111 }
112
GetTypeFromNode(const AnfNodePtr & node,const size_t tuple_idx)113 TypeId GetTypeFromNode(const AnfNodePtr &node, const size_t tuple_idx) {
114 TypeId type = kNumberTypeFloat32;
115 MS_CHECK_TRUE_MSG(node != nullptr, type, "node is nullptr.");
116 if (utils::isa<CNodePtr>(node)) {
117 auto cnode = node->cast<CNodePtr>();
118 MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
119 if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
120 auto abstract_tensor = cnode->abstract()->cast<abstract::AbstractTensorPtr>();
121 if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
122 MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
123 return type;
124 }
125 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
126 MS_CHECK_TRUE_MSG(type_ptr != nullptr, type, "type_ptr is nullptr.");
127 type = type_ptr->type_id();
128 } else if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
129 auto abstract_tuple = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
130 if (abstract_tuple->elements().empty()) {
131 MS_LOG(ERROR) << "abstract_tuple elements is empty.";
132 return type;
133 }
134 if (tuple_idx >= abstract_tuple->size()) {
135 MS_LOG(ERROR) << "tuple_idx out of range.";
136 return type;
137 }
138 auto abstract_base = abstract_tuple->elements()[tuple_idx];
139 if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
140 auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
141 if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
142 MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr";
143 return type;
144 }
145 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
146 MS_CHECK_TRUE_MSG(type_ptr != nullptr, type, "type_ptr is nullptr");
147 type = type_ptr->type_id();
148 }
149 }
150 MS_LOG(INFO) << "node type id is " << type;
151 }
152 return type;
153 }
154
GetIntParameterData(const ParameterPtr & param_ptr)155 std::vector<int> GetIntParameterData(const ParameterPtr ¶m_ptr) {
156 std::vector<int> result;
157 MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
158
159 if (!param_ptr->has_default()) {
160 MS_LOG(DEBUG) << "Param has not default.";
161 return result;
162 }
163 auto default_param = param_ptr->default_param();
164 MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
165 if (!utils::isa<tensor::TensorPtr>(default_param)) {
166 MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
167 return result;
168 }
169 auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
170 MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
171 if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
172 MS_LOG(DEBUG) << "Default param is not int.";
173 return result;
174 }
175
176 auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
177 MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
178 int shape_size =
179 std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
180 for (int i = 0; i < shape_size; i++) {
181 result.emplace_back(ptr[i]);
182 }
183 return result;
184 }
185
GetInt64ParameterData(const ParameterPtr & param_ptr)186 std::vector<int64_t> GetInt64ParameterData(const ParameterPtr ¶m_ptr) {
187 std::vector<int64_t> result;
188 MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
189
190 if (!param_ptr->has_default()) {
191 MS_LOG(DEBUG) << "Param has not default.";
192 return result;
193 }
194 auto default_param = param_ptr->default_param();
195 MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
196 if (!utils::isa<tensor::TensorPtr>(default_param)) {
197 MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
198 return result;
199 }
200 auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
201 MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
202 if (default_param_ptr->data_type() != kNumberTypeInt64) {
203 MS_LOG(DEBUG) << "Default param is not int64.";
204 return result;
205 }
206
207 auto ptr = reinterpret_cast<int64_t *>(default_param_ptr->data_c());
208 MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
209 int shape_size =
210 std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
211 for (int i = 0; i < shape_size; i++) {
212 result.emplace_back(ptr[i]);
213 }
214 return result;
215 }
216
GetFloatParameterData(const ParameterPtr & param_ptr)217 std::vector<float> GetFloatParameterData(const ParameterPtr ¶m_ptr) {
218 std::vector<float> result;
219 MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
220
221 if (!param_ptr->has_default()) {
222 MS_LOG(DEBUG) << "Param has not default.";
223 return result;
224 }
225 auto default_param = param_ptr->default_param();
226 MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
227 if (!utils::isa<tensor::TensorPtr>(default_param)) {
228 MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
229 return result;
230 }
231 auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
232 MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
233 if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) {
234 MS_LOG(DEBUG) << "Default param is not int.";
235 return result;
236 }
237
238 auto ptr = reinterpret_cast<float *>(default_param_ptr->data_c());
239 MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
240 int shape_size =
241 std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<float>());
242 for (int i = 0; i < shape_size; i++) {
243 result.emplace_back(ptr[i]);
244 }
245 return result;
246 }
247
IsCaseNode(const CNodePtr node)248 bool IsCaseNode(const CNodePtr node) {
249 MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
250 if (node->input(0) == nullptr) {
251 MS_LOG(WARNING) << "The input of node is nullptr.";
252 return false;
253 }
254 if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
255 GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
256 return true;
257 }
258 return false;
259 }
260
GetCNodeTargetFuncName(const CNodePtr & cnode)261 std::string GetCNodeTargetFuncName(const CNodePtr &cnode) {
262 if (IsCaseNode(cnode)) {
263 return string("Case");
264 }
265 auto name = GetCNodeFuncName(cnode);
266 if (name == "switch_layer") {
267 name = "";
268 }
269 return name;
270 }
271
DelRedundantParameter(const FuncGraphPtr & func_graph)272 STATUS DelRedundantParameter(const FuncGraphPtr &func_graph) {
273 CHECK_NULL_RETURN(func_graph);
274 auto nodes = TopoSort(func_graph->get_return());
275 auto parameters = func_graph->parameters();
276 for (auto ¶meter : parameters) {
277 CHECK_NULL_RETURN(parameter);
278 if (std::find(nodes.begin(), nodes.end(), parameter) == nodes.end() && !lite::IsGraphInput(parameter)) {
279 func_graph->DropNode(parameter);
280 }
281 }
282 return lite::RET_OK;
283 }
284 } // namespace acl
285 } // namespace lite
286 } // namespace mindspore
287