1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
18 #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
19
20 #include <memory>
21 #include <vector>
22 #include <string>
23 #include <unordered_map>
24 #include "schema/inner/model_generated.h"
25 #include "src/common/common.h"
26 #include "src/common/log_adapter.h"
27 #include "src/tensor.h"
28 #include "include/errorcode.h"
29 #include "securec/include/securec.h"
30 #include "tools/optimizer/common/gllo_utils.h"
31
32 namespace mindspore {
33 namespace lite {
34 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode);
35
36 template <typename T>
CreateOperator(const std::unique_ptr<schema::PrimitiveT> & primitive,schema::PrimitiveType type)37 int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
38 auto attr = std::make_unique<T>();
39 if (attr == nullptr) {
40 MS_LOG(ERROR) << "new attr failed";
41 return RET_NULL_PTR;
42 }
43 primitive->value.type = type;
44 primitive->value.value = attr.release();
45 return RET_OK;
46 }
47
48 using STATUS = int;
49 STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
50
51 STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs);
52
/**
 * Return the operator type tag of a CNodeT.
 * A node whose primitive pointer is null is reported as schema::PrimitiveType_NONE.
 */
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) {
  return cNodeT.primitive == nullptr ? schema::PrimitiveType_NONE : cNodeT.primitive->value.type;
}
60
GetCNodeTTypeName(const schema::CNodeT & cNodeT)61 inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
62 return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));
63 }
64
/**
 * Return the operator type tag of a serialized (flatbuffer) CNode.
 * The primitive field is optional in a flatbuffer, so a missing primitive is reported as
 * schema::PrimitiveType_NONE instead of being dereferenced (matching GetCNodeTType for CNodeT).
 */
inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) {
  const auto *primitive = opDef.primitive();
  return primitive == nullptr ? schema::PrimitiveType_NONE : primitive->value_type();
}
66
GetOpTypeName(const schema::CNode & opDef)67 inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); }
68
69 std::unordered_map<int, int> GetNc2NhAxisMap();
70
71 std::vector<schema::PrimitiveType> GetInsertOpList();
72
73 std::vector<schema::PrimitiveType> GetNhwcOpList();
74
75 std::vector<schema::PrimitiveType> GetNchwOpList();
76
77 std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
78
79 std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
80
81 std::vector<schema::PrimitiveType> Getfp32FullOpList();
82
83 std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
84
85 const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
86
87 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
88
89 class NodeUtils {
90 public:
91 static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
92 std::vector<int32_t> *dst_dims);
93 };
94
/**
 * Filter-layout conversions understood by the Trans* helpers below.
 * Each name reads SRC2DST, where K = output channels, C = input channels, H/W = kernel height/width.
 * Values are spelled out explicitly; they match the implicit 0..20 numbering callers may log or persist.
 */
enum kTransFilterType {
  // From K,C,H,W
  kKCHW2HWCK = 0,
  kKCHW2KHWC = 1,
  // From C,K,H,W
  kCKHW2KHWC = 2,
  kCKHW2HWCK = 3,
  kKCHW2HWKC = 4,
  kCKHW2HWKC = 5,
  // From H,W,C,K
  kHWCK2KCHW = 6,
  kHWCK2CKHW = 7,
  // From H,W,K,C
  kHWKC2KCHW = 8,
  kHWKC2CKHW = 9,
  // From N,H,W,C
  kNHWC2KCHW = 10,
  kNHWC2CKHW = 11,
  kNHWC2HWCK = 12,
  // Mixed
  kKHWC2HWCK = 13,
  kCHWK2HWCK = 14,
  kKHWC2CHWK = 15,
  kCHWK2KHWC = 16,
  kKHWC2KCHW = 17,
  kCKHW2KCHW = 18,
  kCHWK2KCHW = 19,
  kKCHW2CKHW = 20
};
118
119 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
120 int32_t *filterH, int32_t *filterW);
121 STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
122 int32_t filterW);
123
/**
 * Transpose a filter from K,H,W,C layout (srcData) into C,H,W,K layout (dstData).
 *
 * Both buffers must hold filterK * filterC * filterH * filterW elements and must not overlap.
 * Offsets are computed in size_t so large filters cannot overflow 32-bit int arithmetic.
 * If any dimension is non-positive there are no elements and nothing is written.
 */
