1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <algorithm>
19 #include <memory>
20 #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
21 #include "src/common/log_adapter.h"
22 #include "src/common/utils.h"
23 #include "tools/common/graph_util.h"
24 #include "include/errorcode.h"
25 #include "schema/inner/model_generated.h"
26 #include "src/common/log_util.h"
27
28 namespace mindspore {
29 namespace lite {
IsUsing(schema::MetaGraphT * graph,const uint32_t & tensor_idx)30 bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
31 for (const auto &node : graph->nodes) {
32 if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) {
33 return true;
34 }
35 if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) {
36 return true;
37 }
38 }
39 for (const auto &subgraph : graph->subGraph) {
40 if (IsContain<uint32_t>(subgraph->inputIndices, tensor_idx)) {
41 return true;
42 }
43 if (IsContain<uint32_t>(subgraph->outputIndices, tensor_idx)) {
44 return true;
45 }
46 }
47 return false;
48 }
49
UpdateTensorIdx(schema::MetaGraphT * graph,const uint32_t & tensor_idx)50 STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
51 for (const auto &subgraph : graph->subGraph) {
52 UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx);
53 UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx);
54 }
55 for (const auto &node : graph->nodes) {
56 UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx);
57 UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx);
58 }
59 UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx);
60 UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx);
61 return RET_OK;
62 }
63
RemoveUselessTensors(schema::MetaGraphT * graph)64 STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
65 for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) {
66 uint32_t idx = it - graph->allTensors.begin();
67 if (IsUsing(graph, idx)) {
68 it++;
69 } else {
70 it = graph->allTensors.erase(it);
71 UpdateTensorIdx(graph, idx);
72 }
73 }
74 return RET_OK;
75 }
76
SyncMainGraphInputAndOutput(schema::MetaGraphT * graph)77 STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
78 MS_ASSERT(graph->subGraph.size() > 0);
79 graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
80 return RET_OK;
81 }
82
Run(schema::MetaGraphT * graph)83 STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) {
84 CHECK_NULL_RETURN(graph);
85
86 int ret = RemoveUselessTensors(graph);
87 if (ret != RET_OK) {
88 MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret;
89 return ret;
90 }
91
92 ret = SetSubgraphTensorIndices(graph);
93 if (ret != RET_OK) {
94 MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
95 return ret;
96 }
97
98 ret = SyncMainGraphInputAndOutput(graph);
99 if (ret != RET_OK) {
100 MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
101 return ret;
102 }
103
104 return RET_OK;
105 }
106 } // namespace lite
107 } // namespace mindspore
108