1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/anf_exporter/fetch_content.h"
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 #include <unordered_map>
22 #include "tools/converter/quant_param_holder.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "tools/optimizer/common/format_utils.h"
26 #include "nnacl/op_base.h"
27
28 namespace mindspore {
29 namespace lite {
30 namespace {
31 constexpr int kNumWeightIndex = 2;
32 constexpr int kNumTransposePermSize = 4;
33 constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
34 static const std::unordered_map<int, int> TypeToTypeMap = {
35 {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
// Parses the shape vector encoded at the front of a string tensor's payload.
// Expected layout of the raw bytes: "<dim_count>,<dim_0>,<dim_1>,...," — a
// comma-terminated dimension count followed by that many comma-terminated dims.
// On success *offset points one byte past the shape prefix, so the caller can
// read the tensor's real content starting at data_c() + *offset.
// Returns RET_OK on success; RET_ERROR for a non-string tensor or a malformed
// or empty shape prefix.
STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
  MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
  auto data_type = tensor_info->data_type();
  if (data_type != kObjectTypeString) {
    MS_LOG(ERROR) << "This function only used for string tensor.";
    return RET_ERROR;
  }
  shape_vector->clear();
  MS_CHECK_TRUE_MSG(tensor_info->data_c() != nullptr, RET_ERROR, "tensor_info->data_c() is nullptr");
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  std::string shape_str;
  std::string shape_size_str;
  *offset = 0;
  size_t cnt = 0;
  // First field: the number of dimensions, terminated by ','.
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      (*offset)++;
      break;
    }
    shape_size_str.push_back(tensor_data[*offset]);
  }
  if (*offset == 0) {
    MS_LOG(ERROR) << "string tensor's dim size not found.";
    return RET_ERROR;
  }
  size_t shape_size = std::atoi(shape_size_str.c_str());
  MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR);
  // Following fields: one dimension per ','-terminated token, exactly
  // shape_size of them.
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      cnt++;
      int64_t shape = 0;
      try {
        shape = std::stoi(shape_str);
      } catch (const std::exception &e) {
        MS_LOG(ERROR) << "Get shape failed: " << e.what();
        return RET_ERROR;
      }
      shape_vector->push_back(shape);
      shape_str.clear();
    } else {
      shape_str.push_back(tensor_data[*offset]);
    }
    // Stop as soon as the declared number of dims has been read, leaving
    // *offset at the first payload byte.
    if (cnt == shape_size) {
      (*offset)++;
      break;
    }
  }
  if (shape_vector->empty()) {
    MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
    return RET_ERROR;
  }
  return RET_OK;
}
89
GetDataTypeAndShape(const ParameterPtr & param_node,TypeId * data_type,ShapeVector * shape_vector)90 STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape_vector) {
91 MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
92 auto abstract_base = param_node->abstract();
93 if (abstract_base == nullptr) {
94 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
95 return RET_PARAM_INVALID;
96 }
97 if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
98 MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
99 return RET_INPUT_TENSOR_ERROR;
100 }
101 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
102 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
103 auto typePtr = abstract_tensor->element()->GetTypeTrack();
104 MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
105 *data_type = typePtr->type_id();
106 if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
107 MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
108 return RET_PARAM_INVALID;
109 }
110 *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
111 return RET_OK;
112 }
113
FetchFromTensorValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)114 int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
115 bool train_flag, DataInfo *data_info) {
116 MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
117 auto valueAbstract = value_node->abstract();
118 MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
119 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
120 if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
121 MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
122 return RET_ERROR;
123 }
124 auto typePtr = abstract_tensor->element()->GetTypeTrack();
125 MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
126 data_info->data_type_ = typePtr->type_id();
127 auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
128 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
129 data_info->shape_ = dims;
130 if (train_flag && dims.empty()) {
131 data_info->shape_ = {1};
132 }
133 auto value = value_node->value();
134 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
135 auto data = value->cast<tensor::TensorPtr>();
136 MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
137 data_info->data_.resize(data->Size());
138 if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
139 MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
140 return RET_ERROR;
141 }
142
143 // process weight tensor
144 if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
145 MS_LOG(ERROR) << "memcpy_s error.";
146 return RET_ERROR;
147 }
148 return RET_OK;
149 }
150
FetchFromInt32OrInt64ImmValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)151 int FetchFromInt32OrInt64ImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
152 MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
153 // data of int64 is converted to int32 here.
154 data_info->data_type_ = kNumberTypeInt32;
155 data_info->shape_ = {1};
156 data_info->data_.resize(sizeof(int32_t));
157 auto value = value_node->value();
158 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
159 int real_data = opt::CastToInt(value).front();
160 if (memcpy_s(data_info->data_.data(), sizeof(int32_t), &real_data, sizeof(int32_t)) != EOK) {
161 MS_LOG(ERROR) << "memcpy_s failed";
162 return RET_MEMORY_FAILED;
163 }
164 return RET_OK;
165 }
166
FetchFromBoolImmValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)167 int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
168 MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
169 data_info->data_type_ = kNumberTypeBool;
170 data_info->shape_ = {1};
171 data_info->data_.resize(sizeof(bool));
172 auto value = value_node->value();
173 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
174 auto data = value->cast<mindspore::BoolImmPtr>();
175 MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr");
176 auto data_value = data->value();
177 if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
178 MS_LOG(ERROR) << "memcpy_s failed";
179 return RET_MEMORY_FAILED;
180 }
181 return RET_OK;
182 }
183
FetchFromNumberValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)184 int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
185 MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
186 data_info->data_type_ = kNumberTypeInt32;
187 data_info->shape_ = {1};
188 data_info->data_.resize(sizeof(int));
189 auto data = value_node->value()->cast<NumberPtr>();
190 MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed");
191 int number_type = data->number_type();
192 if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
193 number_type = TypeToTypeMap.at(number_type);
194 }
195 if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
196 MS_LOG(ERROR) << "memcpy_s failed";
197 return RET_MEMORY_FAILED;
198 }
199 return RET_OK;
200 }
201
FetchFromSequenceValue(const ValueNodePtr & value_node,const PrimitivePtr & primitive,DataInfo * data_info)202 int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
203 MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
204 auto value = value_node->value();
205 MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
206 std::vector<int32_t> shape;
207 auto value_seq = value->cast<ValueSequeuePtr>();
208 MS_CHECK_TRUE_MSG(value_seq != nullptr, RET_ERROR, "value_seq is nullptr");
209 if (!value_seq->value().empty()) {
210 if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
211 value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
212 shape = GetValue<std::vector<int>>(value);
213 } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
214 auto origin_value = GetValue<std::vector<int64_t>>(value);
215 std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
216 [](int64_t val) { return static_cast<int32_t>(val); });
217 } else {
218 MS_LOG(ERROR) << "Value type is ValueSequence is not integer.";
219 return RET_ERROR;
220 }
221 }
222 data_info->data_type_ = kNumberTypeInt32;
223 data_info->shape_ = {static_cast<int32_t>(shape.size())};
224 data_info->data_.resize(shape.size() * sizeof(int));
225 if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
226 shape.size() * sizeof(int32_t)) != EOK) {
227 MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
228 return RET_ERROR;
229 }
230 return RET_OK;
231 }
232 } // namespace
233
FetchFromDefaultParam(const ParameterPtr & param_node,const converter::FmkType & fmk_type,DataInfo * data_info)234 int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkType &fmk_type, DataInfo *data_info) {
235 MS_ASSERT(param_node != nullptr && data_info != nullptr);
236 ShapeVector shape_vector;
237 TypeId data_type = kTypeUnknown;
238 auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
239 if (status != RET_OK) {
240 MS_LOG(ERROR) << "get data type and shape from param node failed.";
241 return RET_ERROR;
242 }
243 data_info->data_type_ = data_type;
244 auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
245 size_t offset = 0;
246 if (!shape_vector.empty() && data_type == kObjectTypeString) {
247 status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
248 if (status != RET_OK) {
249 MS_LOG(ERROR) << "get shape vector from string tensor failed.";
250 return RET_ERROR;
251 }
252 }
253 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
254 data_info->shape_ = dims;
255 if (tensor_info != nullptr && tensor_info->Size() != 0) {
256 if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
257 data_info->data_.resize(tensor_info->Size() - offset);
258 if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
259 static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
260 MS_LOG(ERROR) << "memcpy_s failed.";
261 return RET_ERROR;
262 }
263 }
264 }
265
266 data_info->format_ = NHWC;
267 return RET_OK;
268 }
269
FetchDataFromParameterNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)270 int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
271 DataInfo *data_info) {
272 MS_ASSERT(cnode != nullptr && data_info != nullptr);
273 auto param_node = cnode->input(index)->cast<ParameterPtr>();
274 MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
275 if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
276 MS_LOG(ERROR) << "fetch information from default param failed.";
277 return RET_ERROR;
278 }
279 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
280 MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
281 if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) {
282 data_info->format_ = mindspore::NHWC;
283 }
284 if (prim->GetAttr(ops::kFormat) != nullptr) {
285 auto value = prim->GetAttr(ops::kFormat);
286 if (value->isa<mindspore::Int64Imm>()) {
287 data_info->format_ = GetValue<int64_t>(value);
288 }
289 }
290 QuantParamHolderPtr quant_param_holder =
291 prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
292 if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
293 data_info->data_type_ == kNumberTypeInt8) {
294 data_info->enable_huffman_code_ = true;
295 }
296 data_info->node_type_ = NodeType_ValueNode;
297 return RET_OK;
298 }
299
FetchDataFromValueNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)300 int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
301 DataInfo *data_info) {
302 MS_ASSERT(cnode != nullptr && data_info != nullptr);
303 auto value_node = cnode->input(index)->cast<ValueNodePtr>();
304 MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
305
306 auto value = value_node->value();
307 int ret = RET_OK;
308 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
309 MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "prim is nullptr");
310 if (value->isa<tensor::Tensor>()) {
311 ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
312 if (index == kNumWeightIndex && prim->GetAttr(ops::kFormat) != nullptr) {
313 data_info->format_ = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
314 }
315 } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
316 ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);
317 } else if (value->isa<mindspore::BoolImm>()) {
318 ret = FetchFromBoolImmValue(value_node, prim, data_info);
319 } else if (value->isa<mindspore::ValueSequeue>()) {
320 ret = FetchFromSequenceValue(value_node, prim, data_info);
321 } else if (value->isa<Number>()) {
322 ret = FetchFromNumberValue(value_node, prim, data_info);
323 } else if (value->isa<FuncGraph>()) {
324 MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
325 return RET_NO_CHANGE;
326 } else if (value->isa<Monad>()) {
327 MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
328 return RET_NO_CHANGE;
329 } else {
330 MS_LOG(ERROR) << "Not support value type , need add support.";
331 return RET_ERROR;
332 }
333 data_info->node_type_ = NodeType_ValueNode;
334 return ret;
335 }
336
SetFormatForCnode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)337 int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
338 DataInfo *data_info) {
339 data_info->format_ = mindspore::NHWC;
340 MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, RET_ERROR, "input is nullptr");
341 auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
342 MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed");
343 if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
344 auto value = input_node_prim->GetAttr(ops::kFormat);
345 if (value->isa<mindspore::Int64Imm>()) {
346 data_info->format_ = GetValue<int64_t>(value);
347 }
348 }
349 if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
350 std::vector<int> perm;
351 if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
352 return RET_ERROR;
353 }
354 if (perm.size() < kNumTransposePermSize) {
355 return RET_OK;
356 }
357 // NHWC to NCHW: perm is {0, 3, 1, 2}
358 // NCHW to NHWC: perm is {0, 2, 3, 1}
359 if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
360 (data_info->format_ == NHWC || data_info->format_ == KHWC)) {
361 data_info->format_ = NCHW;
362 } else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
363 data_info->format_ = NHWC;
364 }
365 }
366 return RET_OK;
367 }
368
FetchDataFromCNode(const CNodePtr & cnode,size_t index,converter::FmkType fmk_type,bool train_flag,DataInfo * data_info)369 int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
370 DataInfo *data_info) {
371 MS_ASSERT(cnode != nullptr && data_info != nullptr);
372 auto abstract = opt::GetCNodeInputAbstract(cnode, index);
373 if (abstract == nullptr) {
374 MS_LOG(ERROR) << "Abstract cnode is nullptr.";
375 return RET_ERROR;
376 }
377 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
378 MS_LOG(ERROR) << "Abstract should be anstract tensor.";
379 return RET_ERROR;
380 }
381 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
382 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
383 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
384 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
385 if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
386 MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
387 return RET_ERROR;
388 }
389 auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
390 std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
391 auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
392 if (ret != RET_OK) {
393 MS_LOG(ERROR) << "set format for cnode failed";
394 return RET_ERROR;
395 }
396 data_info->data_type_ = type_ptr->type_id();
397 data_info->shape_ = dims;
398 data_info->node_type_ = NodeType_CNode;
399 if (type_ptr->type_id() == kObjectTypeTensorType) {
400 auto tensor_info = abstract_tensor->GetValueTrack();
401 if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
402 MS_LOG(ERROR) << "tensor info is invalid.";
403 return RET_ERROR;
404 }
405 auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
406 MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
407 if (tensor_value->Size() >= kTensorListMinSize) {
408 data_info->data_.resize(tensor_value->Size());
409 if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
410 EOK) {
411 MS_LOG(ERROR) << "memcpy data failed.";
412 return RET_ERROR;
413 }
414 }
415 }
416 return RET_OK;
417 }
418
// Rewrites cnode's input list so that Depend nodes are bypassed: each Depend
// input is replaced by the CNode operand(s) it forwards, and the rewritten list
// is installed only if at least one Depend was found.
// NOTE(review): any CNode input whose first input is not a ValueNode makes this
// return RET_NULL_PTR — this looks intentional (primitive-headed cnodes only),
// but confirm against callers.
int RemoveIfDepend(const CNodePtr &cnode) {
  MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
  bool has_depend = false;
  std::vector<AnfNodePtr> inputs;
  inputs.clear();

  // Keep the primitive (input 0) unchanged.
  inputs.emplace_back(cnode->input(0));
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    AnfNodePtr inputNode = cnode->input(i);
    MS_CHECK_TRUE_MSG(inputNode != nullptr, RET_NULL_PTR, "inputNode is nullptr");
    if (!inputNode->isa<CNode>()) {
      inputs.emplace_back(cnode->input(i));
      continue;
    }
    auto depend_node = utils::cast<CNodePtr>(inputNode);
    MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr");
    auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
    MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
    if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
      has_depend = true;
      // A three-input Depend is Depend(value, monad-like mask): only the first
      // CNode operand is forwarded; the mask operand is dropped.
      bool mask_out = (depend_node->inputs().size() == opt::kInputSizeThree);
      for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
        AnfNodePtr depend_input_node = depend_node->input(j);
        MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr");
        // Only CNode operands are forwarded; non-CNode operands of a Depend are
        // intentionally discarded.
        if (depend_input_node->isa<CNode>()) {
          inputs.emplace_back(depend_input_node);
          if (mask_out) {
            break;
          }
        }
      }
    } else {
      inputs.emplace_back(cnode->input(i));
    }
  }
  // Avoid touching the node when nothing changed.
  if (has_depend) {
    cnode->set_inputs(inputs);
  }
  return RET_OK;
}
459
RemoveIfMakeTuple(const CNodePtr & cnode)460 int RemoveIfMakeTuple(const CNodePtr &cnode) {
461 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
462 bool has_make_tuple = false;
463 std::vector<AnfNodePtr> inputs;
464 inputs.clear();
465
466 inputs.emplace_back(cnode->input(0));
467 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
468 AnfNodePtr input_node = cnode->input(i);
469 MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "input_node is nullptr");
470 if (!input_node->isa<CNode>()) {
471 inputs.emplace_back(cnode->input(i));
472 continue;
473 }
474 auto make_tuple_node = utils::cast<CNodePtr>(input_node);
475 MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, RET_NULL_PTR, "make_tuple_node is nullptr");
476 auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
477 MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
478 if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
479 opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
480 has_make_tuple = true;
481 for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
482 inputs.emplace_back(make_tuple_node->input(j));
483 }
484 } else {
485 inputs.emplace_back(cnode->input(i));
486 }
487 }
488 if (has_make_tuple) {
489 cnode->set_inputs(inputs);
490 }
491 return RET_OK;
492 }
493 } // namespace lite
494 } // namespace mindspore
495