1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/parse/function_block.h"
20
21 #include <string>
22 #include <memory>
23 #include <algorithm>
24
25 #include "pybind11/pybind11.h"
26 #include "pipeline/jit/parse/resolve.h"
27 #include "pipeline/jit/parse/parse.h"
28 #include "pipeline/jit/parse/data_converter.h"
29 #include "frontend/operator/ops.h"
30 #include "utils/info.h"
31 #include "debug/trace.h"
32 #include "utils/utils.h"
33
34 namespace mindspore {
35 namespace py = pybind11;
36
37 namespace parse {
FunctionBlock(const Parser & parser)38 FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
39 func_graph_ = std::make_shared<FuncGraph>();
40 matured_ = false;
41 }
42
AddPrevBlock(const FunctionBlockPtr & block)43 void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
44
CanBeIsolatedNode(const std::string & var_name,const AnfNodePtr & node)45 static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
46 auto cnode = dyn_cast<CNode>(node);
47 if (cnode == nullptr || cnode->inputs().empty()) {
48 // Not a valid cnode, can not be isolate node.
49 return false;
50 }
51 auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
52 if (prim == nullptr) {
53 // Not a primitive cnode, it may have side effects or not,
54 // We add it as an isolate node if its name is not '_' or empty.
55 // this means that code like:
56 // _ = func_call()
57 // will be ignored even if func_call() has side effects.
58 return !var_name.empty() && var_name != "_";
59 }
60 // Primitive cnode with side effects can be isolate nodes.
61 auto effect_info = GetPrimEffectInfo(prim);
62 bool has_effects = (effect_info.memory || effect_info.io);
63 if (has_effects) {
64 return true;
65 }
66 // Primitive cnode with 'no_eliminate' flag can be isolate nodes.
67 return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
68 }
69
70 // Write variable records the variable name to corresponding node
WriteVariable(const std::string & var_name,const AnfNodePtr & node)71 void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
72 MS_EXCEPTION_IF_NULL(node);
73 MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
74 << node->DebugString();
75 auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
76 if (!is_new_name) {
77 // If a cnode variable with same name already existed but not used,
78 // add it as an isolate node. for example:
79 // a = print(x)
80 // a = print(y)
81 // When we write variable 'a = print(y)',
82 // the cnode 'print(x)' should added as an isolate node.
83 auto is_used = iter->second.second;
84 auto hidden_node = iter->second.first;
85 auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
86 if (!is_used && is_isolated) {
87 MS_EXCEPTION_IF_NULL(hidden_node);
88 MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(2) << " is hidden by "
89 << node->DebugString(2) << " with the same name, var_name: " << var_name << ", block: " << this
90 << "/" << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
91 << ", Line: " << trace::GetDebugInfo(hidden_node->debug_info(), "", kSourceLineTipDiscard);
92 AddIsolatedNode(hidden_node);
93 }
94 iter->second = std::make_pair(node, false);
95 }
96 }
97
98 // Read variable from predecessors
ReadVariable(const std::string & var)99 AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
100 MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
101 // Get var node if it is found
102 auto found = assigned_vars_.find(var);
103 if (found != assigned_vars_.end()) {
104 auto &node = found->second.first;
105 MS_EXCEPTION_IF_NULL(node);
106 // Mark the variable as used.
107 found->second.second = true;
108 auto iter = resolve_to_removable_phis_.find(node);
109 if (iter != resolve_to_removable_phis_.end()) {
110 return iter->second;
111 }
112 return node;
113 }
114 // Get var from predecessor block, if can't get then make a resolve node to it
115 if (matured_) {
116 // If only one predecessor block, read the definition of var from it.
117 if (prev_blocks_.size() == 1) {
118 auto block = prev_blocks_[0];
119 MS_EXCEPTION_IF_NULL(block);
120 auto res = block->ReadVariable(var);
121 MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
122 << ",\nCurrent: " << py::str(global_py_params())
123 << "\nInsert: " << py::str(block->global_py_params());
124 CopyGlobalPyParam(block->global_py_params());
125 return res;
126 } else if (prev_blocks_.empty()) {
127 // Get namespace and make Resolve
128 auto it = var_to_resolve_.find(var);
129 if (it != var_to_resolve_.end()) {
130 return it->second;
131 }
132 MS_LOG(DEBUG) << "var: " << var;
133 auto tmp_node = MakeResolveSymbol(var);
134 var_to_resolve_[var] = tmp_node;
135 return tmp_node;
136 }
137 }
138 // If have more than one predecessor blocks then build a phi node.
139 auto debug_info = std::make_shared<NodeDebugInfo>();
140 debug_info->set_name(var);
141 TraceGuard guard(std::make_shared<TracePhi>(debug_info));
142 ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
143 MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
144 << phi_param->ToString() << " for " << var;
145 func_graph()->add_parameter(phi_param);
146 phi_nodes_[phi_param] = var;
147 WriteVariable(var, phi_param);
148 if (matured_) {
149 SetPhiArgument(phi_param);
150 }
151 return phi_param;
152 }
153
154 // Resolve Ast operator node
MakeResolveAstOp(const py::object & op)155 AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
156 auto ast = parser_.ast();
157 MS_EXCEPTION_IF_NULL(ast);
158 TraceGuard trace_guard(parser_.GetLocation(op));
159 py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
160 if (namespace_var.size() != 2) {
161 MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
162 }
163 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
164 SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
165 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
166 return MakeResolve(name_space, symbol);
167 }
168
169 // Resolve class member, two possible: method, member variable
MakeResolveClassMember(const std::string & attr)170 AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
171 auto ast = parser_.ast();
172 MS_EXCEPTION_IF_NULL(ast);
173 py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
174 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
175 SymbolPtr symbol = std::make_shared<Symbol>(attr);
176 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
177 return MakeResolve(name_space, symbol);
178 }
179
HandleNamespaceInfo(const py::tuple & namespace_info)180 AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
181 const size_t namespace_info_size = 2;
182 const size_t namespace_more_info_size = 3;
183 if (namespace_info.size() != namespace_info_size && namespace_info.size() != namespace_more_info_size) {
184 MS_EXCEPTION(NameError) << "namespace info size should be 2 or 3, but got " << namespace_info.size();
185 }
186 bool unsupported = false;
187 py::object py_obj;
188 if (namespace_info.size() == namespace_more_info_size) {
189 if (namespace_info[0].is_none()) { // If namespace is None, the symbol is an undefined name.
190 MS_EXCEPTION(NameError) << namespace_info[namespace_more_info_size - 1].cast<std::string>();
191 } else { // Or, the symbol is an unsupported builtin symbol in Graph mode.
192 unsupported = true;
193 py_obj = namespace_info[namespace_more_info_size - 1];
194 }
195 }
196 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
197 SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
198 MS_LOG(DEBUG) << "[" << func_graph()->ToString() << "] name_space: " << name_space->ToString()
199 << ", symbol: " << symbol->ToString() << ", unsupported: " << unsupported;
200 auto resolved_node = MakeResolve(name_space, symbol);
201 if (unsupported) {
202 resolved_node->set_interpret(true);
203 AddGlobalPyParam(symbol->name(), py_obj);
204 MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symblol: {" << symbol->name() << " : "
205 << py::str(py_obj) << "}";
206 }
207 return resolved_node;
208 }
209
210 // Make a resolve node for symbol string
MakeResolveSymbol(const std::string & value)211 AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
212 MS_LOG(DEBUG) << "value: " << value;
213 if (value.compare(0, strlen("self"), "self") == 0) {
214 auto start = value.find_first_of('.') + 1;
215 if (start >= value.size()) {
216 MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
217 return nullptr;
218 }
219 auto bits_str = value.substr(start);
220 return MakeResolveClassMember(bits_str);
221 }
222 auto ast = parser_.ast();
223 MS_EXCEPTION_IF_NULL(ast);
224
225 // The fallback feature is enabled in default.
226 // Not support change the flag during the process is alive.
227 static const auto use_fallback = (parser_.support_fallback() == "1");
228 if (!use_fallback) {
229 py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
230 return HandleNamespaceInfo(namespace_info);
231 } else {
232 py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, value);
233 return HandleNamespaceInfo(namespace_info);
234 }
235 }
236
MakeResolveOperation(const std::string & value)237 AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
238 auto ast = parser_.ast();
239 MS_EXCEPTION_IF_NULL(ast);
240 py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
241 const size_t namespace_var_size = 2;
242 if (namespace_var.size() < namespace_var_size) {
243 MS_EXCEPTION(NameError) << "namespace_var is less than 2";
244 }
245 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
246 SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
247 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
248 return MakeResolve(name_space, symbol);
249 }
250
MakeResolve(const NameSpacePtr & name_space,const SymbolPtr & resolve_symbol)251 AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
252 MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
253 << " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resoleve symbol.");
254 ValueNodePtr module_node = NewValueNode(name_space);
255 ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
256 auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
257 return node;
258 }
259
// Create a PyInterpret(script, global_dict, local_dict) CNode in this block's graph
// that evaluates `script_text`. `orig_node` is recorded on the new node as the node it
// stands for (set_interpreted_node).
AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
                                        const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
  MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
  ScriptPtr script_value = std::make_shared<Script>(script_text);
  auto script_node = NewValueNode(script_value);
  auto interpret_prim = NewValueNode(prim::kPrimPyInterpret);
  auto interpret_node =
    func_graph_->NewCNodeInOrder({interpret_prim, script_node, global_dict_node, local_dict_node});
  interpret_node->set_interpreted_node(orig_node);
  return interpret_node;
}
270
271 // Add input for the block's phi parameter
SetPhiArgument(const ParameterPtr & phi)272 void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
273 MS_EXCEPTION_IF_NULL(phi);
274 TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
275 std::string var = phi_nodes_[phi];
276 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
277 << " for var `" << var << "`";
278 auto removable = CollectRemovablePhi(phi);
279 // If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
280 if (removable) {
281 MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
282 << " var `" << var << "`";
283 return;
284 }
285 for (auto &pred : prev_blocks_) {
286 MS_EXCEPTION_IF_NULL(pred);
287 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
288 << (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
289 AnfNodePtr arg_node = pred->ReadVariable(var);
290 CNodePtr jump = pred->jumps_[this];
291 MS_EXCEPTION_IF_NULL(jump);
292 jump->add_input(arg_node);
293 }
294 }
295
SearchReplaceNode(const std::string & var,const ParameterPtr & phi)296 AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
297 AnfNodePtr arg_node = nullptr;
298 MS_LOG(DEBUG) << "Prev_blocks size: " << prev_blocks_.size();
299 for (auto &prev : prev_blocks_) {
300 MS_EXCEPTION_IF_NULL(prev);
301 AnfNodePtr temp_node = prev->ReadVariable(var);
302 MS_EXCEPTION_IF_NULL(temp_node);
303 if (temp_node != phi) {
304 if (arg_node == nullptr) {
305 arg_node = temp_node;
306 MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
307 << (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
308 } else if (temp_node == arg_node) {
309 MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
310 << (phi ? phi->ToString() : "null") << " is same as node " << arg_node->DebugString();
311 } else {
312 MS_LOG(DEBUG) << "phi " << (phi ? phi->ToString() : "null")
313 << " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString()
314 << ", node2: " << temp_node->DebugString();
315 return nullptr;
316 }
317 }
318 }
319 return arg_node;
320 }
321
322 // Check if there is removable unnecessary phi node in this graph.
323 // As per the FIRM TR 3.2, a phi node can be remove if:
324 // <Quote>
325 // If all arguments of a φ-function are the same value s or the φfunction itself,
326 // then we remove the φ-function and let all users directly uses. We call such a
327 // φ-function obviously unnecessary.
328 // When we removed a φ-function p, then we recursively try to apply this simplification
329 // rule with all (former) users of p, because they may have become obviously unnecessary
330 // due to the removal of p
331 // <Quote>
332 // phi node in graph will be removed after the whole function is parsed in a DFS visit
333 // of that graph.The reason is :
334 // 1. when this function is called, not all usage of this phi node had bound to the
335 // graph of this function block, some may stay in vars_ in other blocks.
336 // 2. it's costly to iterate the graph to replace the phi for each phi.
337 // Args: phi: This parameter node is functioning as a phi node.
CollectRemovablePhi(const ParameterPtr & phi)338 bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
339 MS_EXCEPTION_IF_NULL(phi);
340 std::string var = phi_nodes_[phi];
341 MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
342 if (prev_blocks_.empty()) {
343 MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
344 return false;
345 }
346 AnfNodePtr arg_node = SearchReplaceNode(var, phi);
347 if (arg_node != nullptr) {
348 arg_node->set_debug_info(phi->debug_info());
349 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
350 << " can be replaced with " << arg_node->DebugString();
351 // Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
352 WriteVariable(var, arg_node);
353 removable_phis_[phi] = arg_node;
354 resolve_to_removable_phis_[arg_node] = phi;
355 // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
356 // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node.
357 for (auto &prev : prev_blocks_) {
358 MS_EXCEPTION_IF_NULL(prev);
359 if (!prev->matured_) {
360 continue;
361 }
362 for (auto &phi_iter : prev->removable_phis_) {
363 MS_EXCEPTION_IF_NULL(phi_iter.second);
364 if (phi_iter.second->isa<Parameter>()) {
365 const auto ¶m = phi_iter.second->cast<ParameterPtr>();
366 if (param == phi) {
367 MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " var "
368 << phi_iter.first->DebugString() << " can be replaced from " << param->DebugString()
369 << " with " << arg_node->DebugString() << " in graph "
370 << (arg_node->func_graph() ? arg_node->func_graph()->ToString() : "FG(Null)");
371 prev->removable_phis_[phi_iter.first] = arg_node;
372 }
373 }
374 }
375 }
376 return true;
377 }
378 return false;
379 }
380
381 // A block should be marked matured if its predecessor blocks have been processed
Mature()382 void FunctionBlock::Mature() {
383 const auto &graph_params = func_graph_->parameters();
384 for (auto ¶m_itr : graph_params) {
385 MS_EXCEPTION_IF_NULL(param_itr);
386 auto param = param_itr->cast<ParameterPtr>();
387 if (phi_nodes_.find(param) != phi_nodes_.cend()) {
388 SetPhiArgument(param);
389 }
390 }
391 matured_ = true;
392 }
393
394 // Force the condition node to bool using bool operation
ForceToBoolNode(const AnfNodePtr & cond)395 CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
396 MS_EXCEPTION_IF_NULL(cond);
397 TraceGuard trace_guard(std::make_shared<TraceForceBool>(cond->debug_info()));
398 CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
399 return op_apply_node;
400 }
401
ForceToWhileCond(const AnfNodePtr & cond)402 CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
403 MS_EXCEPTION_IF_NULL(cond);
404 TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
405 CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
406 return op_apply_node;
407 }
408
409 // Perform a jump from this block to target block
Jump(const FunctionBlockPtr & target_block,const std::vector<AnfNodePtr> & args)410 void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
411 MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
412 MS_EXCEPTION_IF_NULL(target_block);
413 if (is_dead_block_) {
414 MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
415 return;
416 }
417 if (func_graph_->get_return() != nullptr) {
418 MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
419 << trace::GetDebugInfo(func_graph_->get_return()->debug_info());
420 }
421 std::vector<AnfNodePtr> input_nodes;
422 input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
423 (void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
424
425 CNodePtr jump = func_graph_->NewCNodeInOrder(input_nodes);
426 jumps_[target_block.get()] = jump;
427 target_block->AddPrevBlock(shared_from_this());
428 func_graph_->set_output(jump);
429 }
430
431 // Perform a conditional jump using switch operation.
432 // The first CNode select graph with condition, and than execute this graph
ConditionalJump(AnfNodePtr condNode,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block,bool)433 void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
434 const FunctionBlockPtr &false_block, bool) {
435 MS_EXCEPTION_IF_NULL(true_block);
436 MS_EXCEPTION_IF_NULL(false_block);
437 if (func_graph_->get_return() != nullptr) {
438 MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
439 << trace::GetDebugInfo(func_graph_->get_return()->debug_info());
440 }
441 CNodePtr switch_app =
442 func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
443 NewValueNode(false_block->func_graph())});
444 CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
445 func_graph_->set_output(switch_app_new);
446 }
447
448 // Create cnode for the assign statement like 'self.target = source'.
449 // convert it to 'P.Assign(self.target, source)' and then add the cnode as isolate node.
SetStateAssign(const AnfNodePtr & target,const AnfNodePtr & source)450 void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source) {
451 const std::string primitive_name("assign");
452 const std::string module_name("mindspore.ops.functional");
453 ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
454 auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
455 MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(2) << ", block: " << this
456 << "/" << func_graph_->ToString()
457 << ", Line: " << trace::GetDebugInfo(assign_node->debug_info(), "", kSourceLineTipDiscard);
458 AddIsolatedNode(assign_node);
459 }
460
FindIsolatedNodes()461 void FunctionBlock::FindIsolatedNodes() {
462 //
463 // Search isolate nodes from variables, for example,
464 // variable 'a' is an isolate node in below code:
465 //
466 // def construct(self, x, y):
467 // a = print(x) # isolate node
468 // return x + y
469 //
470 std::set<AnfNodePtr> used;
471 // Find used variables.
472 for (const auto &var : assigned_vars_) {
473 auto &node = var.second.first;
474 if (node == nullptr) {
475 continue;
476 }
477 bool is_used = var.second.second;
478 if (is_used) {
479 used.emplace(node);
480 }
481 }
482 // Add isolated nodes which is unused var but not found in used set.
483 for (const auto &var : assigned_vars_) {
484 auto &node = var.second.first;
485 bool is_used = var.second.second;
486 if (node == nullptr || is_used) {
487 continue;
488 }
489 auto &var_name = var.first;
490 if (used.find(node) == used.end() && CanBeIsolatedNode(var_name, node)) {
491 MS_LOG(INFO) << "Isolated node found(NoUse), node: " << node->DebugString(2) << ", var_name: " << var_name
492 << ", block: " << this << "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
493 << ", Line: " << trace::GetDebugInfo(node->debug_info(), "", kSourceLineTipDiscard);
494 AddIsolatedNode(node);
495 }
496 }
497 }
498
AddIsolatedNode(const AnfNodePtr & target)499 void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
500
AttachIsolatedNodesBeforeReturn()501 void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
502 if (isolated_nodes_.empty()) {
503 return;
504 }
505 std::vector<AnfNodePtr> states;
506 states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
507 constexpr int recursive_level = 2;
508 for (auto &node : isolated_nodes_) {
509 MS_EXCEPTION_IF_NULL(node);
510 MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
511 << func_graph_->ToString();
512 if (node->func_graph() == func_graph_) {
513 states.emplace_back(node);
514 } else {
515 MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(recursive_level) << " in "
516 << func_graph_->ToString();
517 }
518 }
519 isolated_nodes_.clear();
520
521 AnfNodePtr state = nullptr;
522 if (states.size() == 1) {
523 // Only MakeTuple, no state left.
524 return;
525 } else if (states.size() == 2) {
526 // If there are only MakeTuple and another node in states(the states size is 2),
527 // do not need to MakeTuple, just use the node.
528 state = states[1];
529 } else {
530 state = func_graph_->NewCNode(states);
531 }
532
533 AnfNodePtr old_output = nullptr;
534 auto return_node = func_graph_->get_return();
535 if (return_node) {
536 const size_t return_input_size = 2;
537 if (return_node->inputs().size() < return_input_size) {
538 MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
539 }
540 old_output = return_node->input(1);
541 } else {
542 old_output = NewValueNode(kNone);
543 }
544 AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
545 CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
546 // We add this attribute for @constexpr use scene, since we must infer them before other nodes.
547 // That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
548 depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
549 MS_EXCEPTION_IF_NULL(state);
550 MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
551 << ", state: " << state->DebugString(2);
552 func_graph_->set_output(depend_node, true);
553 }
554
SetAsDeadBlock()555 void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }
556 } // namespace parse
557 } // namespace mindspore
558