1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H_
18 #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H_
19
20 #include <memory>
21 #include <vector>
22 #include <string>
23 #include <unordered_map>
24 #include "schema/inner/model_generated.h"
25 #include "src/common/common.h"
26 #include "src/common/log_adapter.h"
27 #include "src/tensor.h"
28 #include "include/errorcode.h"
29 #include "securec/include/securec.h"
30 #include "ops/primitive_c.h"
31 #include "tools/optimizer/common/gllo_utils.h"
32
33 namespace mindspore {
34 namespace lite {
35 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode);
36
37 template <typename T>
CreateOperator(const std::unique_ptr<schema::PrimitiveT> & primitive,schema::PrimitiveType type)38 int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
39 auto attr = std::make_unique<T>();
40 if (attr == nullptr) {
41 MS_LOG(ERROR) << "new attr failed";
42 return RET_NULL_PTR;
43 }
44 primitive->value.type = type;
45 primitive->value.value = attr.release();
46 return RET_OK;
47 }
48
49 STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
50
/**
 * Returns the primitive type attached to an object-API CNodeT.
 *
 * @param cNodeT node to inspect.
 * @return cNodeT.primitive->value.type, or schema::PrimitiveType_NONE when the node has no primitive.
 */
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) {
  // No primitive attached: report "no type" instead of dereferencing null.
  if (cNodeT.primitive == nullptr) {
    return schema::PrimitiveType_NONE;
  }
  return cNodeT.primitive->value.type;
}
58
GetCNodeTTypeName(const schema::CNodeT & cNodeT)59 inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
60 return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));
61 }
62
/**
 * Returns the primitive value type of a schema::CNode, read through its `primitive()` accessor.
 *
 * @param opDef node to inspect.
 * @return opDef.primitive()->value_type(), or schema::PrimitiveType_NONE when `primitive()` is null
 *         (consistent with GetCNodeTType for CNodeT).
 */
inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) {
  const auto *primitive = opDef.primitive();
  // Previously dereferenced unconditionally; a node without a primitive crashed here.
  return primitive != nullptr ? primitive->value_type() : schema::PrimitiveType_NONE;
}
64
GetOpTypeName(const schema::CNode & opDef)65 inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); }
66
67 std::unordered_map<int, int> GetNc2NhAxisMap();
68
69 std::vector<schema::PrimitiveType> GetInsertOpList();
70
71 std::vector<schema::PrimitiveType> GetNhwcOpList();
72
73 std::vector<schema::PrimitiveType> GetNchwOpList();
74
75 std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
76
77 std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
78
79 std::vector<schema::PrimitiveType> Getfp32FullOpList();
80
81 std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
82
83 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
84
85 class NodeUtils {
86 public:
87 static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
88 std::vector<int32_t> *dst_dims);
89 };
90
/**
 * Filter layout conversions, named k<SRC>2<DST> after the source and destination axis orders
 * handled by the Trans* helpers below. Values are spelled out explicitly (0..20); they are
 * identical to the implicit sequential numbering.
 */
enum kTransFilterType {
  kKCHW2HWCK = 0,
  kKCHW2KHWC = 1,
  kCKHW2KHWC = 2,
  kCKHW2HWCK = 3,
  kKCHW2HWKC = 4,
  kCKHW2HWKC = 5,
  kHWCK2KCHW = 6,
  kHWCK2CKHW = 7,
  kHWKC2KCHW = 8,
  kHWKC2CKHW = 9,
  kNHWC2KCHW = 10,
  kNHWC2CKHW = 11,
  kNHWC2HWCK = 12,
  kKHWC2HWCK = 13,
  kCHWK2HWCK = 14,
  kKHWC2CHWK = 15,
  kCHWK2KHWC = 16,
  kKHWC2KCHW = 17,
  kCKHW2KCHW = 18,
  kCHWK2KCHW = 19,
  kKCHW2CKHW = 20
};
114
115 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
116 int32_t *filterH, int32_t *filterW);
117 STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
118 int32_t filterW);
119
/**
 * Copies a filter stored as K,H,W,C into dstData laid out as C,H,W,K.
 * Both buffers are indexed over filterK * filterC * filterH * filterW elements; dstData is a separate
 * output buffer (the caller in this file passes a temporary).
 */
template <typename T>
static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
  // Source strides (KHWC).
  const int srcStrideK = filterH * filterW * filterC;
  const int srcStrideH = filterW * filterC;
  const int srcStrideW = filterC;
  // Destination strides (CHWK).
  const int dstStrideC = filterK * filterH * filterW;
  const int dstStrideH = filterK * filterW;
  const int dstStrideW = filterK;
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          dstData[c * dstStrideC + h * dstStrideH + w * dstStrideW + k] =
              srcData[k * srcStrideK + h * srcStrideH + w * srcStrideW + c];
        }
      }
    }
  }
}
136
/**
 * Copies a filter stored as K,H,W,C into dstData laid out as H,W,C,K.
 * Both buffers are indexed over filterK * filterC * filterH * filterW elements; dstData is a separate
 * output buffer (the caller in this file passes a temporary).
 */
template <typename T>
static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
  // Source strides (KHWC).
  const int srcStrideK = filterH * filterW * filterC;
  const int srcStrideH = filterW * filterC;
  const int srcStrideW = filterC;
  // Destination strides (HWCK).
  const int dstStrideH = filterW * filterC * filterK;
  const int dstStrideW = filterC * filterK;
  const int dstStrideC = filterK;
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          dstData[h * dstStrideH + w * dstStrideW + c * dstStrideC + k] =
              srcData[k * srcStrideK + h * srcStrideH + w * srcStrideW + c];
        }
      }
    }
  }
}
153
154 template <typename T>
TransCKHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)155 static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
156 T *srcData, T *dstData) {
157 T *p1Buff = nullptr;
158 T *p2Buff = nullptr;
159 for (int c = 0; c < filterC; ++c) {
160 for (int k = 0; k < filterK; ++k) {
161 for (int h = 0; h < filterH; ++h) {
162 for (int w = 0; w < filterW; ++w) {
163 p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
164 if (type == kCKHW2HWCK) {
165 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
166 } else if (type == kCKHW2KHWC) {
167 p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
168 } else {
169 p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
170 }
171 *p2Buff = *p1Buff;
172 }
173 }
174 }
175 }
176 }
177
178 template <typename T>
TransKCHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)179 static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
180 T *srcData, T *dstData) {
181 T *p1Buff = nullptr;
182 T *p2Buff = nullptr;
183 for (int k = 0; k < filterK; ++k) {
184 for (int c = 0; c < filterC; ++c) {
185 for (int h = 0; h < filterH; ++h) {
186 for (int w = 0; w < filterW; ++w) {
187 p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
188 if (type == kKCHW2HWCK) {
189 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
190 } else if (type == kKCHW2KHWC) {
191 p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
192 } else if (type == kKCHW2CKHW) {
193 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
194 } else {
195 p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
196 }
197 *p2Buff = *p1Buff;
198 }
199 }
200 }
201 }
202 }
203
204 template <typename T>
TransCHWK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)205 static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
206 T *srcData, T *dstData) {
207 T *p1Buff = nullptr;
208 T *p2Buff = nullptr;
209 for (int c = 0; c < filterC; ++c) {
210 for (int h = 0; h < filterH; ++h) {
211 for (int w = 0; w < filterW; ++w) {
212 for (int k = 0; k < filterK; ++k) {
213 p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
214 if (type == kCHWK2HWCK) {
215 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
216 } else {
217 p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
218 }
219 *p2Buff = *p1Buff;
220 }
221 }
222 }
223 }
224 }
225
226 template <typename T>
TransHWCK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)227 static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
228 T *srcData, T *dstData) {
229 T *p1Buff = nullptr;
230 T *p2Buff = nullptr;
231 for (int h = 0; h < filterH; ++h) {
232 for (int w = 0; w < filterW; ++w) {
233 for (int c = 0; c < filterC; ++c) {
234 for (int k = 0; k < filterK; ++k) {
235 p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
236 if (type == kHWCK2KCHW) {
237 p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
238 } else {
239 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
240 }
241 *p2Buff = *p1Buff;
242 }
243 }
244 }
245 }
246 }
247
248 template <typename T>
TransHWKC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)249 static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
250 T *srcData, T *dstData) {
251 T *p1Buff = nullptr;
252 T *p2Buff = nullptr;
253 for (int h = 0; h < filterH; ++h) {
254 for (int w = 0; w < filterW; ++w) {
255 for (int c = 0; c < filterC; ++c) {
256 for (int k = 0; k < filterK; ++k) {
257 p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
258 if (type == kHWKC2KCHW) {
259 p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
260 } else {
261 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
262 }
263 *p2Buff = *p1Buff;
264 }
265 }
266 }
267 }
268 }
269
270 template <typename T>
TransNHWC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)271 static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
272 T *srcData, T *dstData) {
273 T *p1Buff = nullptr;
274 T *p2Buff = nullptr;
275 for (int k = 0; k < filterK; ++k) {
276 for (int h = 0; h < filterH; ++h) {
277 for (int w = 0; w < filterW; ++w) {
278 for (int c = 0; c < filterC; ++c) {
279 p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
280 if (type == kNHWC2HWCK) {
281 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
282 } else if (type == kNHWC2CKHW) {
283 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
284 } else {
285 p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
286 }
287 *p2Buff = *p1Buff;
288 }
289 }
290 }
291 }
292 }
293
294 template <typename T>
TransFilterData(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)295 static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
296 T *srcData, T *dstData) {
297 switch (type) {
298 case kCHWK2HWCK:
299 case kCHWK2KHWC: {
300 TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
301 } break;
302 case kKHWC2HWCK: {
303 TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
304 } break;
305 case kKCHW2HWCK:
306 case kKCHW2CKHW:
307 case kKCHW2KHWC:
308 case kKCHW2HWKC: {
309 TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
310 } break;
311 case kCKHW2HWCK:
312 case kCKHW2KHWC:
313 case kCKHW2HWKC: {
314 TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
315 } break;
316 case kHWCK2KCHW:
317 case kHWCK2CKHW: {
318 TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
319 } break;
320 case kHWKC2KCHW:
321 case kHWKC2CKHW: {
322 TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
323 } break;
324 case kNHWC2HWCK:
325 case kNHWC2KCHW:
326 case kNHWC2CKHW: {
327 TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
328 } break;
329 case kKHWC2CHWK: {
330 TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
331 } break;
332 default: {
333 MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
334 return RET_ERROR;
335 }
336 }
337 return RET_OK;
338 }
339
340 template <typename T>
TransFilterData(schema::TensorT * tensor,kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW)341 static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
342 int32_t filterH, int32_t filterW) {
343 MS_ASSERT(tensor != nullptr);
344 int count = filterH * filterW * filterC * filterK;
345 if (count <= 0) {
346 MS_LOG(ERROR) << "Dim size invalid";
347 return RET_ERROR;
348 }
349 std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
350 if (buf == nullptr) {
351 MS_LOG(ERROR) << "new buf failed";
352 return RET_ERROR;
353 }
354
355 void *originWeightDate = tensor->data.data();
356 T *weightData = static_cast<T *>(originWeightDate);
357
358 if (weightData == nullptr) {
359 MS_LOG(ERROR) << "weightData is nullptr";
360 return RET_ERROR;
361 }
362
363 if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
364 MS_LOG(ERROR) << "TransFilterData failed";
365 return RET_ERROR;
366 }
367
368 auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
369 if (ret != EOK) {
370 MS_LOG(ERROR) << "memcpy_s failed: " << ret;
371 return RET_ERROR;
372 }
373 return RET_OK;
374 }
375
376 template <typename T>
TransFilterFormat(schema::TensorT * tensor,kTransFilterType type)377 static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
378 MS_ASSERT(tensor != nullptr);
379 std::vector<int32_t> oriDims = tensor->dims;
380 if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
381 MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
382 return RET_ERROR;
383 }
384
385 int32_t filterH;
386 int32_t filterW;
387 int32_t filterC;
388 int32_t filterK;
389 auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
390 if (status != RET_OK) {
391 MS_LOG(ERROR) << "GetFilterDim failed: " << status;
392 return status;
393 }
394 status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
395 if (status != RET_OK) {
396 MS_LOG(ERROR) << "SetFilterDim failed: " << status;
397 return status;
398 }
399 status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
400 if (status != RET_OK) {
401 MS_LOG(ERROR) << "TransFilterData failed: " << status;
402 return status;
403 }
404
405 return RET_OK;
406 }
407
408 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
409
410 size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
411
412 bool IsPartialFusion(const AnfNodePtr &node);
413
414 bool IsCall(const AnfNodePtr &node);
415
416 bool IsSwitch(const AnfNodePtr &node);
417
418 bool IsSwitchLayer(const AnfNodePtr &node);
419
420 bool IsControlFlowOp(const AnfNodePtr &node);
421
422 bool IsMakeTuple(const AnfNodePtr &node);
423
424 ValueNodePtr GetPartialFusionPrim();
425
426 ValueNodePtr GetSwitchAnfPrim();
427
428 ValueNodePtr GetCallAnfPrim();
429
IsGraphInput(const AnfNodePtr & cnode)430 inline bool IsGraphInput(const AnfNodePtr &cnode) {
431 return cnode->isa<Parameter>() && !cnode->cast<ParameterPtr>()->has_default();
432 }
433
434 int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
435 } // namespace lite
436 } // namespace mindspore
437 #endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H_
438