• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/common/node_util.h"
18 #include <memory>
19 #include <set>
20 #include <vector>
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/nn_optimizer_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "src/common/ops/populate/populate_register.h"
26 #include "src/common/common.h"
27 #include "src/common/log_adapter.h"
28 #include "tools/common/graph_util.h"
29 #include "tools/common/tensor_util.h"
30 #include "src/litert/infer_manager.h"
31 #include "mindspore/core/ops/switch.h"
32 #include "mindspore/core/ops/call.h"
33 #include "mindspore/core/ops/fusion/partial_fusion.h"
34 #include "nnacl/op_base.h"
35 
36 namespace mindspore {
37 namespace lite {
// Initial-size constant for this translation unit.
// NOTE(review): not referenced anywhere in this file; confirm it is still needed before keeping it.
constexpr size_t kInitialSize{1024};
GetInputCNode(const CNodePtr & cnode)39 std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
40   if (cnode == nullptr) {
41     return {};
42   }
43   std::vector<CNodePtr> inputs;
44   for (const auto &input : cnode->inputs()) {
45     if (input == nullptr || !utils::isa<CNodePtr>(input)) {
46       continue;
47     }
48     inputs.emplace_back(utils::cast<CNodePtr>(input));
49   }
50   return inputs;
51 }
52 
ConvertDims(mindspore::schema::Format src_format,const std::vector<int32_t> & src_dims,mindspore::schema::Format dst_format,std::vector<int32_t> * dst_dims)53 STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
54                               mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
55   MS_ASSERT(dst_dims != nullptr);
56   if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
57     MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
58                   << " <3 or src format is equal to dst format,not need convert";
59     *dst_dims = src_dims;
60     return RET_PARAM_INVALID;
61   }
62 
63   std::vector<int32_t> nchw_dim;
64   switch (src_format) {
65     case schema::Format::Format_NCHW:
66       nchw_dim = src_dims;
67       break;
68     case schema::Format::Format_NHWC:
69       if (src_dims.size() == DIM_DEFAULT_SIZE) {
70         nchw_dim.push_back(src_dims[NHWC_N]);
71         nchw_dim.push_back(src_dims[NHWC_C]);
72         nchw_dim.push_back(src_dims[NHWC_H]);
73         nchw_dim.push_back(src_dims[NHWC_W]);
74       } else {
75         nchw_dim.push_back(src_dims[HWC_C]);
76         nchw_dim.push_back(src_dims[HWC_H]);
77         nchw_dim.push_back(src_dims[HWC_W]);
78       }
79       break;
80     default:
81       MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
82       return RET_ERROR;
83   }
84 
85   if (nchw_dim.empty()) {
86     MS_LOG(ERROR) << "Param nchw_dim is empty!";
87     return RET_ERROR;
88   }
89 
90   switch (dst_format) {
91     case schema::Format::Format_NCHW:
92       *dst_dims = nchw_dim;
93       break;
94     case schema::Format::Format_NHWC:
95       if (src_dims.size() == DIM_DEFAULT_SIZE) {
96         dst_dims->push_back(nchw_dim[NCHW_N]);
97         dst_dims->push_back(nchw_dim[NCHW_H]);
98         dst_dims->push_back(nchw_dim[NCHW_W]);
99         dst_dims->push_back(nchw_dim[NCHW_C]);
100       }
101       break;
102     default:
103       MS_LOG(ERROR) << "Not support dst format: " << dst_format;
104       return RET_ERROR;
105   }
106   return RET_OK;
107 }
108 
IsKCHWSource(kTransFilterType type)109 static bool IsKCHWSource(kTransFilterType type) {
110   return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
111 }
112 
IsCKHWSource(kTransFilterType type)113 static bool IsCKHWSource(kTransFilterType type) {
114   return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
115 }
116 
IsHWCKSource(kTransFilterType type)117 static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }
118 
IsHWKCSource(kTransFilterType type)119 static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }
120 
IsNHWCSource(kTransFilterType type)121 static bool IsNHWCSource(kTransFilterType type) {
122   return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
123 }
124 
IsCHWKSource(kTransFilterType type)125 static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }
126 
IsKHWCSource(kTransFilterType type)127 static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }
128 
GetFilterDim(const std::vector<int32_t> & oriDims,kTransFilterType type,int32_t * filterK,int32_t * filterC,int32_t * filterH,int32_t * filterW)129 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
130                     int32_t *filterH, int32_t *filterW) {
131   if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
132     MS_LOG(ERROR) << "null input";
133     return RET_NULL_PTR;
134   }
135   MS_ASSERT(oriDims.size() == 4);
136   if (IsKCHWSource(type)) {
137     *filterK = oriDims.at(KCHW_K);
138     *filterC = oriDims.at(KCHW_C);
139     *filterH = oriDims.at(KCHW_H);
140     *filterW = oriDims.at(KCHW_W);
141   } else if (IsCKHWSource(type)) {
142     *filterC = oriDims.at(CKHW_C);
143     *filterK = oriDims.at(CKHW_K);
144     *filterH = oriDims.at(CKHW_H);
145     *filterW = oriDims.at(CKHW_W);
146   } else if (IsHWCKSource(type)) {
147     *filterH = oriDims.at(HWCK_H);
148     *filterW = oriDims.at(HWCK_W);
149     *filterC = oriDims.at(HWCK_C);
150     *filterK = oriDims.at(HWCK_K);
151   } else if (IsHWKCSource(type)) {
152     *filterH = oriDims.at(HWKC_H);
153     *filterW = oriDims.at(HWKC_W);
154     *filterK = oriDims.at(HWKC_K);
155     *filterC = oriDims.at(HWKC_C);
156   } else if (IsNHWCSource(type)) {
157     *filterK = oriDims.at(NHWC_N);
158     *filterH = oriDims.at(NHWC_H);
159     *filterW = oriDims.at(NHWC_W);
160     *filterC = oriDims.at(NHWC_C);
161   } else if (IsCHWKSource(type)) {
162     *filterC = oriDims.at(CHWK_C);
163     *filterH = oriDims.at(CHWK_H);
164     *filterW = oriDims.at(CHWK_W);
165     *filterK = oriDims.at(CHWK_K);
166   } else if (IsKHWCSource(type)) {
167     *filterK = oriDims.at(KHWC_K);
168     *filterH = oriDims.at(KHWC_H);
169     *filterW = oriDims.at(KHWC_W);
170     *filterC = oriDims.at(KHWC_C);
171   } else {
172     MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
173     return RET_ERROR;
174   }
175   return RET_OK;
176 }
177 
SetFilterDim(schema::TensorT * tensor,kTransFilterType type,int32_t filterK,int32_t filterC,int32_t filterH,int32_t filterW)178 STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
179                     int32_t filterW) {
180   MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "tensor is nullptr");
181   if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
182     tensor->dims = {filterH, filterW, filterC, filterK};
183   } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
184     tensor->dims = {filterH, filterW, filterK, filterC};
185   } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
186     tensor->dims = {filterK, filterC, filterH, filterW};
187   } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
188     tensor->dims = {filterC, filterK, filterH, filterW};
189   } else if (type == kKHWC2CHWK) {
190     tensor->dims = {filterC, filterH, filterW, filterK};
191   } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
192     tensor->dims = {filterK, filterH, filterW, filterC};
193   } else {
194     MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
195     return RET_ERROR;
196   }
197   return RET_OK;
198 }
199 
Convert2KHWC(int srcFormat)200 static int Convert2KHWC(int srcFormat) {
201   if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
202   if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
203   if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
204   return -1;
205 }
206 
Convert2HWCK(int srcFormat)207 static int Convert2HWCK(int srcFormat) {
208   if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
209   if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
210   if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
211   if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
212   return -1;
213 }
214 
Convert2KCHW(int srcFormat)215 static int Convert2KCHW(int srcFormat) {
216   if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
217   if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
218   if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
219   if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
220   if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
221   return -1;
222 }
223 
Convert2CKHW(int srcFormat)224 static int Convert2CKHW(int srcFormat) {
225   if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
226   if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
227   if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
228   return -1;
229 }
230 
GetTensorInputIndexInCNode(const uint32_t & tensor_index,const schema::CNodeT & cnode)231 size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode) {
232   size_t ret = -1;
233   for (size_t i = 0; i < cnode.inputIndex.size(); i++) {
234     if (cnode.inputIndex.at(i) == tensor_index) {
235       ret = i;
236     }
237   }
238   return ret;
239 }
240 
TransFilterFormat(schema::TensorT * tensor,schema::Format dstFormat)241 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
242   if (tensor == nullptr) {
243     MS_LOG(ERROR) << "tensor is null";
244     return RET_NULL_PTR;
245   }
246   std::vector<int32_t> oriDims = tensor->dims;
247   if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
248     MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
249     return RET_ERROR;
250   }
251   auto srcFormat = tensor->format;
252   auto dataType = tensor->dataType;
253   STATUS status;
254   int convert = -1;
255 
256   if (dstFormat == srcFormat) return RET_OK;
257 
258   switch (dstFormat) {
259     case schema::Format::Format_KHWC:
260       convert = Convert2KHWC(srcFormat);
261       break;
262     case schema::Format::Format_HWCK:
263       convert = Convert2HWCK(srcFormat);
264       break;
265     case schema::Format::Format_KCHW:
266       convert = Convert2KCHW(srcFormat);
267       break;
268     case schema::Format::Format_CKHW:
269       convert = Convert2CKHW(srcFormat);
270       break;
271     default:
272       convert = -1;
273   }
274   if (convert == -1) {
275     MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
276     return RET_ERROR;
277   }
278 
279   if (dataType == kNumberTypeFloat32) {
280     status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
281   } else if (dataType == kNumberTypeUInt8) {
282     status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
283   } else if (dataType == kNumberTypeInt8) {
284     status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
285   } else {
286     MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
287     return RET_ERROR;
288   }
289   if (status != RET_OK) {
290     MS_LOG(ERROR) << "TransFilterData failed: " << status;
291     return status;
292   }
293   return RET_OK;
294 }
295 
GetCNodeOutputsSize(const std::shared_ptr<AnfNode> & anf_node,bool train_flag)296 size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
297   MS_CHECK_TRUE_MSG(anf_node != nullptr, 0, "anf_node is nullptr");
298   auto cnode = anf_node->cast<CNodePtr>();
299   MS_CHECK_TRUE_MSG(cnode != nullptr, 0, "cnode is nullptr");
300   if (train_flag &&
301       (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
302     return 1;
303   }
304   if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
305     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
306     return tuple->elements().size();
307   } else {
308     return 1;
309   }
310 }
311 
IsPartialFusion(const AnfNodePtr & node)312 bool IsPartialFusion(const AnfNodePtr &node) {
313   if (node == nullptr) {
314     return false;
315   }
316   if (node->isa<mindspore::CNode>()) {
317     auto cnode = node->cast<CNodePtr>();
318     MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
319     auto vnode_value = cnode->input(0)->cast<ValueNodePtr>()->value();
320     return GetValue<NamedPtr>(vnode_value)->name() == "PartialFusion";
321   }
322   return false;
323 }
324 
IsCall(const AnfNodePtr & node)325 bool IsCall(const AnfNodePtr &node) {
326   if (node == nullptr) {
327     return false;
328   }
329   if (!utils::isa<CNodePtr>(node)) {
330     return false;
331   }
332   auto cnode = node->cast<CNodePtr>();
333   MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
334   if (cnode->inputs().empty()) {
335     return false;
336   }
337   auto cnode_first_input = cnode->input(0);
338   if (utils::isa<CNodePtr>(cnode_first_input)) {
339     return true;
340   }
341   if (utils::isa<ValueNode>(cnode_first_input)) {
342     auto vnode = cnode_first_input->cast<ValueNodePtr>();
343     return GetValueNode<FuncGraphPtr>(vnode) != nullptr;
344   }
345   return false;
346 }
347 
IsSwitch(const AnfNodePtr & node)348 bool IsSwitch(const AnfNodePtr &node) {
349   if (node == nullptr) {
350     return false;
351   }
352   if (!utils::isa<CNodePtr>(node)) {
353     return false;
354   }
355   return opt::CheckPrimitiveType(node, prim::kPrimSwitch);
356 }
357 
IsSwitchLayer(const AnfNodePtr & node)358 bool IsSwitchLayer(const AnfNodePtr &node) {
359   if (node == nullptr) {
360     return false;
361   }
362   if (!utils::isa<CNodePtr>(node)) {
363     return false;
364   }
365   return opt::CheckPrimitiveType(node, prim::kPrimSwitchLayer);
366 }
367 
IsControlFlowOp(const AnfNodePtr & node)368 bool IsControlFlowOp(const AnfNodePtr &node) {
369   return IsPartialFusion(node) || IsCall(node) || IsSwitch(node) || IsSwitchLayer(node);
370 }
371 
IsMakeTuple(const AnfNodePtr & node)372 bool IsMakeTuple(const AnfNodePtr &node) {
373   if (node == nullptr) {
374     return false;
375   }
376   if (!utils::isa<CNodePtr>(node)) {
377     return false;
378   }
379   return opt::CheckPrimitiveType(node, prim::kPrimMakeTuple);
380 }
381 
GetPartialFusionPrim()382 ValueNodePtr GetPartialFusionPrim() {
383   auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
384   MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr");
385   auto partial_prim_c = partial_prim->GetPrim();
386   MS_CHECK_TRUE_MSG(partial_prim_c != nullptr, nullptr, "partial_prim_c is nullptr");
387   ValueNodePtr partial_anf_prim = NewValueNode(partial_prim_c);
388   MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr");
389   return partial_anf_prim;
390 }
391 
GetSwitchAnfPrim()392 ValueNodePtr GetSwitchAnfPrim() {
393   auto switch_prim = std::make_shared<mindspore::ops::Switch>();
394   MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
395   auto switch_prim_c = switch_prim->GetPrim();
396   MS_CHECK_TRUE_MSG(switch_prim_c != nullptr, nullptr, "switch_prim_c is nullptr");
397   ValueNodePtr switch_anf_prim = NewValueNode(switch_prim_c);
398   MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
399   return switch_anf_prim;
400 }
401 
GetCallAnfPrim()402 ValueNodePtr GetCallAnfPrim() {
403   auto call_prim = std::make_shared<mindspore::ops::Call>();
404   MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr");
405   auto call_prim_c = call_prim->GetPrim();
406   MS_CHECK_TRUE_MSG(call_prim_c != nullptr, nullptr, "call_prim_c is nullptr");
407   ValueNodePtr call_anf_prim = NewValueNode(call_prim_c);
408   MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
409   return call_anf_prim;
410 }
411 
UpdateDataType(const AnfNodePtr & cnode,TypeId new_data_type)412 int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
413   auto abstract_base = cnode->abstract();
414   if (abstract_base == nullptr) {
415     MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
416     return RET_NULL_PTR;
417   }
418   if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
419     MS_LOG(ERROR) << "Abstract of node should be anstract tensor, " << cnode->fullname_with_scope();
420     return RET_ERROR;
421   }
422   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
423   CHECK_NULL_RETURN(abstract_tensor);
424   CHECK_NULL_RETURN(abstract_tensor->element());
425   abstract_tensor->element()->set_type(TypeIdToType(new_data_type));
426   return RET_OK;
427 }
428 }  // namespace lite
429 }  // namespace mindspore
430