/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <algorithm>
#include <memory>
#include "src/common/log.h"
#include "src/train/graph_dropout.h"
#include "tools/converter/optimizer.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"

namespace mindspore {
namespace lite {
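// Collects non-owning raw pointers to every node in the graph. The snapshot is
// taken before any pass runs so that SubgraphNodePass can later compare the
// modified graph against the original node set; ownership stays with MetaGraphT.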
std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT) {
  std::vector<schema::CNodeT *> old_nodes{};
  old_nodes.resize(graph_defT.nodes.size());
  std::transform(graph_defT.nodes.begin(), graph_defT.nodes.end(), old_nodes.begin(),
                 [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
  return old_nodes;
}

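// Removes Dropout nodes from the MetaGraph. Dropout is an identity operation at
// inference time, so the nodes can be dropped and the remaining graph cleaned up.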
STATUS GraphDropout::Run(schema::MetaGraphT *graph) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr.";
    return RET_ERROR;
  }
  Optimizer dropout_optimizer;
  auto old_nodes = GetGraphNodes(*graph);
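  // DropoutNodeRemovePass deletes the Dropout nodes, IsolatedNodeRemovePass prunes
  // nodes left disconnected from the graph, and SubgraphNodePass keeps the subgraph
  // node indices consistent with the modified node list, using the snapshot above.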
  dropout_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass());
  dropout_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
  dropout_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
  auto status = dropout_optimizer.Run(graph);
  if (status != RET_OK && status != RET_NO_CHANGE) {
    MS_LOG(ERROR) << "graph dropout node removal failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore