• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/converter/adapter/acl/common/utils.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_set>
#include "mindspore/core/ops/sequence_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/node_util.h"
#include "base/base_ref.h"
#include "abstract/dshape.h"
#include "abstract/abstract_value.h"
#include "include/common/utils/utils.h"
#include "src/common/log_util.h"
#include "ir/func_graph.h"
#include "nnacl/op_base.h"
29 
30 namespace mindspore {
31 namespace lite {
32 namespace acl {
namespace {
// Number of inputs GetShapeVectorFromCNode requires on a TupleGetItem node ("The node must have 3 inputs!").
constexpr size_t kTupleGetItemInputSize = 3;
// Position, in the TupleGetItem input list, of the node whose tuple abstract is indexed.
constexpr size_t kSecondIndex = 1;
// Sentinel returned by GetTupleGetItemOutIndex when the output index cannot be read.
constexpr size_t kInvalidSize = SIZE_MAX;
}  // namespace
38 
GetTupleGetItemOutIndex(const mindspore::CNodePtr & tuple_get_item)39 static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
40   MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, kInvalidSize, "tuple_get_item is nullptr.");
41   MS_CHECK_TRUE_MSG(tuple_get_item->size() == mindspore::kTupleGetItemInputSize, kInvalidSize,
42                     "The node tuple_get_item must have 3 inputs!");
43   auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
44   MS_CHECK_TRUE_MSG(output_index_value_node != nullptr, kInvalidSize, "output_index_value_node is nullptr.");
45   auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
46   MS_CHECK_TRUE_MSG(value_node != nullptr, kInvalidSize, "value_node is nullptr.");
47   auto values = opt::CastToInt(value_node->value());
48   MS_CHECK_TRUE_MSG(values.size() > 0, kInvalidSize, "value_node has no value.");
49   return IntToSize(values.front());
50 }
51 
CheckPrimitiveType(const mindspore::AnfNodePtr & node,const mindspore::PrimitivePtr & primitive_type)52 static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
53   MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
54   if (node->isa<mindspore::CNode>()) {
55     auto cnode = node->cast<mindspore::CNodePtr>();
56     MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
57     return IsPrimitive(cnode->input(0), primitive_type);
58   } else if (node->isa<mindspore::ValueNode>()) {
59     return IsPrimitive(node, primitive_type);
60   }
61   return false;
62 }
63 
GetShapeVectorFromCNode(const mindspore::CNodePtr & cnode,std::vector<int64_t> * shape_vector)64 STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector) {
65   mindspore::AbstractBasePtr cnode_abstract;
66   if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
67     auto tuple_inputs = cnode->inputs();
68     MS_CHECK_TRUE_MSG(tuple_inputs.size() == kTupleGetItemInputSize, lite::RET_ERROR, "The node must have 3 inputs!");
69     auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
70     MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input node is nullptr.");
71     auto idx = GetTupleGetItemOutIndex(cnode);
72     if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
73       MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple, cnode name: "
74                     << get_item_input_cnode->fullname_with_scope();
75       return lite::RET_ERROR;
76     }
77     auto abstract_tuple =
78       mindspore::utils::cast<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
79     auto abstract_list = abstract_tuple->elements();
80     if (abstract_list.size() <= idx) {
81       MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
82       return lite::RET_ERROR;
83     }
84     cnode_abstract = abstract_list[idx];
85   } else {
86     cnode_abstract = cnode->abstract();
87   }
88   CHECK_NULL_RETURN(cnode_abstract);
89   if (cnode_abstract->BuildShape() == mindspore::abstract::kNoShape) {
90     *shape_vector = std::vector<int64_t>();
91     return lite::RET_OK;
92   }
93   if (!mindspore::utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
94     MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
95     return lite::RET_ERROR;
96   }
97   auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
98   CHECK_NULL_RETURN(cnode_abstract_tensor);
99   if (!mindspore::utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
100     MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
101     return lite::RET_ERROR;
102   }
103   auto shape_ptr = mindspore::utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
104   CHECK_NULL_RETURN(shape_ptr);
105   if (shape_ptr->shape().empty()) {
106     MS_LOG(INFO) << "Shape is empty " << cnode->fullname_with_scope();
107   }
108 
109   *shape_vector = shape_ptr->shape();
110   return lite::RET_OK;
111 }
112 
GetTypeFromNode(const AnfNodePtr & node,const size_t tuple_idx)113 TypeId GetTypeFromNode(const AnfNodePtr &node, const size_t tuple_idx) {
114   TypeId type = kNumberTypeFloat32;
115   MS_CHECK_TRUE_MSG(node != nullptr, type, "node is nullptr.");
116   if (utils::isa<CNodePtr>(node)) {
117     auto cnode = node->cast<CNodePtr>();
118     MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
119     if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
120       auto abstract_tensor = cnode->abstract()->cast<abstract::AbstractTensorPtr>();
121       if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
122         MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
123         return type;
124       }
125       auto type_ptr = abstract_tensor->element()->GetTypeTrack();
126       MS_CHECK_TRUE_MSG(type_ptr != nullptr, type, "type_ptr is nullptr.");
127       type = type_ptr->type_id();
128     } else if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
129       auto abstract_tuple = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
130       if (abstract_tuple->elements().empty()) {
131         MS_LOG(ERROR) << "abstract_tuple elements is empty.";
132         return type;
133       }
134       if (tuple_idx >= abstract_tuple->size()) {
135         MS_LOG(ERROR) << "tuple_idx out of range.";
136         return type;
137       }
138       auto abstract_base = abstract_tuple->elements()[tuple_idx];
139       if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
140         auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
141         if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
142           MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr";
143           return type;
144         }
145         auto type_ptr = abstract_tensor->element()->GetTypeTrack();
146         MS_CHECK_TRUE_MSG(type_ptr != nullptr, type, "type_ptr is nullptr");
147         type = type_ptr->type_id();
148       }
149     }
150     MS_LOG(INFO) << "node type id is " << type;
151   }
152   return type;
153 }
154 
GetIntParameterData(const ParameterPtr & param_ptr)155 std::vector<int> GetIntParameterData(const ParameterPtr &param_ptr) {
156   std::vector<int> result;
157   MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
158 
159   if (!param_ptr->has_default()) {
160     MS_LOG(DEBUG) << "Param has not default.";
161     return result;
162   }
163   auto default_param = param_ptr->default_param();
164   MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
165   if (!utils::isa<tensor::TensorPtr>(default_param)) {
166     MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
167     return result;
168   }
169   auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
170   MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
171   if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
172     MS_LOG(DEBUG) << "Default param is not int.";
173     return result;
174   }
175 
176   auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
177   MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
178   int shape_size =
179     std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
180   for (int i = 0; i < shape_size; i++) {
181     result.emplace_back(ptr[i]);
182   }
183   return result;
184 }
185 
GetInt64ParameterData(const ParameterPtr & param_ptr)186 std::vector<int64_t> GetInt64ParameterData(const ParameterPtr &param_ptr) {
187   std::vector<int64_t> result;
188   MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
189 
190   if (!param_ptr->has_default()) {
191     MS_LOG(DEBUG) << "Param has not default.";
192     return result;
193   }
194   auto default_param = param_ptr->default_param();
195   MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
196   if (!utils::isa<tensor::TensorPtr>(default_param)) {
197     MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
198     return result;
199   }
200   auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
201   MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
202   if (default_param_ptr->data_type() != kNumberTypeInt64) {
203     MS_LOG(DEBUG) << "Default param is not int64.";
204     return result;
205   }
206 
207   auto ptr = reinterpret_cast<int64_t *>(default_param_ptr->data_c());
208   MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
209   int shape_size =
210     std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
211   for (int i = 0; i < shape_size; i++) {
212     result.emplace_back(ptr[i]);
213   }
214   return result;
215 }
216 
GetFloatParameterData(const ParameterPtr & param_ptr)217 std::vector<float> GetFloatParameterData(const ParameterPtr &param_ptr) {
218   std::vector<float> result;
219   MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
220 
221   if (!param_ptr->has_default()) {
222     MS_LOG(DEBUG) << "Param has not default.";
223     return result;
224   }
225   auto default_param = param_ptr->default_param();
226   MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
227   if (!utils::isa<tensor::TensorPtr>(default_param)) {
228     MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
229     return result;
230   }
231   auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
232   MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
233   if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) {
234     MS_LOG(DEBUG) << "Default param is not int.";
235     return result;
236   }
237 
238   auto ptr = reinterpret_cast<float *>(default_param_ptr->data_c());
239   MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
240   int shape_size =
241     std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<float>());
242   for (int i = 0; i < shape_size; i++) {
243     result.emplace_back(ptr[i]);
244   }
245   return result;
246 }
247 
IsCaseNode(const CNodePtr node)248 bool IsCaseNode(const CNodePtr node) {
249   MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
250   if (node->input(0) == nullptr) {
251     MS_LOG(WARNING) << "The input of node is nullptr.";
252     return false;
253   }
254   if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
255       GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
256     return true;
257   }
258   return false;
259 }
260 
GetCNodeTargetFuncName(const CNodePtr & cnode)261 std::string GetCNodeTargetFuncName(const CNodePtr &cnode) {
262   if (IsCaseNode(cnode)) {
263     return string("Case");
264   }
265   auto name = GetCNodeFuncName(cnode);
266   if (name == "switch_layer") {
267     name = "";
268   }
269   return name;
270 }
271 
DelRedundantParameter(const FuncGraphPtr & func_graph)272 STATUS DelRedundantParameter(const FuncGraphPtr &func_graph) {
273   CHECK_NULL_RETURN(func_graph);
274   auto nodes = TopoSort(func_graph->get_return());
275   auto parameters = func_graph->parameters();
276   for (auto &parameter : parameters) {
277     CHECK_NULL_RETURN(parameter);
278     if (std::find(nodes.begin(), nodes.end(), parameter) == nodes.end() && !lite::IsGraphInput(parameter)) {
279       func_graph->DropNode(parameter);
280     }
281   }
282   return lite::RET_OK;
283 }
284 }  // namespace acl
285 }  // namespace lite
286 }  // namespace mindspore
287