1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/common/graph_util.h"
18 #include <algorithm>
19 #include <functional>
20 #include <ctime>
21 #include <utility>
22 #include <set>
23 #include "schema/inner/model_generated.h"
24 #include "tools/common/tensor_util.h"
25 #include "tools/converter/quantizer/bitpacking.h"
26 #include "tools/common/node_util.h"
27 #include "src/common/log_adapter.h"
28 #include "src/common/utils.h"
29 #include "tools/converter/ops/ops_def.h"
30 #include "nnacl/op_base.h"
31
32 namespace mindspore {
33 namespace lite {
34 namespace {
35 enum QuantBitNum { QuantBitNum_INT8 = 8, QuantBitNum_INT16 = 16 };
36 const int kZeroPointGap = 128;
37 } // namespace
SetFuncGraphOutput(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> & outputs)38 int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
39 if (graph == nullptr || outputs.empty()) {
40 MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
41 return RET_INPUT_PARAM_INVALID;
42 }
43 if (outputs.size() == 1) {
44 graph->set_output(outputs.front(), false);
45 return RET_OK;
46 }
47 auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
48 if (make_tuple_prim_ptr == nullptr) {
49 MS_LOG(DEBUG) << "new MakeTuple failed";
50 return lite::RET_NULL_PTR;
51 }
52 auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_ptr, outputs);
53 if (make_tuple_prim_ptr == nullptr) {
54 MS_LOG(DEBUG) << "new cnode failed";
55 return lite::RET_NULL_PTR;
56 }
57 make_tuple_cnode->set_fullname_with_scope("return tuple");
58 graph->set_output(make_tuple_cnode, false);
59 return RET_OK;
60 }
61
GetSimpleOpCopyer()62 OpDefCopyer GetSimpleOpCopyer() {
63 return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
64 std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
65 if (newCNode == nullptr) {
66 return nullptr;
67 }
68
69 newCNode->name = inCNode->name;
70 newCNode->quantType = inCNode->quantType;
71 newCNode->primitive = std::make_unique<schema::PrimitiveT>();
72 newCNode->primitive->value.type = inCNode->primitive->value.type;
73 return newCNode;
74 };
75 }
76
GetInputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int inputIndexIdx)77 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
78 return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
79 }
80
GetInputNodeIdx(const schema::MetaGraphT & graphT,const CNodeT & node,const int inputIndexIdx)81 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
82 std::vector<uint32_t> inputIndexes;
83 if (inputIndexIdx == -1) {
84 inputIndexes = node.inputIndex;
85 } else {
86 MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
87 inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
88 }
89 std::set<size_t> inputNodeIdx;
90 for (uint32_t inputIdx : inputIndexes) {
91 auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
92 inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
93 }
94 std::vector<size_t> ret;
95 ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
96 return ret;
97 }
98
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int outputIndexIdx)99 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
100 const int outputIndexIdx) {
101 return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
102 }
103
ReplaceOutput(const uint32_t & old_index,const uint32_t & new_index,schema::MetaGraphT * graphT)104 void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
105 std::replace_if(
106 std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
107 [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
108
109 for (auto &subGraph : graphT->subGraph) {
110 std::replace_if(
111 std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
112 [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
113 }
114 }
115
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const CNodeT & node,const int outputIndexIdx)116 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
117 std::vector<uint32_t> outputIndexes;
118 if (outputIndexIdx == -1) {
119 outputIndexes = node.outputIndex;
120 } else {
121 MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
122 outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
123 }
124 std::set<size_t> outputNodeIdx;
125 for (uint32_t outputIdx : outputIndexes) {
126 auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
127 outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
128 }
129 std::vector<size_t> ret;
130 ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
131 return ret;
132 }
133
GetLinkedPreIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)134 std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
135 std::vector<size_t> preNodeIdx;
136 for (size_t i = 0; i < graphT.nodes.size(); i++) {
137 auto &oldNode = graphT.nodes.at(i);
138 if (oldNode == nullptr) {
139 continue;
140 }
141 auto outputIndexes = oldNode->outputIndex;
142 if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
143 preNodeIdx.emplace_back(i);
144 }
145 }
146 return preNodeIdx;
147 }
148
GetLinkedPostIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)149 std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
150 std::vector<size_t> postNodeIdx;
151 for (size_t i = 0; i < graphT.nodes.size(); i++) {
152 auto &oldNode = graphT.nodes.at(i);
153 if (oldNode == nullptr) {
154 continue;
155 }
156 auto inputIndexes = oldNode->inputIndex;
157 if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
158 postNodeIdx.emplace_back(i);
159 }
160 }
161 return postNodeIdx;
162 }
163
IsolateNode(schema::MetaGraphT * graphT,CNodeT * node)164 STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
165 MS_ASSERT(graphT != nullptr);
166 MS_ASSERT(node != nullptr);
167 size_t nodeIdx = 0;
168 for (size_t i = 0; i < graphT->nodes.size(); i++) {
169 auto &inNode = graphT->nodes.at(i);
170 MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
171 if (inNode->name == node->name) {
172 nodeIdx = i;
173 break;
174 }
175 }
176 auto inputTensorIdxes = node->inputIndex;
177 auto outputTensorIdxes = node->outputIndex;
178 if (inputTensorIdxes.empty()) {
179 MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
180 return RET_ERROR;
181 }
182 if (outputTensorIdxes.size() != 1) {
183 MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
184 << "should has 1 output, in fact: " << outputTensorIdxes.size();
185 return RET_ERROR;
186 }
187 auto inDataTensorIdx = inputTensorIdxes.front();
188 auto outDataTensorIdx = outputTensorIdxes.front();
189
190 MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
191 ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
192
193 // find poseNode
194 auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
195 for (auto postNodeIdx : postNodeIdxes) {
196 MS_ASSERT(graphT->nodes.size() > postNodeIdx);
197 auto &postNode = graphT->nodes.at(postNodeIdx);
198 MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
199 for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
200 if (*iter == outDataTensorIdx) {
201 *iter = inDataTensorIdx;
202 break;
203 }
204 }
205 }
206
207 RemoveTensor(graphT, outputTensorIdxes);
208 node->inputIndex.clear();
209 node->outputIndex.clear();
210
211 return RET_OK;
212 }
213
// Isolates the node at `nodeIdx` from the main graph.
// NOTE(review): `subGraphIdx` is currently unused — the call always forwards
// to the main-graph overload.
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graph != nullptr);
  return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}
