• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
18 #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
19 
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "schema/inner/model_generated.h"
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"
31 
32 namespace mindspore {
33 namespace lite {
34 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode);
35 
36 template <typename T>
CreateOperator(const std::unique_ptr<schema::PrimitiveT> & primitive,schema::PrimitiveType type)37 int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
38   auto attr = std::make_unique<T>();
39   if (attr == nullptr) {
40     MS_LOG(ERROR) << "new attr failed";
41     return RET_NULL_PTR;
42   }
43   primitive->value.type = type;
44   primitive->value.value = attr.release();
45   return RET_OK;
46 }
47 
48 using STATUS = int;
49 STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
50 
51 STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs);
52 
// Returns the primitive type stored on `cNodeT`, or PrimitiveType_NONE when the node carries no primitive.
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) {
  return cNodeT.primitive == nullptr ? schema::PrimitiveType_NONE : cNodeT.primitive->value.type;
}
60 
GetCNodeTTypeName(const schema::CNodeT & cNodeT)61 inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
62   return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));
63 }
64 
// Returns the primitive type of the serialized node `opDef`, or PrimitiveType_NONE when the node has
// no primitive. The null guard mirrors GetCNodeTType; primitive() returns a pointer that may be null.
inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) {
  const auto *primitive = opDef.primitive();
  return primitive != nullptr ? primitive->value_type() : schema::PrimitiveType_NONE;
}
66 
GetOpTypeName(const schema::CNode & opDef)67 inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); }
68 
69 std::unordered_map<int, int> GetNc2NhAxisMap();
70 
71 std::vector<schema::PrimitiveType> GetInsertOpList();
72 
73 std::vector<schema::PrimitiveType> GetNhwcOpList();
74 
75 std::vector<schema::PrimitiveType> GetNchwOpList();
76 
77 std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
78 
79 std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
80 
81 std::vector<schema::PrimitiveType> Getfp32FullOpList();
82 
83 std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
84 
85 const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
86 
87 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
88 
89 class NodeUtils {
90  public:
91   static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
92                             std::vector<int32_t> *dst_dims);
93 };
94 
/**
 * Filter-layout conversions, named <source layout>2<destination layout> over the four filter axes
 * K, C, H, W. Values are spelled out explicitly so the numbering is visible at a glance.
 */
enum kTransFilterType {
  kKCHW2HWCK = 0,
  kKCHW2KHWC = 1,
  kCKHW2KHWC = 2,
  kCKHW2HWCK = 3,
  kKCHW2HWKC = 4,
  kCKHW2HWKC = 5,
  kHWCK2KCHW = 6,
  kHWCK2CKHW = 7,
  kHWKC2KCHW = 8,
  kHWKC2CKHW = 9,
  kNHWC2KCHW = 10,
  kNHWC2CKHW = 11,
  kNHWC2HWCK = 12,
  kKHWC2HWCK = 13,
  kCHWK2HWCK = 14,
  kKHWC2CHWK = 15,
  kCHWK2KHWC = 16,
  kKHWC2KCHW = 17,
  kCKHW2KCHW = 18,
  kCHWK2KCHW = 19,
  kKCHW2CKHW = 20
};
118 
119 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
120                     int32_t *filterH, int32_t *filterW);
121 STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
122                     int32_t filterW);
123 
/**
 * Copies a KHWC filter from srcData into CHWK order in dstData.
 * Both buffers hold filterK * filterH * filterW * filterC elements of T; overlapping buffers are not handled.
 */
template <typename T>
static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
  // Destination (CHWK) strides: k is contiguous, then w, then h, with c outermost.
  const int dstStrideW = filterK;
  const int dstStrideH = filterW * filterK;
  const int dstStrideC = filterH * filterW * filterK;
  // The loops walk k, h, w, c — exactly KHWC order — so the source offset simply counts up.
  int srcOffset = 0;
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          dstData[c * dstStrideC + h * dstStrideH + w * dstStrideW + k] = srcData[srcOffset++];
        }
      }
    }
  }
}
140 
/**
 * Copies a KHWC filter from srcData into HWCK order in dstData.
 * Both buffers hold filterK * filterH * filterW * filterC elements of T; overlapping buffers are not handled.
 */
