• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H_
18 #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <string>
23 #include <unordered_map>
24 #include "schema/inner/model_generated.h"
25 #include "src/common/common.h"
26 #include "src/common/log_adapter.h"
27 #include "src/tensor.h"
28 #include "include/errorcode.h"
29 #include "securec/include/securec.h"
30 #include "ops/primitive_c.h"
31 #include "tools/optimizer/common/gllo_utils.h"
32 
33 namespace mindspore {
34 namespace lite {
// Returns a vector of CNodePtr obtained from `cnode`; only the declaration is visible in this header.
// NOTE(review): presumably the inputs of `cnode` that are themselves CNodes — confirm against the definition.
35 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode);
36 
37 template <typename T>
CreateOperator(const std::unique_ptr<schema::PrimitiveT> & primitive,schema::PrimitiveType type)38 int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
39   auto attr = std::make_unique<T>();
40   if (attr == nullptr) {
41     MS_LOG(ERROR) << "new attr failed";
42     return RET_NULL_PTR;
43   }
44   primitive->value.type = type;
45   primitive->value.value = attr.release();
46   return RET_OK;
47 }
48 
// Declared here, defined elsewhere; returns a STATUS code.
// NOTE(review): presumably propagates quantization parameters of `node` across its tensors in `graphT` —
// confirm against the definition.
49 STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
50 
// Returns the primitive type tag stored in `cNodeT`, or schema::PrimitiveType_NONE when the node carries
// no primitive.
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) {
  const auto &prim = cNodeT.primitive;
  return prim == nullptr ? schema::PrimitiveType_NONE : prim->value.type;
}
58 
GetCNodeTTypeName(const schema::CNodeT & cNodeT)59 inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
60   return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));
61 }
62 
// Returns the primitive type of the serialized node `opDef`.
// When the node has no primitive, schema::PrimitiveType_NONE is returned (same convention as GetCNodeTType)
// instead of dereferencing a null pointer.
inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) {
  const auto primitive = opDef.primitive();
  if (primitive == nullptr) {
    return schema::PrimitiveType_NONE;
  }
  return primitive->value_type();
}
64 
GetOpTypeName(const schema::CNode & opDef)65 inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); }
66 
// Lookup-table accessors; every function in this group is only declared here and defined elsewhere.
// NOTE(review): the descriptions below are inferred from names and return types — confirm against the
// definitions before relying on them.

