1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "src/train/optimizer/fusion/remove_redundant_tensor.h" 18 #include <map> 19 #include "src/common/log_adapter.h" 20 #include "nnacl/op_base.h" 21 22 namespace mindspore { 23 namespace lite { Run(schema::MetaGraphT * graph)24STATUS RemoveRedundantTensor::Run(schema::MetaGraphT *graph) { 25 if (graph == nullptr) { 26 MS_LOG(ERROR) << "The graph is a nullptr."; 27 return RET_NULL_PTR; 28 } 29 std::map<uint32_t, uint32_t> index_map; 30 uint32_t index = 0; 31 auto graph_input_index = graph->inputIndex; 32 graph->inputIndex.clear(); 33 for (auto input_index : graph_input_index) { 34 if (index_map.find(input_index) == index_map.end()) { 35 index_map[input_index] = index; 36 ++index; 37 } 38 graph->inputIndex.push_back(index_map[input_index]); 39 } 40 for (auto &node : graph->nodes) { 41 auto node_in_index = node->inputIndex; 42 node->inputIndex.clear(); 43 for (auto in_index : node_in_index) { 44 if (index_map.find(in_index) == index_map.end()) { 45 index_map[in_index] = index; 46 ++index; 47 } 48 node->inputIndex.push_back(index_map[in_index]); 49 } 50 auto node_out_index = node->outputIndex; 51 node->outputIndex.clear(); 52 for (auto out_index : node_out_index) { 53 if (index_map.find(out_index) == index_map.end()) { 54 index_map[out_index] = index; 55 ++index; 56 } 57 node->outputIndex.push_back(index_map[out_index]); 58 } 59 } 60 auto graph_output_index = graph->outputIndex; 61 graph->outputIndex.clear(); 62 for (auto output_index : graph_output_index) { 63 if (index_map.find(output_index) == index_map.end()) { 64 index_map[output_index] = index; 65 ++index; 66 } 67 graph->outputIndex.push_back(index_map[output_index]); 68 } 69 std::vector<std::unique_ptr<mindspore::schema::TensorT>> old_tensors; 70 old_tensors.swap(graph->allTensors); 71 graph->allTensors.resize(index_map.size()); 72 for (size_t i = 0; i < old_tensors.size(); ++i) { 73 if (index_map.find(i) == index_map.end()) { 74 continue; 75 } 76 graph->allTensors[index_map[i]].swap(old_tensors[i]); 77 } 78 if (!graph->subGraph.empty()) { 79 graph->subGraph[0]->inputIndices = graph->inputIndex; 80 graph->subGraph[0]->outputIndices = graph->outputIndex; 81 graph->subGraph[0]->tensorIndices = {}; 82 for (uint32_t i = 0; i < index; ++i) { 83 graph->subGraph[0]->tensorIndices.push_back(i); 84 } 85 } 86 return RET_OK; 87 } 88 } // namespace lite 89 } // namespace mindspore 90