• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/common/node_util.h"
18 #include <memory>
19 #include <set>
20 #include <vector>
21 #include "src/ops/populate/populate_register.h"
22 #include "src/common/common.h"
23 #include "src/common/log_adapter.h"
24 #include "tools/common/graph_util.h"
25 #include "tools/common/tensor_util.h"
26 #include "src/runtime/infer_manager.h"
27 #include "mindspore/core/ops/switch.h"
28 #include "mindspore/core/ops/call.h"
29 #include "mindspore/core/ops/fusion/partial_fusion.h"
30 #include "nnacl/op_base.h"
31 
32 namespace mindspore {
33 namespace lite {
34 constexpr size_t kInitialSize = 1024;
GetInputCNode(const CNodePtr & cnode)35 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
36   if (cnode == nullptr) {
37     return {};
38   }
39   std::vector<CNodePtr> inputs;
40   for (const auto &input : cnode->inputs()) {
41     if (input == nullptr || !utils::isa<CNodePtr>(input)) {
42       continue;
43     }
44     inputs.emplace_back(utils::cast<CNodePtr>(input));
45   }
46   return inputs;
47 }
48 
ConvertToPrimitive(schema::PrimitiveT * primitive_t,flatbuffers::FlatBufferBuilder * fbb)49 const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
50   if (primitive_t == nullptr || fbb == nullptr) {
51     MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
52     return nullptr;
53   }
54   auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
55   fbb->Finish(prim_offset);
56   auto prim_buf = fbb->GetBufferPointer();
57   return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
58 }
59 
// Convert a 4-D (or 3-D HWC) shape vector from src_format to dst_format.
// Only NCHW and NHWC are supported on either side; the conversion goes
// through an intermediate (N)CHW representation.
// Returns RET_PARAM_INVALID (after copying src_dims into *dst_dims) for an
// unsupported rank or a no-op src==dst conversion, RET_ERROR for an
// unsupported format, RET_OK otherwise.
STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
                              mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
  MS_ASSERT(dst_dims != nullptr);
  // Reject unsupported ranks and identity conversions; callers still receive
  // a copy of the source dims in that case.
  if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
    MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
                  << " <3 or src format is equal to dst format,not need convert";
    *dst_dims = src_dims;
    return RET_PARAM_INVALID;
  }

  // Step 1: normalize the source dims into (N)CHW order.
  std::vector<int32_t> nchw_dim;
  switch (src_format) {
    case schema::Format::Format_NCHW:
      nchw_dim = src_dims;
      break;
    case schema::Format::Format_NHWC:
      if (src_dims.size() == DIM_DEFAULT_SIZE) {
        nchw_dim.push_back(src_dims[NHWC_N]);
        nchw_dim.push_back(src_dims[NHWC_C]);
        nchw_dim.push_back(src_dims[NHWC_H]);
        nchw_dim.push_back(src_dims[NHWC_W]);
      } else {
        // 3-D input is interpreted as HWC and reordered to CHW.
        nchw_dim.push_back(src_dims[HWC_C]);
        nchw_dim.push_back(src_dims[HWC_H]);
        nchw_dim.push_back(src_dims[HWC_W]);
      }
      break;
    default:
      MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
      return RET_ERROR;
  }

  if (nchw_dim.empty()) {
    MS_LOG(ERROR) << "Param nchw_dim is empty!";
    return RET_ERROR;
  }

  // Step 2: emit the destination layout from the normalized dims.
  switch (dst_format) {
    case schema::Format::Format_NCHW:
      *dst_dims = nchw_dim;
      break;
    case schema::Format::Format_NHWC:
      // NOTE(review): a 3-D source converted to NHWC falls through this
      // branch without filling *dst_dims yet still returns RET_OK — confirm
      // callers handle an empty result.
      if (src_dims.size() == DIM_DEFAULT_SIZE) {
        dst_dims->push_back(nchw_dim[NCHW_N]);
        dst_dims->push_back(nchw_dim[NCHW_H]);
        dst_dims->push_back(nchw_dim[NCHW_W]);
        dst_dims->push_back(nchw_dim[NCHW_C]);
      }
      break;
    default:
      MS_LOG(ERROR) << "Not support dst format: " << dst_format;
      return RET_ERROR;
  }
  return RET_OK;
}
115 
IsKCHWSource(kTransFilterType type)116 static bool IsKCHWSource(kTransFilterType type) {
117   return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
118 }
119 
IsCKHWSource(kTransFilterType type)120 static bool IsCKHWSource(kTransFilterType type) {
121   return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
122 }
123 
IsHWCKSource(kTransFilterType type)124 static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }
125 
IsHWKCSource(kTransFilterType type)126 static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }
127 
IsNHWCSource(kTransFilterType type)128 static bool IsNHWCSource(kTransFilterType type) {
129   return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
130 }
131 
IsCHWKSource(kTransFilterType type)132 static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }
133 
IsKHWCSource(kTransFilterType type)134 static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }
135 
// Decompose a 4-D filter shape into its K/C/H/W components according to the
// SOURCE layout encoded in `type` (e.g. every kKCHW2* transform reads the
// dims in KCHW order, every kHWCK2* transform in HWCK order, and so on).
// Returns RET_NULL_PTR when any out-pointer is null and RET_ERROR for an
// unrecognized transform type.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
    MS_LOG(ERROR) << "null input";
    return RET_NULL_PTR;
  }
  MS_ASSERT(oriDims.size() == 4);
  if (IsKCHWSource(type)) {
    *filterK = oriDims.at(KCHW_K);
    *filterC = oriDims.at(KCHW_C);
    *filterH = oriDims.at(KCHW_H);
    *filterW = oriDims.at(KCHW_W);
  } else if (IsCKHWSource(type)) {
    *filterC = oriDims.at(CKHW_C);
    *filterK = oriDims.at(CKHW_K);
    *filterH = oriDims.at(CKHW_H);
    *filterW = oriDims.at(CKHW_W);
  } else if (IsHWCKSource(type)) {
    *filterH = oriDims.at(HWCK_H);
    *filterW = oriDims.at(HWCK_W);
    *filterC = oriDims.at(HWCK_C);
    *filterK = oriDims.at(HWCK_K);
  } else if (IsHWKCSource(type)) {
    *filterH = oriDims.at(HWKC_H);
    *filterW = oriDims.at(HWKC_W);
    *filterK = oriDims.at(HWKC_K);
    *filterC = oriDims.at(HWKC_C);
  } else if (IsNHWCSource(type)) {
    // For NHWC sources the batch dimension N plays the role of K.
    *filterK = oriDims.at(NHWC_N);
    *filterH = oriDims.at(NHWC_H);
    *filterW = oriDims.at(NHWC_W);
    *filterC = oriDims.at(NHWC_C);
  } else if (IsCHWKSource(type)) {
    *filterC = oriDims.at(CHWK_C);
    *filterH = oriDims.at(CHWK_H);
    *filterW = oriDims.at(CHWK_W);
    *filterK = oriDims.at(CHWK_K);
  } else if (IsKHWCSource(type)) {
    *filterK = oriDims.at(KHWC_K);
    *filterH = oriDims.at(KHWC_H);
    *filterW = oriDims.at(KHWC_W);
    *filterC = oriDims.at(KHWC_C);
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}
184 
// Write `tensor->dims` in the DESTINATION layout implied by `type`, built
// from the individual filter dimensions K/C/H/W (typically produced by
// GetFilterDim). Returns RET_ERROR for an unrecognized transform type.
STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
                    int32_t filterW) {
  MS_ASSERT(tensor != nullptr);
  // Each branch groups all transforms that share the same destination order.
  if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
    tensor->dims = {filterH, filterW, filterC, filterK};
  } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
    tensor->dims = {filterH, filterW, filterK, filterC};
  } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
    tensor->dims = {filterK, filterC, filterH, filterW};
  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
    tensor->dims = {filterC, filterK, filterH, filterW};
  } else if (type == kKHWC2CHWK) {
    tensor->dims = {filterC, filterH, filterW, filterK};
  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
    tensor->dims = {filterK, filterH, filterW, filterC};
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}
206 
Convert2KHWC(int srcFormat)207 static int Convert2KHWC(int srcFormat) {
208   if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
209   if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
210   if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
211   return -1;
212 }
213 
Convert2HWCK(int srcFormat)214 static int Convert2HWCK(int srcFormat) {
215   if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
216   if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
217   if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
218   if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
219   return -1;
220 }
221 
Convert2KCHW(int srcFormat)222 static int Convert2KCHW(int srcFormat) {
223   if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
224   if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
225   if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
226   if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
227   if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
228   return -1;
229 }
230 
Convert2CKHW(int srcFormat)231 static int Convert2CKHW(int srcFormat) {
232   if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
233   if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
234   if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
235   return -1;
236 }
237 
// Run single-node shape inference: pack the node's primitive into a
// flatbuffer, obtain its OpParameter via the populate registry, and delegate
// to KernelInferShape over `inputs`/`outputs`.
// NOTE(review): "Shpae" is an existing typo of "Shape", kept because callers
// use this name.
STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs) {
  flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  auto prim = ConvertToPrimitive(node.primitive.get(), &fbb);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "get primitive failed.";
    fbb.Clear();
    return RET_ERROR;
  }
  auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
  if (parameter_gen == nullptr) {
    fbb.Clear();
    MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
    return RET_ERROR;
  }
  auto parameter = parameter_gen(prim);
  if (parameter == nullptr) {
    fbb.Clear();
    MS_LOG(ERROR) << "parameter is nullptr.";
    return RET_ERROR;
  }
  auto ret = KernelInferShape(inputs, *outputs, parameter);
  // `prim` points into fbb's buffer, so fbb is only cleared after inference;
  // the parameter is released through its own destroy hook plus free().
  fbb.Clear();
  if (parameter->destroy_func_ != nullptr) {
    parameter->destroy_func_(parameter);
  }
  free(parameter);
  parameter = nullptr;
  return ret;
}
267 
GetTensorInputIndexInCNode(const uint32_t & tensor_index,const schema::CNodeT & cnode)268 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode) {
269   size_t ret = -1;
270   for (size_t i = 0; i < cnode.inputIndex.size(); i++) {
271     if (cnode.inputIndex.at(i) == tensor_index) {
272       ret = i;
273     }
274   }
275   return ret;
276 }
277 
// Transpose a 4-D weight tensor in-place into `dstFormat`: select the
// matching kTransFilterType for the (srcFormat, dstFormat) pair, then
// dispatch to the typed TransFilterFormat<T> helper based on the tensor's
// data type (float32 / uint8 / int8 only).
// Returns RET_NULL_PTR for a null tensor and RET_ERROR for unsupported
// ranks, format pairs, or data types.
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "tensor is null";
    return RET_NULL_PTR;
  }
  std::vector<int32_t> oriDims = tensor->dims;
  if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return RET_ERROR;
  }
  auto srcFormat = tensor->format;
  auto dataType = tensor->dataType;
  STATUS status;
  int convert = -1;

  // Nothing to do when the tensor is already in the requested layout.
  if (dstFormat == srcFormat) return RET_OK;

  switch (dstFormat) {
    case schema::Format::Format_KHWC:
      convert = Convert2KHWC(srcFormat);
      break;
    case schema::Format::Format_HWCK:
      convert = Convert2HWCK(srcFormat);
      break;
    case schema::Format::Format_KCHW:
      convert = Convert2KCHW(srcFormat);
      break;
    case schema::Format::Format_CKHW:
      convert = Convert2CKHW(srcFormat);
      break;
    default:
      convert = -1;
  }
  if (convert == -1) {
    MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
    return RET_ERROR;
  }

  if (dataType == kNumberTypeFloat32) {
    status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
  } else if (dataType == kNumberTypeUInt8) {
    status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
  } else if (dataType == kNumberTypeInt8) {
    status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
  } else {
    MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
    return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return RET_OK;
}
332 
// Number of outputs a cnode produces: the element count of its
// AbstractTuple abstract, otherwise 1. In training mode Conv2DFusion and
// Adam nodes are forced to a single output regardless of their abstract.
size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
  MS_ASSERT(anf_node != nullptr);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_ASSERT(cnode != nullptr);
  if (train_flag &&
      (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
    return 1;
  }
  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    // NOTE(review): reinterpret_pointer_cast after an isa<> check — a
    // static_pointer_cast would be the conventional choice; verify intent.
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    return tuple->elements().size();
  } else {
    return 1;
  }
}
348 
IsPartialFusion(const AnfNodePtr & node)349 bool IsPartialFusion(const AnfNodePtr &node) {
350   if (node == nullptr) {
351     return false;
352   }
353   if (node->isa<mindspore::CNode>()) {
354     auto cnode = node->cast<CNodePtr>();
355     MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
356     auto vnode_value = cnode->input(0)->cast<ValueNodePtr>()->value();
357     return GetValue<NamedPtr>(vnode_value)->name() == "PartialFusion";
358   }
359   return false;
360 }
361 
IsCall(const AnfNodePtr & node)362 bool IsCall(const AnfNodePtr &node) {
363   if (node == nullptr) {
364     return false;
365   }
366   if (!utils::isa<CNodePtr>(node)) {
367     return false;
368   }
369   auto cnode = node->cast<CNodePtr>();
370   MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
371   if (cnode->inputs().empty()) {
372     return false;
373   }
374   auto cnode_first_input = cnode->input(0);
375   if (utils::isa<CNodePtr>(cnode_first_input)) {
376     return true;
377   }
378   if (utils::isa<ValueNode>(cnode_first_input)) {
379     auto vnode = cnode_first_input->cast<ValueNodePtr>();
380     return GetValueNode<FuncGraphPtr>(vnode) != nullptr;
381   }
382   return false;
383 }
384 
IsSwitch(const AnfNodePtr & node)385 bool IsSwitch(const AnfNodePtr &node) {
386   if (node == nullptr) {
387     return false;
388   }
389   if (!utils::isa<CNodePtr>(node)) {
390     return false;
391   }
392   return opt::CheckPrimitiveType(node, prim::kPrimSwitch);
393 }
394 
IsMakeTuple(const AnfNodePtr & node)395 bool IsMakeTuple(const AnfNodePtr &node) {
396   if (node == nullptr) {
397     return false;
398   }
399   if (!utils::isa<CNodePtr>(node)) {
400     return false;
401   }
402   return opt::CheckPrimitiveType(node, prim::kPrimMakeTuple);
403 }
404 
GetPartialFusionPrim()405 ValueNodePtr GetPartialFusionPrim() {
406   auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
407   MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr");
408   ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
409   MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr");
410   return partial_anf_prim;
411 }
412 
GetSwitchAnfPrim()413 ValueNodePtr GetSwitchAnfPrim() {
414   auto switch_prim = std::make_shared<mindspore::ops::Switch>();
415   MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
416   ValueNodePtr switch_anf_prim = NewValueNode(switch_prim);
417   MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
418   return switch_anf_prim;
419 }
420 
GetCallAnfPrim()421 ValueNodePtr GetCallAnfPrim() {
422   auto call_prim = std::make_shared<mindspore::ops::Call>();
423   MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr");
424   ValueNodePtr call_anf_prim = NewValueNode(call_prim);
425   MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
426   return call_anf_prim;
427 }
428 }  // namespace lite
429 }  // namespace mindspore
430