1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/common/node_util.h"
18 #include <memory>
19 #include <set>
20 #include <vector>
21 #include "src/ops/populate/populate_register.h"
22 #include "src/common/common.h"
23 #include "src/common/log_adapter.h"
24 #include "tools/common/graph_util.h"
25 #include "tools/common/tensor_util.h"
26 #include "src/runtime/infer_manager.h"
27 #include "mindspore/core/ops/switch.h"
28 #include "mindspore/core/ops/call.h"
29 #include "mindspore/core/ops/fusion/partial_fusion.h"
30 #include "nnacl/op_base.h"
31
32 namespace mindspore {
33 namespace lite {
// Initial buffer size handed to the FlatBufferBuilder used by NodeInferShpae.
constexpr size_t kInitialSize{1024};
GetInputCNode(const CNodePtr & cnode)35 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
36 if (cnode == nullptr) {
37 return {};
38 }
39 std::vector<CNodePtr> inputs;
40 for (const auto &input : cnode->inputs()) {
41 if (input == nullptr || !utils::isa<CNodePtr>(input)) {
42 continue;
43 }
44 inputs.emplace_back(utils::cast<CNodePtr>(input));
45 }
46 return inputs;
47 }
48
ConvertToPrimitive(schema::PrimitiveT * primitive_t,flatbuffers::FlatBufferBuilder * fbb)49 const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
50 if (primitive_t == nullptr || fbb == nullptr) {
51 MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
52 return nullptr;
53 }
54 auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
55 fbb->Finish(prim_offset);
56 auto prim_buf = fbb->GetBufferPointer();
57 return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
58 }
59
ConvertDims(mindspore::schema::Format src_format,const std::vector<int32_t> & src_dims,mindspore::schema::Format dst_format,std::vector<int32_t> * dst_dims)60 STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
61 mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
62 MS_ASSERT(dst_dims != nullptr);
63 if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
64 MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
65 << " <3 or src format is equal to dst format,not need convert";
66 *dst_dims = src_dims;
67 return RET_PARAM_INVALID;
68 }
69
70 std::vector<int32_t> nchw_dim;
71 switch (src_format) {
72 case schema::Format::Format_NCHW:
73 nchw_dim = src_dims;
74 break;
75 case schema::Format::Format_NHWC:
76 if (src_dims.size() == DIM_DEFAULT_SIZE) {
77 nchw_dim.push_back(src_dims[NHWC_N]);
78 nchw_dim.push_back(src_dims[NHWC_C]);
79 nchw_dim.push_back(src_dims[NHWC_H]);
80 nchw_dim.push_back(src_dims[NHWC_W]);
81 } else {
82 nchw_dim.push_back(src_dims[HWC_C]);
83 nchw_dim.push_back(src_dims[HWC_H]);
84 nchw_dim.push_back(src_dims[HWC_W]);
85 }
86 break;
87 default:
88 MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
89 return RET_ERROR;
90 }
91
92 if (nchw_dim.empty()) {
93 MS_LOG(ERROR) << "Param nchw_dim is empty!";
94 return RET_ERROR;
95 }
96
97 switch (dst_format) {
98 case schema::Format::Format_NCHW:
99 *dst_dims = nchw_dim;
100 break;
101 case schema::Format::Format_NHWC:
102 if (src_dims.size() == DIM_DEFAULT_SIZE) {
103 dst_dims->push_back(nchw_dim[NCHW_N]);
104 dst_dims->push_back(nchw_dim[NCHW_H]);
105 dst_dims->push_back(nchw_dim[NCHW_W]);
106 dst_dims->push_back(nchw_dim[NCHW_C]);
107 }
108 break;
109 default:
110 MS_LOG(ERROR) << "Not support dst format: " << dst_format;
111 return RET_ERROR;
112 }
113 return RET_OK;
114 }
115
IsKCHWSource(kTransFilterType type)116 static bool IsKCHWSource(kTransFilterType type) {
117 return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
118 }
119
IsCKHWSource(kTransFilterType type)120 static bool IsCKHWSource(kTransFilterType type) {
121 return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
122 }
123
IsHWCKSource(kTransFilterType type)124 static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }
125
IsHWKCSource(kTransFilterType type)126 static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }
127
IsNHWCSource(kTransFilterType type)128 static bool IsNHWCSource(kTransFilterType type) {
129 return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
130 }
131
IsCHWKSource(kTransFilterType type)132 static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }
133
IsKHWCSource(kTransFilterType type)134 static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }
135
GetFilterDim(const std::vector<int32_t> & oriDims,kTransFilterType type,int32_t * filterK,int32_t * filterC,int32_t * filterH,int32_t * filterW)136 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
137 int32_t *filterH, int32_t *filterW) {
138 if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
139 MS_LOG(ERROR) << "null input";
140 return RET_NULL_PTR;
141 }
142 MS_ASSERT(oriDims.size() == 4);
143 if (IsKCHWSource(type)) {
144 *filterK = oriDims.at(KCHW_K);
145 *filterC = oriDims.at(KCHW_C);
146 *filterH = oriDims.at(KCHW_H);
147 *filterW = oriDims.at(KCHW_W);
148 } else if (IsCKHWSource(type)) {
149 *filterC = oriDims.at(CKHW_C);
150 *filterK = oriDims.at(CKHW_K);
151 *filterH = oriDims.at(CKHW_H);
152 *filterW = oriDims.at(CKHW_W);
153 } else if (IsHWCKSource(type)) {
154 *filterH = oriDims.at(HWCK_H);
155 *filterW = oriDims.at(HWCK_W);
156 *filterC = oriDims.at(HWCK_C);
157 *filterK = oriDims.at(HWCK_K);
158 } else if (IsHWKCSource(type)) {
159 *filterH = oriDims.at(HWKC_H);
160 *filterW = oriDims.at(HWKC_W);
161 *filterK = oriDims.at(HWKC_K);
162 *filterC = oriDims.at(HWKC_C);
163 } else if (IsNHWCSource(type)) {
164 *filterK = oriDims.at(NHWC_N);
165 *filterH = oriDims.at(NHWC_H);
166 *filterW = oriDims.at(NHWC_W);
167 *filterC = oriDims.at(NHWC_C);
168 } else if (IsCHWKSource(type)) {
169 *filterC = oriDims.at(CHWK_C);
170 *filterH = oriDims.at(CHWK_H);
171 *filterW = oriDims.at(CHWK_W);
172 *filterK = oriDims.at(CHWK_K);
173 } else if (IsKHWCSource(type)) {
174 *filterK = oriDims.at(KHWC_K);
175 *filterH = oriDims.at(KHWC_H);
176 *filterW = oriDims.at(KHWC_W);
177 *filterC = oriDims.at(KHWC_C);
178 } else {
179 MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
180 return RET_ERROR;
181 }
182 return RET_OK;
183 }
184
SetFilterDim(schema::TensorT * tensor,kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW)185 STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
186 int32_t filterW) {
187 MS_ASSERT(tensor != nullptr);
188 if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
189 tensor->dims = {filterH, filterW, filterC, filterK};
190 } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
191 tensor->dims = {filterH, filterW, filterK, filterC};
192 } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
193 tensor->dims = {filterK, filterC, filterH, filterW};
194 } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
195 tensor->dims = {filterC, filterK, filterH, filterW};
196 } else if (type == kKHWC2CHWK) {
197 tensor->dims = {filterC, filterH, filterW, filterK};
198 } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
199 tensor->dims = {filterK, filterH, filterW, filterC};
200 } else {
201 MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
202 return RET_ERROR;
203 }
204 return RET_OK;
205 }
206
Convert2KHWC(int srcFormat)207 static int Convert2KHWC(int srcFormat) {
208 if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
209 if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
210 if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
211 return -1;
212 }
213
Convert2HWCK(int srcFormat)214 static int Convert2HWCK(int srcFormat) {
215 if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
216 if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
217 if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
218 if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
219 return -1;
220 }
221
Convert2KCHW(int srcFormat)222 static int Convert2KCHW(int srcFormat) {
223 if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
224 if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
225 if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
226 if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
227 if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
228 return -1;
229 }
230
Convert2CKHW(int srcFormat)231 static int Convert2CKHW(int srcFormat) {
232 if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
233 if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
234 if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
235 return -1;
236 }
237
NodeInferShpae(const schema::CNodeT & node,const std::vector<Tensor * > & inputs,std::vector<Tensor * > * outputs)238 STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs) {
239 flatbuffers::FlatBufferBuilder fbb(kInitialSize);
240 auto prim = ConvertToPrimitive(node.primitive.get(), &fbb);
241 if (prim == nullptr) {
242 MS_LOG(ERROR) << "get primitive failed.";
243 fbb.Clear();
244 return RET_ERROR;
245 }
246 auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
247 if (parameter_gen == nullptr) {
248 fbb.Clear();
249 MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
250 return RET_ERROR;
251 }
252 auto parameter = parameter_gen(prim);
253 if (parameter == nullptr) {
254 fbb.Clear();
255 MS_LOG(ERROR) << "parameter is nullptr.";
256 return RET_ERROR;
257 }
258 auto ret = KernelInferShape(inputs, *outputs, parameter);
259 fbb.Clear();
260 if (parameter->destroy_func_ != nullptr) {
261 parameter->destroy_func_(parameter);
262 }
263 free(parameter);
264 parameter = nullptr;
265 return ret;
266 }
267
GetTensorInputIndexInCNode(const uint32_t & tensor_index,const schema::CNodeT & cnode)268 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode) {
269 size_t ret = -1;
270 for (size_t i = 0; i < cnode.inputIndex.size(); i++) {
271 if (cnode.inputIndex.at(i) == tensor_index) {
272 ret = i;
273 }
274 }
275 return ret;
276 }
277
TransFilterFormat(schema::TensorT * tensor,schema::Format dstFormat)278 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
279 if (tensor == nullptr) {
280 MS_LOG(ERROR) << "tensor is null";
281 return RET_NULL_PTR;
282 }
283 std::vector<int32_t> oriDims = tensor->dims;
284 if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
285 MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
286 return RET_ERROR;
287 }
288 auto srcFormat = tensor->format;
289 auto dataType = tensor->dataType;
290 STATUS status;
291 int convert = -1;
292
293 if (dstFormat == srcFormat) return RET_OK;
294
295 switch (dstFormat) {
296 case schema::Format::Format_KHWC:
297 convert = Convert2KHWC(srcFormat);
298 break;
299 case schema::Format::Format_HWCK:
300 convert = Convert2HWCK(srcFormat);
301 break;
302 case schema::Format::Format_KCHW:
303 convert = Convert2KCHW(srcFormat);
304 break;
305 case schema::Format::Format_CKHW:
306 convert = Convert2CKHW(srcFormat);
307 break;
308 default:
309 convert = -1;
310 }
311 if (convert == -1) {
312 MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
313 return RET_ERROR;
314 }
315
316 if (dataType == kNumberTypeFloat32) {
317 status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
318 } else if (dataType == kNumberTypeUInt8) {
319 status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
320 } else if (dataType == kNumberTypeInt8) {
321 status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
322 } else {
323 MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
324 return RET_ERROR;
325 }
326 if (status != RET_OK) {
327 MS_LOG(ERROR) << "TransFilterData failed: " << status;
328 return status;
329 }
330 return RET_OK;
331 }
332
GetCNodeOutputsSize(const std::shared_ptr<AnfNode> & anf_node,bool train_flag)333 size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
334 MS_ASSERT(anf_node != nullptr);
335 auto cnode = anf_node->cast<CNodePtr>();
336 MS_ASSERT(cnode != nullptr);
337 if (train_flag &&
338 (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
339 return 1;
340 }
341 if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
342 auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
343 return tuple->elements().size();
344 } else {
345 return 1;
346 }
347 }
348
IsPartialFusion(const AnfNodePtr & node)349 bool IsPartialFusion(const AnfNodePtr &node) {
350 if (node == nullptr) {
351 return false;
352 }
353 if (node->isa<mindspore::CNode>()) {
354 auto cnode = node->cast<CNodePtr>();
355 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
356 auto vnode_value = cnode->input(0)->cast<ValueNodePtr>()->value();
357 return GetValue<NamedPtr>(vnode_value)->name() == "PartialFusion";
358 }
359 return false;
360 }
361
IsCall(const AnfNodePtr & node)362 bool IsCall(const AnfNodePtr &node) {
363 if (node == nullptr) {
364 return false;
365 }
366 if (!utils::isa<CNodePtr>(node)) {
367 return false;
368 }
369 auto cnode = node->cast<CNodePtr>();
370 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
371 if (cnode->inputs().empty()) {
372 return false;
373 }
374 auto cnode_first_input = cnode->input(0);
375 if (utils::isa<CNodePtr>(cnode_first_input)) {
376 return true;
377 }
378 if (utils::isa<ValueNode>(cnode_first_input)) {
379 auto vnode = cnode_first_input->cast<ValueNodePtr>();
380 return GetValueNode<FuncGraphPtr>(vnode) != nullptr;
381 }
382 return false;
383 }
384
IsSwitch(const AnfNodePtr & node)385 bool IsSwitch(const AnfNodePtr &node) {
386 if (node == nullptr) {
387 return false;
388 }
389 if (!utils::isa<CNodePtr>(node)) {
390 return false;
391 }
392 return opt::CheckPrimitiveType(node, prim::kPrimSwitch);
393 }
394
IsMakeTuple(const AnfNodePtr & node)395 bool IsMakeTuple(const AnfNodePtr &node) {
396 if (node == nullptr) {
397 return false;
398 }
399 if (!utils::isa<CNodePtr>(node)) {
400 return false;
401 }
402 return opt::CheckPrimitiveType(node, prim::kPrimMakeTuple);
403 }
404
GetPartialFusionPrim()405 ValueNodePtr GetPartialFusionPrim() {
406 auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
407 MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr");
408 ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
409 MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr");
410 return partial_anf_prim;
411 }
412
GetSwitchAnfPrim()413 ValueNodePtr GetSwitchAnfPrim() {
414 auto switch_prim = std::make_shared<mindspore::ops::Switch>();
415 MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
416 ValueNodePtr switch_anf_prim = NewValueNode(switch_prim);
417 MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
418 return switch_anf_prim;
419 }
420
GetCallAnfPrim()421 ValueNodePtr GetCallAnfPrim() {
422 auto call_prim = std::make_shared<mindspore::ops::Call>();
423 MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr");
424 ValueNodePtr call_anf_prim = NewValueNode(call_prim);
425 MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
426 return call_anf_prim;
427 }
428 } // namespace lite
429 } // namespace mindspore
430