1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <algorithm>
19 #include <memory>
20 #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
21 #include "src/common/log_adapter.h"
22 #include "src/common/utils.h"
23 #include "tools/common/graph_util.h"
24 #include "include/errorcode.h"
25 #include "schema/inner/model_generated.h"
26 #include "src/common/log_util.h"
27
28 namespace mindspore {
29 namespace lite {
IsUsing(const schema::MetaGraphT & graph,const uint32_t & tensor_idx)30 bool SubgraphTensorPass::IsUsing(const schema::MetaGraphT &graph, const uint32_t &tensor_idx) {
31 for (const auto &node : graph.nodes) {
32 if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) {
33 return true;
34 }
35 if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) {
36 return true;
37 }
38 }
39 for (const auto &subgraph : graph.subGraph) {
40 if (IsContain<uint32_t>(subgraph->inputIndices, tensor_idx)) {
41 return true;
42 }
43 if (IsContain<uint32_t>(subgraph->outputIndices, tensor_idx)) {
44 return true;
45 }
46 }
47 return false;
48 }
49
UpdateTensorIdx(schema::MetaGraphT * graph,const uint32_t & tensor_idx)50 void SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
51 for (const auto &subgraph : graph->subGraph) {
52 UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx);
53 UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx);
54 }
55 for (const auto &node : graph->nodes) {
56 UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx);
57 UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx);
58 }
59 UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx);
60 UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx);
61 }
62
RemoveUselessTensors(schema::MetaGraphT * graph)63 void SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
64 for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) {
65 uint32_t idx = static_cast<uint32_t>(it - graph->allTensors.begin());
66 if (IsUsing(*graph, idx)) {
67 it++;
68 } else {
69 it = graph->allTensors.erase(it);
70 UpdateTensorIdx(graph, idx);
71 }
72 }
73 }
74
SyncMainGraphInputAndOutput(const schema::MetaGraphT & graph)75 void SubgraphTensorPass::SyncMainGraphInputAndOutput(const schema::MetaGraphT &graph) {
76 MS_ASSERT(graph.subGraph.size() > 0);
77 graph.subGraph[0]->inputIndices.assign(graph.inputIndex.begin(), graph.inputIndex.end());
78 }
79
Run(schema::MetaGraphT * graph)80 STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) {
81 CHECK_NULL_RETURN(graph);
82
83 RemoveUselessTensors(graph);
84
85 SetSubgraphTensorIndices(graph);
86
87 SyncMainGraphInputAndOutput(*graph);
88
89 return RET_OK;
90 }
91 } // namespace lite
92 } // namespace mindspore
93