• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/opt.h"
18 
19 #include <deque>
20 #include <memory>
21 #include <algorithm>
22 #include <utility>
23 
24 #include "mindspore/core/ops/structure_ops.h"
25 #include "utils/hash_map.h"
26 #include "ir/anf.h"
27 #include "ir/manager.h"
28 #include "frontend/optimizer/optimizer.h"
29 #include "utils/log_adapter.h"
30 #include "utils/compile_config.h"
31 
32 namespace mindspore {
33 /* namespace to support opt */
34 namespace opt {
MakeSubstitution(const OptimizerCallerPtr & transform,const std::string & name,const PrimitivePtr & prim,const RenormAction & renorm_action,bool has_priority_pattern)35 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
36                                  const RenormAction &renorm_action, bool has_priority_pattern) {
37   auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
38   return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
39 }
40 
MakeSubstitution(const OptimizerCallerPtr & transform,const std::string & name,const std::vector<PrimitivePtr> & prims,const RenormAction & renorm_action,bool has_priority_pattern)41 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
42                                  const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
43                                  bool has_priority_pattern) {
44   auto fn = [prims](const AnfNodePtr &node) -> bool {
45     auto cnode = dyn_cast_ptr<CNode>(node);
46     if (cnode == nullptr) {
47       return false;
48     }
49     auto cnode_prim = GetValuePtr<Primitive>(cnode->input(0));
50     if (cnode_prim == nullptr) {
51       return false;
52     }
53     auto hash = cnode_prim->Hash();
54     const auto &name = cnode_prim->name();
55     return std::any_of(prims.begin(), prims.end(), [&hash, &name](const PrimitivePtr &prim) {
56       return (prim->Hash() == hash) && (prim->name() == name);
57     });
58   };
59   return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
60 }
61 
MakeSubstitution(const OptimizerCallerPtr & transform,const std::string & name,const PredicateFuncType & predicate,const RenormAction & renorm_action,bool has_priority_pattern)62 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
63                                  const PredicateFuncType &predicate, const RenormAction &renorm_action,
64                                  bool has_priority_pattern) {
65   return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
66 }
67 
/// Applies the wrapped transform to `node`.
/// Without an optimizer, the transform is called directly.
/// With an optimizer, the call is profiled under "substitution.<name>" and "match.<name>". The match statistic is
/// interrupted when the transform yields nothing.
/// If the optimizer watches renormalization and a node was produced that is either forced to renormalize
/// (FORCE_RENORM) or has no abstract, the optimizer is flagged via set_is_untyped_generated().
/// @return The node produced by the transform, or nullptr if it did not apply.
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  if (optimizer == nullptr) {
    return (*transform_)(optimizer, node);
  }

  AnfNodePtr produced;
  {
    // Scoped so both profiling guards end before the renormalization bookkeeping below.
    MsProfileStatGuard stat_subs_guard("substitution." + name_);
    MsProfileStatGuard stat_match_guard("match." + name_);
    produced = (*transform_)(optimizer, node);
    if (produced == nullptr) {
      stat_match_guard.Interrupt();
    }
  }

  if (optimizer->is_watch_renormalize() && produced != nullptr) {
    const bool needs_type_inference = (renorm_action_ == FORCE_RENORM) || (produced->abstract() == nullptr);
    if (needs_type_inference) {
      optimizer->set_is_untyped_generated();
    }
  }
  return produced;
}
89 
isTraversable(const AnfNodePtr & node)90 static inline bool isTraversable(const AnfNodePtr &node) {
91   if (node->isa<CNode>() || node->isa<Parameter>()) {
92     return true;
93   }
94   // FuncGraph or RefKey value node is traversable.
95   auto value_node = dyn_cast_ptr<ValueNode>(node);
96   MS_EXCEPTION_IF_NULL(value_node);
97   const auto &value = value_node->value();
98   return (value != nullptr) && (value->isa<FuncGraph>() || value->isa<RefKey>() || value->isa<MindIRClassType>() ||
99                                 value->isa<MindIRMetaFuncGraph>() || value->isa<parse::ClassType>() ||
100                                 value->isa<prim::DoSignaturePrimitive>() || value->isa<ValueSequence>() ||
101                                 value->isa<parse::NameSpace>() || value->isa<ValueDictionary>());
102 }
103 
/// Tries a single `substitution` on `node`.
/// The substitution's predicate is evaluated first (profiled as "predicate.<name>"). On a match, the substitution
/// runs under a TraceOpt trace guard and the node's scope.
/// If it produces a node different from `node`, the manager replaces `node` with it (profiled as "replace.<name>").
/// @return The replacement node, or nullptr if the predicate failed or nothing different was produced.
static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
                              const SubstitutionPtr &substitution) {
  auto manager = optimizer->manager();
  MS_EXCEPTION_IF_NULL(manager);

  const bool matched = [&node, &substitution]() {
    MsProfileStatGuard stat_predicate_guard("predicate." + substitution->name_);
    return substitution->predicate_(node);
  }();
  if (!matched) {
    return nullptr;
  }

  TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
  ScopeGuard scope_guard(node->scope());
  auto replacement = (*substitution)(optimizer, node);
  if (replacement == nullptr || replacement == node) {
    return nullptr;
  }

  MsProfileStatGuard stat_guard("replace." + substitution->name_);
  MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << replacement->DebugString() << ", by "
                << substitution->name_;
  (void)manager->Replace(node, replacement);
  return replacement;
}
127 
UpdateTransformingListForSubstitutions(const AnfNodePtr & node,std::deque<AnfNodePtr> * todo,bool change)128 static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
129   auto fg = GetValuePtr<FuncGraph>(node);
130   if (fg != nullptr) {
131     (void)todo->emplace_back(fg->return_node());
132   }
133 
134   if (change) {
135     (void)todo->emplace_back(node);
136   } else {
137     auto cnode = dyn_cast_ptr<CNode>(node);
138     if (cnode != nullptr) {
139       const auto &inputs = cnode->inputs();
140       (void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
141     }
142   }
143 }
144 
UpdateTransformingListForIR(const AnfNodePtr & node,std::deque<AnfNodePtr> * todo,bool change,const SubstitutionPtr & substitution)145 static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
146                                         const SubstitutionPtr &substitution) {
147   auto fg = GetValuePtr<FuncGraph>(node);
148   if (fg != nullptr) {
149     (void)todo->emplace_back(fg->return_node());
150   }
151 
152   // If there is a priority pattern in substitution, don't transform the new node,
153   // otherwise some nodes may match the wrong patterns.
154   if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
155     (void)todo->emplace_back(node);
156   } else {
157     auto cnode = dyn_cast_ptr<CNode>(node);
158     if (cnode != nullptr) {
159       const auto &inputs = cnode->inputs();
160       (void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
161     }
162   }
163 }
164 
UpdateTransformingListWithUserNodes(const FuncGraphManagerPtr & manager,const AnfNodePtr & node,std::deque<AnfNodePtr> * todo,bool change,SeenNum seen)165 static void UpdateTransformingListWithUserNodes(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
166                                                 std::deque<AnfNodePtr> *todo, bool change, SeenNum seen) {
167   if (!change) {
168     return;
169   }
170   MS_EXCEPTION_IF_NULL(manager);
171   auto &node_users = manager->node_users();
172   auto users_iterator = node_users.find(node);
173   if (users_iterator == node_users.end()) {
174     return;
175   }
176   auto users = users_iterator->second;
177   for (auto &use : users) {
178     auto use_node = use.first;
179     if (use_node == nullptr) {
180       continue;
181     }
182     (*todo).emplace_back(use_node);
183     if (use_node->seen_ == seen) {
184       use_node->seen_--;
185     }
186   }
187 }
188 
ApplyIRToSubstitutions(const OptimizerPtr & optimizer,const FuncGraphPtr & func_graph) const189 bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
190   MsProfileStatGuard stat_guard("opt.transform." + optimizer->name());
191   FuncGraphManagerPtr manager = optimizer->manager();
192   auto seen = NewSeenGeneration();
193   std::deque<AnfNodePtr> todo;
194   (void)todo.emplace_back(func_graph->return_node());
195   bool changes = false;
196   auto &all_nodes = manager->all_nodes();
197   while (!todo.empty()) {
198     AnfNodePtr node = std::move(todo.front());
199     todo.pop_front();
200 
201     if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
202       continue;
203     }
204     node->seen_ = seen;
205 
206     bool change = false;
207     for (auto &substitution : list_) {
208       auto res = DoTransform(optimizer, node, substitution);
209       if (res != nullptr && res != node) {
210         change = true;
211         changes = true;
212         node = res;
213         break;
214       }
215     }
216     UpdateTransformingListForSubstitutions(node, &todo, change);
217     UpdateTransformingListWithUserNodes(manager, node, &todo, change, seen);
218   }
219   return changes;
220 }
221 
ApplySubstitutionToIR(const OptimizerPtr & optimizer,const FuncGraphPtr & func_graph,const SubstitutionPtr & substitution) const222 bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
223                                              const SubstitutionPtr &substitution) const {
224   MsProfileStatGuard stat_guard("opt.transform." + optimizer->name());
225   FuncGraphManagerPtr manager = optimizer->manager();
226   MS_EXCEPTION_IF_NULL(manager);
227   auto seen = NewSeenGeneration();
228   std::deque<AnfNodePtr> todo;
229   (void)todo.emplace_back(func_graph->return_node());
230   bool changes = false;
231 
232   auto &all_nodes = manager->all_nodes();
233   while (!todo.empty()) {
234     AnfNodePtr node = todo.front();
235     todo.pop_front();
236 
237     if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
238       continue;
239     }
240     node->seen_ = seen;
241 
242     bool change = false;
243     auto res = DoTransform(optimizer, node, substitution);
244     if (res != nullptr && res != node) {
245       change = true;
246       changes = true;
247       node = res;
248     }
249     UpdateTransformingListForIR(node, &todo, change, substitution);
250     UpdateTransformingListWithUserNodes(manager, node, &todo, change, seen);
251   }
252   return changes;
253 }
254 
DisplayStatusOfSubstitution(const mindspore::HashMap<std::string,std::vector<bool>> & status,const OptimizerPtr & optimizer,size_t space) const255 void SubstitutionList::DisplayStatusOfSubstitution(const mindspore::HashMap<std::string, std::vector<bool>> &status,
256                                                    const OptimizerPtr &optimizer, size_t space) const {
257   constexpr int pad_width = 4;
258   std::stringstream ss;
259   ss << std::endl
260      << "Pass: " << optimizer->name() << "(" << optimizer->current_pass_.counter << ")_"
261      << optimizer->current_pass_.name << std::endl;
262   for (size_t i = 0; i < list_.size(); i++) {
263     auto name = list_[i]->name_;
264     ss << std::left << std::setw(SizeToInt(space) + pad_width) << name << "\t";
265     auto iter = status.find(name + std::to_string(i));
266     if (iter == status.cend()) {
267       continue;
268     }
269     for (auto change : iter->second) {
270       ss << change << " ";
271     }
272     ss << std::endl;
273   }
274   MS_LOG(DEBUG) << ss.str();
275 }
276 
ApplySubstitutionsToIR(const OptimizerPtr & optimizer,const FuncGraphPtr & func_graph) const277 bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
278   // Add for substitution status counting
279   size_t space = 0;
280   mindspore::HashMap<std::string, std::vector<bool>> status;
281   if (optimizer->is_on_debug_) {
282     for (size_t i = 0; i < list_.size(); i++) {
283       status[list_[i]->name_ + std::to_string(i)] = {};
284     }
285   }
286 
287   bool changes = false;
288   bool loop = true;
289   while (loop) {
290     loop = false;
291     for (size_t i = 0; i < list_.size(); i++) {
292       const auto &substitution = list_[i];
293       MS_LOG(INFO) << "Start substitution: " << substitution->name_;
294       bool change = ApplySubstitutionToIR(optimizer, func_graph, substitution);
295       MS_LOG(INFO) << "End substitution: " << substitution->name_ << ", change: " << change;
296       changes = changes || change;
297       loop = loop || change;
298 #ifdef ENABLE_DUMP_IR
299       static const auto enable_dump_pass = GetDumpConfig().enable_dump_pass_ir;
300       static const auto input_name = common::GetEnv("MS_DEV_DUMP_IR_PASSES");
301       auto enable_dump_pass_ir = (input_name.size() != 0) || enable_dump_pass;
302       auto context = MsContext::GetInstance();
303       if ((enable_dump_pass_ir && context->CanDump(kIntroductory)) || context->CanDump(kFully)) {
304         auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->current_pass_.counter) + "_" +
305                        optimizer->current_pass_.name + "_" + substitution->name_;
306         static const auto switch_order = (common::GetEnv("MS_DEV_SAVE_GRAPHS_SORT_MODE") == "1");
307         if (switch_order) {
308           ExportIR(fg_name + ".ir", func_graph);
309         } else {
310           DumpIR(fg_name + ".ir", func_graph);
311         }
312         if (context->CanDump(kFully)) {
313           draw::Draw(fg_name + ".dot", func_graph);
314         }
315       }
316 #endif
317 
318       // Record the status of each substitution
319       if (optimizer->is_on_debug_) {
320         status[substitution->name_ + std::to_string(i)].push_back(change);
321         space = std::max(substitution->name_.size(), space);
322       }
323     }
324     if (is_once_) {
325       break;
326     }
327   }
328 
329   // Display the status of each substitution
330   if (optimizer->is_on_debug_) {
331     DisplayStatusOfSubstitution(status, optimizer, space);
332   }
333   return changes;
334 }
335 
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer) const336 bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
337   MS_EXCEPTION_IF_NULL(optimizer);
338   MS_EXCEPTION_IF_NULL(func_graph);
339   FuncGraphManagerPtr manager = optimizer->manager();
340   MS_EXCEPTION_IF_NULL(manager);
341   manager->AddFuncGraph(func_graph);
342   bool changes = false;
343   static const auto traverse_mode =
344     (common::GetCompileConfig("TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions
345                                                                     : kOptTraverseFromSubstitutionsToIR);
346   if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
347       MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
348       optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) {
349     MS_LOG(INFO) << "IR >> SUB, *, " << optimizer->name() << "_r" << optimizer->current_pass_.counter << "_"
350                  << optimizer->current_pass_.name;
351     changes = ApplyIRToSubstitutions(optimizer, func_graph);
352   } else {
353     MS_LOG(INFO) << "SUB >> IR, " << optimizer->name() << "_r" << optimizer->current_pass_.counter << "_"
354                  << optimizer->current_pass_.name;
355     changes = ApplySubstitutionsToIR(optimizer, func_graph);
356   }
357   return changes;
358 }
359 
Run()360 bool SimpleRewriter::Run() {
361   bool changed = false;
362   auto seen = NewSeenGeneration();
363   std::deque<AnfNodePtr> todo;
364   auto add_todo = [&seen, &todo](const AnfNodePtr &node) {
365     if (node != nullptr && node->seen_ != seen) {
366       (void)todo.emplace_back(node);
367     }
368   };
369   (void)todo.emplace_back(root_graph_->return_node());
370   auto &all_nodes = manager_->all_nodes();
371   while (!todo.empty()) {
372     AnfNodePtr node = std::move(todo.front());
373     todo.pop_front();
374     if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) {
375       continue;
376     }
377     node->seen_ = seen;
378     auto cnode = node->cast_ptr<CNode>();
379     if (cnode != nullptr) {
380       for (auto &input : cnode->weak_inputs()) {
381         add_todo(input.lock());
382       }
383     } else {
384       auto fg = GetValuePtr<FuncGraph>(node);
385       if (fg != nullptr) {
386         add_todo(fg->return_node());
387       }
388     }
389     TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
390     ScopeGuard scope_guard(node->scope());
391     auto new_node = NodeRewrite(node);
392     if (new_node != nullptr) {
393       (void)manager_->Replace(node, new_node);
394       changed = true;
395       // Need push the users of new_node to the deque.
396       UpdateTransformingListWithUserNodes(manager_, new_node, &todo, changed, seen);
397     }
398   }
399   return changed;
400 }
401 }  // namespace opt
402 }  // namespace mindspore
403