1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/common/graph_util.h"
19 #include <algorithm>
20 #include <functional>
21 #include <ctime>
22 #include <utility>
23 #include <set>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "tools/common/meta_graph_utils.h"
27 #include "schema/inner/model_generated.h"
28 #include "tools/common/tensor_util.h"
29 #include "src/common/log_adapter.h"
30 #include "src/common/utils.h"
31 #include "nnacl/op_base.h"
32 #include "ops/make_tuple.h"
33 #include "tools/converter/converter_context.h"
34 #include "tools/optimizer/common/gllo_utils.h"
35 #include "tools/common/string_util.h"
36
37 namespace mindspore {
38 namespace lite {
39 namespace {
40 const int kZeroPointGap = 128;
41 constexpr size_t kTupleGetItemFirstInputIdx = 1;
42 constexpr size_t kDependInputNum = 3;
43 constexpr size_t kDependFirstInputIdx = 1;
44 constexpr size_t kSequenceCodeGetItemInputSize = 3;
45 constexpr size_t kSecondIndex = 1;
46 constexpr size_t kInvalidSize = SIZE_MAX;
47 constexpr auto kMakeTuple = "MakeTuple";
48 constexpr auto kMakeList = "make_list";
49 constexpr size_t kEncMaxLen = 16;
50 } // namespace
51
GetAbstractfromSequenceCodeGetItem(const CNodePtr & cnode,AbstractBasePtr * abstract,size_t * idx)52 static STATUS GetAbstractfromSequenceCodeGetItem(const CNodePtr &cnode, AbstractBasePtr *abstract, size_t *idx) {
53 MS_CHECK_TRUE_MSG(abstract != nullptr, lite::RET_ERROR, "Abstract is nullptr.");
54 MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr.");
55 auto SequenceCode_inputs = cnode->inputs();
56 MS_CHECK_TRUE_MSG(SequenceCode_inputs.size() == kSequenceCodeGetItemInputSize, lite::RET_ERROR,
57 "The node must have 3 inputs!");
58 auto get_item_input_cnode = SequenceCode_inputs.at(kSecondIndex);
59 MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input node is nullptr.");
60
61 AbstractBasePtrList abstract_list;
62 if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
63 *idx = opt::GetTupleGetItemOutIndex(cnode);
64 if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
65 MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple, cnode name: "
66 << get_item_input_cnode->fullname_with_scope();
67 return lite::RET_ERROR;
68 }
69 auto input_node_abstract = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
70 abstract_list = input_node_abstract->elements();
71 } else {
72 *idx = opt::GetListGetItemOutIndex(cnode);
73 if (!mindspore::utils::isa<mindspore::abstract::AbstractListPtr>(get_item_input_cnode->abstract())) {
74 MS_LOG(ERROR) << "ListGetItem's abstract is not AbstractTuple, cnode name: "
75 << get_item_input_cnode->fullname_with_scope();
76 return lite::RET_ERROR;
77 }
78 auto input_node_abstract = utils::cast<abstract::AbstractListPtr>(get_item_input_cnode->abstract());
79 abstract_list = input_node_abstract->elements();
80 }
81
82 if (abstract_list.size() <= *idx) {
83 MS_LOG(ERROR) << "Abstract's size is smaller than expect";
84 return lite::RET_ERROR;
85 }
86 *abstract = abstract_list[*idx];
87 return lite::RET_OK;
88 }
89
GetShapeVectorFromParameter(const mindspore::ParameterPtr & param_node,std::vector<int64_t> * shape_vector)90 STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr ¶m_node, std::vector<int64_t> *shape_vector) {
91 MS_CHECK_TRUE_MSG(shape_vector != nullptr, RET_ERROR, "shape vector is nullptr.");
92 auto abstract_base = param_node->abstract();
93 if (abstract_base == nullptr) {
94 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
95 return RET_ERROR;
96 }
97
98 if (!abstract_base->isa<abstract::AbstractTensor>()) {
99 MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
100 return lite::RET_ERROR;
101 }
102 auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
103 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
104 *shape_vector = abstract_tensor->shape()->shape();
105 return lite::RET_OK;
106 }
107
// Resolves the output shape of `cnode`. For TupleGetItem/ListGetItem nodes the
// shape of the selected element is used and `idx` receives the selected index;
// for any other node the cnode's own abstract is used and `idx` is untouched.
// A missing abstract or a kNoShape abstract yields an empty shape with RET_OK,
// which can legally happen in control-flow models.
STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector<int64_t> *shape_vector, size_t *idx) {
  MS_CHECK_TRUE_MSG(shape_vector != nullptr, lite::RET_ERROR, "shape is nullptr");

  AbstractBasePtr cnode_abstract = nullptr;
  if ((opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) ||
      (opt::CheckPrimitiveType(cnode, prim::kPrimListGetItem))) {
    // idx is only used when cnode is type of kPrimTupleGetItem or kPrimListGetItem.
    MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr");
    if (GetAbstractfromSequenceCodeGetItem(cnode, &cnode_abstract, idx) != lite::RET_OK) {
      MS_LOG(ERROR) << "Get abstract from tuple get item failed.";
      return lite::RET_ERROR;
    }
  } else {
    cnode_abstract = cnode->abstract();
  }
  // the control flow model may be nullptr
  if (cnode_abstract == nullptr) {
    *shape_vector = std::vector<int64_t>();
    return lite::RET_OK;
  }
  // kNoShape: the abstract carries no shape information; report empty shape.
  if (cnode_abstract->BuildShape() == mindspore::abstract::kNoShape) {
    *shape_vector = std::vector<int64_t>();
    return lite::RET_OK;
  }
  if (!utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
    MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
  CHECK_NULL_RETURN(cnode_abstract_tensor);
  if (!utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto shape_ptr = utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
  CHECK_NULL_RETURN(shape_ptr);
  // An empty (but known) shape is tolerated with a warning only.
  if (shape_ptr->shape().empty()) {
    MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
  }
  *shape_vector = shape_ptr->shape();
  return lite::RET_OK;
}
150
GetCNodeOrParameterShapeVec(const AnfNodePtr & anf_node,std::vector<int> * shape)151 STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector<int> *shape) {
152 auto int64_t_to_int_func = [](int64_t x) -> int { return static_cast<int>(x); };
153 std::vector<int64_t> in_shape;
154 if (anf_node->isa<CNode>()) {
155 auto status = GetShapeVectorAndIdxFromCNode(anf_node->cast<CNodePtr>(), &in_shape);
156 if (status != RET_OK) {
157 MS_LOG(ERROR) << "Get shape from CNode failed.";
158 return status;
159 }
160 } else if (anf_node->isa<Parameter>()) {
161 auto param_node = anf_node->cast<ParameterPtr>();
162 auto status = GetShapeVectorFromParameter(param_node, &in_shape);
163 if (status != RET_OK) {
164 MS_LOG(ERROR) << "Get shape from Parameter failed.";
165 return status;
166 }
167 } else {
168 MS_LOG(ERROR) << "Node type is not recognized.";
169 return RET_ERROR;
170 }
171 shape->resize(in_shape.size());
172 (void)std::transform(in_shape.begin(), in_shape.end(), shape->begin(), int64_t_to_int_func);
173 return RET_OK;
174 }
175
// Recursively collects the real outputs reachable from `node`:
//  - MakeTuple / make_list nodes are flattened (UpdateState/Load inputs are
//    control-flow artifacts and are skipped),
//  - Depend nodes are traversed through their first real input,
//  - chains of TupleGetItem are walked down to the producing cnode,
//  - Parameters/ValueNodes are recorded directly with an empty shape.
// For every output the node, its selected output index, a unique name and its
// shape are appended to the three parallel output vectors.
// NOTE(review): `iter` is a function-local static used only for log messages;
// it persists across calls and is not thread-safe — confirm this is intended.
static STATUS TraceOutput(const AnfNodePtr &node, std::vector<std::pair<AnfNodePtr, int64_t>> *outputs,
                          std::vector<std::string> *output_names, std::vector<std::vector<int64_t>> *output_dims) {
  static size_t iter = 0;
  CHECK_NULL_RETURN(node);
  // Parameters and value nodes are terminal outputs; no shape is recorded.
  if (utils::isa<ParameterPtr>(node) || utils::isa<ValueNode>(node)) {
    MS_LOG(INFO) << "Name of graph output value node is : " << node->fullname_with_scope();
    outputs->emplace_back(std::pair<AnfNodePtr, int64_t>(node, 0));
    output_names->push_back(node->fullname_with_scope());
    output_dims->emplace_back(std::vector<int64_t>());
    return lite::RET_OK;
  }
  // Walk down TupleGetItem chains; pre_node remembers the last getitem node so
  // the selected index and element shape can be resolved below.
  AnfNodePtr cur_node = node;
  CNodePtr pre_node = nullptr;
  while (cur_node->isa<CNode>() && IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
    auto tmp = cur_node->cast<CNodePtr>();
    CHECK_NULL_RETURN(tmp);
    pre_node = tmp;
    cur_node = tmp->input(kTupleGetItemFirstInputIdx);
    CHECK_NULL_RETURN(cur_node);
  }
  auto cnode = cur_node->cast<CNodePtr>();
  CHECK_NULL_RETURN(cnode);
  std::string name = GetCNodeFuncName(cnode);
  iter++;
  MS_LOG(INFO) << "Func name of cnode " << name << " ,trace iter: " << iter;
  if ((name == kMakeTuple) || (name == kMakeList)) {
    // Flatten each real element; input 0 is the primitive itself.
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto make_tuple_input = cnode->input(i);
      if (opt::CheckPrimitiveType(make_tuple_input, prim::kPrimUpdateState) ||
          opt::CheckPrimitiveType(make_tuple_input, prim::kPrimLoad)) {
        continue;
      }
      if (TraceOutput(make_tuple_input, outputs, output_names, output_dims) != lite::RET_OK) {
        MS_LOG(ERROR) << "The input[ " << i << "]"
                      << " trace output failed, name: " << name;
        return lite::RET_ERROR;
      }
    }
  } else if (name == prim::kPrimDepend->name()) {
    if (cnode->size() < kDependInputNum) {
      MS_LOG(ERROR) << "Length of inputs is " << cnode->size() << ", which is less than three.";
      return lite::RET_ERROR;
    }
    // Only the first input of a Depend node carries the real value.
    if (TraceOutput(cnode->input(kDependFirstInputIdx), outputs, output_names, output_dims) != lite::RET_OK) {
      MS_LOG(ERROR) << "Depend node trace output failed.";
      return lite::RET_ERROR;
    }
  } else {
    MS_LOG(INFO) << "Name of graph output node is " << cnode->fullname_with_scope();
    std::string node_name = cnode->fullname_with_scope();
    std::vector<int64_t> dims;
    // -1 wraps to SIZE_MAX (kInvalidSize), meaning "no getitem index selected".
    size_t idx = -1;
    STATUS ret;
    if (pre_node != nullptr && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) {
      // The shape/index come from the getitem wrapper; suffix the name with idx
      // so multi-output nodes get distinct output names.
      ret = GetShapeVectorAndIdxFromCNode(pre_node, &dims, &idx);
      node_name = node_name + "_" + std::to_string(idx);
    } else {
      ret = GetShapeVectorAndIdxFromCNode(cnode, &dims, &idx);
    }
    if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << "Get node shape failed.";
      return lite::RET_ERROR;
    }
    outputs->emplace_back(std::pair<AnfNodePtr, int64_t>(cnode, idx));
    output_names->emplace_back(node_name);
    output_dims->emplace_back(dims);
  }
  return lite::RET_OK;
}
245
SetFuncGraphOutput(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> & outputs)246 int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
247 if (graph == nullptr || outputs.empty()) {
248 MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
249 return RET_INPUT_PARAM_INVALID;
250 }
251 if (outputs.size() == 1) {
252 graph->set_output(outputs.front(), false);
253 return RET_OK;
254 }
255 auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
256 if (make_tuple_prim_ptr == nullptr) {
257 MS_LOG(DEBUG) << "new MakeTuple failed";
258 return lite::RET_NULL_PTR;
259 }
260 auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
261 MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, lite::RET_NULL_PTR, "make_tuple_prim_c is nullptr");
262 auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_c, outputs);
263 if (make_tuple_cnode == nullptr) {
264 MS_LOG(DEBUG) << "new cnode failed";
265 return lite::RET_NULL_PTR;
266 }
267 make_tuple_cnode->set_fullname_with_scope("return tuple");
268 graph->set_output(make_tuple_cnode, false);
269 return RET_OK;
270 }
271
GetSimpleOpCopyer()272 OpDefCopyer GetSimpleOpCopyer() {
273 return [](const CNodeT &inCNode) -> std::unique_ptr<CNodeT> {
274 std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
275 if (newCNode == nullptr) {
276 return nullptr;
277 }
278
279 newCNode->name = inCNode.name;
280 newCNode->quantType = inCNode.quantType;
281 newCNode->primitive = std::make_unique<schema::PrimitiveT>();
282 newCNode->primitive->value.type = inCNode.primitive->value.type;
283 return newCNode;
284 };
285 }
286
AddTensor2Node(schema::MetaGraphT * graphT,uint32_t nodeIdx,std::unique_ptr<TensorT> tensor,InsertPlace place)287 STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
288 InsertPlace place) {
289 MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
290 if (nodeIdx >= graphT->nodes.size()) {
291 MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
292 return RET_PARAM_INVALID;
293 }
294 graphT->allTensors.emplace_back(std::move(tensor));
295 uint32_t newTensorIdx = graphT->allTensors.size() - 1;
296 auto node = graphT->nodes.at(nodeIdx).get();
297 MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
298 if (place == kBefore) {
299 node->inputIndex.emplace_back(newTensorIdx);
300 } else {
301 node->outputIndex.emplace_back(newTensorIdx);
302 }
303 return RET_OK;
304 }
305
ReplaceTensorOfNode(schema::MetaGraphT * graphT,uint32_t nodeIdx,uint32_t inTensorIdx,std::unique_ptr<TensorT> tensor)306 STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
307 std::unique_ptr<TensorT> tensor) {
308 MS_ASSERT(graphT != nullptr);
309 if (nodeIdx >= graphT->nodes.size()) {
310 MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
311 return RET_PARAM_INVALID;
312 }
313 auto node = graphT->nodes.at(nodeIdx).get();
314 MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
315 if (inTensorIdx >= graphT->allTensors.size()) {
316 MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
317 return RET_PARAM_INVALID;
318 }
319 if (!IsContain(node->inputIndex, inTensorIdx)) {
320 MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not a inputIdx of node(" << nodeIdx << ")";
321 return RET_PARAM_INVALID;
322 }
323 graphT->allTensors.at(inTensorIdx).swap(tensor);
324 return RET_OK;
325 }
326
InsertNode(schema::MetaGraphT * graphT,uint32_t existNodeIdx,InsertPlace place,size_t inoutIndex,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)327 NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
328 std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
329 const OpDefCopyer &opDefCopyer) {
330 MS_ASSERT(graphT != nullptr);
331 MS_ASSERT(errorCode != nullptr);
332 if (existNodeIdx >= graphT->nodes.size()) {
333 MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
334 return graphT->nodes.end();
335 }
336 auto node_iter = graphT->nodes.begin() + existNodeIdx;
337 MS_ASSERT(node_iter != graphT->nodes.begin());
338 MS_ASSERT((*node_iter) != nullptr);
339 return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode, insert_num);
340 }
341
InsertNode(schema::MetaGraphT * graphT,NodeIter existNodeIter,InsertPlace place,size_t inoutIndexIdx,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)342 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
343 std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
344 const OpDefCopyer &opDefCopyer) {
345 MS_ASSERT(graphT != nullptr);
346 MS_ASSERT(errorCode != nullptr);
347 if (place == kBefore) {
348 return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
349 opDefCopyer);
350 } else if (place == kAfter) {
351 return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
352 opDefCopyer);
353 } else {
354 MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
355 return graphT->nodes.end();
356 }
357 }
358
// Inserts copies of `toAddNodeIn` between input tensor `inputIndexIdx` of
// `existNodeIter` and each node producing that tensor (one copy per producer,
// or a single copy when the tensor has no producer, e.g. a graph input).
// For QuantDTypeCast insertions the zero points of the affected tensors are
// shifted by kZeroPointGap (128) to convert between int8 and uint8.
// On success returns the iterator just past the inserted nodes, sets
// *errorCode to RET_OK and increments *insert_num per inserted node; on
// failure returns nodes.end() with *errorCode set.
NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
                          std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
                          const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > preTensorIdx);

  auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
  // No producer node (graph input / constant): still insert one copy.
  size_t insert_node_num = preNodeIdxes.empty() ? 1 : preNodeIdxes.size();
  std::vector<std::unique_ptr<CNodeT>> toAddNodes;
  for (size_t i = 0; i < insert_node_num; ++i) {
    auto &preTensor = graphT->allTensors.at(preTensorIdx);
    MS_ASSERT(preTensor != nullptr);
    // The inserted node's output tensor starts as a copy of the input tensor.
    auto toAddTensor = CopyTensorDefT(preTensor);
    if (toAddTensor == nullptr) {
      *errorCode = RET_NULL_PTR;
      MS_LOG(ERROR) << "Copy Tensor failed";
      return graphT->nodes.end();
    }
    toAddTensor->nodeType = NodeType_CNode;
    toAddTensor->refCount = 0;
    toAddTensor->data.clear();
    MS_ASSERT(toAddNodeIn->primitive != nullptr);
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
      MS_ASSERT(prim != nullptr);
      // Shift zero points by 128 when casting between uint8 and int8.
      // NOTE(review): assumes quantParams is non-empty for these tensors —
      // confirm before relying on front().
      if (prim->src_t == TypeId::kNumberTypeUInt8) {
        if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
          toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        } else {
          preTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        }
      } else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
        if (preTensor->dataType == TypeId::kNumberTypeInt8) {
          toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        } else {
          preTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        }
      }
      preTensor->dataType = prim->src_t;
      toAddTensor->dataType = prim->dst_t;
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(*toAddNodeIn);
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    if (!preNodeIdxes.empty()) {
      // Disambiguate names when one node is inserted per producer.
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
    }
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(preTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    // Re-wire existNode's matching input to consume the freshly added tensor.
    for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
      if (*iter == preTensorIdx) {
        *iter = toAddTensorIdx;
        break;
      }
    }
    toAddNodes.emplace_back(std::move(toAddNode));
  }
  // Insert the new nodes just before existNode, preserving their order.
  for (auto &toAddNode : toAddNodes) {
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
    *insert_num += 1;
  }
  *errorCode = RET_OK;
  return existNodeIter;
}
437
// Mirror of InsertNodeBefore: inserts copies of `toAddNodeIn` between output
// tensor `outputIndexIdx` of `existNodeIter` and each of its consumers; when
// the tensor is also (or only) a graph output, one extra copy is inserted to
// feed the graph output. Zero points are shifted by kZeroPointGap (128) for
// QuantDTypeCast conversions between int8 and uint8.
// On success returns the iterator just past the inserted nodes with *errorCode
// set to RET_OK; on failure returns nodes.end().
NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
                         std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
                         const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
  bool is_output_index = IsContain(graphT->outputIndex, postTensorIdx);
  // One copy per consumer node, plus one extra when the tensor feeds the
  // graph output directly.
  size_t insert_node_num = (postNodeIdxes.empty() || is_output_index) ? postNodeIdxes.size() + 1 : postNodeIdxes.size();
  bool has_insert_for_graph_out = postNodeIdxes.empty() || is_output_index;
  std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
  for (size_t i = 0; i < insert_node_num; ++i) {
    auto &postTensor = graphT->allTensors.at(postTensorIdx);
    MS_ASSERT(postTensor != nullptr);
    // The inserted node's output tensor starts as a copy of the source tensor.
    auto toAddTensor = CopyTensorDefT(postTensor);
    if (toAddTensor == nullptr) {
      MS_LOG(ERROR) << "Copy TensorT failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddTensor->nodeType = NodeType_CNode;
    MS_ASSERT(toAddNodeIn->primitive != nullptr);
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
      MS_ASSERT(prim != nullptr);
      // Shift zero points by 128 when casting between uint8 and int8.
      // NOTE(review): assumes quantParams is non-empty for these tensors —
      // confirm before relying on front().
      if (prim->dst_t == TypeId::kNumberTypeUInt8) {
        if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
          postTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        } else {
          toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        }
      } else if (prim->src_t == TypeId::kNumberTypeUInt8) {
        if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
          toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        } else {
          postTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        }
      }
      postTensor->dataType = prim->src_t;
      toAddTensor->dataType = prim->dst_t;
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(*toAddNodeIn);
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(postTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    if (!postNodeIdxes.empty()) {
      // Disambiguate names when one node is inserted per consumer.
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
    }
    if (has_insert_for_graph_out) {
      // First iteration handles the graph output re-wiring (done once).
      ReplaceOutput(postTensorIdx, toAddTensorIdx, graphT);
      has_insert_for_graph_out = false;
    } else {
      // When the graph-output copy was handled at i == 0, consumer i maps to
      // postNodeIdxes[i - 1]; re-wire all of its matching inputs.
      auto &postNode = graphT->nodes.at(postNodeIdxes[is_output_index ? i - 1 : i]);
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == postTensorIdx) {
          *iter = toAddTensorIdx;
        }
      }
    }
    toAddNodes.emplace_back(std::move(toAddNode));
  }
  // Insert the new nodes just before existNodeIter's position, in order.
  for (auto &toAddNode : toAddNodes) {
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
    *insert_num += 1;
  }
  *errorCode = RET_OK;
  return existNodeIter;
}
520
ValidateFileStr(const std::string & modelFile,const std::string & fileType)521 STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
522 if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
523 return RET_OK;
524 } else {
525 return RET_ERROR;
526 }
527 }
528
SetSubgraphTensorIndices(schema::MetaGraphT * meta_graphT)529 void SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
530 if (meta_graphT == nullptr) {
531 MS_LOG(ERROR) << "meta_graphT is nullptr.";
532 return;
533 }
534 for (auto &subgraph : meta_graphT->subGraph) {
535 std::vector<uint32_t> subgraph_indices{};
536 subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
537 subgraph_indices.assign(subgraph->outputIndices.begin(), subgraph->outputIndices.end());
538 for (auto &node_idx : subgraph->nodeIndices) {
539 auto &node = meta_graphT->nodes.at(node_idx);
540 for (auto &input_idx : node->inputIndex) {
541 if (IsContain(subgraph_indices, input_idx)) {
542 continue;
543 } else {
544 subgraph_indices.push_back(input_idx);
545 }
546 }
547 for (auto &output_idx : node->outputIndex) {
548 if (IsContain(subgraph_indices, output_idx)) {
549 continue;
550 } else {
551 subgraph_indices.push_back(output_idx);
552 }
553 }
554 }
555 subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
556 }
557 }
558
GetModelName(const std::string & modelFile)559 std::string GetModelName(const std::string &modelFile) {
560 std::string modelName = modelFile;
561 modelName = modelName.substr(modelName.find_last_of('/') + 1);
562 modelName = modelName.substr(0, modelName.find_last_of('.'));
563 return modelName;
564 }
565
GetTransposePerm(MetaGraphT * graph,const std::unique_ptr<CNodeT> & cnode)566 std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNodeT> &cnode) {
567 MS_ASSERT(graph != nullptr && cnode != nullptr);
568 std::vector<int> perm;
569 if (cnode->primitive->value.type != schema::PrimitiveType_Transpose) {
570 return perm;
571 }
572 if (cnode->inputIndex.size() < 2) {
573 MS_LOG(ERROR) << "transpose node input size is less than 2.";
574 return perm;
575 }
576 MS_ASSERT(cnode->outputIndex.at(1) < graph->allTensors.size());
577 auto &perm_tensor = graph->allTensors.at(cnode->inputIndex.at(1));
578 if (perm_tensor->data.empty()) {
579 return perm;
580 }
581 MS_ASSERT(perm_tensor->dims.size() != 0);
582 perm.resize(perm_tensor->dims[0]);
583 if (memcpy_s(perm.data(), perm_tensor->dims[0] * sizeof(int), perm_tensor->data.data(),
584 perm_tensor->dims[0] * sizeof(int)) != EOK) {
585 MS_LOG(ERROR) << "memcpy data failed.";
586 return {};
587 }
588 return perm;
589 }
590
GetAbstractTensorDtype(const abstract::AbstractTensorPtr & tensor)591 TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
592 if (tensor == nullptr || tensor->element() == nullptr) {
593 MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
594 return kTypeUnknown;
595 }
596 auto type_ptr = tensor->element()->GetTypeTrack();
597 MS_CHECK_TRUE_MSG(type_ptr != nullptr, kTypeUnknown, "type_ptr is nullptr");
598 return type_ptr->type_id();
599 }
600
GetParameterDtype(const ParameterPtr & param_node)601 TypeId GetParameterDtype(const ParameterPtr ¶m_node) {
602 MS_CHECK_TRUE_MSG(param_node != nullptr, kTypeUnknown, "param_node is nullptr");
603 auto abstract_base = param_node->abstract();
604 MS_CHECK_TRUE_MSG(abstract_base != nullptr, kTypeUnknown, "abstract_base is nullptr");
605 auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
606 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, kTypeUnknown, "Cast to abstract tensor failed!");
607 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
608 MS_CHECK_TRUE_MSG(type_ptr != nullptr, kTypeUnknown, "type_ptr is nullptr");
609 return type_ptr->type_id();
610 }
611
// Records the dtypes of the graph's inputs and outputs into the converter
// context, keyed by positional index. Tuple outputs are flattened one level;
// tuples nested inside a tuple element are skipped.
STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  // update graph inputs dtype
  size_t idx = 0;
  for (auto &input : func_graph->get_inputs()) {
    TypeId type = GetParameterDtype(input->cast<ParameterPtr>());
    ConverterInnerContext::GetInstance()->UpdateGraphInputDType(idx, type);
    idx++;
  }
  // update graph outputs dtype
  auto graph_return = func_graph->get_return();
  idx = 0;
  // The return node's inputs include the primitive; only CNode inputs are
  // inspected below, so the primitive input is naturally skipped.
  for (auto &input : graph_return->inputs()) {
    if (input->isa<CNode>()) {
      if (utils::isa<abstract::AbstractTuple>(input->abstract())) {
        // NOTE(review): reinterpret_pointer_cast is used although the isa<>
        // check suggests static_pointer_cast would suffice — confirm intent.
        auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input->abstract());
        if (tuple == nullptr) {
          MS_LOG(ERROR) << "tuple is nullptr";
          return RET_ERROR;
        }
        for (const auto &tuple_item : tuple->elements()) {
          // Nested tuples are not flattened further; they are skipped.
          if (utils::isa<abstract::AbstractTuple>(tuple_item)) {
            continue;
          }
          TypeId type = GetAbstractTensorDtype(tuple_item->cast<abstract::AbstractTensorPtr>());
          ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(idx, type);
          idx++;
        }
      } else if (utils::isa<abstract::AbstractTensor>(input->abstract())) {
        TypeId type = GetAbstractTensorDtype(input->abstract()->cast<abstract::AbstractTensorPtr>());
        ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(idx, type);
        idx++;
      } else {
        // Unknown abstract kind: record kTypeUnknown to keep indices aligned.
        ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(idx, kTypeUnknown);
        idx++;
      }
    }
  }
  return RET_OK;
}
652
GetFuncGraphOutputsInfo(const FuncGraphPtr & func_graph,std::vector<std::pair<AnfNodePtr,int64_t>> * outputs,std::vector<std::string> * output_names,std::vector<std::vector<int64_t>> * output_dims)653 STATUS GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, int64_t>> *outputs,
654 std::vector<std::string> *output_names, std::vector<std::vector<int64_t>> *output_dims) {
655 MS_CHECK_TRUE_MSG(outputs != nullptr, lite::RET_ERROR, "Output is nullptr.");
656 MS_CHECK_TRUE_MSG(output_names != nullptr, lite::RET_ERROR, "Output names is nullptr.");
657 MS_CHECK_TRUE_MSG(output_dims != nullptr, lite::RET_ERROR, "Output dims is nullptr.");
658 AnfNodePtr return_input = func_graph->output();
659 CHECK_NULL_RETURN(return_input);
660 if (TraceOutput(return_input, outputs, output_names, output_dims) != lite::RET_OK) {
661 MS_LOG(ERROR) << "Trace output failed.";
662 return lite::RET_ERROR;
663 }
664 return lite::RET_OK;
665 }
666
UpdateGraphOutputName(schema::MetaGraphT * meta_graph)667 STATUS UpdateGraphOutputName(schema::MetaGraphT *meta_graph) {
668 MS_CHECK_TRUE_MSG(meta_graph != nullptr, RET_NULL_PTR, "meta_graph is nullptr");
669 auto output_names = ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames();
670 if (output_names.size() > meta_graph->outputIndex.size()) {
671 MS_LOG(ERROR) << "the num of setting output_names is greater than actual, " << output_names.size() << " > "
672 << meta_graph->outputIndex.size() << ".";
673 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
674 return RET_ERROR;
675 }
676 for (size_t idx = 0; idx < output_names.size(); idx++) {
677 auto &tensor = meta_graph->allTensors.at(meta_graph->outputIndex.at(idx));
678 tensor->name = output_names.at(idx);
679 }
680 return RET_OK;
681 }
682
TransferMetaGraph(const schema::MetaGraphT & graph,void ** model_buf,size_t * size)683 int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t *size) {
684 if (model_buf == nullptr) {
685 MS_LOG(ERROR) << "input model_buf invalid";
686 return RET_ERROR;
687 }
688 if (size == nullptr) {
689 MS_LOG(ERROR) << "input size invalid";
690 return RET_ERROR;
691 }
692
693 /* model_buf malloc here, free outside */
694 if (*model_buf != nullptr) {
695 MS_LOG(ERROR) << "input model_buf must be nullptr";
696 return RET_ERROR;
697 }
698 flatbuffers::FlatBufferBuilder builder(MAX_GRAPH_SIZE);
699 auto offset = schema::MetaGraph::Pack(builder, &graph);
700 builder.Finish(offset);
701 schema::FinishMetaGraphBuffer(builder, offset);
702 *size = builder.GetSize();
703 auto content = builder.GetBufferPointer();
704 if (content == nullptr) {
705 MS_LOG(ERROR) << "GetBufferPointer nullptr";
706 return RET_ERROR;
707 }
708 *model_buf = new (std::nothrow) char[*size];
709 if (*model_buf == nullptr) {
710 MS_LOG(ERROR) << "malloc model_buf failed";
711 return RET_ERROR;
712 }
713 return memcpy_s(*model_buf, *size, content, *size);
714 }
715
InitEncryptKey(const std::shared_ptr<ConverterPara> & param,unsigned char * encKey,size_t * keyLen)716 int InitEncryptKey(const std::shared_ptr<ConverterPara> ¶m, unsigned char *encKey, size_t *keyLen) {
717 if (!param->enable_encryption) {
718 return RET_OK;
719 }
720 if (param->encrypt_key.empty()) {
721 MS_LOG(ERROR) << "param->encrypt_key is empty.";
722 return RET_INPUT_PARAM_INVALID;
723 }
724 *keyLen = lite::Hex2ByteArray(param->encrypt_key, encKey, kEncMaxLen);
725 if (*keyLen != kEncMaxLen) {
726 MS_LOG(ERROR) << "enc_key must expressed in hexadecimal characters "
727 << " and only support AES-GCM method and the key length is " << kEncMaxLen;
728 return RET_INPUT_PARAM_INVALID;
729 }
730
731 return RET_OK;
732 }
733 } // namespace lite
734 } // namespace mindspore
735