template <typename T>
static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
  // Destination (HWCK) strides: k is contiguous, then c, then w, with h outermost.
  const int dstStrideC = filterK;
  const int dstStrideW = filterC * filterK;
  const int dstStrideH = filterW * filterC * filterK;
  // The loops walk k, h, w, c — exactly KHWC order — so the source offset simply counts up.
  int srcOffset = 0;
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          dstData[h * dstStrideH + w * dstStrideW + c * dstStrideC + k] = srcData[srcOffset++];
        }
      }
    }
  }
}
157 
158 template <typename T>
TransCKHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)159 static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
160                       T *srcData, T *dstData) {
161   T *p1Buff = nullptr;
162   T *p2Buff = nullptr;
163   for (int c = 0; c < filterC; ++c) {
164     for (int k = 0; k < filterK; ++k) {
165       for (int h = 0; h < filterH; ++h) {
166         for (int w = 0; w < filterW; ++w) {
167           p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
168           if (type == kCKHW2HWCK) {
169             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
170           } else if (type == kCKHW2KHWC) {
171             p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
172           } else {
173             p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
174           }
175           *p2Buff = *p1Buff;
176         }
177       }
178     }
179   }
180 }
181 
182 template <typename T>
TransKCHW(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)183 static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
184                       T *srcData, T *dstData) {
185   T *p1Buff = nullptr;
186   T *p2Buff = nullptr;
187   for (int k = 0; k < filterK; ++k) {
188     for (int c = 0; c < filterC; ++c) {
189       for (int h = 0; h < filterH; ++h) {
190         for (int w = 0; w < filterW; ++w) {
191           p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
192           if (type == kKCHW2HWCK) {
193             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
194           } else if (type == kKCHW2KHWC) {
195             p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
196           } else if (type == kKCHW2CKHW) {
197             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
198           } else {
199             p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
200           }
201           *p2Buff = *p1Buff;
202         }
203       }
204     }
205   }
206 }
207 
208 template <typename T>
TransCHWK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)209 static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
210                       T *srcData, T *dstData) {
211   T *p1Buff = nullptr;
212   T *p2Buff = nullptr;
213   for (int c = 0; c < filterC; ++c) {
214     for (int h = 0; h < filterH; ++h) {
215       for (int w = 0; w < filterW; ++w) {
216         for (int k = 0; k < filterK; ++k) {
217           p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
218           if (type == kCHWK2HWCK) {
219             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
220           } else {
221             p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
222           }
223           *p2Buff = *p1Buff;
224         }
225       }
226     }
227   }
228 }
229 
230 template <typename T>
TransHWCK(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)231 static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
232                       T *srcData, T *dstData) {
233   T *p1Buff = nullptr;
234   T *p2Buff = nullptr;
235   for (int h = 0; h < filterH; ++h) {
236     for (int w = 0; w < filterW; ++w) {
237       for (int c = 0; c < filterC; ++c) {
238         for (int k = 0; k < filterK; ++k) {
239           p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
240           if (type == kHWCK2KCHW) {
241             p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
242           } else {
243             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
244           }
245           *p2Buff = *p1Buff;
246         }
247       }
248     }
249   }
250 }
251 
252 template <typename T>
TransHWKC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)253 static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
254                       T *srcData, T *dstData) {
255   T *p1Buff = nullptr;
256   T *p2Buff = nullptr;
257   for (int h = 0; h < filterH; ++h) {
258     for (int w = 0; w < filterW; ++w) {
259       for (int c = 0; c < filterC; ++c) {
260         for (int k = 0; k < filterK; ++k) {
261           p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
262           if (type == kHWKC2KCHW) {
263             p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
264           } else {
265             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
266           }
267           *p2Buff = *p1Buff;
268         }
269       }
270     }
271   }
272 }
273 
274 template <typename T>
TransNHWC(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)275 static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
276                       T *srcData, T *dstData) {
277   T *p1Buff = nullptr;
278   T *p2Buff = nullptr;
279   for (int k = 0; k < filterK; ++k) {
280     for (int h = 0; h < filterH; ++h) {
281       for (int w = 0; w < filterW; ++w) {
282         for (int c = 0; c < filterC; ++c) {
283           p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
284           if (type == kNHWC2HWCK) {
285             p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
286           } else if (type == kNHWC2CKHW) {
287             p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
288           } else {
289             p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
290           }
291           *p2Buff = *p1Buff;
292         }
293       }
294     }
295   }
296 }
297 
298 template <typename T>
TransFilterData(kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW,T * srcData,T * dstData)299 static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
300                               T *srcData, T *dstData) {
301   switch (type) {
302     case kCHWK2HWCK:
303     case kCHWK2KHWC: {
304       TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
305     } break;
306     case kKHWC2HWCK: {
307       TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
308     } break;
309     case kKCHW2HWCK:
310     case kKCHW2CKHW:
311     case kKCHW2KHWC:
312     case kKCHW2HWKC: {
313       TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
314     } break;
315     case kCKHW2HWCK:
316     case kCKHW2KHWC:
317     case kCKHW2HWKC: {
318       TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
319     } break;
320     case kHWCK2KCHW:
321     case kHWCK2CKHW: {
322       TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
323     } break;
324     case kHWKC2KCHW:
325     case kHWKC2CKHW: {
326       TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
327     } break;
328     case kNHWC2HWCK:
329     case kNHWC2KCHW:
330     case kNHWC2CKHW: {
331       TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
332     } break;
333     case kKHWC2CHWK: {
334       TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
335     } break;
336     default: {
337       MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
338       return RET_ERROR;
339     }
340   }
341   return RET_OK;
342 }
343 
344 template <typename T>
TransFilterData(schema::TensorT * tensor,kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW)345 static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
346                               int32_t filterH, int32_t filterW) {
347   MS_ASSERT(tensor != nullptr);
348   int count = filterH * filterW * filterC * filterK;
349   if (count <= 0) {
350     MS_LOG(ERROR) << "Dim size invalid";
351     return RET_ERROR;
352   }
353   std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
354   if (buf == nullptr) {
355     MS_LOG(ERROR) << "new buf failed";
356     return RET_ERROR;
357   }
358 
359   void *originWeightDate = tensor->data.data();
360   T *weightData = static_cast<T *>(originWeightDate);
361 
362   if (weightData == nullptr) {
363     MS_LOG(ERROR) << "weightData is nullptr";
364     return RET_ERROR;
365   }
366 
367   if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
368     MS_LOG(ERROR) << "TransFilterData failed";
369     return RET_ERROR;
370   }
371 
372   auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
373   if (ret != EOK) {
374     MS_LOG(ERROR) << "memcpy_s failed: " << ret;
375     return RET_ERROR;
376   }
377   return RET_OK;
378 }
379 
380 template <typename T>
TransFilterFormat(schema::TensorT * tensor,kTransFilterType type)381 static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
382   MS_ASSERT(tensor != nullptr);
383   std::vector<int32_t> oriDims = tensor->dims;
384   if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
385     MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
386     return RET_ERROR;
387   }
388 
389   int32_t filterH;
390   int32_t filterW;
391   int32_t filterC;
392   int32_t filterK;
393   auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
394   if (status != RET_OK) {
395     MS_LOG(ERROR) << "GetFilterDim failed: " << status;
396     return status;
397   }
398   status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
399   if (status != RET_OK) {
400     MS_LOG(ERROR) << "SetFilterDim failed: " << status;
401     return status;
402   }
403   status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
404   if (status != RET_OK) {
405     MS_LOG(ERROR) << "TransFilterData failed: " << status;
406     return status;
407   }
408 
409   return RET_OK;
410 }
411 
412 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
413 
414 size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
415 
416 bool IsPartialFusion(const AnfNodePtr &node);
417 
418 bool IsCall(const AnfNodePtr &node);
419 
420 bool IsSwitch(const AnfNodePtr &node);
421 
422 bool IsMakeTuple(const AnfNodePtr &node);
423 
424 ValueNodePtr GetPartialFusionPrim();
425 
426 ValueNodePtr GetSwitchAnfPrim();
427 
428 ValueNodePtr GetCallAnfPrim();
429 
IsGraphInput(const AnfNodePtr & cnode)430 inline bool IsGraphInput(const AnfNodePtr &cnode) {
431   return cnode->isa<Parameter>() && !cnode->cast<ParameterPtr>()->has_default();
432 }
433 }  // namespace lite
434 }  // namespace mindspore
435 #endif  // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
436