• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/train/optimizer/fusion/remove_redundant_tensor.h"
18 #include <map>
19 #include "src/common/log_adapter.h"
20 #include "nnacl/op_base.h"
21 
22 namespace mindspore {
23 namespace lite {
Run(schema::MetaGraphT * graph)24 STATUS RemoveRedundantTensor::Run(schema::MetaGraphT *graph) {
25   if (graph == nullptr) {
26     MS_LOG(ERROR) << "The graph is a nullptr.";
27     return RET_NULL_PTR;
28   }
29   std::map<uint32_t, uint32_t> index_map;
30   uint32_t index = 0;
31   auto graph_input_index = graph->inputIndex;
32   graph->inputIndex.clear();
33   for (auto input_index : graph_input_index) {
34     if (index_map.find(input_index) == index_map.end()) {
35       index_map[input_index] = index;
36       ++index;
37     }
38     graph->inputIndex.push_back(index_map[input_index]);
39   }
40   for (auto &node : graph->nodes) {
41     auto node_in_index = node->inputIndex;
42     node->inputIndex.clear();
43     for (auto in_index : node_in_index) {
44       if (index_map.find(in_index) == index_map.end()) {
45         index_map[in_index] = index;
46         ++index;
47       }
48       node->inputIndex.push_back(index_map[in_index]);
49     }
50     auto node_out_index = node->outputIndex;
51     node->outputIndex.clear();
52     for (auto out_index : node_out_index) {
53       if (index_map.find(out_index) == index_map.end()) {
54         index_map[out_index] = index;
55         ++index;
56       }
57       node->outputIndex.push_back(index_map[out_index]);
58     }
59   }
60   auto graph_output_index = graph->outputIndex;
61   graph->outputIndex.clear();
62   for (auto output_index : graph_output_index) {
63     if (index_map.find(output_index) == index_map.end()) {
64       index_map[output_index] = index;
65       ++index;
66     }
67     graph->outputIndex.push_back(index_map[output_index]);
68   }
69   std::vector<std::unique_ptr<mindspore::schema::TensorT>> old_tensors;
70   old_tensors.swap(graph->allTensors);
71   graph->allTensors.resize(index_map.size());
72   for (size_t i = 0; i < old_tensors.size(); ++i) {
73     if (index_map.find(i) == index_map.end()) {
74       continue;
75     }
76     graph->allTensors[index_map[i]].swap(old_tensors[i]);
77   }
78   if (!graph->subGraph.empty()) {
79     graph->subGraph[0]->inputIndices = graph->inputIndex;
80     graph->subGraph[0]->outputIndices = graph->outputIndex;
81     graph->subGraph[0]->tensorIndices = {};
82     for (uint32_t i = 0; i < index; ++i) {
83       graph->subGraph[0]->tensorIndices.push_back(i);
84     }
85   }
86   return RET_OK;
87 }
88 }  // namespace lite
89 }  // namespace mindspore
90