template <typename T>
static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, const T *srcData,
                           T *dstData) {
  if (filterK <= 0 || filterC <= 0 || filterH <= 0 || filterW <= 0) {
    return;
  }
  const auto dim_k = static_cast<size_t>(filterK);
  const auto dim_c = static_cast<size_t>(filterC);
  const auto dim_h = static_cast<size_t>(filterH);
  const auto dim_w = static_cast<size_t>(filterW);
  for (size_t k = 0; k < dim_k; ++k) {
    for (size_t h = 0; h < dim_h; ++h) {
      for (size_t w = 0; w < dim_w; ++w) {
        for (size_t c = 0; c < dim_c; ++c) {
          const size_t src_offset = ((k * dim_h + h) * dim_w + w) * dim_c + c;  // K,H,W,C
          const size_t dst_offset = ((c * dim_h + h) * dim_w + w) * dim_k + k;  // C,H,W,K
          dstData[dst_offset] = srcData[src_offset];
        }
      }
    }
  }
}
140
/**
 * Transpose a filter from K,H,W,C layout (srcData) into H,W,C,K layout (dstData).
 *
 * Both buffers must hold filterK * filterC * filterH * filterW elements and must not overlap.
 * Offsets are computed in size_t so large filters cannot overflow 32-bit int arithmetic.
 * If any dimension is non-positive there are no elements and nothing is written.
 */
template <typename T>
static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, const T *srcData,
                           T *dstData) {
  if (filterK <= 0 || filterC <= 0 || filterH <= 0 || filterW <= 0) {
    return;
  }
  const auto dim_k = static_cast<size_t>(filterK);
  const auto dim_c = static_cast<size_t>(filterC);
  const auto dim_h = static_cast<size_t>(filterH);
  const auto dim_w = static_cast<size_t>(filterW);
  for (size_t k = 0; k < dim_k; ++k) {
    for (size_t h = 0; h < dim_h; ++h) {
      for (size_t w = 0; w < dim_w; ++w) {
        for (size_t c = 0; c < dim_c; ++c) {
          const size_t src_offset = ((k * dim_h + h) * dim_w + w) * dim_c + c;  // K,H,W,C
          const size_t dst_offset = ((h * dim_w + w) * dim_c + c) * dim_k + k;  // H,W,C,K
          dstData[dst_offset] = srcData[src_offset];
        }
      }
    }
  }
}
157
158 template <typename T>
TransCKHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)159 static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
160 T *srcData, T *dstData) {
161 T *p1Buff = nullptr;
162 T *p2Buff = nullptr;
163 for (int c = 0; c < filterC; ++c) {
164 for (int k = 0; k < filterK; ++k) {
165 for (int h = 0; h < filterH; ++h) {
166 for (int w = 0; w < filterW; ++w) {
167 p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
168 if (type == kCKHW2HWCK) {
169 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
170 } else if (type == kCKHW2KHWC) {
171 p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
172 } else {
173 p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
174 }
175 *p2Buff = *p1Buff;
176 }
177 }
178 }
179 }
180 }
181
182 template <typename T>
TransKCHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)183 static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
184 T *srcData, T *dstData) {
185 T *p1Buff = nullptr;
186 T *p2Buff = nullptr;
187 for (int k = 0; k < filterK; ++k) {
188 for (int c = 0; c < filterC; ++c) {
189 for (int h = 0; h < filterH; ++h) {
190 for (int w = 0; w < filterW; ++w) {
191 p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
192 if (type == kKCHW2HWCK) {
193 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
194 } else if (type == kKCHW2KHWC) {
195 p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
196 } else if (type == kKCHW2CKHW) {
197 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
198 } else {
199 p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
200 }
201 *p2Buff = *p1Buff;
202 }
203 }
204 }
205 }
206 }
207
208 template <typename T>
TransCHWK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)209 static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
210 T *srcData, T *dstData) {
211 T *p1Buff = nullptr;
212 T *p2Buff = nullptr;
213 for (int c = 0; c < filterC; ++c) {
214 for (int h = 0; h < filterH; ++h) {
215 for (int w = 0; w < filterW; ++w) {
216 for (int k = 0; k < filterK; ++k) {
217 p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
218 if (type == kCHWK2HWCK) {
219 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
220 } else {
221 p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
222 }
223 *p2Buff = *p1Buff;
224 }
225 }
226 }
227 }
228 }
229
230 template <typename T>
TransHWCK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)231 static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
232 T *srcData, T *dstData) {
233 T *p1Buff = nullptr;
234 T *p2Buff = nullptr;
235 for (int h = 0; h < filterH; ++h) {
236 for (int w = 0; w < filterW; ++w) {
237 for (int c = 0; c < filterC; ++c) {
238 for (int k = 0; k < filterK; ++k) {
239 p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
240 if (type == kHWCK2KCHW) {
241 p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
242 } else {
243 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
244 }
245 *p2Buff = *p1Buff;
246 }
247 }
248 }
249 }
250 }
251
252 template <typename T>
TransHWKC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)253 static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
254 T *srcData, T *dstData) {
255 T *p1Buff = nullptr;
256 T *p2Buff = nullptr;
257 for (int h = 0; h < filterH; ++h) {
258 for (int w = 0; w < filterW; ++w) {
259 for (int c = 0; c < filterC; ++c) {
260 for (int k = 0; k < filterK; ++k) {
261 p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
262 if (type == kHWKC2KCHW) {
263 p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
264 } else {
265 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
266 }
267 *p2Buff = *p1Buff;
268 }
269 }
270 }
271 }
272 }
273
274 template <typename T>
TransNHWC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)275 static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
276 T *srcData, T *dstData) {
277 T *p1Buff = nullptr;
278 T *p2Buff = nullptr;
279 for (int k = 0; k < filterK; ++k) {
280 for (int h = 0; h < filterH; ++h) {
281 for (int w = 0; w < filterW; ++w) {
282 for (int c = 0; c < filterC; ++c) {
283 p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
284 if (type == kNHWC2HWCK) {
285 p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
286 } else if (type == kNHWC2CKHW) {
287 p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
288 } else {
289 p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
290 }
291 *p2Buff = *p1Buff;
292 }
293 }
294 }
295 }
296 }
297
298 template <typename T>
TransFilterData(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)299 static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
300 T *srcData, T *dstData) {
301 switch (type) {
302 case kCHWK2HWCK:
303 case kCHWK2KHWC: {
304 TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
305 } break;
306 case kKHWC2HWCK: {
307 TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
308 } break;
309 case kKCHW2HWCK:
310 case kKCHW2CKHW:
311 case kKCHW2KHWC:
312 case kKCHW2HWKC: {
313 TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
314 } break;
315 case kCKHW2HWCK:
316 case kCKHW2KHWC:
317 case kCKHW2HWKC: {
318 TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
319 } break;
320 case kHWCK2KCHW:
321 case kHWCK2CKHW: {
322 TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
323 } break;
324 case kHWKC2KCHW:
325 case kHWKC2CKHW: {
326 TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
327 } break;
328 case kNHWC2HWCK:
329 case kNHWC2KCHW:
330 case kNHWC2CKHW: {
331 TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
332 } break;
333 case kKHWC2CHWK: {
334 TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
335 } break;
336 default: {
337 MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
338 return RET_ERROR;
339 }
340 }
341 return RET_OK;
342 }
343
344 template <typename T>
TransFilterData(schema::TensorT * tensor,kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW)345 static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
346 int32_t filterH, int32_t filterW) {
347 MS_ASSERT(tensor != nullptr);
348 int count = filterH * filterW * filterC * filterK;
349 if (count <= 0) {
350 MS_LOG(ERROR) << "Dim size invalid";
351 return RET_ERROR;
352 }
353 std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
354 if (buf == nullptr) {
355 MS_LOG(ERROR) << "new buf failed";
356 return RET_ERROR;
357 }
358
359 void *originWeightDate = tensor->data.data();
360 T *weightData = static_cast<T *>(originWeightDate);
361
362 if (weightData == nullptr) {
363 MS_LOG(ERROR) << "weightData is nullptr";
364 return RET_ERROR;
365 }
366
367 if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
368 MS_LOG(ERROR) << "TransFilterData failed";
369 return RET_ERROR;
370 }
371
372 auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
373 if (ret != EOK) {
374 MS_LOG(ERROR) << "memcpy_s failed: " << ret;
375 return RET_ERROR;
376 }
377 return RET_OK;
378 }
379
380 template <typename T>
TransFilterFormat(schema::TensorT * tensor,kTransFilterType type)381 static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
382 MS_ASSERT(tensor != nullptr);
383 std::vector<int32_t> oriDims = tensor->dims;
384 if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
385 MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
386 return RET_ERROR;
387 }
388
389 int32_t filterH;
390 int32_t filterW;
391 int32_t filterC;
392 int32_t filterK;
393 auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
394 if (status != RET_OK) {
395 MS_LOG(ERROR) << "GetFilterDim failed: " << status;
396 return status;
397 }
398 status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
399 if (status != RET_OK) {
400 MS_LOG(ERROR) << "SetFilterDim failed: " << status;
401 return status;
402 }
403 status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
404 if (status != RET_OK) {
405 MS_LOG(ERROR) << "TransFilterData failed: " << status;
406 return status;
407 }
408
409 return RET_OK;
410 }
411
412 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
413
414 size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
415
416 bool IsPartialFusion(const AnfNodePtr &node);
417
418 bool IsCall(const AnfNodePtr &node);
419
420 bool IsSwitch(const AnfNodePtr &node);
421
422 bool IsMakeTuple(const AnfNodePtr &node);
423
424 ValueNodePtr GetPartialFusionPrim();
425
426 ValueNodePtr GetSwitchAnfPrim();
427
428 ValueNodePtr GetCallAnfPrim();
429
IsGraphInput(const AnfNodePtr & cnode)430 inline bool IsGraphInput(const AnfNodePtr &cnode) {
431 return cnode->isa<Parameter>() && !cnode->cast<ParameterPtr>()->has_default();
432 }
433 } // namespace lite
434 } // namespace mindspore
435 #endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
436