// Returns an int -> int map; presumably maps an NCHW axis index to the matching NHWC axis index.
67 std::unordered_map<int, int> GetNc2NhAxisMap();
68 
// Returns a list of primitive types; presumably ops around which format-transform nodes may be inserted.
69 std::vector<schema::PrimitiveType> GetInsertOpList();
70 
// Returns a list of primitive types; presumably ops that operate on NHWC data.
71 std::vector<schema::PrimitiveType> GetNhwcOpList();
72 
// Returns a list of primitive types; presumably ops that operate on NCHW data.
73 std::vector<schema::PrimitiveType> GetNchwOpList();
74 
// Returns a list of primitive types; presumably ops whose inputs are all expected in NHWC.
75 std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
76 
// Returns a map from primitive type to a list of ints; presumably extra input indexes treated as NHWC.
77 std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
78 
// Returns a list of primitive types; presumably ops that must run fully in fp32.
79 std::vector<schema::PrimitiveType> Getfp32FullOpList();
80 
// Returns a list of primitive types; presumably uint8 ops that operate on NHWC data.
81 std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
82 
// Returns a size_t computed from a tensor index and a node; presumably the position of `tensor_index`
// within the input list of `cnode`.
83 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
84 
// Holder for node-related static helpers; only a single static member function is declared here.
85 class NodeUtils {
86  public:
  // Writes a result into `dst_dims` based on `src_dims`, `src_format` and `dst_format`; returns a STATUS code.
  // NOTE(review): presumably reorders the dims from the source layout to the destination layout — the
  // definition is not visible here, confirm the supported format pairs there.
87   static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
88                             std::vector<int32_t> *dst_dims);
89 };
90 
// Identifiers for filter layout conversions, named k<SRC>2<DST>.
// In the conversion helpers of this header the letters K, C, H and W correspond to the filterK, filterC,
// filterH and filterW arguments; the kNHWC2* entries are handled by TransNHWC.
// This is a plain enum without explicit initializers, so the enumerators take consecutive values from 0.
// The pointer overload of TransFilterData in this header has no case for kKHWC2KCHW, kCKHW2KCHW and
// kCHWK2KCHW and reports them as unsupported.
91 enum kTransFilterType {
92   kKCHW2HWCK,  // 0
93   kKCHW2KHWC,
94   kCKHW2KHWC,
95   kCKHW2HWCK,
96   kKCHW2HWKC,
97   kCKHW2HWKC,
98   kHWCK2KCHW,
99   kHWCK2CKHW,
100   kHWKC2KCHW,
101   kHWKC2CKHW,
102   kNHWC2KCHW,  // 10
103   kNHWC2CKHW,
104   kNHWC2HWCK,
105   kKHWC2HWCK,
106   kCHWK2HWCK,
107   kKHWC2CHWK,
108   kCHWK2KHWC,
109   kKHWC2KCHW,
110   kCKHW2KCHW,
111   kCHWK2KCHW,
112   kKCHW2CKHW  // 20
113 };
114 
// Fills *filterK, *filterC, *filterH and *filterW from `oriDims` for the given conversion `type`; returns a
// STATUS code. Only declared here.
// NOTE(review): presumably the source layout encoded in `type` decides which entry of `oriDims` feeds which
// output — confirm against the definition.
115 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
116                     int32_t *filterH, int32_t *filterW);
// Updates `tensor` from the four filter extents for the given conversion `type`; returns a STATUS code.
// Only declared here.
// NOTE(review): presumably rewrites tensor->dims in the destination layout encoded in `type` — confirm.
117 STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
118                     int32_t filterW);
119 
// Copies every element of a filter stored as [K][H][W][C] in `srcData` to its [C][H][W][K] position in
// `dstData`. Both buffers are indexed over filterK * filterH * filterW * filterC elements.
template <typename T>
static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
  for (int kIdx = 0; kIdx < filterK; ++kIdx) {
    for (int hIdx = 0; hIdx < filterH; ++hIdx) {
      for (int wIdx = 0; wIdx < filterW; ++wIdx) {
        for (int cIdx = 0; cIdx < filterC; ++cIdx) {
          // Row-major offsets written in nested form; the trailing comment names the index order.
          const int srcPos = ((kIdx * filterH + hIdx) * filterW + wIdx) * filterC + cIdx;  // [k][h][w][c]
          const int dstPos = ((cIdx * filterH + hIdx) * filterW + wIdx) * filterK + kIdx;  // [c][h][w][k]
          dstData[dstPos] = srcData[srcPos];
        }
      }
    }
  }
}
136 
// Copies every element of a filter stored as [K][H][W][C] in `srcData` to its [H][W][C][K] position in
// `dstData`. Both buffers are indexed over filterK * filterH * filterW * filterC elements.
template <typename T>
static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
  for (int kIdx = 0; kIdx < filterK; ++kIdx) {
    for (int hIdx = 0; hIdx < filterH; ++hIdx) {
      for (int wIdx = 0; wIdx < filterW; ++wIdx) {
        for (int cIdx = 0; cIdx < filterC; ++cIdx) {
          // Row-major offsets written in nested form; the trailing comment names the index order.
          const int srcPos = ((kIdx * filterH + hIdx) * filterW + wIdx) * filterC + cIdx;  // [k][h][w][c]
          const int dstPos = ((hIdx * filterW + wIdx) * filterC + cIdx) * filterK + kIdx;  // [h][w][c][k]
          dstData[dstPos] = srcData[srcPos];
        }
      }
    }
  }
}
153 
154 template <typename T>
TransCKHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)155 static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
156                       T *srcData, T *dstData) {
157   T *p1Buff = nullptr;
158   T *p2Buff = nullptr;
159   for (int c = 0; c < filterC; ++c) {
160     for (int k = 0; k < filterK; ++k) {
161       for (int h = 0; h < filterH; ++h) {
162         for (int w = 0; w < filterW; ++w) {
163           p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
164           if (type == kCKHW2HWCK) {
165             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
166           } else if (type == kCKHW2KHWC) {
167             p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
168           } else {
169             p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
170           }
171           *p2Buff = *p1Buff;
172         }
173       }
174     }
175   }
176 }
177 
178 template <typename T>
TransKCHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)179 static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
180                       T *srcData, T *dstData) {
181   T *p1Buff = nullptr;
182   T *p2Buff = nullptr;
183   for (int k = 0; k < filterK; ++k) {
184     for (int c = 0; c < filterC; ++c) {
185       for (int h = 0; h < filterH; ++h) {
186         for (int w = 0; w < filterW; ++w) {
187           p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
188           if (type == kKCHW2HWCK) {
189             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
190           } else if (type == kKCHW2KHWC) {
191             p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
192           } else if (type == kKCHW2CKHW) {
193             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
194           } else {
195             p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
196           }
197           *p2Buff = *p1Buff;
198         }
199       }
200     }
201   }
202 }
203 
204 template <typename T>
TransCHWK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)205 static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
206                       T *srcData, T *dstData) {
207   T *p1Buff = nullptr;
208   T *p2Buff = nullptr;
209   for (int c = 0; c < filterC; ++c) {
210     for (int h = 0; h < filterH; ++h) {
211       for (int w = 0; w < filterW; ++w) {
212         for (int k = 0; k < filterK; ++k) {
213           p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
214           if (type == kCHWK2HWCK) {
215             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
216           } else {
217             p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
218           }
219           *p2Buff = *p1Buff;
220         }
221       }
222     }
223   }
224 }
225 
226 template <typename T>
TransHWCK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)227 static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
228                       T *srcData, T *dstData) {
229   T *p1Buff = nullptr;
230   T *p2Buff = nullptr;
231   for (int h = 0; h < filterH; ++h) {
232     for (int w = 0; w < filterW; ++w) {
233       for (int c = 0; c < filterC; ++c) {
234         for (int k = 0; k < filterK; ++k) {
235           p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
236           if (type == kHWCK2KCHW) {
237             p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
238           } else {
239             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
240           }
241           *p2Buff = *p1Buff;
242         }
243       }
244     }
245   }
246 }
247 
248 template <typename T>
TransHWKC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)249 static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
250                       T *srcData, T *dstData) {
251   T *p1Buff = nullptr;
252   T *p2Buff = nullptr;
253   for (int h = 0; h < filterH; ++h) {
254     for (int w = 0; w < filterW; ++w) {
255       for (int c = 0; c < filterC; ++c) {
256         for (int k = 0; k < filterK; ++k) {
257           p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
258           if (type == kHWKC2KCHW) {
259             p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
260           } else {
261             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
262           }
263           *p2Buff = *p1Buff;
264         }
265       }
266     }
267   }
268 }
269 
270 template <typename T>
TransNHWC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)271 static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
272                       T *srcData, T *dstData) {
273   T *p1Buff = nullptr;
274   T *p2Buff = nullptr;
275   for (int k = 0; k < filterK; ++k) {
276     for (int h = 0; h < filterH; ++h) {
277       for (int w = 0; w < filterW; ++w) {
278         for (int c = 0; c < filterC; ++c) {
279           p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
280           if (type == kNHWC2HWCK) {
281             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
282           } else if (type == kNHWC2CKHW) {
283             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
284           } else {
285             p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
286           }
287           *p2Buff = *p1Buff;
288         }
289       }
290     }
291   }
292 }
293 
294 template <typename T>
TransFilterData(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)295 static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
296                               T *srcData, T *dstData) {
297   switch (type) {
298     case kCHWK2HWCK:
299     case kCHWK2KHWC: {
300       TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
301     } break;
302     case kKHWC2HWCK: {
303       TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
304     } break;
305     case kKCHW2HWCK:
306     case kKCHW2CKHW:
307     case kKCHW2KHWC:
308     case kKCHW2HWKC: {
309       TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
310     } break;
311     case kCKHW2HWCK:
312     case kCKHW2KHWC:
313     case kCKHW2HWKC: {
314       TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
315     } break;
316     case kHWCK2KCHW:
317     case kHWCK2CKHW: {
318       TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
319     } break;
320     case kHWKC2KCHW:
321     case kHWKC2CKHW: {
322       TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
323     } break;
324     case kNHWC2HWCK:
325     case kNHWC2KCHW:
326     case kNHWC2CKHW: {
327       TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
328     } break;
329     case kKHWC2CHWK: {
330       TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
331     } break;
332     default: {
333       MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
334       return RET_ERROR;
335     }
336   }
337   return RET_OK;
338 }
339 
340 template <typename T>
TransFilterData(schema::TensorT * tensor,kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW)341 static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
342                               int32_t filterH, int32_t filterW) {
343   MS_ASSERT(tensor != nullptr);
344   int count = filterH * filterW * filterC * filterK;
345   if (count <= 0) {
346     MS_LOG(ERROR) << "Dim size invalid";
347     return RET_ERROR;
348   }
349   std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
350   if (buf == nullptr) {
351     MS_LOG(ERROR) << "new buf failed";
352     return RET_ERROR;
353   }
354 
355   void *originWeightDate = tensor->data.data();
356   T *weightData = static_cast<T *>(originWeightDate);
357 
358   if (weightData == nullptr) {
359     MS_LOG(ERROR) << "weightData is nullptr";
360     return RET_ERROR;
361   }
362 
363   if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
364     MS_LOG(ERROR) << "TransFilterData failed";
365     return RET_ERROR;
366   }
367 
368   auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
369   if (ret != EOK) {
370     MS_LOG(ERROR) << "memcpy_s failed: " << ret;
371     return RET_ERROR;
372   }
373   return RET_OK;
374 }
375 
376 template <typename T>
TransFilterFormat(schema::TensorT * tensor,kTransFilterType type)377 static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
378   MS_ASSERT(tensor != nullptr);
379   std::vector<int32_t> oriDims = tensor->dims;
380   if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
381     MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
382     return RET_ERROR;
383   }
384 
385   int32_t filterH;
386   int32_t filterW;
387   int32_t filterC;
388   int32_t filterK;
389   auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
390   if (status != RET_OK) {
391     MS_LOG(ERROR) << "GetFilterDim failed: " << status;
392     return status;
393   }
394   status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
395   if (status != RET_OK) {
396     MS_LOG(ERROR) << "SetFilterDim failed: " << status;
397     return status;
398   }
399   status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
400   if (status != RET_OK) {
401     MS_LOG(ERROR) << "TransFilterData failed: " << status;
402     return status;
403   }
404 
405   return RET_OK;
406 }
407 
// Declarations only; every function in this group is defined elsewhere.
// NOTE(review): the descriptions below are inferred from names and signatures — confirm against the
// definitions before relying on them.

