• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/common/optimizer.h"
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 #include <algorithm>
22 #include <utility>
23 
24 #include "backend/optimizer/common/pass_manager.h"
25 #include "backend/session/anf_runtime_algorithm.h"
26 #include "ir/manager.h"
27 
28 namespace mindspore {
29 namespace opt {
PatternProcessPass(const std::string & name,bool multigraph)30 PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
31     : NodePass(name),
32       multigraph_(multigraph),
33       pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
34       primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
35 
DefinePattern() const36 const BaseRef PatternProcessPass::DefinePattern() const {
37   VarPtr X = std::make_shared<Var>();
38   return BaseRef({X});
39 }
40 
Build()41 void PatternProcessPass::Build() {
42   VarPtr fg = std::make_shared<Var>("RootG");
43   BaseRef pattern = std::move(DefinePattern());
44   pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
45 }
46 
// Attempts to match pattern_ against `node` and, on success, rewrites it.
// The pattern is built lazily on first use. A node is only tried against the
// full pattern if IsPrimitiveCNode accepts it for the pattern root's primitive.
// Returns the result of Process() when the match yields a non-empty Equiv.
// Otherwise it returns nullptr, meaning no rewrite.
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (pattern_ == nullptr) {
    Build();
  }

  const auto root_primitive = GetCNodePrimitive(pattern_);
  if (!IsPrimitiveCNode(node, root_primitive)) {
    return nullptr;
  }

  MS_EXCEPTION_IF_NULL(primitive_vars_);
  auto seed_equiv = std::make_shared<Equiv>();
  EquivPtr bindings = pattern_engine_.Match(pattern_, node, *primitive_vars_, seed_equiv);
  const bool matched = bindings != nullptr && !bindings->empty();
  if (!matched) {
    return nullptr;
  }
  return Process(func_graph, node, bindings);
}
63 
MatchAnotherPattern(const AnfNodePtr & node,const EquivPtr & equiv) const64 bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
65   MS_EXCEPTION_IF_NULL(node);
66   MS_EXCEPTION_IF_NULL(equiv);
67   VarPtr fg = std::make_shared<Var>("RootG");
68   auto empty_equiv = std::make_shared<Equiv>();
69   MS_EXCEPTION_IF_NULL(child_primitive_vars_);
70   EquivPtr another_equiv =
71     child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
72                                 *child_primitive_vars_, empty_equiv);
73   if (another_equiv != nullptr && !another_equiv->empty()) {
74     return IsShareNodes(equiv, another_equiv);
75   }
76   return false;
77 }
78 
AddPassManager(const PassManagerPtr & pass_manager)79 void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
80   if (pass_manager != nullptr) {
81     pass_managers_.push_back(pass_manager);
82   }
83 }
84 
// Runs every registered pass manager over `func_graph` and returns the same graph.
//
// Each sweep calls Run() on every non-null pass manager, in insertion order.
// When run_only_once is false, sweeps repeat until one reports no change.
// With exactly one registered pass manager, a single sweep is always done,
// whatever run_only_once says. The effective setting is stored in run_only_once_.
//
// NOTE(review): there is no iteration cap. With run_only_once == false, a pass
// that always reports a change loops forever.
// Throws (via MS_EXCEPTION_IF_NULL) if func_graph is null.
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
  MS_EXCEPTION_IF_NULL(func_graph);
  run_only_once_ = run_only_once || pass_managers_.size() == 1;
  // This local must stay alive for the whole optimization. Keep it even though
  // it is never read, since it holds the result of Manage() while the passes
  // run (presumably the graph does not own its manager; verify before removing).
  // cppcheck-suppress *
  auto manager = Manage(func_graph, true);

  bool changed = true;
  while (changed) {
    changed = false;
    for (const PassManagerPtr &pm : pass_managers_) {
      if (pm != nullptr && pm->Run(func_graph)) {
        changed = true;
      }
    }
    if (run_only_once_) {
      break;
    }
  }

  // The sorted order is discarded. The call is kept for any side effects of
  // TopoSort on the optimized graph (presumably a sanity/cycle check; confirm
  // before removing).
  (void)TopoSort(func_graph->get_return());
  return func_graph;
}
110 }  // namespace opt
111 }  // namespace mindspore
112