• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/anf_exporter/fetch_content.h"
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 #include <unordered_map>
22 #include "tools/converter/quant_param_holder.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "tools/optimizer/common/format_utils.h"
26 #include "nnacl/op_base.h"
27 
28 namespace mindspore {
29 namespace lite {
30 namespace {
31 constexpr int kNumWeightIndex = 2;
32 constexpr int kNumTransposePermSize = 4;
33 constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
34 static const std::unordered_map<int, int> TypeToTypeMap = {
35   {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
GetShapeVectorFromStringTensor(const tensor::TensorPtr & tensor_info,ShapeVector * shape_vector,size_t * offset)36 STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
37   MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
38   auto data_type = tensor_info->data_type();
39   if (data_type != kObjectTypeString) {
40     MS_LOG(ERROR) << "This function only used for string tensor.";
41     return RET_ERROR;
42   }
43   shape_vector->clear();
44   MS_CHECK_TRUE_MSG(tensor_info->data_c() != nullptr, RET_ERROR, "tensor_info->data_c() is nullptr");
45   auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
46   std::string shape_str;
47   std::string shape_size_str;
48   *offset = 0;
49   size_t cnt = 0;
50   for (; *offset < tensor_info->Size(); (*offset)++) {
51     if (tensor_data[*offset] == ',') {
52       (*offset)++;
53       break;
54     }
55     shape_size_str.push_back(tensor_data[*offset]);
56   }
57   if (*offset == 0) {
58     MS_LOG(ERROR) << "string tensor's dim size not found.";
59     return RET_ERROR;
60   }
61   size_t shape_size = std::atoi(shape_size_str.c_str());
62   MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR);
63   for (; *offset < tensor_info->Size(); (*offset)++) {
64     if (tensor_data[*offset] == ',') {
65       cnt++;
66       int64_t shape = 0;
67       try {
68         shape = std::stoi(shape_str);
69       } catch (const std::exception &e) {
70         MS_LOG(ERROR) << "Get shape failed: " << e.what();
71         return RET_ERROR;
72       }
73       shape_vector->push_back(shape);
74       shape_str.clear();
75     } else {
76       shape_str.push_back(tensor_data[*offset]);
77     }
78     if (cnt == shape_size) {
79       (*offset)++;
80       break;
81     }
82   }
83   if (shape_vector->empty()) {
84     MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
85     return RET_ERROR;
86   }
87   return RET_OK;
88 }
89 
GetDataTypeAndShape(const ParameterPtr & param_node,TypeId * data_type,ShapeVector * shape_vector)90 STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
91   MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
92   auto abstract_base = param_node->abstract();
93   if (abstract_base == nullptr) {
94     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
95     return RET_PARAM_INVALID;
96   }
97   if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
98     MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
99     return RET_INPUT_TENSOR_ERROR;
100   }
101   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
102   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
103   auto typePtr = abstract_tensor->element()->GetTypeTrack();
104   MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
105   *data_type = typePtr->type_id();
106   if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
107     MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
108     return RET_PARAM_INVALID;
109   }
110   *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
111   return RET_OK;
112 }
113 
FetchFromTensorValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)114 int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
115                          bool train_flag, DataInfo *data_info) {
116   MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
117   auto valueAbstract = value_node->abstract();
118   MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
119   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
120   if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
121     MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
122     return RET_ERROR;
123   }
124   auto typePtr = abstract_tensor->element()->GetTypeTrack();
125   MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
126   data_info->data_type_ = typePtr->type_id();
127   auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
128   std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
129   data_info->shape_ = dims;
130   if (train_flag && dims.empty()) {
131     data_info->shape_ = {1};
132   }
133   auto value = value_node->value();
134   MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
135   auto data = value->cast<tensor::TensorPtr>();
136   MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
137   data_info->data_.resize(data->Size());
138   if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
139     MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
140     return RET_ERROR;
141   }
142 
143   // process weight tensor
144   if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
145     MS_LOG(ERROR) << "memcpy_s error.";
146     return RET_ERROR;
147   }
148   return RET_OK;
149 }
150 
FetchFromInt32OrInt64ImmValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)151 int FetchFromInt32OrInt64ImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
152   MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
153   // data of int64 is converted to int32 here.
154   data_info->data_type_ = kNumberTypeInt32;
155   data_info->shape_ = {1};
156   data_info->data_.resize(sizeof(int32_t));
157   auto value = value_node->value();
158   MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
159   int real_data = opt::CastToInt(value).front();
160   if (memcpy_s(data_info->data_.data(), sizeof(int32_t), &real_data, sizeof(int32_t)) != EOK) {
161     MS_LOG(ERROR) << "memcpy_s failed";
162     return RET_MEMORY_FAILED;
163   }
164   return RET_OK;
165 }
166 
FetchFromBoolImmValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)167 int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
168   MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
169   data_info->data_type_ = kNumberTypeBool;
170   data_info->shape_ = {1};
171   data_info->data_.resize(sizeof(bool));
172   auto value = value_node->value();
173   MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
174   auto data = value->cast<mindspore::BoolImmPtr>();
175   MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr");
176   auto data_value = data->value();
177   if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
178     MS_LOG(ERROR) << "memcpy_s failed";
179     return RET_MEMORY_FAILED;
180   }
181   return RET_OK;
182 }
183 
FetchFromNumberValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)184 int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
185   MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
186   data_info->data_type_ = kNumberTypeInt32;
187   data_info->shape_ = {1};
188   data_info->data_.resize(sizeof(int));
189   auto data = value_node->value()->cast<NumberPtr>();
190   MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed");
191   int number_type = data->number_type();
192   if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
193     number_type = TypeToTypeMap.at(number_type);
194   }
195   if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
196     MS_LOG(ERROR) << "memcpy_s failed";
197     return RET_MEMORY_FAILED;
198   }
199   return RET_OK;
200 }
201 
FetchFromSequenceValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)202 int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
203   MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
204   auto value = value_node->value();
205   MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
206   std::vector<int32_t> shape;
207   auto value_seq = value->cast<ValueSequeuePtr>();
208   MS_CHECK_TRUE_MSG(value_seq != nullptr, RET_ERROR, "value_seq is nullptr");
209   if (!value_seq->value().empty()) {
210     if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
211         value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
212       shape = GetValue<std::vector<int>>(value);
213     } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
214       auto origin_value = GetValue<std::vector<int64_t>>(value);
215       std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
216                      [](int64_t val) { return static_cast<int32_t>(val); });
217     } else {
218       MS_LOG(ERROR) << "Value type is ValueSequence is not integer.";
219       return RET_ERROR;
220     }
221   }
222   data_info->data_type_ = kNumberTypeInt32;
223   data_info->shape_ = {static_cast<int32_t>(shape.size())};
224   data_info->data_.resize(shape.size() * sizeof(int));
225   if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
226                                  shape.size() * sizeof(int32_t)) != EOK) {
227     MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
228     return RET_ERROR;
229   }
230   return RET_OK;
231 }
232 }  // namespace
233 
FetchFromDefaultParam(const ParameterPtr & param_node,const converter::FmkType & fmk_type,DataInfo * data_info)234 int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info) {
235   MS_ASSERT(param_node != nullptr && data_info != nullptr);
236   ShapeVector shape_vector;
237   TypeId data_type = kTypeUnknown;
238   auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
239   if (status != RET_OK) {
240     MS_LOG(ERROR) << "get data type and shape from param node failed.";
241     return RET_ERROR;
242   }
243   data_info->data_type_ = data_type;
244   auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
245   size_t offset = 0;
246   if (!shape_vector.empty() && data_type == kObjectTypeString) {
247     status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
248     if (status != RET_OK) {
249       MS_LOG(ERROR) << "get shape vector from string tensor failed.";
250       return RET_ERROR;
251     }
252   }
253   std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
254   data_info->shape_ = dims;
255   if (tensor_info != nullptr && tensor_info->Size() != 0) {
256     if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
257       data_info->data_.resize(tensor_info->Size() - offset);
258       if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
259                           static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
260         MS_LOG(ERROR) << "memcpy_s failed.";
261         return RET_ERROR;
262       }
263     }
264   }
265 
266   data_info->format_ = NHWC;
267   return RET_OK;
268 }
269 
FetchDataFromParameterNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)270 int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
271                                DataInfo *data_info) {
272   MS_ASSERT(cnode != nullptr && data_info != nullptr);
273   auto param_node = cnode->input(index)->cast<ParameterPtr>();
274   MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
275   if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
276     MS_LOG(ERROR) << "fetch information from default param failed.";
277     return RET_ERROR;
278   }
279   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
280   MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
281   if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) {
282     data_info->format_ = mindspore::NHWC;
283   }
284   if (prim->GetAttr(ops::kFormat) != nullptr) {
285     auto value = prim->GetAttr(ops::kFormat);
286     if (value->isa<mindspore::Int64Imm>()) {
287       data_info->format_ = GetValue<int64_t>(value);
288     }
289   }
290   QuantParamHolderPtr quant_param_holder =
291     prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
292   if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
293       data_info->data_type_ == kNumberTypeInt8) {
294     data_info->enable_huffman_code_ = true;
295   }
296   data_info->node_type_ = NodeType_ValueNode;
297   return RET_OK;
298 }
299 
FetchDataFromValueNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)300 int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
301                            DataInfo *data_info) {
302   MS_ASSERT(cnode != nullptr && data_info != nullptr);
303   auto value_node = cnode->input(index)->cast<ValueNodePtr>();
304   MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
305 
306   auto value = value_node->value();
307   int ret = RET_OK;
308   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
309   MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "prim is nullptr");
310   if (value->isa<tensor::Tensor>()) {
311     ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
312     if (index == kNumWeightIndex && prim->GetAttr(ops::kFormat) != nullptr) {
313       data_info->format_ = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
314     }
315   } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
316     ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);
317   } else if (value->isa<mindspore::BoolImm>()) {
318     ret = FetchFromBoolImmValue(value_node, prim, data_info);
319   } else if (value->isa<mindspore::ValueSequeue>()) {
320     ret = FetchFromSequenceValue(value_node, prim, data_info);
321   } else if (value->isa<Number>()) {
322     ret = FetchFromNumberValue(value_node, prim, data_info);
323   } else if (value->isa<FuncGraph>()) {
324     MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
325     return RET_NO_CHANGE;
326   } else if (value->isa<Monad>()) {
327     MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
328     return RET_NO_CHANGE;
329   } else {
330     MS_LOG(ERROR) << "Not support value type , need add support.";
331     return RET_ERROR;
332   }
333   data_info->node_type_ = NodeType_ValueNode;
334   return ret;
335 }
336 
SetFormatForCnode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)337 int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
338                       DataInfo *data_info) {
339   data_info->format_ = mindspore::NHWC;
340   MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, RET_ERROR, "input is nullptr");
341   auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
342   MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed");
343   if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
344     auto value = input_node_prim->GetAttr(ops::kFormat);
345     if (value->isa<mindspore::Int64Imm>()) {
346       data_info->format_ = GetValue<int64_t>(value);
347     }
348   }
349   if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
350     std::vector<int> perm;
351     if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
352       return RET_ERROR;
353     }
354     if (perm.size() < kNumTransposePermSize) {
355       return RET_OK;
356     }
357     // NHWC to NCHW: perm is {0, 3, 1, 2}
358     // NCHW to NHWC: perm is {0, 2, 3, 1}
359     if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
360         (data_info->format_ == NHWC || data_info->format_ == KHWC)) {
361       data_info->format_ = NCHW;
362     } else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
363       data_info->format_ = NHWC;
364     }
365   }
366   return RET_OK;
367 }
368 
FetchDataFromCNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)369 int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
370                        DataInfo *data_info) {
371   MS_ASSERT(cnode != nullptr && data_info != nullptr);
372   auto abstract = opt::GetCNodeInputAbstract(cnode, index);
373   if (abstract == nullptr) {
374     MS_LOG(ERROR) << "Abstract cnode is nullptr.";
375     return RET_ERROR;
376   }
377   if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
378     MS_LOG(ERROR) << "Abstract should be anstract tensor.";
379     return RET_ERROR;
380   }
381   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
382   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
383   auto type_ptr = abstract_tensor->element()->GetTypeTrack();
384   MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
385   if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
386     MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
387     return RET_ERROR;
388   }
389   auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
390   std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
391   auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
392   if (ret != RET_OK) {
393     MS_LOG(ERROR) << "set format for cnode failed";
394     return RET_ERROR;
395   }
396   data_info->data_type_ = type_ptr->type_id();
397   data_info->shape_ = dims;
398   data_info->node_type_ = NodeType_CNode;
399   if (type_ptr->type_id() == kObjectTypeTensorType) {
400     auto tensor_info = abstract_tensor->GetValueTrack();
401     if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
402       MS_LOG(ERROR) << "tensor info is invalid.";
403       return RET_ERROR;
404     }
405     auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
406     MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
407     if (tensor_value->Size() >= kTensorListMinSize) {
408       data_info->data_.resize(tensor_value->Size());
409       if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
410           EOK) {
411         MS_LOG(ERROR) << "memcpy data failed.";
412         return RET_ERROR;
413       }
414     }
415   }
416   return RET_OK;
417 }
418 
RemoveIfDepend(const CNodePtr & cnode)419 int RemoveIfDepend(const CNodePtr &cnode) {
420   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
421   bool has_depend = false;
422   std::vector<AnfNodePtr> inputs;
423   inputs.clear();
424 
425   inputs.emplace_back(cnode->input(0));
426   for (size_t i = 1; i < cnode->inputs().size(); ++i) {
427     AnfNodePtr inputNode = cnode->input(i);
428     MS_CHECK_TRUE_MSG(inputNode != nullptr, RET_NULL_PTR, "inputNode is nullptr");
429     if (!inputNode->isa<CNode>()) {
430       inputs.emplace_back(cnode->input(i));
431       continue;
432     }
433     auto depend_node = utils::cast<CNodePtr>(inputNode);
434     MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr");
435     auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
436     MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
437     if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
438       has_depend = true;
439       bool mask_out = (depend_node->inputs().size() == opt::kInputSizeThree);
440       for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
441         AnfNodePtr depend_input_node = depend_node->input(j);
442         MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr");
443         if (depend_input_node->isa<CNode>()) {
444           inputs.emplace_back(depend_input_node);
445           if (mask_out) {
446             break;
447           }
448         }
449       }
450     } else {
451       inputs.emplace_back(cnode->input(i));
452     }
453   }
454   if (has_depend) {
455     cnode->set_inputs(inputs);
456   }
457   return RET_OK;
458 }
459 
RemoveIfMakeTuple(const CNodePtr & cnode)460 int RemoveIfMakeTuple(const CNodePtr &cnode) {
461   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
462   bool has_make_tuple = false;
463   std::vector<AnfNodePtr> inputs;
464   inputs.clear();
465 
466   inputs.emplace_back(cnode->input(0));
467   for (size_t i = 1; i < cnode->inputs().size(); ++i) {
468     AnfNodePtr input_node = cnode->input(i);
469     MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "input_node is nullptr");
470     if (!input_node->isa<CNode>()) {
471       inputs.emplace_back(cnode->input(i));
472       continue;
473     }
474     auto make_tuple_node = utils::cast<CNodePtr>(input_node);
475     MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, RET_NULL_PTR, "make_tuple_node is nullptr");
476     auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
477     MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
478     if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
479                                            opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
480       has_make_tuple = true;
481       for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
482         inputs.emplace_back(make_tuple_node->input(j));
483       }
484     } else {
485       inputs.emplace_back(cnode->input(i));
486     }
487   }
488   if (has_make_tuple) {
489     cnode->set_inputs(inputs);
490   }
491   return RET_OK;
492 }
493 }  // namespace lite
494 }  // namespace mindspore
495