// Non-template overload taking a destination format; presumably picks the conversion from the tensor's
// current format and `dstFormat`.
408 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
409 
// Returns a size_t for `anf_node`; presumably its number of outputs. `train_flag` defaults to false.
410 size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
411 
// Predicates on an ANF node; presumably each reports whether `node` is the named kind of operation.
412 bool IsPartialFusion(const AnfNodePtr &node);
413 
414 bool IsCall(const AnfNodePtr &node);
415 
416 bool IsSwitch(const AnfNodePtr &node);
417 
418 bool IsSwitchLayer(const AnfNodePtr &node);
419 
420 bool IsControlFlowOp(const AnfNodePtr &node);
421 
422 bool IsMakeTuple(const AnfNodePtr &node);
423 
// Factories returning a ValueNodePtr; presumably a value node wrapping the named primitive.
424 ValueNodePtr GetPartialFusionPrim();
425 
426 ValueNodePtr GetSwitchAnfPrim();
427 
428 ValueNodePtr GetCallAnfPrim();
429 
IsGraphInput(const AnfNodePtr & cnode)430 inline bool IsGraphInput(const AnfNodePtr &cnode) {
431   return cnode->isa<Parameter>() && !cnode->cast<ParameterPtr>()->has_default();
432 }
433 
// Declared here, defined elsewhere; returns an int status.
// NOTE(review): presumably sets the data type recorded for `cnode` to `new_data_type` — confirm against
// the definition.
434 int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
435 }  // namespace lite
436 }  // namespace mindspore
437 #endif  // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H_
438