1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/opt.h"
18
19 #include <deque>
20 #include <memory>
21 #include <algorithm>
22 #include <utility>
23
24 #include "mindspore/core/ops/structure_ops.h"
25 #include "utils/hash_map.h"
26 #include "ir/anf.h"
27 #include "ir/manager.h"
28 #include "frontend/optimizer/optimizer.h"
29 #include "utils/log_adapter.h"
30 #include "utils/compile_config.h"
31
32 namespace mindspore {
33 /* namespace to support opt */
34 namespace opt {
MakeSubstitution(const OptimizerCallerPtr & transform,const std::string & name,const PrimitivePtr & prim,const RenormAction & renorm_action,bool has_priority_pattern)35 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
36 const RenormAction &renorm_action, bool has_priority_pattern) {
37 auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
38 return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
39 }
40
MakeSubstitution(const OptimizerCallerPtr & transform,const std::string & name,const std::vector<PrimitivePtr> & prims,const RenormAction & renorm_action,bool has_priority_pattern)41 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
42 const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
43 bool has_priority_pattern) {
44 auto fn = [prims](const AnfNodePtr &node) -> bool {
45 auto cnode = dyn_cast_ptr<CNode>(node);
46 if (cnode == nullptr) {
47 return false;
48 }
49 auto cnode_prim = GetValuePtr<Primitive>(cnode->input(0));
50 if (cnode_prim == nullptr) {
51 return false;
52 }
53 auto hash = cnode_prim->Hash();
54 const auto &name = cnode_prim->name();
55 return std::any_of(prims.begin(), prims.end(), [&hash, &name](const PrimitivePtr &prim) {
56 return (prim->Hash() == hash) && (prim->name() == name);
57 });
58 };
59 return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
60 }
61
MakeSubstitution(const OptimizerCallerPtr & transform,const std::string & name,const PredicateFuncType & predicate,const RenormAction & renorm_action,bool has_priority_pattern)62 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
63 const PredicateFuncType &predicate, const RenormAction &renorm_action,
64 bool has_priority_pattern) {
65 return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
66 }
67
// Runs the wrapped transform on `node`.
// With an optimizer present, the run is profiled ("substitution.<name>" / "match.<name>"). The match stat is
// interrupted when the transform produces nothing. When the optimizer watches renormalization and a result was
// produced, the optimizer is flagged as having generated untyped IR if this substitution forces renormalization or
// the result carries no abstract.
// Returns the transform's result (nullptr when it did not apply).
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  AnfNodePtr produced = nullptr;
  if (optimizer == nullptr) {
    produced = (*transform_)(optimizer, node);
  } else {
    MsProfileStatGuard stat_subs_guard("substitution." + name_);
    MsProfileStatGuard stat_match_guard("match." + name_);
    produced = (*transform_)(optimizer, node);
    if (produced == nullptr) {
      stat_match_guard.Interrupt();
    }
  }

  if (optimizer == nullptr || !optimizer->is_watch_renormalize() || produced == nullptr) {
    return produced;
  }
  const bool needs_renorm = (renorm_action_ == FORCE_RENORM) || (produced->abstract() == nullptr);
  if (needs_renorm) {
    optimizer->set_is_untyped_generated();
  }
  return produced;
}
89
isTraversable(const AnfNodePtr & node)90 static inline bool isTraversable(const AnfNodePtr &node) {
91 if (node->isa<CNode>() || node->isa<Parameter>()) {
92 return true;
93 }
94 // FuncGraph or RefKey value node is traversable.
95 auto value_node = dyn_cast_ptr<ValueNode>(node);
96 MS_EXCEPTION_IF_NULL(value_node);
97 const auto &value = value_node->value();
98 return (value != nullptr) && (value->isa<FuncGraph>() || value->isa<RefKey>() || value->isa<MindIRClassType>() ||
99 value->isa<MindIRMetaFuncGraph>() || value->isa<parse::ClassType>() ||
100 value->isa<prim::DoSignaturePrimitive>() || value->isa<ValueSequence>() ||
101 value->isa<parse::NameSpace>() || value->isa<ValueDictionary>());
102 }
103
// Tries one substitution on one node.
// If the substitution's predicate accepts `node` and the substitution yields a different node, the manager replaces
// `node` with it and the replacement is returned. Otherwise nothing changes and nullptr is returned.
static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
                              const SubstitutionPtr &substitution) {
  auto manager = optimizer->manager();
  MS_EXCEPTION_IF_NULL(manager);

  bool is_match = false;
  {
    // Scope the profiling guard to the predicate evaluation only.
    MsProfileStatGuard stat_predicate_guard("predicate." + substitution->name_);
    is_match = substitution->predicate_(node);
  }
  if (!is_match) {
    return nullptr;
  }

  // Nodes created by the substitution inherit the debug trace and scope of the node being rewritten.
  TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
  ScopeGuard scope_guard(node->scope());
  auto replacement = (*substitution)(optimizer, node);
  if (replacement == nullptr || replacement == node) {
    return nullptr;
  }

  MsProfileStatGuard stat_guard("replace." + substitution->name_);
  MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << replacement->DebugString() << ", by "
                << substitution->name_;
  (void)manager->Replace(node, replacement);
  return replacement;
}
127
UpdateTransformingListForSubstitutions(const AnfNodePtr & node,std::deque<AnfNodePtr> * todo,bool change)128 static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
129 auto fg = GetValuePtr<FuncGraph>(node);
130 if (fg != nullptr) {
131 (void)todo->emplace_back(fg->return_node());
132 }
133
134 if (change) {
135 (void)todo->emplace_back(node);
136 } else {
137 auto cnode = dyn_cast_ptr<CNode>(node);
138 if (cnode != nullptr) {
139 const auto &inputs = cnode->inputs();
140 (void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
141 }
142 }
143 }
144
UpdateTransformingListForIR(const AnfNodePtr & node,std::deque<AnfNodePtr> * todo,bool change,const SubstitutionPtr & substitution)145 static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
146 const SubstitutionPtr &substitution) {
147 auto fg = GetValuePtr<FuncGraph>(node);
148 if (fg != nullptr) {
149 (void)todo->emplace_back(fg->return_node());
150 }
151
152 // If there is a priority pattern in substitution, don't transform the new node,
153 // otherwise some nodes may match the wrong patterns.
154 if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
155 (void)todo->emplace_back(node);
156 } else {
157 auto cnode = dyn_cast_ptr<CNode>(node);
158 if (cnode != nullptr) {
159 const auto &inputs = cnode->inputs();
160 (void)todo->insert(todo->end(), inputs.cbegin(), inputs.cend());
161 }
162 }
163 }
164
UpdateTransformingListWithUserNodes(const FuncGraphManagerPtr & manager,const AnfNodePtr & node,std::deque<AnfNodePtr> * todo,bool change,SeenNum seen)165 static void UpdateTransformingListWithUserNodes(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
166 std::deque<AnfNodePtr> *todo, bool change, SeenNum seen) {
167 if (!change) {
168 return;
169 }
170 MS_EXCEPTION_IF_NULL(manager);
171 auto &node_users = manager->node_users();
172 auto users_iterator = node_users.find(node);
173 if (users_iterator == node_users.end()) {
174 return;
175 }
176 auto users = users_iterator->second;
177 for (auto &use : users) {
178 auto use_node = use.first;
179 if (use_node == nullptr) {
180 continue;
181 }
182 (*todo).emplace_back(use_node);
183 if (use_node->seen_ == seen) {
184 use_node->seen_--;
185 }
186 }
187 }
188
ApplyIRToSubstitutions(const OptimizerPtr & optimizer,const FuncGraphPtr & func_graph) const189 bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
190 MsProfileStatGuard stat_guard("opt.transform." + optimizer->name());
191 FuncGraphManagerPtr manager = optimizer->manager();
192 auto seen = NewSeenGeneration();
193 std::deque<AnfNodePtr> todo;
194 (void)todo.emplace_back(func_graph->return_node());
195 bool changes = false;
196 auto &all_nodes = manager->all_nodes();
197 while (!todo.empty()) {
198 AnfNodePtr node = std::move(todo.front());
199 todo.pop_front();
200
201 if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
202 continue;
203 }
204 node->seen_ = seen;
205
206 bool change = false;
207 for (auto &substitution : list_) {
208 auto res = DoTransform(optimizer, node, substitution);
209 if (res != nullptr && res != node) {
210 change = true;
211 changes = true;
212 node = res;
213 break;
214 }
215 }
216 UpdateTransformingListForSubstitutions(node, &todo, change);
217 UpdateTransformingListWithUserNodes(manager, node, &todo, change, seen);
218 }
219 return changes;
220 }
221
ApplySubstitutionToIR(const OptimizerPtr & optimizer,const FuncGraphPtr & func_graph,const SubstitutionPtr & substitution) const222 bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
223 const SubstitutionPtr &substitution) const {
224 MsProfileStatGuard stat_guard("opt.transform." + optimizer->name());
225 FuncGraphManagerPtr manager = optimizer->manager();
226 MS_EXCEPTION_IF_NULL(manager);
227 auto seen = NewSeenGeneration();
228 std::deque<AnfNodePtr> todo;
229 (void)todo.emplace_back(func_graph->return_node());
230 bool changes = false;
231
232 auto &all_nodes = manager->all_nodes();
233 while (!todo.empty()) {
234 AnfNodePtr node = todo.front();
235 todo.pop_front();
236
237 if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
238 continue;
239 }
240 node->seen_ = seen;
241
242 bool change = false;
243 auto res = DoTransform(optimizer, node, substitution);
244 if (res != nullptr && res != node) {
245 change = true;
246 changes = true;
247 node = res;
248 }
249 UpdateTransformingListForIR(node, &todo, change, substitution);
250 UpdateTransformingListWithUserNodes(manager, node, &todo, change, seen);
251 }
252 return changes;
253 }
254
DisplayStatusOfSubstitution(const mindspore::HashMap<std::string,std::vector<bool>> & status,const OptimizerPtr & optimizer,size_t space) const255 void SubstitutionList::DisplayStatusOfSubstitution(const mindspore::HashMap<std::string, std::vector<bool>> &status,
256 const OptimizerPtr &optimizer, size_t space) const {
257 constexpr int pad_width = 4;
258 std::stringstream ss;
259 ss << std::endl
260 << "Pass: " << optimizer->name() << "(" << optimizer->current_pass_.counter << ")_"
261 << optimizer->current_pass_.name << std::endl;
262 for (size_t i = 0; i < list_.size(); i++) {
263 auto name = list_[i]->name_;
264 ss << std::left << std::setw(SizeToInt(space) + pad_width) << name << "\t";
265 auto iter = status.find(name + std::to_string(i));
266 if (iter == status.cend()) {
267 continue;
268 }
269 for (auto change : iter->second) {
270 ss << change << " ";
271 }
272 ss << std::endl;
273 }
274 MS_LOG(DEBUG) << ss.str();
275 }
276
ApplySubstitutionsToIR(const OptimizerPtr & optimizer,const FuncGraphPtr & func_graph) const277 bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
278 // Add for substitution status counting
279 size_t space = 0;
280 mindspore::HashMap<std::string, std::vector<bool>> status;
281 if (optimizer->is_on_debug_) {
282 for (size_t i = 0; i < list_.size(); i++) {
283 status[list_[i]->name_ + std::to_string(i)] = {};
284 }
285 }
286
287 bool changes = false;
288 bool loop = true;
289 while (loop) {
290 loop = false;
291 for (size_t i = 0; i < list_.size(); i++) {
292 const auto &substitution = list_[i];
293 MS_LOG(INFO) << "Start substitution: " << substitution->name_;
294 bool change = ApplySubstitutionToIR(optimizer, func_graph, substitution);
295 MS_LOG(INFO) << "End substitution: " << substitution->name_ << ", change: " << change;
296 changes = changes || change;
297 loop = loop || change;
298 #ifdef ENABLE_DUMP_IR
299 static const auto enable_dump_pass = GetDumpConfig().enable_dump_pass_ir;
300 static const auto input_name = common::GetEnv("MS_DEV_DUMP_IR_PASSES");
301 auto enable_dump_pass_ir = (input_name.size() != 0) || enable_dump_pass;
302 auto context = MsContext::GetInstance();
303 if ((enable_dump_pass_ir && context->CanDump(kIntroductory)) || context->CanDump(kFully)) {
304 auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->current_pass_.counter) + "_" +
305 optimizer->current_pass_.name + "_" + substitution->name_;
306 static const auto switch_order = (common::GetEnv("MS_DEV_SAVE_GRAPHS_SORT_MODE") == "1");
307 if (switch_order) {
308 ExportIR(fg_name + ".ir", func_graph);
309 } else {
310 DumpIR(fg_name + ".ir", func_graph);
311 }
312 if (context->CanDump(kFully)) {
313 draw::Draw(fg_name + ".dot", func_graph);
314 }
315 }
316 #endif
317
318 // Record the status of each substitution
319 if (optimizer->is_on_debug_) {
320 status[substitution->name_ + std::to_string(i)].push_back(change);
321 space = std::max(substitution->name_.size(), space);
322 }
323 }
324 if (is_once_) {
325 break;
326 }
327 }
328
329 // Display the status of each substitution
330 if (optimizer->is_on_debug_) {
331 DisplayStatusOfSubstitution(status, optimizer, space);
332 }
333 return changes;
334 }
335
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer) const336 bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
337 MS_EXCEPTION_IF_NULL(optimizer);
338 MS_EXCEPTION_IF_NULL(func_graph);
339 FuncGraphManagerPtr manager = optimizer->manager();
340 MS_EXCEPTION_IF_NULL(manager);
341 manager->AddFuncGraph(func_graph);
342 bool changes = false;
343 static const auto traverse_mode =
344 (common::GetCompileConfig("TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions
345 : kOptTraverseFromSubstitutionsToIR);
346 if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
347 MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
348 optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) {
349 MS_LOG(INFO) << "IR >> SUB, *, " << optimizer->name() << "_r" << optimizer->current_pass_.counter << "_"
350 << optimizer->current_pass_.name;
351 changes = ApplyIRToSubstitutions(optimizer, func_graph);
352 } else {
353 MS_LOG(INFO) << "SUB >> IR, " << optimizer->name() << "_r" << optimizer->current_pass_.counter << "_"
354 << optimizer->current_pass_.name;
355 changes = ApplySubstitutionsToIR(optimizer, func_graph);
356 }
357 return changes;
358 }
359
Run()360 bool SimpleRewriter::Run() {
361 bool changed = false;
362 auto seen = NewSeenGeneration();
363 std::deque<AnfNodePtr> todo;
364 auto add_todo = [&seen, &todo](const AnfNodePtr &node) {
365 if (node != nullptr && node->seen_ != seen) {
366 (void)todo.emplace_back(node);
367 }
368 };
369 (void)todo.emplace_back(root_graph_->return_node());
370 auto &all_nodes = manager_->all_nodes();
371 while (!todo.empty()) {
372 AnfNodePtr node = std::move(todo.front());
373 todo.pop_front();
374 if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) {
375 continue;
376 }
377 node->seen_ = seen;
378 auto cnode = node->cast_ptr<CNode>();
379 if (cnode != nullptr) {
380 for (auto &input : cnode->weak_inputs()) {
381 add_todo(input.lock());
382 }
383 } else {
384 auto fg = GetValuePtr<FuncGraph>(node);
385 if (fg != nullptr) {
386 add_todo(fg->return_node());
387 }
388 }
389 TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
390 ScopeGuard scope_guard(node->scope());
391 auto new_node = NodeRewrite(node);
392 if (new_node != nullptr) {
393 (void)manager_->Replace(node, new_node);
394 changed = true;
395 // Need push the users of new_node to the deque.
396 UpdateTransformingListWithUserNodes(manager_, new_node, &todo, changed, seen);
397 }
398 }
399 return changed;
400 }
401 } // namespace opt
402 } // namespace mindspore
403