• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
19 #ifndef USE_DEPRECATED_API
20 #define USE_DEPRECATED_API
21 #endif
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 #include "ops/primitive_c.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "ir/anf.h"
29 #include "ir/func_graph.h"
30 #include "src/common/utils.h"
31 #include "include/backend/optimizer/pattern_engine.h"
32 #include "ops/fusion/conv2d_backprop_input_fusion.h"
33 #include "schema/inner/model_generated.h"
34 #include "tools/converter/converter_context.h"
35 
36 using PrimitiveCPtr = std::shared_ptr<mindspore::ops::PrimitiveC>;
37 using mindspore::lite::RET_ERROR;
38 using mindspore::lite::RET_OK;
39 using mindspore::lite::STATUS;
40 namespace mindspore {
41 namespace opt {
42 // used for common op, which corresponding value is a boolean.
43 constexpr auto kInferDone = "infer_done";
44 // used for control_flow op(while and if), which corresponding value is a boolean vec.
45 constexpr auto kInferFlags = "infer_flags";
46 inline constexpr int kInputIndexOne = 1;
47 inline constexpr int kInputIndexTwo = 2;
48 inline constexpr int kInputIndexThree = 3;
49 inline constexpr int kInputIndexFour = 4;
50 inline constexpr int kInputIndexFive = 5;
51 inline constexpr int kInputIndexSix = 6;
52 inline constexpr int kInputIndexSeven = 7;
53 inline constexpr size_t kInputSizeTwo = 2;
54 inline constexpr size_t kInputSizeThree = 3;
55 inline constexpr size_t kInputSizeFour = 4;
56 inline constexpr size_t kInputSizeFive = 5;
57 inline const std::vector<int> kNH2NC = {0, 3, 1, 2};
58 inline const std::vector<int> kNC2NH = {0, 2, 3, 1};
59 inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
60 inline const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>("Conv2DBackpropInputFusion");
61 
62 using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
63 
64 std::vector<int> CastToInt(const ValuePtr &value);
65 
66 std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value);
67 
68 std::vector<float> CastToFloat(const ValuePtr &value);
69 
70 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
71 
72 int GetPrimitiveType(const AnfNodePtr &node, std::string *name);
73 
74 bool IsRealCNodeKernel(const AnfNodePtr &node);
75 
76 bool IsGraphKernel(const AnfNodePtr &node);
77 
78 bool CheckInputs(const CNodePtr &cnode);
79 
80 ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id);
81 
82 bool IsParamNode(const BaseRef &n);
83 
84 bool IsParamOrValueNodeWithData(const BaseRef &n);
85 
86 bool IsParallelSplitConvNode(const BaseRef &n);
87 
88 bool IsConvNode(const BaseRef &n);
89 
90 bool IsOpType(const BaseRef &n, const PrimitivePtr &prim);
91 
92 bool CheckIsAllInputsParam(const AnfNodePtr &node);
93 
94 size_t GetOutputTensorNum(const AnfNodePtr &node);
95 
96 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
97 
98 AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
99 
100 size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
101 
102 size_t GetListGetItemOutIndex(const CNodePtr &list_get_item);
103 
104 tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node);
105 
106 AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index);
107 
108 STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format);
109 
110 ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::TensorPtr &tensor_info,
111                                 const std::string &node_name, bool keep_origin_dtype = false);
112 
113 ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
114                                         const std::string &node_name, bool empty_shape = false);
115 
116 ParameterPtr BuildInt64ValueParameterNode(const FuncGraphPtr &func_graph, const int64_t &data,
117                                           const std::string &node_name, bool empty_shape = false);
118 
119 ValueNodePtr BuildIntVecValueNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data);
120 
121 ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
122                                       const std::string &node_name);
123 
124 ParameterPtr BuildInt64VecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &data,
125                                         const std::string &node_name);
126 
127 ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
128                                         const std::string &node_name);
129 
130 ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
131                                           const std::string &node_name, bool empty_shape = false);
132 
133 ParameterPtr BuildFloatVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float> &data,
134                                         const std::string &node_name);
135 
136 ParameterPtr BuildFloat16ValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
137                                             const std::string &node_name, bool empty_shape);
138 
139 ParameterPtr BuildFloat16VecParameterNode(const FuncGraphPtr &func_graph, const std::vector<float16> &data,
140                                           const std::string &node_name);
141 
142 ParameterPtr BuildFloatVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<float>> &data,
143                                           const std::string &node_name);
144 
145 CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
146                           const std::string &cnode_name);
147 
148 CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const std::string &cnode_name,
149                      const TypeId dst_type, const AbstractBasePtr &abstract);
150 
151 CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
152                         const std::string &cnode_name);
153 
154 CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
155                        const std::string &cnode_name, const std::vector<int> &axis = {0});
156 
157 CNodePtr GenGatherNodeDynamicIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
158                                    const AnfNodePtr &indices_node, const std::string &cnode_name,
159                                    const std::vector<int> &axis);
160 
161 CNodePtr GenConcatNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &input_node_vec,
162                        const std::string &cnode_name, int64_t axis = 0);
163 
164 CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index);
165 
166 STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape);
167 
168 STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index);
169 
170 bool IsTrainOp(const CNodePtr &cnode);
171 
172 bool IsMarkedTrainOp(const CNodePtr &cnode);
173 
174 ShapeVector GetAnfNodeOutputShape(const AnfNodePtr &node, size_t output_idx);
175 
176 int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id);
177 
178 size_t GetOutputSize(const AnfNodePtr &anf_node);
179 
180 bool IsQuantParameterNode(const PrimitivePtr &prim);
181 
182 void UpdateManager(const FuncGraphPtr &func_graph);
183 
184 std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index);
185 
186 int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
187 
188 bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type);
189 
190 void PrintFuncGraph(const FuncGraphPtr &func_graph, const std::string &output_file);
191 
192 std::vector<KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node);
193 
194 bool IsReduceModeMeetOutEqualIn(const PrimitivePtr &prim);
195 
196 STATUS AdjustInputToCnode(const CNodePtr &cnode, size_t input_index);
197 
198 template <const PrimitivePtr *prim = nullptr>
IsSpecifiedNode(const BaseRef & n)199 inline bool IsSpecifiedNode(const BaseRef &n) {
200   if (utils::isa<AnfNodePtr>(n)) {
201     auto anf_node = utils::cast<AnfNodePtr>(n);
202     return CheckPrimitiveType(anf_node, *prim);
203   }
204   return false;
205 }
206 
207 tensor::TensorPtr GetTensorFromParameterNode(const EquivPtr &equiv, const VarPtr &input);
208 const float GetFloatParameterValue(const EquivPtr &equiv, const VarPtr &input);
209 const int GetIntParameterValue(const EquivPtr &equiv, const VarPtr &input);
210 }  // namespace opt
211 }  // namespace mindspore
212 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
213