• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/parse/function_block.h"
20 
21 #include <string>
22 #include <memory>
23 #include <algorithm>
24 
25 #include "pybind11/pybind11.h"
26 #include "pipeline/jit/parse/resolve.h"
27 #include "pipeline/jit/parse/parse.h"
28 #include "pipeline/jit/parse/data_converter.h"
29 #include "frontend/operator/ops.h"
30 #include "utils/info.h"
31 #include "debug/trace.h"
32 #include "utils/utils.h"
33 
34 namespace mindspore {
35 namespace py = pybind11;
36 
37 namespace parse {
FunctionBlock(const Parser & parser)38 FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
39   func_graph_ = std::make_shared<FuncGraph>();
40   matured_ = false;
41 }
42 
AddPrevBlock(const FunctionBlockPtr & block)43 void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
44 
CanBeIsolatedNode(const std::string & var_name,const AnfNodePtr & node)45 static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
46   auto cnode = dyn_cast<CNode>(node);
47   if (cnode == nullptr || cnode->inputs().empty()) {
48     // Not a valid cnode, can not be isolate node.
49     return false;
50   }
51   auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
52   if (prim == nullptr) {
53     // Not a primitive cnode, it may have side effects or not,
54     // We add it as an isolate node if its name is not '_' or empty.
55     // this means that code like:
56     //    _ = func_call()
57     // will be ignored even if func_call() has side effects.
58     return !var_name.empty() && var_name != "_";
59   }
60   // Primitive cnode with side effects can be isolate nodes.
61   auto effect_info = GetPrimEffectInfo(prim);
62   bool has_effects = (effect_info.memory || effect_info.io);
63   if (has_effects) {
64     return true;
65   }
66   // Primitive cnode with 'no_eliminate' flag can be isolate nodes.
67   return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
68 }
69 
70 // Write variable records the variable name to corresponding node
WriteVariable(const std::string & var_name,const AnfNodePtr & node)71 void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
72   MS_EXCEPTION_IF_NULL(node);
73   MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
74                 << node->DebugString();
75   auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
76   if (!is_new_name) {
77     // If a cnode variable with same name already existed but not used,
78     // add it as an isolate node. for example:
79     //   a = print(x)
80     //   a = print(y)
81     // When we write variable 'a = print(y)',
82     // the cnode 'print(x)' should added as an isolate node.
83     auto is_used = iter->second.second;
84     auto hidden_node = iter->second.first;
85     auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
86     if (!is_used && is_isolated) {
87       MS_EXCEPTION_IF_NULL(hidden_node);
88       MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(2) << " is hidden by "
89                    << node->DebugString(2) << " with the same name, var_name: " << var_name << ", block: " << this
90                    << "/" << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
91                    << ", Line: " << trace::GetDebugInfo(hidden_node->debug_info(), "", kSourceLineTipDiscard);
92       AddIsolatedNode(hidden_node);
93     }
94     iter->second = std::make_pair(node, false);
95   }
96 }
97 
98 // Read variable from predecessors
ReadVariable(const std::string & var)99 AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
100   MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
101   // Get var node if it is found
102   auto found = assigned_vars_.find(var);
103   if (found != assigned_vars_.end()) {
104     auto &node = found->second.first;
105     MS_EXCEPTION_IF_NULL(node);
106     // Mark the variable as used.
107     found->second.second = true;
108     auto iter = resolve_to_removable_phis_.find(node);
109     if (iter != resolve_to_removable_phis_.end()) {
110       return iter->second;
111     }
112     return node;
113   }
114   // Get var from predecessor block, if can't get then make a resolve node to it
115   if (matured_) {
116     // If only one predecessor block, read the definition of var from it.
117     if (prev_blocks_.size() == 1) {
118       auto block = prev_blocks_[0];
119       MS_EXCEPTION_IF_NULL(block);
120       auto res = block->ReadVariable(var);
121       MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
122                    << ",\nCurrent: " << py::str(global_py_params())
123                    << "\nInsert: " << py::str(block->global_py_params());
124       CopyGlobalPyParam(block->global_py_params());
125       return res;
126     } else if (prev_blocks_.empty()) {
127       // Get namespace and make Resolve
128       auto it = var_to_resolve_.find(var);
129       if (it != var_to_resolve_.end()) {
130         return it->second;
131       }
132       MS_LOG(DEBUG) << "var: " << var;
133       auto tmp_node = MakeResolveSymbol(var);
134       var_to_resolve_[var] = tmp_node;
135       return tmp_node;
136     }
137   }
138   // If have more than one predecessor blocks then build a phi node.
139   auto debug_info = std::make_shared<NodeDebugInfo>();
140   debug_info->set_name(var);
141   TraceGuard guard(std::make_shared<TracePhi>(debug_info));
142   ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
143   MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
144                 << phi_param->ToString() << " for " << var;
145   func_graph()->add_parameter(phi_param);
146   phi_nodes_[phi_param] = var;
147   WriteVariable(var, phi_param);
148   if (matured_) {
149     SetPhiArgument(phi_param);
150   }
151   return phi_param;
152 }
153 
154 // Resolve Ast operator node
MakeResolveAstOp(const py::object & op)155 AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
156   auto ast = parser_.ast();
157   MS_EXCEPTION_IF_NULL(ast);
158   TraceGuard trace_guard(parser_.GetLocation(op));
159   py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
160   if (namespace_var.size() != 2) {
161     MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
162   }
163   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
164   SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
165   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
166   return MakeResolve(name_space, symbol);
167 }
168 
169 // Resolve class member, two possible: method, member variable
MakeResolveClassMember(const std::string & attr)170 AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
171   auto ast = parser_.ast();
172   MS_EXCEPTION_IF_NULL(ast);
173   py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
174   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
175   SymbolPtr symbol = std::make_shared<Symbol>(attr);
176   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
177   return MakeResolve(name_space, symbol);
178 }
179 
HandleNamespaceInfo(const py::tuple & namespace_info)180 AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
181   const size_t namespace_info_size = 2;
182   const size_t namespace_more_info_size = 3;
183   if (namespace_info.size() != namespace_info_size && namespace_info.size() != namespace_more_info_size) {
184     MS_EXCEPTION(NameError) << "namespace info size should be 2 or 3, but got " << namespace_info.size();
185   }
186   bool unsupported = false;
187   py::object py_obj;
188   if (namespace_info.size() == namespace_more_info_size) {
189     if (namespace_info[0].is_none()) {  // If namespace is None, the symbol is an undefined name.
190       MS_EXCEPTION(NameError) << namespace_info[namespace_more_info_size - 1].cast<std::string>();
191     } else {  // Or, the symbol is an unsupported builtin symbol in Graph mode.
192       unsupported = true;
193       py_obj = namespace_info[namespace_more_info_size - 1];
194     }
195   }
196   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
197   SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
198   MS_LOG(DEBUG) << "[" << func_graph()->ToString() << "] name_space: " << name_space->ToString()
199                 << ", symbol: " << symbol->ToString() << ", unsupported: " << unsupported;
200   auto resolved_node = MakeResolve(name_space, symbol);
201   if (unsupported) {
202     resolved_node->set_interpret(true);
203     AddGlobalPyParam(symbol->name(), py_obj);
204     MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symblol: {" << symbol->name() << " : "
205                  << py::str(py_obj) << "}";
206   }
207   return resolved_node;
208 }
209 
210 // Make a resolve node for symbol string
MakeResolveSymbol(const std::string & value)211 AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
212   MS_LOG(DEBUG) << "value: " << value;
213   if (value.compare(0, strlen("self"), "self") == 0) {
214     auto start = value.find_first_of('.') + 1;
215     if (start >= value.size()) {
216       MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
217       return nullptr;
218     }
219     auto bits_str = value.substr(start);
220     return MakeResolveClassMember(bits_str);
221   }
222   auto ast = parser_.ast();
223   MS_EXCEPTION_IF_NULL(ast);
224 
225   // The fallback feature is enabled in default.
226   // Not support change the flag during the process is alive.
227   static const auto use_fallback = (parser_.support_fallback() == "1");
228   if (!use_fallback) {
229     py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
230     return HandleNamespaceInfo(namespace_info);
231   } else {
232     py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, value);
233     return HandleNamespaceInfo(namespace_info);
234   }
235 }
236 
MakeResolveOperation(const std::string & value)237 AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
238   auto ast = parser_.ast();
239   MS_EXCEPTION_IF_NULL(ast);
240   py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
241   const size_t namespace_var_size = 2;
242   if (namespace_var.size() < namespace_var_size) {
243     MS_EXCEPTION(NameError) << "namespace_var is less than 2";
244   }
245   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
246   SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
247   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
248   return MakeResolve(name_space, symbol);
249 }
250 
MakeResolve(const NameSpacePtr & name_space,const SymbolPtr & resolve_symbol)251 AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
252   MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
253                 << " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resoleve symbol.");
254   ValueNodePtr module_node = NewValueNode(name_space);
255   ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
256   auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
257   return node;
258 }
259 
// Append a `PyInterpret(script, global_dict, local_dict)` CNode to this block's graph and
// link it to `orig_node`, the node it stands in for.
AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
                                        const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
  MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
  const ScriptPtr script = std::make_shared<Script>(script_text);
  const auto script_value = NewValueNode(script);
  auto interpret_node = func_graph_->NewCNodeInOrder(
    {NewValueNode(prim::kPrimPyInterpret), script_value, global_dict_node, local_dict_node});
  interpret_node->set_interpreted_node(orig_node);
  return interpret_node;
}
270 
271 // Add input for the block's phi parameter
SetPhiArgument(const ParameterPtr & phi)272 void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
273   MS_EXCEPTION_IF_NULL(phi);
274   TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
275   std::string var = phi_nodes_[phi];
276   MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
277                 << " for var `" << var << "`";
278   auto removable = CollectRemovablePhi(phi);
279   // If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
280   if (removable) {
281     MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
282                   << " var `" << var << "`";
283     return;
284   }
285   for (auto &pred : prev_blocks_) {
286     MS_EXCEPTION_IF_NULL(pred);
287     MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
288                   << (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
289     AnfNodePtr arg_node = pred->ReadVariable(var);
290     CNodePtr jump = pred->jumps_[this];
291     MS_EXCEPTION_IF_NULL(jump);
292     jump->add_input(arg_node);
293   }
294 }
295 
SearchReplaceNode(const std::string & var,const ParameterPtr & phi)296 AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
297   AnfNodePtr arg_node = nullptr;
298   MS_LOG(DEBUG) << "Prev_blocks size: " << prev_blocks_.size();
299   for (auto &prev : prev_blocks_) {
300     MS_EXCEPTION_IF_NULL(prev);
301     AnfNodePtr temp_node = prev->ReadVariable(var);
302     MS_EXCEPTION_IF_NULL(temp_node);
303     if (temp_node != phi) {
304       if (arg_node == nullptr) {
305         arg_node = temp_node;
306         MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
307                       << (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
308       } else if (temp_node == arg_node) {
309         MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
310                       << (phi ? phi->ToString() : "null") << " is same as node " << arg_node->DebugString();
311       } else {
312         MS_LOG(DEBUG) << "phi " << (phi ? phi->ToString() : "null")
313                       << " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString()
314                       << ", node2: " << temp_node->DebugString();
315         return nullptr;
316       }
317     }
318   }
319   return arg_node;
320 }
321 
322 // Check if there is removable unnecessary phi node in this graph.
323 // As per the FIRM TR 3.2, a phi node can be remove if:
324 // <Quote>
325 //    If all arguments of a φ-function are the same value s or the φfunction itself,
326 //    then we remove the φ-function and let all users directly uses. We call such a
327 //    φ-function obviously unnecessary.
328 //    When we removed a φ-function p, then we recursively try to apply this simplification
329 //    rule with all (former) users of p, because they may have become obviously unnecessary
330 //    due to the removal of p
331 // <Quote>
332 // phi node in graph will be removed after the whole function is parsed in a DFS visit
333 // of that graph.The reason is :
334 // 1. when this function is called, not all usage of this phi node had bound to the
335 // graph of this function block, some may stay in vars_ in other blocks.
336 // 2. it's costly to iterate the graph to replace the phi for each phi.
337 // Args: phi: This parameter node is functioning as a phi node.
CollectRemovablePhi(const ParameterPtr & phi)338 bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
339   MS_EXCEPTION_IF_NULL(phi);
340   std::string var = phi_nodes_[phi];
341   MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
342   if (prev_blocks_.empty()) {
343     MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
344     return false;
345   }
346   AnfNodePtr arg_node = SearchReplaceNode(var, phi);
347   if (arg_node != nullptr) {
348     arg_node->set_debug_info(phi->debug_info());
349     MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
350                   << " can be replaced with " << arg_node->DebugString();
351     // Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
352     WriteVariable(var, arg_node);
353     removable_phis_[phi] = arg_node;
354     resolve_to_removable_phis_[arg_node] = phi;
355     // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
356     // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node.
357     for (auto &prev : prev_blocks_) {
358       MS_EXCEPTION_IF_NULL(prev);
359       if (!prev->matured_) {
360         continue;
361       }
362       for (auto &phi_iter : prev->removable_phis_) {
363         MS_EXCEPTION_IF_NULL(phi_iter.second);
364         if (phi_iter.second->isa<Parameter>()) {
365           const auto &param = phi_iter.second->cast<ParameterPtr>();
366           if (param == phi) {
367             MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " var "
368                           << phi_iter.first->DebugString() << " can be replaced from " << param->DebugString()
369                           << " with " << arg_node->DebugString() << " in graph "
370                           << (arg_node->func_graph() ? arg_node->func_graph()->ToString() : "FG(Null)");
371             prev->removable_phis_[phi_iter.first] = arg_node;
372           }
373         }
374       }
375     }
376     return true;
377   }
378   return false;
379 }
380 
381 // A block should be marked matured if its predecessor blocks have been processed
Mature()382 void FunctionBlock::Mature() {
383   const auto &graph_params = func_graph_->parameters();
384   for (auto &param_itr : graph_params) {
385     MS_EXCEPTION_IF_NULL(param_itr);
386     auto param = param_itr->cast<ParameterPtr>();
387     if (phi_nodes_.find(param) != phi_nodes_.cend()) {
388       SetPhiArgument(param);
389     }
390   }
391   matured_ = true;
392 }
393 
394 // Force the condition node to bool using bool operation
ForceToBoolNode(const AnfNodePtr & cond)395 CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
396   MS_EXCEPTION_IF_NULL(cond);
397   TraceGuard trace_guard(std::make_shared<TraceForceBool>(cond->debug_info()));
398   CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
399   return op_apply_node;
400 }
401 
ForceToWhileCond(const AnfNodePtr & cond)402 CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
403   MS_EXCEPTION_IF_NULL(cond);
404   TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
405   CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
406   return op_apply_node;
407 }
408 
409 // Perform a jump from this block to target block
Jump(const FunctionBlockPtr & target_block,const std::vector<AnfNodePtr> & args)410 void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
411   MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
412   MS_EXCEPTION_IF_NULL(target_block);
413   if (is_dead_block_) {
414     MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
415     return;
416   }
417   if (func_graph_->get_return() != nullptr) {
418     MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
419                       << trace::GetDebugInfo(func_graph_->get_return()->debug_info());
420   }
421   std::vector<AnfNodePtr> input_nodes;
422   input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
423   (void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
424 
425   CNodePtr jump = func_graph_->NewCNodeInOrder(input_nodes);
426   jumps_[target_block.get()] = jump;
427   target_block->AddPrevBlock(shared_from_this());
428   func_graph_->set_output(jump);
429 }
430 
431 // Perform a conditional jump using switch operation.
432 // The first CNode select graph with condition, and than execute this graph
ConditionalJump(AnfNodePtr condNode,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block,bool)433 void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
434                                     const FunctionBlockPtr &false_block, bool) {
435   MS_EXCEPTION_IF_NULL(true_block);
436   MS_EXCEPTION_IF_NULL(false_block);
437   if (func_graph_->get_return() != nullptr) {
438     MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
439                       << trace::GetDebugInfo(func_graph_->get_return()->debug_info());
440   }
441   CNodePtr switch_app =
442     func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
443                                   NewValueNode(false_block->func_graph())});
444   CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
445   func_graph_->set_output(switch_app_new);
446 }
447 
448 // Create cnode for the assign statement like 'self.target = source'.
449 // convert it to 'P.Assign(self.target, source)' and then add the cnode as isolate node.
SetStateAssign(const AnfNodePtr & target,const AnfNodePtr & source)450 void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source) {
451   const std::string primitive_name("assign");
452   const std::string module_name("mindspore.ops.functional");
453   ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
454   auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
455   MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(2) << ", block: " << this
456                 << "/" << func_graph_->ToString()
457                 << ", Line: " << trace::GetDebugInfo(assign_node->debug_info(), "", kSourceLineTipDiscard);
458   AddIsolatedNode(assign_node);
459 }
460 
FindIsolatedNodes()461 void FunctionBlock::FindIsolatedNodes() {
462   //
463   // Search isolate nodes from variables, for example,
464   // variable 'a' is an isolate node in below code:
465   //
466   //    def construct(self, x, y):
467   //        a = print(x) # isolate node
468   //        return x + y
469   //
470   std::set<AnfNodePtr> used;
471   // Find used variables.
472   for (const auto &var : assigned_vars_) {
473     auto &node = var.second.first;
474     if (node == nullptr) {
475       continue;
476     }
477     bool is_used = var.second.second;
478     if (is_used) {
479       used.emplace(node);
480     }
481   }
482   // Add isolated nodes which is unused var but not found in used set.
483   for (const auto &var : assigned_vars_) {
484     auto &node = var.second.first;
485     bool is_used = var.second.second;
486     if (node == nullptr || is_used) {
487       continue;
488     }
489     auto &var_name = var.first;
490     if (used.find(node) == used.end() && CanBeIsolatedNode(var_name, node)) {
491       MS_LOG(INFO) << "Isolated node found(NoUse), node: " << node->DebugString(2) << ", var_name: " << var_name
492                    << ", block: " << this << "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
493                    << ", Line: " << trace::GetDebugInfo(node->debug_info(), "", kSourceLineTipDiscard);
494       AddIsolatedNode(node);
495     }
496   }
497 }
498 
AddIsolatedNode(const AnfNodePtr & target)499 void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
500 
AttachIsolatedNodesBeforeReturn()501 void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
502   if (isolated_nodes_.empty()) {
503     return;
504   }
505   std::vector<AnfNodePtr> states;
506   states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
507   constexpr int recursive_level = 2;
508   for (auto &node : isolated_nodes_) {
509     MS_EXCEPTION_IF_NULL(node);
510     MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
511                   << func_graph_->ToString();
512     if (node->func_graph() == func_graph_) {
513       states.emplace_back(node);
514     } else {
515       MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(recursive_level) << " in "
516                    << func_graph_->ToString();
517     }
518   }
519   isolated_nodes_.clear();
520 
521   AnfNodePtr state = nullptr;
522   if (states.size() == 1) {
523     // Only MakeTuple, no state left.
524     return;
525   } else if (states.size() == 2) {
526     // If there are only MakeTuple and another node in states(the states size is 2),
527     // do not need to MakeTuple, just use the node.
528     state = states[1];
529   } else {
530     state = func_graph_->NewCNode(states);
531   }
532 
533   AnfNodePtr old_output = nullptr;
534   auto return_node = func_graph_->get_return();
535   if (return_node) {
536     const size_t return_input_size = 2;
537     if (return_node->inputs().size() < return_input_size) {
538       MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
539     }
540     old_output = return_node->input(1);
541   } else {
542     old_output = NewValueNode(kNone);
543   }
544   AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
545   CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
546   // We add this attribute for @constexpr use scene, since we must infer them before other nodes.
547   // That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
548   depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
549   MS_EXCEPTION_IF_NULL(state);
550   MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
551                << ", state: " << state->DebugString(2);
552   func_graph_->set_output(depend_node, true);
553 }
554 
SetAsDeadBlock()555 void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }
556 }  // namespace parse
557 }  // namespace mindspore
558