/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/acl/common/utils.h"
#include <functional>
#include <numeric>
#include "tools/optimizer/common/gllo_utils.h"
#include "base/base_ref.h"
#include "base/core_ops.h"
#include "abstract/dshape.h"
#include "abstract/abstract_value.h"
#include "utils/utils.h"
#include "src/common/log_util.h"

namespace mindspore {
namespace lite {
namespace acl {
namespace {
constexpr size_t kTupleGetItemInputSize = 3;
constexpr size_t kSecondIndex = 1;
constexpr size_t kInvalidSize = SIZE_MAX;
}  // namespace

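// Returns the output index carried by a TupleGetItem node, i.e. the integer value stored
// in its index operand. Returns kInvalidSize if the node does not have the expected inputs.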
static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
  MS_ASSERT(tuple_get_item != nullptr);
  if (tuple_get_item->size() != mindspore::kTupleGetItemInputSize) {
    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
    return kInvalidSize;
  }
  auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
  MS_ASSERT(output_index_value_node != nullptr);
  auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
  MS_ASSERT(value_node != nullptr);
  return IntToSize(opt::CastToInt(value_node->value()).front());
}

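// Checks whether the given node is a CNode whose first input matches the given primitive,
// or a ValueNode holding that primitive.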
static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
  if (node == nullptr) {
    return false;
  }
  if (node->isa<mindspore::CNode>()) {
    auto cnode = node->cast<mindspore::CNodePtr>();
    return IsPrimitive(cnode->input(0), primitive_type);
  } else if (node->isa<mindspore::ValueNode>()) {
    return IsPrimitive(node, primitive_type);
  }
  return false;
}

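// Extracts the shape of the tensor produced by the given cnode into shape_vector.
// For a TupleGetItem node the shape is taken from the corresponding element of the
// producer's AbstractTuple; otherwise it comes from the cnode's own abstract.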
STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector) {
  mindspore::AbstractBasePtr cnode_abstract;
  if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
    auto tuple_inputs = cnode->inputs();
    MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
    auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
    MS_ASSERT(get_item_input_cnode != nullptr);
    auto idx = GetTupleGetItemOutIndex(cnode);
    if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
      MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
      return lite::RET_ERROR;
    }
    auto abstract_tuple =
      mindspore::utils::cast<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
    auto abstract_list = abstract_tuple->elements();
    if (abstract_list.size() <= idx) {
      MS_LOG(ERROR) << "AbstractTuple's size is smaller than expected";
      return lite::RET_ERROR;
    }
    cnode_abstract = abstract_list[idx];
  } else {
    cnode_abstract = cnode->abstract();
  }
  CHECK_NULL_RETURN(cnode_abstract);
  if (!mindspore::utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
    MS_LOG(ERROR) << "Abstract is not an abstract tensor. " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
  CHECK_NULL_RETURN(cnode_abstract_tensor);
  if (!mindspore::utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto shape_ptr = mindspore::utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
  CHECK_NULL_RETURN(shape_ptr);
  if (shape_ptr->shape().empty()) {
    MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
  }

  *shape_vector = shape_ptr->shape();
  return lite::RET_OK;
}

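// Returns the element type id of the tensor produced by the given node.
// Falls back to kNumberTypeFloat32 when the type cannot be determined.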
TypeId GetTypeFromNode(const AnfNodePtr &node) {
  TypeId type = kNumberTypeFloat32;
  if (utils::isa<CNodePtr>(node)) {
    auto cnode = node->cast<CNodePtr>();
    if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
      auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
      if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
        MS_LOG(WARNING) << "Abstract tensor or its element is nullptr.";
        return type;
      }
      auto type_ptr = abstract_tensor->element()->GetTypeTrack();
      if (type_ptr == nullptr) {
        MS_LOG(WARNING) << "Type ptr is nullptr, use default type.";
        return type;
      }
      type = type_ptr->type_id();
    }
    MS_LOG(INFO) << "Node type id is " << type;
  }
  return type;
}

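// Reads the int32 data of a Parameter's default tensor into a flat vector.
// Returns an empty vector if the parameter has no default value or its data is not int32.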
std::vector<int> GetIntParameterData(const ParameterPtr &param_ptr) {
  std::vector<int> result;
  if (param_ptr == nullptr) {
    MS_LOG(DEBUG) << "Param is nullptr.";
    return result;
  }

  if (!param_ptr->has_default()) {
    MS_LOG(DEBUG) << "Param has no default value.";
    return result;
  }
  auto default_param = param_ptr->default_param();
  if (!utils::isa<tensor::TensorPtr>(default_param)) {
    MS_LOG(DEBUG) << "Default param is not a tensor::TensorPtr.";
    return result;
  }
  auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
  if (default_param_ptr == nullptr) {
    MS_LOG(DEBUG) << "Default param ptr is nullptr.";
    return result;
  }
  if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
    MS_LOG(DEBUG) << "Default param is not int.";
    return result;
  }

  auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
  int shape_size =
    std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
  for (int i = 0; i < shape_size; i++) {
    result.emplace_back(ptr[i]);
  }
  return result;
}

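// Returns true if the node's first input is a "switch_layer" call, which marks a Case node.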
bool IsCaseNode(const CNodePtr node) {
  if (node == nullptr || node->inputs().empty() || node->input(0) == nullptr) {
    MS_LOG(WARNING) << "The node or its first input is nullptr.";
    return false;
  }
  if (node->input(0)->isa<CNode>() && GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
    return true;
  }
  return false;
}

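// Returns the target function name of the cnode: "Case" for case nodes, an empty string
// for "switch_layer", and the cnode's own function name otherwise.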
std::string GetCNodeTargetFuncName(const CNodePtr &cnode) {
  if (IsCaseNode(cnode)) {
    return "Case";
  }
  auto name = GetCNodeFuncName(cnode);
  if (name == "switch_layer") {
    name = "";
  }
  return name;
}
}  // namespace acl
}  // namespace lite
}  // namespace mindspore