218
// Detaches the node at `nodeIdx` by re-wiring its single output tensor to
// its first input tensor: graph/subgraph outputs and every consumer of the
// output tensor are redirected to the input tensor. Only nodes with at most
// one producing predecessor and at most one output tensor are supported.
// When `removeTensor` is true the node's now-unused output tensors are also
// erased from the graph. The node itself stays in graphT->nodes with empty
// input/output index lists.
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  if (graphT->nodes.size() <= nodeIdx) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  CNodeT *node = graphT->nodes.at(nodeIdx).get();
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is null";
    return RET_NULL_PTR;
  }
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
  if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
    MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
    return RET_ERROR;
  }
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  // A node with no output tensors needs no re-wiring; only the removal below.
  if (!outputTensorIdxes.empty()) {
    auto outDataTensorIdx = outputTensorIdxes.front();
    MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
    MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
    // Graph/subgraph outputs that referenced the node's output tensor now
    // reference its input tensor.
    ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);

    // Redirect every consumer of the output tensor to the input tensor.
    auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
    for (auto postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == outDataTensorIdx) {
          *iter = inDataTensorIdx;
          break;
        }
      }
    }
  }

  if (removeTensor) {
    // now all node's outputTensors are useless
    // remove all node's outputTensors
    auto status = RemoveTensor(graphT, outputTensorIdxes);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
      return RET_ERROR;
    }
  }
  node->inputIndex.clear();
  node->outputIndex.clear();
  return RET_OK;
}
276
IsolateOneWayNode(schema::MetaGraphT * graphT,CNodeT * node,bool removeTensor)277 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
278 MS_ASSERT(graphT != nullptr);
279 MS_ASSERT(node != nullptr);
280 bool isSubNode = false;
281 size_t nodeIdx = 0;
282 for (size_t i = 0; i < graphT->nodes.size(); i++) {
283 auto &inNode = graphT->nodes.at(i);
284 MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
285 if (inNode->name == node->name) {
286 isSubNode = true;
287 nodeIdx = i;
288 break;
289 }
290 }
291 if (!isSubNode) {
292 MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
293 return RET_PARAM_INVALID;
294 } else {
295 return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
296 }
297 }
298
// Erases the tensors listed in `toDeleteTensorIdxes` from graphT->allTensors
// and renumbers every index in the graph (graph and subgraph input/output
// lists, subgraph tensorIndices, and node input/output lists) to account for
// each removal. Unless `forceDelete` is set, a tensor whose reference count
// is still above one is kept. `toDeleteTensorIdxes` is taken by value and
// its remaining entries are shifted down as earlier entries are deleted.
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
  MS_ASSERT(graphT != nullptr);
  for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
    uint32_t deleteIdx = *iter;
    if (!forceDelete) {
      // Skip tensors that are still referenced elsewhere.
      if (GetRefCount(graphT, deleteIdx) > 1) {
        iter++;
        continue;
      }
    }
    // update graph input indices
    for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
      if (*gInIdx > deleteIdx) {
        (*gInIdx)--;
      }
    }
    // update graph output indices
    for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
      if (*gOutIdx > deleteIdx) {
        (*gOutIdx)--;
      }
    }

    for (auto &subgraph : graphT->subGraph) {
      // update subgraph input indices
      for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
        if (*gInIdx > deleteIdx) {
          (*gInIdx)--;
        }
      }
      // update subgraph output indices
      for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
        if (*gOutIdx > deleteIdx) {
          (*gOutIdx)--;
        }
      }
      // update subgraph tensor indices
      for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
        if (*idx > deleteIdx) {
          (*idx)--;
        }
      }
    }

    // update nodes indexes
    for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
      // update nodes input indexes (drops deleteIdx and shifts larger ones)
      UpdateNodeIndex((*node_iter).get(), deleteIdx);
    }
    // update deleteTensorIdx: the remaining pending deletions refer to the
    // post-erase numbering, so shift them down too.
    for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
      if (*selfIt > deleteIdx) {
        (*selfIt)--;
      }
    }
    graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
    iter = toDeleteTensorIdxes.erase(iter);
  }
  return RET_OK;
}
359
UpdateNodeIndex(CNodeT * node,uint32_t deleteIdx)360 STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
361 MS_ASSERT(node != nullptr);
362 for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
363 if (*inIdxIt == deleteIdx) {
364 inIdxIt = node->inputIndex.erase(inIdxIt);
365 } else {
366 if (*inIdxIt > deleteIdx) {
367 (*inIdxIt)--;
368 }
369 inIdxIt++;
370 }
371 }
372 // update nodes output indexes
373 for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
374 if (*outIdxIt == deleteIdx) {
375 outIdxIt = node->outputIndex.erase(outIdxIt);
376 } else {
377 if (*outIdxIt > deleteIdx) {
378 (*outIdxIt)--;
379 }
380 outIdxIt++;
381 }
382 }
383 return RET_OK;
384 }
385
AddTensor2Node(schema::MetaGraphT * graphT,uint32_t nodeIdx,std::unique_ptr<TensorT> tensor,InsertPlace place)386 STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
387 InsertPlace place) {
388 if (nodeIdx >= graphT->nodes.size()) {
389 MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
390 return RET_PARAM_INVALID;
391 }
392 graphT->allTensors.emplace_back(std::move(tensor));
393 uint32_t newTensorIdx = graphT->allTensors.size() - 1;
394 auto node = graphT->nodes.at(nodeIdx).get();
395 MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
396 if (place == kBefore) {
397 node->inputIndex.emplace_back(newTensorIdx);
398 } else {
399 node->outputIndex.emplace_back(newTensorIdx);
400 }
401 return RET_OK;
402 }
403
ReplaceTensorOfNode(schema::MetaGraphT * graphT,uint32_t nodeIdx,uint32_t inTensorIdx,std::unique_ptr<TensorT> tensor)404 STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
405 std::unique_ptr<TensorT> tensor) {
406 MS_ASSERT(graphT != nullptr);
407 if (nodeIdx >= graphT->nodes.size()) {
408 MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
409 return RET_PARAM_INVALID;
410 }
411 auto node = graphT->nodes.at(nodeIdx).get();
412 MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
413 if (inTensorIdx >= graphT->allTensors.size()) {
414 MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
415 return RET_PARAM_INVALID;
416 }
417 if (!IsContain(node->inputIndex, inTensorIdx)) {
418 MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not a inputIdx of node(" << nodeIdx << ")";
419 return RET_PARAM_INVALID;
420 }
421 graphT->allTensors.at(inTensorIdx).swap(tensor);
422 return RET_OK;
423 }
424
DoBitPack(const int & bit_num,schema::TensorT * tensor_input)425 int DoBitPack(const int &bit_num, schema::TensorT *tensor_input) {
426 if (bit_num > 0 && bit_num < 8) {
427 std::vector<int8_t> origin_data(tensor_input->data.size());
428 auto status = memcpy_s(origin_data.data(), origin_data.size() * sizeof(int8_t), tensor_input->data.data(),
429 tensor_input->data.size() * sizeof(uint8_t));
430 if (status != EOK) {
431 MS_LOG(ERROR) << "memcpy failed. " << status;
432 return RET_ERROR;
433 }
434 std::vector<uint8_t> pack_data{};
435 BitPack::BitPacking<int8_t, uint8_t>(bit_num, origin_data, &pack_data);
436 tensor_input->data.resize(pack_data.size() * sizeof(uint8_t));
437 status = memcpy_s(tensor_input->data.data(), tensor_input->data.size() * sizeof(uint8_t), pack_data.data(),
438 pack_data.size() * sizeof(uint8_t));
439 if (status != EOK) {
440 MS_LOG(ERROR) << "memcpy_s failed. " << status;
441 return RET_ERROR;
442 }
443 } else if (bit_num > QuantBitNum_INT8 && bit_num < QuantBitNum_INT16) {
444 auto shape_size =
445 std::accumulate(tensor_input->dims.begin(), tensor_input->dims.end(), size_t(1), std::multiplies<size_t>());
446 std::vector<int16_t> origin_data(shape_size);
447 auto status = memcpy_s(origin_data.data(), origin_data.size() * sizeof(int16_t), tensor_input->data.data(),
448 tensor_input->data.size() * sizeof(uint8_t));
449 if (status != EOK) {
450 MS_LOG(ERROR) << "memcpy failed. " << status;
451 return RET_ERROR;
452 }
453 std::vector<uint16_t> pack_data{};
454 BitPack::BitPacking<int16_t, uint16_t>(bit_num, origin_data, &pack_data);
455 tensor_input->data.resize(pack_data.size() * sizeof(uint16_t));
456 status = memcpy_s(tensor_input->data.data(), tensor_input->data.size() * sizeof(uint8_t), pack_data.data(),
457 pack_data.size() * sizeof(uint16_t));
458 if (status != EOK) {
459 MS_LOG(ERROR) << "memcpy_s failed. " << status;
460 return RET_ERROR;
461 }
462 }
463 return RET_OK;
464 }
465
InsertNode(schema::MetaGraphT * graphT,uint32_t existNodeIdx,InsertPlace place,size_t inoutIndex,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)466 NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
467 std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
468 const OpDefCopyer &opDefCopyer) {
469 MS_ASSERT(graphT != nullptr);
470 MS_ASSERT(errorCode != nullptr);
471 if (existNodeIdx >= graphT->nodes.size()) {
472 MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
473 return graphT->nodes.end();
474 }
475 auto node_iter = graphT->nodes.begin() + existNodeIdx;
476 MS_ASSERT(node_iter != graphT->nodes.begin());
477 MS_ASSERT((*node_iter) != nullptr);
478 return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode, insert_num);
479 }
480
InsertNode(schema::MetaGraphT * graphT,NodeIter existNodeIter,InsertPlace place,size_t inoutIndexIdx,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)481 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
482 std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
483 const OpDefCopyer &opDefCopyer) {
484 MS_ASSERT(graphT != nullptr);
485 MS_ASSERT(errorCode != nullptr);
486 if (place == kBefore) {
487 return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
488 opDefCopyer);
489 } else if (place == kAfter) {
490 return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
491 opDefCopyer);
492 } else {
493 MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
494 return graphT->nodes.end();
495 }
496 }
497
// Inserts copies of `toAddNodeIn` between the `inputIndexIdx`-th input
// tensor of `*existNodeIter` and that tensor's producers: one copy per
// producer, or a single copy when the tensor has no producer. Each copy gets
// a fresh output tensor cloned from the input tensor, and the existing node
// is re-wired to consume that new tensor. For QuantDTypeCast nodes the
// tensors on either side take the cast's src_t/dst_t, with int8<->uint8
// conversions shifting the quant zero point by kZeroPointGap (128).
// On success sets *errorCode to RET_OK, adds the number of inserted nodes to
// *insert_num and returns an iterator to the existing node (now positioned
// after the inserted ones); on failure returns graphT->nodes.end().
NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
                          std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
                          const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > preTensorIdx);

  // One copy per producer; a producer-less tensor still gets one copy.
  auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
  size_t insert_node_num = preNodeIdxes.empty() ? 1 : preNodeIdxes.size();
  std::vector<std::unique_ptr<CNodeT>> toAddNodes;
  for (size_t i = 0; i < insert_node_num; ++i) {
    auto &preTensor = graphT->allTensors.at(preTensorIdx);
    MS_ASSERT(preTensor != nullptr);
    // The inserted node's output tensor mirrors the input tensor but carries
    // no data of its own.
    auto toAddTensor = CopyTensorDefT(preTensor);
    if (toAddTensor == nullptr) {
      *errorCode = RET_NULL_PTR;
      MS_LOG(ERROR) << "Copy Tensor failed";
      return graphT->nodes.end();
    }
    toAddTensor->nodeType = NodeType_CNode;
    toAddTensor->refCount = 0;
    toAddTensor->data.clear();
    MS_ASSERT(toAddNodeIn->primitive != nullptr);
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
      MS_ASSERT(prim != nullptr);
      // int8 <-> uint8 casts shift the quant zero point by 128 on whichever
      // side changes representation.
      if (prim->src_t == TypeId::kNumberTypeUInt8) {
        if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
          toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        } else {
          preTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        }
      } else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
        if (preTensor->dataType == TypeId::kNumberTypeInt8) {
          toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        } else {
          preTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        }
      }
      preTensor->dataType = prim->src_t;
      toAddTensor->dataType = prim->dst_t;
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(toAddNodeIn.get());
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    if (!preNodeIdxes.empty()) {
      // Multiple copies need distinct names.
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
    }
    // Wiring: preTensor -> toAddNode -> toAddTensor -> existNode.
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(preTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
      if (*iter == preTensorIdx) {
        *iter = toAddTensorIdx;
        break;
      }
    }
    toAddNodes.emplace_back(std::move(toAddNode));
  }
  // Splice the copies into the node list just before the existing node; the
  // iterator ends up pointing at the existing node again.
  for (auto &toAddNode : toAddNodes) {
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
    *insert_num += 1;
  }
  *errorCode = RET_OK;
  return existNodeIter;
}
576
// Inserts copies of `toAddNodeIn` between the `outputIndexIdx`-th output
// tensor of `*existNodeIter` and its consumers: one copy per consuming node,
// plus one extra copy when the tensor is also a graph/subgraph output (or
// has no consumers at all). Each copy gets a fresh output tensor cloned from
// the original; consumers (or the graph output list) are re-wired to it. For
// QuantDTypeCast nodes the tensors take the cast's src_t/dst_t, with
// int8<->uint8 conversions shifting the quant zero point by kZeroPointGap.
// Note: the copies are spliced positionally *before* the existing node in
// graphT->nodes, though they sit after it in dataflow order.
// On success sets *errorCode to RET_OK, adds the number of inserted nodes to
// *insert_num and returns an iterator to the existing node; on failure
// returns graphT->nodes.end().
NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
                         std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
                         const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
  bool is_output_index = IsContain(graphT->outputIndex, postTensorIdx);
  // One copy per consumer, plus one for the graph-output (or consumer-less)
  // case, handled by the first loop iteration.
  size_t insert_node_num = (postNodeIdxes.empty() || is_output_index) ? postNodeIdxes.size() + 1 : postNodeIdxes.size();
  bool has_insert_for_graph_out = postNodeIdxes.empty() || is_output_index;
  std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
  for (size_t i = 0; i < insert_node_num; ++i) {
    auto &postTensor = graphT->allTensors.at(postTensorIdx);
    MS_ASSERT(postTensor != nullptr);
    // The inserted node's output tensor mirrors the existing output tensor.
    auto toAddTensor = CopyTensorDefT(postTensor);
    if (toAddTensor == nullptr) {
      MS_LOG(ERROR) << "Copy TensorT failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddTensor->nodeType = NodeType_CNode;
    MS_ASSERT(toAddNodeIn->primitive != nullptr);
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
      MS_ASSERT(prim != nullptr);
      // int8 <-> uint8 casts shift the quant zero point by 128 on whichever
      // side changes representation.
      if (prim->dst_t == TypeId::kNumberTypeUInt8) {
        if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
          postTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        } else {
          toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        }
      } else if (prim->src_t == TypeId::kNumberTypeUInt8) {
        if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
          toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
        } else {
          postTensor->quantParams.front()->zeroPoint += kZeroPointGap;
        }
      }
      postTensor->dataType = prim->src_t;
      toAddTensor->dataType = prim->dst_t;
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(toAddNodeIn.get());
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    // Wiring: existNode -> postTensor -> toAddNode -> toAddTensor.
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(postTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    if (!postNodeIdxes.empty()) {
      // Multiple copies need distinct names.
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
    }
    if (has_insert_for_graph_out) {
      // First iteration when the tensor is a graph output: redirect the
      // graph/subgraph output lists to the new tensor.
      ReplaceOutput(postTensorIdx, toAddTensorIdx, graphT);
      has_insert_for_graph_out = false;
    } else {
      // Otherwise re-wire the i-th consumer (offset by one if iteration 0
      // was consumed by the graph-output case above).
      auto &postNode = graphT->nodes.at(postNodeIdxes[is_output_index ? i - 1 : i]);
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == postTensorIdx) {
          *iter = toAddTensorIdx;
        }
      }
    }
    toAddNodes.emplace_back(std::move(toAddNode));
  }
  // Splice the copies into the node list; the iterator ends up pointing at
  // the existing node again.
  for (auto &toAddNode : toAddNodes) {
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
    *insert_num += 1;
  }
  *errorCode = RET_OK;
  return existNodeIter;
}
659
ValidateFileStr(const std::string & modelFile,const std::string & fileType)660 STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
661 if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
662 return RET_OK;
663 } else {
664 return RET_ERROR;
665 }
666 }
667
// Extracts the bare model name from a path: strips the directory part (up to
// the last '/') and the trailing extension (from the last '.').
std::string GetModelName(const std::string &modelFile) {
  const size_t slash_pos = modelFile.find_last_of('/');
  // npos + 1 wraps to 0, so a path without '/' keeps the whole string.
  std::string base_name = modelFile.substr(slash_pos + 1);
  return base_name.substr(0, base_name.find_last_of('.'));
}
674
SetSubgraphTensorIndices(schema::MetaGraphT * meta_graphT)675 int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
676 for (auto &subgraph : meta_graphT->subGraph) {
677 std::vector<uint32_t> subgraph_indices{};
678 subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
679 subgraph_indices.assign(subgraph->outputIndices.begin(), subgraph->outputIndices.end());
680 for (auto &node_idx : subgraph->nodeIndices) {
681 auto &node = meta_graphT->nodes.at(node_idx);
682 for (auto &input_idx : node->inputIndex) {
683 if (IsContain(subgraph_indices, input_idx)) {
684 continue;
685 } else {
686 subgraph_indices.push_back(input_idx);
687 }
688 }
689 for (auto &output_idx : node->outputIndex) {
690 if (IsContain(subgraph_indices, output_idx)) {
691 continue;
692 } else {
693 subgraph_indices.push_back(output_idx);
694 }
695 }
696 }
697 subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
698 }
699 return RET_OK;
700 }
701
GetTransposePerm(MetaGraphT * graph,const std::unique_ptr<CNodeT> & cnode)702 std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNodeT> &cnode) {
703 MS_ASSERT(graph != nullptr && cnode != nullptr);
704 std::vector<int> perm;
705 if (cnode->primitive->value.type != schema::PrimitiveType_Transpose) {
706 return perm;
707 }
708 if (cnode->inputIndex.size() < 2) {
709 MS_LOG(ERROR) << "transpose node input size is less than 2.";
710 return perm;
711 }
712 MS_ASSERT(cnode->outputIndex.at(1) < graph->allTensors.size());
713 auto &perm_tensor = graph->allTensors.at(cnode->inputIndex.at(1));
714 if (perm_tensor->data.empty()) {
715 return perm;
716 }
717 MS_ASSERT(perm_tensor->dims.size() != 0);
718 perm.resize(perm_tensor->dims[0]);
719 if (memcpy_s(perm.data(), perm_tensor->dims[0] * sizeof(int), perm_tensor->data.data(),
720 perm_tensor->dims[0] * sizeof(int)) != EOK) {
721 MS_LOG(ERROR) << "memcpy data failed.";
722 return {};
723 }
724 return perm;
725 }
726
727 namespace {
728 constexpr size_t kBitNumPerByte = 8;
729 }
730
// Packs a vector of bools into a string, 8 bits per byte, MSB first. The
// last byte is zero-padded when the bit count is not a multiple of 8.
std::string BoolVectorToString(const std::vector<bool> &bool_vec) {
  constexpr size_t kBitsPerByte = 8;
  // Bugfix: the old `ceil(size / kBitNumPerByte)` performed *integer*
  // division before the ceil, so e.g. 9 bits allocated 1 byte and the loop
  // below wrote one element past the end of the string (UB). Round up with
  // integer arithmetic instead.
  size_t size_in_byte = (bool_vec.size() + kBitsPerByte - 1) / kBitsPerByte;
  std::string str(size_in_byte, '\0');
  auto iter = str.begin();
  size_t shift = kBitsPerByte;
  for (bool bit : bool_vec) {
    *iter |= bit << (shift - 1);
    if (--shift == 0) {
      iter++;
      shift = kBitsPerByte;
    }
  }
  return str;
}
745
GetAbstractTensorDtype(const abstract::AbstractTensorPtr & tensor)746 TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
747 if (tensor == nullptr || tensor->element() == nullptr) {
748 MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
749 return kTypeUnknown;
750 }
751 auto type_ptr = tensor->element()->GetTypeTrack();
752 return type_ptr->type_id();
753 }
754
GetParameterDtype(const ParameterPtr & param_node)755 TypeId GetParameterDtype(const ParameterPtr ¶m_node) {
756 auto abstract_base = param_node->abstract();
757 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
758 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
759 return type_ptr->type_id();
760 }
761
// Records the dtype of every graph input and output into the
// ConverterContext singleton. Outputs are taken from the return node's CNode
// inputs: an AbstractTuple output contributes one entry per tuple element,
// a plain AbstractTensor contributes one entry, and anything else is
// recorded as kTypeUnknown.
STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  // update graph inputs dtype
  size_t idx = 0;
  for (auto &input : func_graph->get_inputs()) {
    TypeId type = GetParameterDtype(input->cast<ParameterPtr>());
    ConverterContext::GetInstance()->UpdateGraphInputDType(idx, type);
    idx++;
  }
  // update graph outputs dtype
  auto graph_return = func_graph->get_return();
  idx = 0;
  for (auto &input : graph_return->inputs()) {
    // Non-CNode inputs of the return node are skipped — presumably these are
    // the return primitive itself; verify against FuncGraph semantics.
    if (input->isa<CNode>()) {
      if (utils::isa<abstract::AbstractTuple>(input->abstract())) {
        auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input->abstract());
        if (tuple == nullptr) {
          MS_LOG(ERROR) << "tuple is nullptr";
          return RET_ERROR;
        }
        // One output entry per tuple element.
        for (const auto &tuple_item : tuple->elements()) {
          TypeId type = GetAbstractTensorDtype(tuple_item->cast<abstract::AbstractTensorPtr>());
          ConverterContext::GetInstance()->UpdateGraphOutputDType(idx, type);
          idx++;
        }
      } else if (utils::isa<abstract::AbstractTensor>(input->abstract())) {
        TypeId type = GetAbstractTensorDtype(input->abstract()->cast<abstract::AbstractTensorPtr>());
        ConverterContext::GetInstance()->UpdateGraphOutputDType(idx, type);
        idx++;
      } else {
        // Unrecognized abstract kind: record the slot as unknown rather than
        // skipping it, so output numbering stays aligned.
        ConverterContext::GetInstance()->UpdateGraphOutputDType(idx, kTypeUnknown);
        idx++;
      }
    }
  }
  return RET_OK;
}
799 } // namespace lite
800 } // namespace mindspore
801