/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/common/node_util.h"
#include <memory>
#include <set>
#include <vector>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "src/litert/infer_manager.h"
#include "mindspore/core/ops/switch.h"
#include "mindspore/core/ops/call.h"
#include "mindspore/core/ops/fusion/partial_fusion.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace lite {
constexpr size_t kInitialSize = 1024;
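// Collects the inputs of `cnode` that are themselves CNodes; null inputs, parameters and value nodes are skipped.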
std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    return {};
  }
  std::vector<CNodePtr> inputs;
  for (const auto &input : cnode->inputs()) {
    if (input == nullptr || !utils::isa<CNodePtr>(input)) {
      continue;
    }
    inputs.emplace_back(utils::cast<CNodePtr>(input));
  }
  return inputs;
}

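// Converts src_dims laid out in src_format into dst_dims laid out in dst_format, using NCHW as the
// intermediate layout; a 3-D NHWC input is treated as HWC.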
STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
                              mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
  MS_ASSERT(dst_dims != nullptr);
  if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
    MS_LOG(ERROR) << "Cannot convert format: src dims size " << src_dims.size()
                  << " is neither 3 nor 4, or src format equals dst format, so no conversion is needed";
    *dst_dims = src_dims;
    return RET_PARAM_INVALID;
  }

  std::vector<int32_t> nchw_dim;
  switch (src_format) {
    case schema::Format::Format_NCHW:
      nchw_dim = src_dims;
      break;
    case schema::Format::Format_NHWC:
      if (src_dims.size() == DIM_DEFAULT_SIZE) {
        nchw_dim.push_back(src_dims[NHWC_N]);
        nchw_dim.push_back(src_dims[NHWC_C]);
        nchw_dim.push_back(src_dims[NHWC_H]);
        nchw_dim.push_back(src_dims[NHWC_W]);
      } else {
        nchw_dim.push_back(src_dims[HWC_C]);
        nchw_dim.push_back(src_dims[HWC_H]);
        nchw_dim.push_back(src_dims[HWC_W]);
      }
      break;
    default:
      MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
      return RET_ERROR;
  }

  if (nchw_dim.empty()) {
    MS_LOG(ERROR) << "Param nchw_dim is empty!";
    return RET_ERROR;
  }

  switch (dst_format) {
    case schema::Format::Format_NCHW:
      *dst_dims = nchw_dim;
      break;
    case schema::Format::Format_NHWC:
      if (src_dims.size() == DIM_DEFAULT_SIZE) {
        dst_dims->push_back(nchw_dim[NCHW_N]);
        dst_dims->push_back(nchw_dim[NCHW_H]);
        dst_dims->push_back(nchw_dim[NCHW_W]);
        dst_dims->push_back(nchw_dim[NCHW_C]);
      }
      break;
    default:
      MS_LOG(ERROR) << "Not support dst format: " << EnumNameFormat(dst_format);
      return RET_ERROR;
  }
  return RET_OK;
}

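// The Is*Source helpers below classify a kTransFilterType by the weight layout it converts from.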
static bool IsKCHWSource(kTransFilterType type) {
  return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
}

static bool IsCKHWSource(kTransFilterType type) {
  return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
}

static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }

static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }

static bool IsNHWCSource(kTransFilterType type) {
  return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
}

static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }

static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }

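// Extracts the K/C/H/W extents from oriDims according to the source layout implied by the transform type.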
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
    MS_LOG(ERROR) << "null input";
    return RET_NULL_PTR;
  }
  MS_ASSERT(oriDims.size() == 4);
  if (IsKCHWSource(type)) {
    *filterK = oriDims.at(KCHW_K);
    *filterC = oriDims.at(KCHW_C);
    *filterH = oriDims.at(KCHW_H);
    *filterW = oriDims.at(KCHW_W);
  } else if (IsCKHWSource(type)) {
    *filterC = oriDims.at(CKHW_C);
    *filterK = oriDims.at(CKHW_K);
    *filterH = oriDims.at(CKHW_H);
    *filterW = oriDims.at(CKHW_W);
  } else if (IsHWCKSource(type)) {
    *filterH = oriDims.at(HWCK_H);
    *filterW = oriDims.at(HWCK_W);
    *filterC = oriDims.at(HWCK_C);
    *filterK = oriDims.at(HWCK_K);
  } else if (IsHWKCSource(type)) {
    *filterH = oriDims.at(HWKC_H);
    *filterW = oriDims.at(HWKC_W);
    *filterK = oriDims.at(HWKC_K);
    *filterC = oriDims.at(HWKC_C);
  } else if (IsNHWCSource(type)) {
    *filterK = oriDims.at(NHWC_N);
    *filterH = oriDims.at(NHWC_H);
    *filterW = oriDims.at(NHWC_W);
    *filterC = oriDims.at(NHWC_C);
  } else if (IsCHWKSource(type)) {
    *filterC = oriDims.at(CHWK_C);
    *filterH = oriDims.at(CHWK_H);
    *filterW = oriDims.at(CHWK_W);
    *filterK = oriDims.at(CHWK_K);
  } else if (IsKHWCSource(type)) {
    *filterK = oriDims.at(KHWC_K);
    *filterH = oriDims.at(KHWC_H);
    *filterW = oriDims.at(KHWC_W);
    *filterC = oriDims.at(KHWC_C);
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}

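// Writes the filter extents back into tensor->dims in the destination layout implied by the transform type.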
STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
                    int32_t filterW) {
  MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "tensor is nullptr");
  if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
    tensor->dims = {filterH, filterW, filterC, filterK};
  } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
    tensor->dims = {filterH, filterW, filterK, filterC};
  } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
    tensor->dims = {filterK, filterC, filterH, filterW};
  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
    tensor->dims = {filterC, filterK, filterH, filterW};
  } else if (type == kKHWC2CHWK) {
    tensor->dims = {filterC, filterH, filterW, filterK};
  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
    tensor->dims = {filterK, filterH, filterW, filterC};
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}

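// The Convert2* helpers map a source schema::Format to the kTransFilterType that converts it into the named
// destination layout, or return -1 when the combination is unsupported.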
static int Convert2KHWC(int srcFormat) {
  if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
  if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
  if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
  return -1;
}

static int Convert2HWCK(int srcFormat) {
  if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
  if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
  if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
  if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
  return -1;
}

static int Convert2KCHW(int srcFormat) {
  if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
  if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
  if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
  if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
  if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
  return -1;
}

static int Convert2CKHW(int srcFormat) {
  if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
  if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
  if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
  return -1;
}

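// Returns the position of tensor_index inside cnode.inputIndex (the last match wins), or SIZE_MAX if absent.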
size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode) {
  size_t ret = -1;
  for (size_t i = 0; i < cnode.inputIndex.size(); i++) {
    if (cnode.inputIndex.at(i) == tensor_index) {
      ret = i;
    }
  }
  return ret;
}

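// Transposes a convolution filter tensor from its current format to dstFormat, dispatching on the tensor's
// data type (float32, uint8 or int8).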
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "tensor is null";
    return RET_NULL_PTR;
  }
  std::vector<int32_t> oriDims = tensor->dims;
  if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return RET_ERROR;
  }
  auto srcFormat = tensor->format;
  auto dataType = tensor->dataType;
  STATUS status;
  int convert = -1;

  if (dstFormat == srcFormat) return RET_OK;

  switch (dstFormat) {
    case schema::Format::Format_KHWC:
      convert = Convert2KHWC(srcFormat);
      break;
    case schema::Format::Format_HWCK:
      convert = Convert2HWCK(srcFormat);
      break;
    case schema::Format::Format_KCHW:
      convert = Convert2KCHW(srcFormat);
      break;
    case schema::Format::Format_CKHW:
      convert = Convert2CKHW(srcFormat);
      break;
    default:
      convert = -1;
  }
  if (convert == -1) {
    MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
    return RET_ERROR;
  }

  if (dataType == kNumberTypeFloat32) {
    status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
  } else if (dataType == kNumberTypeUInt8) {
    status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
  } else if (dataType == kNumberTypeInt8) {
    status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
  } else {
    MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
    return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return RET_OK;
}

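// Returns the output count of a CNode: the tuple size when its abstract is an AbstractTuple, otherwise 1.
// In training mode, Conv2DFusion and Adam nodes are treated as single-output.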
size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
  MS_CHECK_TRUE_MSG(anf_node != nullptr, 0, "anf_node is nullptr");
  auto cnode = anf_node->cast<CNodePtr>();
  MS_CHECK_TRUE_MSG(cnode != nullptr, 0, "cnode is nullptr");
  if (train_flag &&
      (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
    return 1;
  }
  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    return tuple->elements().size();
  } else {
    return 1;
  }
}

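// Checks whether the node is a CNode whose primitive is PartialFusion.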
bool IsPartialFusion(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  if (node->isa<mindspore::CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
    auto vnode = cnode->input(0)->cast<ValueNodePtr>();
    MS_CHECK_TRUE_MSG(vnode != nullptr, false, "the first input is not a value node");
    return GetValue<NamedPtr>(vnode->value())->name() == "PartialFusion";
  }
  return false;
}

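// A call node is a CNode whose first input is either another CNode or a ValueNode holding a FuncGraph.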
bool IsCall(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
  if (cnode->inputs().empty()) {
    return false;
  }
  auto cnode_first_input = cnode->input(0);
  if (utils::isa<CNodePtr>(cnode_first_input)) {
    return true;
  }
  if (utils::isa<ValueNode>(cnode_first_input)) {
    auto vnode = cnode_first_input->cast<ValueNodePtr>();
    return GetValueNode<FuncGraphPtr>(vnode) != nullptr;
  }
  return false;
}

bool IsSwitch(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return false;
  }
  return opt::CheckPrimitiveType(node, prim::kPrimSwitch);
}

bool IsSwitchLayer(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return false;
  }
  return opt::CheckPrimitiveType(node, prim::kPrimSwitchLayer);
}

bool IsControlFlowOp(const AnfNodePtr &node) {
  return IsPartialFusion(node) || IsCall(node) || IsSwitch(node) || IsSwitchLayer(node);
}

bool IsMakeTuple(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return false;
  }
  return opt::CheckPrimitiveType(node, prim::kPrimMakeTuple);
}

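// The helpers below create fresh ValueNodes wrapping the PartialFusion, Switch and Call primitives,
// returning nullptr if any allocation fails.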
ValueNodePtr GetPartialFusionPrim() {
  auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
  MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr");
  auto partial_prim_c = partial_prim->GetPrim();
  MS_CHECK_TRUE_MSG(partial_prim_c != nullptr, nullptr, "partial_prim_c is nullptr");
  ValueNodePtr partial_anf_prim = NewValueNode(partial_prim_c);
  MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr");
  return partial_anf_prim;
}

ValueNodePtr GetSwitchAnfPrim() {
  auto switch_prim = std::make_shared<mindspore::ops::Switch>();
  MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
  auto switch_prim_c = switch_prim->GetPrim();
  MS_CHECK_TRUE_MSG(switch_prim_c != nullptr, nullptr, "switch_prim_c is nullptr");
  ValueNodePtr switch_anf_prim = NewValueNode(switch_prim_c);
  MS_CHECK_TRUE_MSG(switch_anf_prim != nullptr, nullptr, "switch_anf_prim is nullptr");
  return switch_anf_prim;
}

ValueNodePtr GetCallAnfPrim() {
  auto call_prim = std::make_shared<mindspore::ops::Call>();
  MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr");
  auto call_prim_c = call_prim->GetPrim();
  MS_CHECK_TRUE_MSG(call_prim_c != nullptr, nullptr, "call_prim_c is nullptr");
  ValueNodePtr call_anf_prim = NewValueNode(call_prim_c);
  MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
  return call_anf_prim;
}

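// Sets the element data type of the node's tensor abstract; fails if the abstract is missing or not a tensor.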
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
  auto abstract_base = cnode->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
    return RET_NULL_PTR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of node should be abstract tensor, " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  CHECK_NULL_RETURN(abstract_tensor);
  CHECK_NULL_RETURN(abstract_tensor->element());
  abstract_tensor->element()->set_type(TypeIdToType(new_data_type));
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore