1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <vector> 18 #include <algorithm> 19 #include <memory> 20 #include "src/common/log.h" 21 #include "src/train/graph_dropout.h" 22 #include "tools/converter/optimizer.h" 23 #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" 24 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" 25 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 26 27 namespace mindspore { 28 namespace lite { GetGraphNodes(const schema::MetaGraphT & graph_defT)29std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT) { 30 std::vector<schema::CNodeT *> old_nodes{}; 31 old_nodes.resize(graph_defT.nodes.size()); 32 std::transform(graph_defT.nodes.begin(), graph_defT.nodes.end(), old_nodes.begin(), 33 [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); }); 34 return old_nodes; 35 } 36 Run(schema::MetaGraphT * graph)37STATUS GraphDropout::Run(schema::MetaGraphT *graph) { 38 if (graph == nullptr) { 39 MS_LOG(ERROR) << "graph is nullptr."; 40 return RET_ERROR; 41 } 42 Optimizer dropout_optimizer; 43 auto old_nodes = GetGraphNodes(*graph); 44 dropout_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass()); 45 dropout_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); 46 dropout_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); 47 auto status = dropout_optimizer.Run(graph); 48 if (status != RET_OK && status != RET_NO_CHANGE) { 49 MS_LOG(ERROR) << "graph fusion failed."; 50 return RET_ERROR; 51 } 52 return RET_OK; 53 } 54 } // namespace lite 55 } // namespace mindspore 56