• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/parse/function_block.h"
20 
21 #include <algorithm>
22 #include <queue>
23 
24 #include "frontend/operator/ops.h"
25 #include "include/common/utils/python_adapter.h"
26 #include "include/common/utils/utils.h"
27 #include "ir/cell.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "mindspore/core/ops/sequence_ops.h"
30 #include "mindspore/core/ops/structure_ops.h"
31 #include "pipeline/jit/ps/debug/trace.h"
32 #include "pipeline/jit/ps/fallback.h"
33 #include "pipeline/jit/ps/parse/data_converter.h"
34 #include "pipeline/jit/ps/parse/parse.h"
35 #include "pipeline/jit/ps/parse/parse_base.h"
36 #include "pipeline/jit/ps/parse/resolve.h"
37 #include "utils/hash_set.h"
38 #include "utils/info.h"
39 #include "utils/compile_config.h"
40 
41 namespace mindspore {
42 namespace py = pybind11;
43 
44 namespace parse {
FunctionBlock(const Parser & parser)45 FunctionBlock::FunctionBlock(const Parser &parser)
46     : func_graph_(std::make_shared<FuncGraph>()), parser_(parser), matured_(false) {}
47 
AddPrevBlock(const FunctionBlockPtr & block)48 void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
49 
CanBeIsolatedNode(const std::string & var_name,const AnfNodePtr & node)50 static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
51   auto cnode = dyn_cast<CNode>(node);
52   if (cnode == nullptr || cnode->inputs().empty()) {
53     // Not a valid cnode, can not be isolate node.
54     return false;
55   }
56   auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
57   if (prim == nullptr) {
58     // Not a primitive cnode, it may have side effects or not,
59     // We add it as an isolate node if its name is not '_' or empty.
60     // this means that code like:
61     //    _ = func_call()
62     // will be ignored even if func_call() has side effects.
63     return !var_name.empty() && var_name != "_";
64   }
65   // Primitive cnode with side effects can be isolate nodes.
66   auto effect_info = GetPrimEffectInfo(prim);
67   bool has_effects = (effect_info.memory || effect_info.io);
68   if (has_effects) {
69     return true;
70   }
71   // Primitive cnode with 'no_eliminate' flag can be isolate nodes.
72   return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
73 }
74 
75 // Write variable records the variable name to corresponding node
WriteVariable(const std::string & var_name,const AnfNodePtr & node)76 void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
77   MS_EXCEPTION_IF_NULL(node);
78   MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
79                 << node->DebugString();
80   constexpr auto kRecursiveLevel = 2;
81   // a[::][::] = b will be translated to c = a[::] c[::] = b and the c is a no named variable.
82   if (var_name.empty()) {
83     MS_LOG(DEBUG) << "The node is " << node->DebugString(kRecursiveLevel)
84                   << "added in the isolated list.\nBlock: " << this << "/"
85                   << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
86                   << ", Line: " << trace::GetDebugInfoStr(node->debug_info(), "", kSourceLineTipDiscard);
87     AddIsolatedNode(node);
88     return;
89   }
90   auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
91   if (!is_new_name) {
92     // If a cnode variable with same name already existed but not used,
93     // add it as an isolate node. for example:
94     //   a = print(x)
95     //   a = print(y)
96     // When we write variable 'a = print(y)',
97     // the cnode 'print(x)' should added as an isolate node.
98     auto is_used = iter->second.second;
99     auto hidden_node = iter->second.first;
100     auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
101     if (!is_used && is_isolated) {
102       MS_EXCEPTION_IF_NULL(hidden_node);
103       MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(kRecursiveLevel)
104                    << " is hidden by " << node->DebugString(kRecursiveLevel)
105                    << " with the same name, var_name: " << var_name << ", block: " << this << "/"
106                    << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
107                    << ", Line: " << trace::GetDebugInfoStr(hidden_node->debug_info(), "", kSourceLineTipDiscard);
108       AddIsolatedNode(hidden_node);
109     }
110     MS_LOG(INFO) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " update var `" << var_name
111                  << "` with node " << node->DebugString();
112     iter->second = std::make_pair(node, false);
113   }
114   if (!HasGlobalPyParam(var_name)) {
115     UpdateLocalPyParam(var_name, node);
116   }
117 }
118 
ReadLocalVariable(const std::string & var_name)119 AnfNodePtr FunctionBlock::ReadLocalVariable(const std::string &var_name) {
120   auto found = assigned_vars_.find(var_name);
121   if (found != assigned_vars_.end()) {
122     auto &node = found->second.first;
123     MS_EXCEPTION_IF_NULL(node);
124     // Mark the variable as used.
125     found->second.second = true;
126     MS_LOG(DEBUG) << "Found var: " << var_name << ", as: " << node->DebugString();
127     return node;
128   }
129   return nullptr;
130 }
131 
CheckHasVariable(const std::string & var_name)132 bool FunctionBlock::CheckHasVariable(const std::string &var_name) {
133   auto node = ReadLocalVariable(var_name);
134   if (node != nullptr) {
135     return true;
136   }
137   if (!prev_blocks_.empty()) {
138     auto block = prev_blocks_[0];
139     MS_EXCEPTION_IF_NULL(block);
140     return block->CheckHasVariable(var_name);
141   }
142   return false;
143 }
144 
FindPredInterpretNode(const std::string & var_name)145 std::pair<AnfNodePtr, bool> FunctionBlock::FindPredInterpretNode(const std::string &var_name) {
146   // Search the predecessors of the current block for the local parameter. If one of the local parameter of the
147   // predecessors is interpret node, the phi_param needs to set the interpret true.
148   mindspore::HashSet<FunctionBlock *> visited_block;
149   std::queue<FunctionBlock *> block_queue;
150   block_queue.push(this);
151   bool has_found = false;
152   while (!block_queue.empty()) {
153     const auto cur_block = block_queue.front();
154     MS_EXCEPTION_IF_NULL(cur_block);
155     block_queue.pop();
156     (void)visited_block.insert(cur_block);
157     auto pred_node = cur_block->ReadLocalVariable(var_name);
158     if (pred_node != nullptr) {
159       has_found = true;
160       bool interpret_without_internal =
161         IsPrimitiveCNode(pred_node, prim::kPrimPyInterpret) && !pred_node->interpret_internal_type();
162       if (pred_node->interpret() || interpret_without_internal) {
163         return std::make_pair(pred_node, has_found);
164       }
165     } else {
166       for (const auto &cur_pred_block : cur_block->prev_blocks()) {
167         if (visited_block.count(cur_pred_block) == 0) {
168           block_queue.push(cur_pred_block);
169         }
170       }
171     }
172   }
173   return std::make_pair(nullptr, has_found);
174 }
175 
176 // Read variable from predecessors
ReadVariable(const std::string & var_name)177 AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
178   MS_LOG(DEBUG) << "Read begin, var_name: " << var_name << ", block: " << ToString();
179   // Get var node if it is found
180   auto node = ReadLocalVariable(var_name);
181   if (node != nullptr) {
182     if (!HasGlobalPyParam(var_name)) {
183       UpdateLocalPyParam(var_name, node);
184     }
185     return node;
186   }
187 
188   MS_LOG(DEBUG) << "matured_: " << matured_ << ", prev_blocks_.size: " << prev_blocks_.size();
189   // Get var from predecessor block, if can't get then make a resolve node to it
190   if (matured_) {
191     // If only one predecessor block, read the definition of var from it.
192     if (prev_blocks_.size() == 1) {
193       auto block = prev_blocks_[0];
194       MS_EXCEPTION_IF_NULL(block);
195       auto res = block->ReadVariable(var_name);
196       MS_LOG(DEBUG) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
197                     << ",\nCurrent: " << py::str(const_cast<py::dict &>(global_py_params()))
198                     << "\nInsert: " << py::str(const_cast<py::dict &>(block->global_py_params()));
199       UpdateGlobalPyParam(block->global_py_params());
200       if (!HasGlobalPyParam(var_name)) {
201         UpdateLocalPyParam(var_name, res);
202       }
203       return res;
204     } else if (prev_blocks_.empty()) {
205       // Get namespace and make Resolve
206       auto it = var_to_resolve_.find(var_name);
207       if (it != var_to_resolve_.end()) {
208         return it->second;
209       }
210       MS_LOG(DEBUG) << "var_name: " << var_name;
211       auto tmp_node = MakeResolveSymbol(var_name);
212       var_to_resolve_[var_name] = tmp_node;
213       return tmp_node;
214     }
215   }
216   // If have more than one predecessor blocks then build a phi node.
217   auto debug_info = std::make_shared<NodeDebugInfo>();
218   debug_info->set_name(var_name);
219   TraceGuard guard(std::make_shared<TracePhi>(debug_info));
220   ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
221   MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
222                 << phi_param->ToString() << " for " << var_name;
223 
224   auto [pred_node, has_found] = FindPredInterpretNode(var_name);
225   if (pred_node != nullptr) {
226     phi_param->set_interpret(true);
227   } else if (!has_found) {
228     // If the current node is created as a phi node at the first time.(the var_name has not be found in pre blocks)
229     // need resolve to determine whether it needs to be marked with interpret.
230     auto resolve_node = MakeResolveSymbol(var_name);
231     MS_EXCEPTION_IF_NULL(resolve_node);
232     // Avoid to build phi node if current block is not matured and resolve_node is an undefined symbol.
233     if (!matured_ && resolve_node->isa<ValueNode>()) {
234       auto value = GetValuePtr<ValueProblem>(resolve_node->cast<ValueNodePtr>());
235       if (!is_dead_block() && value != nullptr && value->IsUndefined()) {
236         MS_LOG(DEBUG) << "Avoid to build phi node and return undefined node, var_name: " << var_name
237                       << ", block: " << ToString() << ", matured_: " << matured_
238                       << ", prev_blocks_.size: " << prev_blocks_.size();
239         return resolve_node;
240       }
241     }
242     phi_param->set_interpret(resolve_node->interpret());
243     phi_param->set_interpret_internal_type(resolve_node->interpret_internal_type());
244     if (resolve_node->isa<Parameter>()) {
245       phi_param->set_debug_info(resolve_node->debug_info());
246     }
247   }
248 
249   func_graph()->add_parameter(phi_param);
250   phi_nodes_[phi_param] = var_name;
251   WriteVariable(var_name, phi_param);
252   if (matured_) {
253     SetPhiArgument(phi_param);
254   }
255   // In SetPhiArgument/CollectRemovablePhi, this phi may be set as removable and set it as
256   // real node, so check it again.
257   MS_LOG(DEBUG) << "Read again, var_name: " << var_name << ", block: " << ToString();
258   node = ReadLocalVariable(var_name);
259   if (node != nullptr) {
260     return node;
261   }
262   return phi_param;
263 }
264 
265 // Resolve Ast operator node
GetAstOpNameSpace(const py::object & op)266 py::tuple FunctionBlock::GetAstOpNameSpace(const py::object &op) {
267   auto ast = parser_.ast();
268   MS_EXCEPTION_IF_NULL(ast);
269   TraceGuard trace_guard(parser_.GetLocation(op));
270   py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
271   constexpr size_t namespace_size = 3;
272   if (namespace_var.size() != namespace_size) {
273     MS_LOG(INTERNAL_EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
274   }
275   return namespace_var;
276 }
277 
278 // Resolve Ast operator node
MakeResolveAstOpNameSpace(const py::tuple & namespace_var)279 AnfNodePtr FunctionBlock::MakeResolveAstOpNameSpace(const py::tuple &namespace_var) {
280   constexpr size_t namespace_index = 0;
281   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[namespace_index]);
282   constexpr size_t symbol_index = 1;
283   SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[symbol_index].cast<std::string>());
284   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
285   return MakeResolve(name_space, symbol);
286 }
287 
288 // Resolve class object self.
MakeResolveClassObject()289 AnfNodePtr FunctionBlock::MakeResolveClassObject() {
290   auto ast = parser_.ast();
291   MS_EXCEPTION_IF_NULL(ast);
292   py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
293   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_OBJECT, namespace_var, ast->obj());
294   constexpr auto self_name = "self";
295   SymbolPtr symbol = std::make_shared<Symbol>(self_name);  // Must be 'self'.
296   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
297   return MakeResolve(name_space, symbol);
298 }
299 
300 // Resolve class member: method, member variable.
MakeResolveClassMember(const std::string & attr_or_self)301 AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr_or_self) {
302   auto ast = parser_.ast();
303   MS_EXCEPTION_IF_NULL(ast);
304   py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
305   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var, ast->obj());
306   SymbolPtr symbol = std::make_shared<Symbol>(attr_or_self);
307   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
308   return MakeResolve(name_space, symbol);
309 }
310 
CheckUndefinedSymbol(const std::string & var,const AnfNodePtr & node) const311 void FunctionBlock::CheckUndefinedSymbol(const std::string &var, const AnfNodePtr &node) const {
312   if (node->isa<ValueNode>()) {
313     auto value = GetValuePtr<ValueProblem>(node->cast<ValueNodePtr>());
314     if (!is_dead_block() && value != nullptr && value->IsUndefined()) {
315       MS_EXCEPTION(NameError) << "The name '" << var << "' is not defined, or not supported in graph mode.";
316     }
317   }
318 }
319 
HandleNamespaceSymbol(const std::string & var_name)320 AnfNodePtr FunctionBlock::HandleNamespaceSymbol(const std::string &var_name) {
321   auto ast = parser_.ast();
322   MS_EXCEPTION_IF_NULL(ast);
323   const py::tuple &info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, var_name);
324 
325   constexpr size_t closure_info_size = 3;
326   constexpr size_t global_info_size = 4;
327   constexpr size_t namespace_index = 0;
328   constexpr size_t symbol_index = 1;
329   constexpr size_t value_index = 2;
330   constexpr size_t flag_index = 3;
331   if (info.size() != closure_info_size && info.size() != global_info_size) {
332     MS_INTERNAL_EXCEPTION(NameError) << "The namespace info size should be 3 or 4, but got " << info.size();
333   }
334   // If namespace is None, the symbol is an undefined name.
335   if (info[namespace_index].is_none()) {
336     const auto undefined_symbol = std::make_shared<ValueProblem>(ValueProblemType::kUndefined);
337     MS_LOG(WARNING) << "Undefined symbol: " << var_name << ", during parsing " << py::str(ast->function()) << " of "
338                     << py::str(ast->obj());
339     return NewValueNode(undefined_symbol);
340   }
341 
342   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, info[namespace_index]);
343   SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
344   auto resolved_node = MakeResolve(name_space, symbol);
345 
346   // Handle closure namespace info.
347   if (info.size() == closure_info_size) {
348     return resolved_node;
349   }
350 
351   // Handle global namespace info.
352   auto syntax_support = info[flag_index].cast<int32_t>();
353   py::object py_obj = info[value_index];
354   if (syntax_support != SYNTAX_SUPPORTED && syntax_support != SYNTAX_HYBRID_TYPE) {
355     resolved_node->set_interpret(true);
356     if (syntax_support == SYNTAX_UNSUPPORTED_INTERNAL_TYPE) {
357       resolved_node->set_interpret_internal_type(true);
358       resolved_node->set_user_data<py::object>(kClassTensorObject, std::make_shared<py::object>(py_obj));
359     }
360   }
361 
362   auto symbol_name = info[symbol_index].cast<std::string>();
363   if (symbol_name == "Shard") {
364     MS_LOG(EXCEPTION) << "Cell.shard or ms.shard not supported in jit syntax, please use shard out of jit"
365                       << " or construct scope.";
366   }
367   AddGlobalPyParam(symbol_name, py_obj);
368   MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol_name << " : "
369                << py::str(py_obj) << "}";
370   fallback::SetPyObjectToNode(resolved_node, py_obj);
371   return resolved_node;
372 }
373 
374 // Make a resolve node for symbol string
MakeResolveSymbol(const std::string & var_name)375 AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &var_name) {
376   MS_LOG(DEBUG) << "var_name: " << var_name << ", ast object type: " << parser_.ast()->target_type();
377 
378   // Handle self. The prefix of var_name is "self".
379   constexpr auto self_name = "self";
380   const auto self_name_len = strlen(self_name);
381   // For PARSE_TARGET_METHOD or PARSE_TARGET_OBJECT_INSTANCE, should deal with self here, exclude PARSE_TARGET_FUNCTION.
382   if ((parser_.ast()->target_type() == PARSE_TARGET_METHOD ||
383        parser_.ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) &&
384       var_name.compare(0, self_name_len, self_name) == 0) {
385     auto start = var_name.find_first_of('.');
386     if (start != std::string::npos) {  // 'self.xxx'
387       ++start;
388       if (start >= var_name.size()) {
389         MS_LOG(ERROR) << "Find invalid resolve symbol str: " << var_name;
390         return nullptr;
391       }
392       auto bits_str = var_name.substr(start);
393       auto resolve_node = MakeResolveClassMember(bits_str);
394       if (!HasGlobalPyParam(var_name)) {
395         UpdateLocalPyParam(var_name, resolve_node);
396       }
397       return resolve_node;
398     } else if (var_name.size() == self_name_len) {  // 'self'
399       auto resolve_node = MakeResolveClassObject();
400       if (!HasGlobalPyParam(var_name)) {
401         UpdateLocalPyParam(var_name, resolve_node);
402       }
403       return resolve_node;
404     }
405   }
406 
407   // Handle non-self.
408   return HandleNamespaceSymbol(var_name);
409 }
410 
MakeResolveOperation(const std::string & value)411 AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
412   auto ast = parser_.ast();
413   MS_EXCEPTION_IF_NULL(ast);
414   py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
415   const size_t namespace_var_size = 2;
416   if (namespace_var.size() < namespace_var_size) {
417     MS_INTERNAL_EXCEPTION(NameError) << "namespace_var is less than 2";
418   }
419   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
420   SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
421   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
422   return MakeResolve(name_space, symbol);
423 }
424 
425 namespace {
426 // The same as TransformVectorFuncValueNode() in mindspore/ccsrc/pipeline/jit/parse/resolve.cc, but not add to manager.
TransformVectorFuncValueNode(const FuncGraphPtr & func_graph,const ValuePtr & value,AnfNodePtr * const transformed)427 bool TransformVectorFuncValueNode(const FuncGraphPtr &func_graph, const ValuePtr &value,
428                                   AnfNodePtr *const transformed) {
429   MS_EXCEPTION_IF_NULL(value);
430   const auto &value_vec = GetValue<ValuePtrList>(value);
431   if (value_vec.empty()) {
432     return false;
433   }
434   std::vector<AnfNodePtr> nodes;
435   (void)nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
436   bool is_all_func = true;
437   for (auto &elem : value_vec) {
438     MS_EXCEPTION_IF_NULL(elem);
439     AnfNodePtr node = nullptr;
440     if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
441       is_all_func = is_all_func && TransformVectorFuncValueNode(func_graph, elem, &node);
442     } else if (elem->isa<FuncGraph>()) {
443       FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
444       node = NewValueNode(new_fg);
445     } else if (elem->isa<Primitive>()) {
446       node = NewValueNode(elem);
447     } else {
448       is_all_func = false;
449     }
450     (void)nodes.emplace_back(node);
451   }
452   if (is_all_func) {
453     // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
454     // So if has graph in list, try to replace the node with make tuple of graph value node.
455     // We do this because the graph manager won't investigate the graph inside valuetuple,
456     // change the vector of graph to be make_tuple of graph value node.
457     // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
458     // independent nodes.
459     *transformed = func_graph->NewCNode(std::move(nodes));
460   }
461   return is_all_func;
462 }
463 }  // namespace
464 
MakeResolve(const NameSpacePtr & name_space,const SymbolPtr & resolve_symbol)465 AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
466   MS_LOG(DEBUG) << "MakeResolve for "
467                 << (name_space ? (std::string)py::str(name_space->namespace_obj()) : "null namespace") << " , "
468                 << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
469   ValueNodePtr module_node = NewValueNode(name_space);
470   ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
471   auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
472 
473   // Directly resolve the symbol.
474   return DoResolve(node, name_space, resolve_symbol);
475 }
476 
// Tries to resolve the Resolve node `node` immediately.
// Returns `node` unchanged when resolving is deferred by the parser, when the
// compile config "GREED_PARSE" is not "1", or when the symbol is a Cell's 'self'.
// Otherwise returns the node converted from the resolved Python object. A value
// tuple/list made only of graphs/primitives is turned into a MakeTuple CNode.
AnfNodePtr FunctionBlock::DoResolve(const AnfNodePtr &node, const std::shared_ptr<NameSpace> &name_space,
                                    const std::shared_ptr<Symbol> &resolve_symbol) {
  static const auto boost_parse = common::GetCompileConfig("GREED_PARSE");
  const bool resolve_now = !Parser::defer_resolve() && boost_parse == "1";
  if (!resolve_now) {
    return node;
  }
  // Directly resolve the symbol.
  const auto &obj = GetSymbolObject(name_space, resolve_symbol, node);
  // Avoid recursively resolving Cell.
  const bool is_cell_self = py::isinstance<Cell>(obj) && resolve_symbol->symbol() == "self";
  if (is_cell_self) {
    MS_LOG(ERROR) << "Not direct resolve Cell self. node: " << node->DebugString() << ", ns: " << name_space->ToString()
                  << ", sym: " << resolve_symbol->ToString();
    return node;
  }
  AnfNodePtr resolved_node = nullptr;
  const bool converted = ResolveObjectToNode(node, obj, &resolved_node);
  if (!converted || resolved_node == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parse Resolve covert failed." << node->DebugString()
                               << ", ns: " << name_space->ToString() << ", sym: " << resolve_symbol->ToString();
  }
  // If the constant node is constant of vector of graph, add graph to manager.
  const bool is_value_sequence = IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node);
  if (is_value_sequence) {
    auto sequence_value = resolved_node->cast<ValueNodePtr>()->value();
    const bool transformed = TransformVectorFuncValueNode(func_graph_, sequence_value, &resolved_node);
    if (!transformed) {
      MS_LOG(INFO) << "Fail to convert value tuple/list to CNode, " << resolved_node->DebugString();
    }
  }
  MS_LOG(DEBUG) << "node: " << node->DebugString() << ", ns: " << name_space->ToString()
                << ", sym: " << resolve_symbol->ToString() << ", resolved_node: " << resolved_node->DebugString();
  return resolved_node;
}
508 
// Creates a PyInterpret CNode (prim PyInterpret, script, global dict, local dict)
// in this block's graph. The new node takes the debug info and the
// interpret_internal_type mark from `orig_node`.
AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
                                        const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
  MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
  MS_EXCEPTION_IF_NULL(orig_node);
  auto script_value_node = NewValueNode(std::make_shared<parse::Script>(script_text));
  auto interpret_node = func_graph_->NewCNodeInOrder(
    {NewValueNode(prim::kPrimPyInterpret), script_value_node, global_dict_node, local_dict_node});
  MS_EXCEPTION_IF_NULL(interpret_node);
  interpret_node->set_debug_info(orig_node->debug_info());
  interpret_node->set_interpret_internal_type(orig_node->interpret_internal_type());
  return interpret_node;
}
522 
523 // Add input for the block's phi parameter
SetPhiArgument(const ParameterPtr & phi)524 void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
525   MS_EXCEPTION_IF_NULL(phi);
526   TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
527   std::string var = phi_nodes_[phi];
528   MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
529                 << " for var `" << var << "`";
530   CollectRemovablePhi(phi);
531   for (auto &pred : prev_blocks_) {
532     MS_EXCEPTION_IF_NULL(pred);
533     MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
534                   << (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
535     AnfNodePtr arg_node = pred->ReadVariable(var);
536     auto jump = pred->GetJumpNode(this);
537     if (jump == nullptr) {
538       // If prev block is a switch call's prev block, no jumps here.
539       continue;
540     }
541     jump->add_input(arg_node);
542   }
543 }
544 
545 namespace {
GetVariableDefinedLocation(const FunctionBlock * block,const std::string & var,int start_line)546 std::string GetVariableDefinedLocation(const FunctionBlock *block, const std::string &var, int start_line) {
547   MS_EXCEPTION_IF_NULL(block);
548   HashSet<FunctionBlock *> visited;
549   std::vector<FunctionBlock *> todo_list = {};
550   (void)std::copy(block->prev_blocks().cbegin(), block->prev_blocks().cend(), std::back_inserter(todo_list));
551   while (!todo_list.empty()) {
552     auto cur_block = todo_list.front();
553     (void)todo_list.erase(todo_list.begin());
554     if (visited.find(cur_block) != visited.cend()) {
555       continue;
556     }
557     (void)visited.insert(cur_block);
558     (void)std::copy(cur_block->prev_blocks().cbegin(), cur_block->prev_blocks().cend(), std::back_inserter(todo_list));
559     auto node = cur_block->ReadLocalVariable(var);
560     if (node != nullptr) {
561       const auto &debug_info = trace::GetSourceCodeDebugInfo(node->debug_info());
562       const auto &location = debug_info->location();
563       return location->ToString(kSourceSectionTipNextLineHere, start_line);
564     }
565   }
566   return "";
567 }
568 }  // namespace
569 
CheckVariableNotDefined(const std::pair<std::string,AnfNodePtr> & not_defined_branch,const std::string & var)570 void FunctionBlock::CheckVariableNotDefined(const std::pair<std::string, AnfNodePtr> &not_defined_branch,
571                                             const std::string &var) {
572   std::ostringstream oss;
573   std::string not_defined_branch_name = not_defined_branch.first;
574   const auto &debug_info = trace::GetSourceCodeDebugInfo(this->func_graph()->debug_info());
575   const auto &location = debug_info->location();
576   int start_line = location->line();
577   if ((not_defined_branch_name == "while") || (not_defined_branch_name == "for")) {
578     oss << "The local variable '" << var << "' defined in the '" << not_defined_branch_name
579         << "' loop body cannot be used outside of the loop body. "
580         << "Please define variable '" << var << "' before '" << not_defined_branch_name << "'.\n";
581   }
582   if ((not_defined_branch_name == "true branch") || (not_defined_branch_name == "false branch")) {
583     oss << "The local variable '" << var << "' is not defined in " << not_defined_branch_name << ", but defined in "
584         << (not_defined_branch_name == "true branch" ? "false branch" : "true branch") << ".\n";
585   }
586   std::string location_info = GetVariableDefinedLocation(this, var, start_line);
587   if (location_info.empty()) {
588     return;
589   }
590   oss << location_info;
591   MS_EXCEPTION(UnboundLocalError) << oss.str();
592 }
593 
SearchAllArgsOfPhiNode(const std::string & var,const ParameterPtr & phi)594 std::set<AnfNodePtr> FunctionBlock::SearchAllArgsOfPhiNode(const std::string &var, const ParameterPtr &phi) {
595   std::vector<std::pair<std::string, AnfNodePtr>> defined_branch;
596   std::pair<std::string, AnfNodePtr> not_defined_branch;
597   MS_LOG(DEBUG) << "Search block:" << ToString() << "Prev_blocks size: " << prev_blocks_.size();
598   for (auto &prev : prev_blocks_) {
599     MS_EXCEPTION_IF_NULL(prev);
600     AnfNodePtr temp_node = prev->ReadVariable(var);
601     MS_EXCEPTION_IF_NULL(temp_node);
602     MS_LOG(DEBUG) << "Read from prev block:" << prev->ToString() << "Found var: " << var
603                   << ", as: " << temp_node->DebugString();
604     bool undefined_symbol_flag = false;
605     if (temp_node->isa<ValueNode>()) {
606       auto value = GetValuePtr<ValueProblem>(temp_node->cast<ValueNodePtr>());
607       if ((value != nullptr) && (value->IsUndefined())) {
608         undefined_symbol_flag = true;
609       }
610     }
611     if (undefined_symbol_flag) {
612       not_defined_branch = std::make_pair(prev->block_name(), temp_node);
613     } else {
614       defined_branch.push_back(std::make_pair(prev->block_name(), temp_node));
615     }
616   }
617   if (defined_branch.size() == 1) {
618     auto arg_node = defined_branch.front().second;
619     MS_EXCEPTION_IF_NULL(arg_node);
620     MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi "
621                   << (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
622   }
623 
624   if (not_defined_branch.second != nullptr) {
625     if (!defined_branch.empty()) {
626       auto locaction = trace::GetSourceCodeDebugInfo(phi->debug_info())->location();
627       MS_EXCEPTION_IF_NULL(locaction);
628       TraceGuard trace_guard(locaction);
629       CheckVariableNotDefined(not_defined_branch, var);
630     }
631     MS_EXCEPTION(NameError) << "The name '" << var << "' is not defined, or not supported in graph mode.";
632   }
633 
634   std::set<AnfNodePtr> all_arg_nodes;
635   for (auto &item : defined_branch) {
636     (void)all_arg_nodes.insert(item.second);
637   }
638   return all_arg_nodes;
639 }
640 
641 // Check if there is removable unnecessary phi node in this graph.
642 // As per the FIRM TR 3.2, a phi node can be remove if:
643 // <Quote>
644 //    If all arguments of a φ-function are the same value s or the φfunction itself,
645 //    then we remove the φ-function and let all users directly uses. We call such a
646 //    φ-function obviously unnecessary.
647 //    When we removed a φ-function p, then we recursively try to apply this simplification
648 //    rule with all (former) users of p, because they may have become obviously unnecessary
649 //    due to the removal of p
650 // <Quote>
651 // phi node in graph will be removed after the whole function is parsed in a DFS visit
652 // of that graph.The reason is :
653 // 1. when this function is called, not all usage of this phi node had bound to the
654 // graph of this function block, some may stay in vars_ in other blocks.
655 // 2. it's costly to iterate the graph to replace the phi for each phi.
656 // Args: phi: This parameter node is functioning as a phi node.
CollectRemovablePhi(const ParameterPtr & phi)657 void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
658   MS_EXCEPTION_IF_NULL(phi);
659   const auto &var_name = phi_nodes_[phi];
660   MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var_name;
661   if (prev_blocks_.empty()) {
662     MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var_name;
663     return;
664   }
665   auto arg_nodes = SearchAllArgsOfPhiNode(var_name, phi);
666   phi_args_[phi] = arg_nodes;
667   if (arg_nodes.size() == 1) {
668     auto arg_node = *arg_nodes.begin();
669     if (arg_node->debug_info() == nullptr) {
670       arg_node->set_debug_info(phi->debug_info());
671     }
672     MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
673                   << " can be replaced with " << arg_node->DebugString();
674     // Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
675     WriteVariable(var_name, arg_node);
676     bool interpret_without_internal =
677       IsPrimitiveCNode(arg_node, prim::kPrimPyInterpret) && !arg_node->interpret_internal_type();
678     if (arg_node->interpret() || interpret_without_internal) {
679       phi->set_interpret(true);
680       if (arg_node->interpret_internal_type()) {
681         phi->set_interpret_internal_type(true);
682       }
683     }
684   }
685 }
686 
687 // A block should be marked matured if its predecessor blocks have been processed
Mature()688 void FunctionBlock::Mature() {
689   const auto &graph_params = func_graph_->parameters();
690   for (auto &param_itr : graph_params) {
691     MS_EXCEPTION_IF_NULL(param_itr);
692     auto param = param_itr->cast<ParameterPtr>();
693     if (phi_nodes_.find(param) != phi_nodes_.cend()) {
694       SetPhiArgument(param);
695     }
696   }
697   matured_ = true;
698 }
699 
700 // Get the truth value testing for cond node.
ForceToCondNode(const AnfNodePtr & cond,bool is_while_cond)701 CNodePtr FunctionBlock::ForceToCondNode(const AnfNodePtr &cond, bool is_while_cond) {
702   MS_EXCEPTION_IF_NULL(cond);
703   CNodePtr op_apply_node =
704     func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimCond), cond, NewValueNode(MakeValue(is_while_cond))});
705   return op_apply_node;
706 }
707 
708 // Perform a jump from this block to target block
Jump(const FunctionBlockPtr & target_block,const std::vector<AnfNodePtr> & args)709 void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
710   MS_EXCEPTION_IF_NULL(target_block);
711   MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
712   if (is_dead_block_) {
713     MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
714     return;
715   }
716   if (func_graph_->get_return() != nullptr) {
717     MS_LOG(INTERNAL_EXCEPTION) << "Failure: have return node! NodeInfo: "
718                                << trace::GetDebugInfoStr(func_graph_->get_return()->debug_info());
719   }
720   std::vector<AnfNodePtr> input_nodes;
721   input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
722   (void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
723 
724   CNodePtr jump = func_graph_->NewCNodeInOrder(std::move(input_nodes));
725   jumps_[target_block.get()] = jump;
726   target_block->AddPrevBlock(shared_from_this());
727   func_graph_->set_output(jump);
728 }
729 
730 // Perform a conditional jump using switch operation.
731 // The first CNode select graph with condition, and than execute this graph
ConditionalJump(const AnfNodePtr & cond_node,const AnfNodePtr & true_block_call,const AnfNodePtr & false_block_call)732 CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
733                                         const AnfNodePtr &false_block_call) {
734   MS_EXCEPTION_IF_NULL(true_block_call);
735   MS_EXCEPTION_IF_NULL(false_block_call);
736   if (func_graph_->get_return() != nullptr) {
737     MS_LOG(INTERNAL_EXCEPTION) << "Failure: have return node! fg: " << func_graph_->ToString()
738                                << "\nNodeInfo: " << trace::GetDebugInfoStr(func_graph_->get_return()->debug_info())
739                                << "\ncond_node: " << cond_node->DebugString()
740                                << "\nNodeInfo: " << trace::GetDebugInfoStr(cond_node->debug_info());
741   }
742   CNodePtr switch_app =
743     func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, true_block_call, false_block_call});
744   CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
745   func_graph_->set_output(switch_app_new);
746   return switch_app_new;
747 }
748 
ConditionalJump(const AnfNodePtr & cond_node,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block)749 CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
750                                         const FunctionBlockPtr &false_block) {
751   MS_EXCEPTION_IF_NULL(true_block);
752   MS_EXCEPTION_IF_NULL(false_block);
753   return ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph()));
754 }
755 
756 // Create cnode for the assign statement like 'self.target = source'.
757 // convert it to 'P.Assign(self.target, source)' and then add the cnode as isolate node.
SetStateAssign(const AnfNodePtr & target,const AnfNodePtr & source)758 void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source) {
759   const std::string primitive_name("assign");
760   const std::string module_name("mindspore.ops.functional");
761   ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
762   auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
763   const int recursive_level = 2;
764   MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(recursive_level)
765                 << ", block: " << this << "/" << func_graph_->ToString()
766                 << ", Line: " << trace::GetDebugInfoStr(assign_node->debug_info(), "", kSourceLineTipDiscard);
767   AddIsolatedNode(assign_node);
768 }
769 
ConvertUnusedNodesToIsolated(const std::pair<std::string,std::pair<AnfNodePtr,bool>> var)770 void FunctionBlock::ConvertUnusedNodesToIsolated(const std::pair<std::string, std::pair<AnfNodePtr, bool>> var) {
771   auto &node = var.second.first;
772   bool is_used = var.second.second;
773   if (node == nullptr || is_used) {
774     return;
775   }
776   auto &var_name = var.first;
777   if (CanBeIsolatedNode(var_name, node)) {
778     const int recursive_level = 2;
779     MS_LOG(INFO) << "Isolated node found(NoUse), node: " << node->DebugString(recursive_level)
780                  << ", var_name: " << var_name << ", block: " << this << "/"
781                  << (func_graph() ? func_graph()->ToString() : "FG(Null)")
782                  << ", Line: " << trace::GetDebugInfoStr(node->debug_info(), "", kSourceLineTipDiscard);
783     AddIsolatedNode(node);
784   }
785 }
786 
FindIsolatedNodes()787 void FunctionBlock::FindIsolatedNodes() {
788   //
789   // Search isolate nodes from variables, for example,
790   // variable 'a' is an isolate node in below code:
791   //
792   //    def construct(self, x, y):
793   //        a = print(x) # isolate node
794   //        return x + y
795   //
796   // Add isolated nodes which is unused var but not found in used set.
797   for (const auto &var : assigned_vars_) {
798     ConvertUnusedNodesToIsolated(var);
799   }
800 }
801 
AddIsolatedNode(const AnfNodePtr & target)802 void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
803 
AttachIsolatedNodesBeforeReturn()804 void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
805   if (isolated_nodes_.empty()) {
806     return;
807   }
808   std::vector<AnfNodePtr> states;
809   states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
810   constexpr int recursive_level = 2;
811   for (const auto &node : isolated_nodes_) {
812     MS_EXCEPTION_IF_NULL(node);
813     MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
814                   << func_graph_->ToString();
815     if (node->func_graph() == func_graph_) {
816       states.emplace_back(node);
817     } else {
818       MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(recursive_level) << " in "
819                    << func_graph_->ToString();
820     }
821   }
822   isolated_nodes_.clear();
823 
824   AnfNodePtr state = nullptr;
825   constexpr size_t no_state_size = 1;
826   constexpr size_t only_one_state_size = 2;
827   if (states.size() == no_state_size) {
828     // Only MakeTuple, no state left.
829     return;
830   } else if (states.size() == only_one_state_size) {
831     // If there are only MakeTuple and another node in states(the states size is 2),
832     // do not need to MakeTuple, just use the node.
833     state = states[1];
834   } else {
835     state = func_graph_->NewCNode(std::move(states));
836     if (state != nullptr && state->debug_info() != nullptr) {
837       state->debug_info()->set_location(nullptr);
838     }
839   }
840 
841   AnfNodePtr old_output = nullptr;
842   auto return_node = func_graph_->get_return();
843   if (return_node != nullptr) {
844     const size_t return_input_size = 2;
845     if (return_node->size() < return_input_size) {
846       MS_LOG(INTERNAL_EXCEPTION) << "Length of inputs of output node is less than 2";
847     }
848     old_output = return_node->input(1);
849   } else {
850     old_output = NewValueNode(kNone);
851   }
852   AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
853   CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
854   if (stop_grad_node->debug_info()) {
855     stop_grad_node->debug_info()->set_location(nullptr);
856   }
857   if (depend_node->debug_info()) {
858     depend_node->debug_info()->set_location(old_output->debug_info()->location());
859   }
860   // We add this attribute for @constexpr use scene, since we must infer them before other nodes.
861   // That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
862   depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
863   MS_EXCEPTION_IF_NULL(state);
864   MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
865                << ", state: " << state->DebugString(recursive_level);
866   func_graph_->set_output(depend_node, true);
867   // Update new return node's debug_info with old one.
868   if (return_node != nullptr && return_node->debug_info()) {
869     auto new_return = func_graph_->get_return();
870     MS_EXCEPTION_IF_NULL(new_return);
871     new_return->set_debug_info(return_node->debug_info());
872   }
873 }
874 
SetAsDeadBlock()875 void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }
876 
GetJumpNode(FunctionBlock * target_block)877 CNodePtr FunctionBlock::GetJumpNode(FunctionBlock *target_block) {
878   auto it = jumps_.find(target_block);
879   if (it == jumps_.end()) {
880     MS_LOG(DEBUG) << "Can't find jump node from block:" << ToString() << " to block:" << target_block->ToString();
881     return nullptr;
882   }
883   return it->second;
884 }
885 
set_is_return_statement_inside()886 void FunctionBlock::set_is_return_statement_inside() { is_return_statement_inside_ = true; }
set_break_continue_statement_inside()887 void FunctionBlock::set_break_continue_statement_inside() { is_break_continue_statement_inside_ = true; }
888 }  // namespace parse
889 }  // namespace mindspore
890