1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/parse/function_block.h"
20
21 #include <algorithm>
22 #include <queue>
23
24 #include "frontend/operator/ops.h"
25 #include "include/common/utils/python_adapter.h"
26 #include "include/common/utils/utils.h"
27 #include "ir/cell.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "mindspore/core/ops/sequence_ops.h"
30 #include "mindspore/core/ops/structure_ops.h"
31 #include "pipeline/jit/ps/debug/trace.h"
32 #include "pipeline/jit/ps/fallback.h"
33 #include "pipeline/jit/ps/parse/data_converter.h"
34 #include "pipeline/jit/ps/parse/parse.h"
35 #include "pipeline/jit/ps/parse/parse_base.h"
36 #include "pipeline/jit/ps/parse/resolve.h"
37 #include "utils/hash_set.h"
38 #include "utils/info.h"
39 #include "utils/compile_config.h"
40
41 namespace mindspore {
42 namespace py = pybind11;
43
44 namespace parse {
FunctionBlock(const Parser & parser)45 FunctionBlock::FunctionBlock(const Parser &parser)
46 : func_graph_(std::make_shared<FuncGraph>()), parser_(parser), matured_(false) {}
47
AddPrevBlock(const FunctionBlockPtr & block)48 void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
49
CanBeIsolatedNode(const std::string & var_name,const AnfNodePtr & node)50 static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
51 auto cnode = dyn_cast<CNode>(node);
52 if (cnode == nullptr || cnode->inputs().empty()) {
53 // Not a valid cnode, can not be isolate node.
54 return false;
55 }
56 auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
57 if (prim == nullptr) {
58 // Not a primitive cnode, it may have side effects or not,
59 // We add it as an isolate node if its name is not '_' or empty.
60 // this means that code like:
61 // _ = func_call()
62 // will be ignored even if func_call() has side effects.
63 return !var_name.empty() && var_name != "_";
64 }
65 // Primitive cnode with side effects can be isolate nodes.
66 auto effect_info = GetPrimEffectInfo(prim);
67 bool has_effects = (effect_info.memory || effect_info.io);
68 if (has_effects) {
69 return true;
70 }
71 // Primitive cnode with 'no_eliminate' flag can be isolate nodes.
72 return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
73 }
74
75 // Write variable records the variable name to corresponding node
WriteVariable(const std::string & var_name,const AnfNodePtr & node)76 void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
77 MS_EXCEPTION_IF_NULL(node);
78 MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
79 << node->DebugString();
80 constexpr auto kRecursiveLevel = 2;
81 // a[::][::] = b will be translated to c = a[::] c[::] = b and the c is a no named variable.
82 if (var_name.empty()) {
83 MS_LOG(DEBUG) << "The node is " << node->DebugString(kRecursiveLevel)
84 << "added in the isolated list.\nBlock: " << this << "/"
85 << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
86 << ", Line: " << trace::GetDebugInfoStr(node->debug_info(), "", kSourceLineTipDiscard);
87 AddIsolatedNode(node);
88 return;
89 }
90 auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
91 if (!is_new_name) {
92 // If a cnode variable with same name already existed but not used,
93 // add it as an isolate node. for example:
94 // a = print(x)
95 // a = print(y)
96 // When we write variable 'a = print(y)',
97 // the cnode 'print(x)' should added as an isolate node.
98 auto is_used = iter->second.second;
99 auto hidden_node = iter->second.first;
100 auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
101 if (!is_used && is_isolated) {
102 MS_EXCEPTION_IF_NULL(hidden_node);
103 MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(kRecursiveLevel)
104 << " is hidden by " << node->DebugString(kRecursiveLevel)
105 << " with the same name, var_name: " << var_name << ", block: " << this << "/"
106 << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
107 << ", Line: " << trace::GetDebugInfoStr(hidden_node->debug_info(), "", kSourceLineTipDiscard);
108 AddIsolatedNode(hidden_node);
109 }
110 MS_LOG(INFO) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " update var `" << var_name
111 << "` with node " << node->DebugString();
112 iter->second = std::make_pair(node, false);
113 }
114 if (!HasGlobalPyParam(var_name)) {
115 UpdateLocalPyParam(var_name, node);
116 }
117 }
118
ReadLocalVariable(const std::string & var_name)119 AnfNodePtr FunctionBlock::ReadLocalVariable(const std::string &var_name) {
120 auto found = assigned_vars_.find(var_name);
121 if (found != assigned_vars_.end()) {
122 auto &node = found->second.first;
123 MS_EXCEPTION_IF_NULL(node);
124 // Mark the variable as used.
125 found->second.second = true;
126 MS_LOG(DEBUG) << "Found var: " << var_name << ", as: " << node->DebugString();
127 return node;
128 }
129 return nullptr;
130 }
131
CheckHasVariable(const std::string & var_name)132 bool FunctionBlock::CheckHasVariable(const std::string &var_name) {
133 auto node = ReadLocalVariable(var_name);
134 if (node != nullptr) {
135 return true;
136 }
137 if (!prev_blocks_.empty()) {
138 auto block = prev_blocks_[0];
139 MS_EXCEPTION_IF_NULL(block);
140 return block->CheckHasVariable(var_name);
141 }
142 return false;
143 }
144
FindPredInterpretNode(const std::string & var_name)145 std::pair<AnfNodePtr, bool> FunctionBlock::FindPredInterpretNode(const std::string &var_name) {
146 // Search the predecessors of the current block for the local parameter. If one of the local parameter of the
147 // predecessors is interpret node, the phi_param needs to set the interpret true.
148 mindspore::HashSet<FunctionBlock *> visited_block;
149 std::queue<FunctionBlock *> block_queue;
150 block_queue.push(this);
151 bool has_found = false;
152 while (!block_queue.empty()) {
153 const auto cur_block = block_queue.front();
154 MS_EXCEPTION_IF_NULL(cur_block);
155 block_queue.pop();
156 (void)visited_block.insert(cur_block);
157 auto pred_node = cur_block->ReadLocalVariable(var_name);
158 if (pred_node != nullptr) {
159 has_found = true;
160 bool interpret_without_internal =
161 IsPrimitiveCNode(pred_node, prim::kPrimPyInterpret) && !pred_node->interpret_internal_type();
162 if (pred_node->interpret() || interpret_without_internal) {
163 return std::make_pair(pred_node, has_found);
164 }
165 } else {
166 for (const auto &cur_pred_block : cur_block->prev_blocks()) {
167 if (visited_block.count(cur_pred_block) == 0) {
168 block_queue.push(cur_pred_block);
169 }
170 }
171 }
172 }
173 return std::make_pair(nullptr, has_found);
174 }
175
176 // Read variable from predecessors
ReadVariable(const std::string & var_name)177 AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
178 MS_LOG(DEBUG) << "Read begin, var_name: " << var_name << ", block: " << ToString();
179 // Get var node if it is found
180 auto node = ReadLocalVariable(var_name);
181 if (node != nullptr) {
182 if (!HasGlobalPyParam(var_name)) {
183 UpdateLocalPyParam(var_name, node);
184 }
185 return node;
186 }
187
188 MS_LOG(DEBUG) << "matured_: " << matured_ << ", prev_blocks_.size: " << prev_blocks_.size();
189 // Get var from predecessor block, if can't get then make a resolve node to it
190 if (matured_) {
191 // If only one predecessor block, read the definition of var from it.
192 if (prev_blocks_.size() == 1) {
193 auto block = prev_blocks_[0];
194 MS_EXCEPTION_IF_NULL(block);
195 auto res = block->ReadVariable(var_name);
196 MS_LOG(DEBUG) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
197 << ",\nCurrent: " << py::str(const_cast<py::dict &>(global_py_params()))
198 << "\nInsert: " << py::str(const_cast<py::dict &>(block->global_py_params()));
199 UpdateGlobalPyParam(block->global_py_params());
200 if (!HasGlobalPyParam(var_name)) {
201 UpdateLocalPyParam(var_name, res);
202 }
203 return res;
204 } else if (prev_blocks_.empty()) {
205 // Get namespace and make Resolve
206 auto it = var_to_resolve_.find(var_name);
207 if (it != var_to_resolve_.end()) {
208 return it->second;
209 }
210 MS_LOG(DEBUG) << "var_name: " << var_name;
211 auto tmp_node = MakeResolveSymbol(var_name);
212 var_to_resolve_[var_name] = tmp_node;
213 return tmp_node;
214 }
215 }
216 // If have more than one predecessor blocks then build a phi node.
217 auto debug_info = std::make_shared<NodeDebugInfo>();
218 debug_info->set_name(var_name);
219 TraceGuard guard(std::make_shared<TracePhi>(debug_info));
220 ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
221 MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
222 << phi_param->ToString() << " for " << var_name;
223
224 auto [pred_node, has_found] = FindPredInterpretNode(var_name);
225 if (pred_node != nullptr) {
226 phi_param->set_interpret(true);
227 } else if (!has_found) {
228 // If the current node is created as a phi node at the first time.(the var_name has not be found in pre blocks)
229 // need resolve to determine whether it needs to be marked with interpret.
230 auto resolve_node = MakeResolveSymbol(var_name);
231 MS_EXCEPTION_IF_NULL(resolve_node);
232 // Avoid to build phi node if current block is not matured and resolve_node is an undefined symbol.
233 if (!matured_ && resolve_node->isa<ValueNode>()) {
234 auto value = GetValuePtr<ValueProblem>(resolve_node->cast<ValueNodePtr>());
235 if (!is_dead_block() && value != nullptr && value->IsUndefined()) {
236 MS_LOG(DEBUG) << "Avoid to build phi node and return undefined node, var_name: " << var_name
237 << ", block: " << ToString() << ", matured_: " << matured_
238 << ", prev_blocks_.size: " << prev_blocks_.size();
239 return resolve_node;
240 }
241 }
242 phi_param->set_interpret(resolve_node->interpret());
243 phi_param->set_interpret_internal_type(resolve_node->interpret_internal_type());
244 if (resolve_node->isa<Parameter>()) {
245 phi_param->set_debug_info(resolve_node->debug_info());
246 }
247 }
248
249 func_graph()->add_parameter(phi_param);
250 phi_nodes_[phi_param] = var_name;
251 WriteVariable(var_name, phi_param);
252 if (matured_) {
253 SetPhiArgument(phi_param);
254 }
255 // In SetPhiArgument/CollectRemovablePhi, this phi may be set as removable and set it as
256 // real node, so check it again.
257 MS_LOG(DEBUG) << "Read again, var_name: " << var_name << ", block: " << ToString();
258 node = ReadLocalVariable(var_name);
259 if (node != nullptr) {
260 return node;
261 }
262 return phi_param;
263 }
264
265 // Resolve Ast operator node
GetAstOpNameSpace(const py::object & op)266 py::tuple FunctionBlock::GetAstOpNameSpace(const py::object &op) {
267 auto ast = parser_.ast();
268 MS_EXCEPTION_IF_NULL(ast);
269 TraceGuard trace_guard(parser_.GetLocation(op));
270 py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
271 constexpr size_t namespace_size = 3;
272 if (namespace_var.size() != namespace_size) {
273 MS_LOG(INTERNAL_EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
274 }
275 return namespace_var;
276 }
277
278 // Resolve Ast operator node
MakeResolveAstOpNameSpace(const py::tuple & namespace_var)279 AnfNodePtr FunctionBlock::MakeResolveAstOpNameSpace(const py::tuple &namespace_var) {
280 constexpr size_t namespace_index = 0;
281 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[namespace_index]);
282 constexpr size_t symbol_index = 1;
283 SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[symbol_index].cast<std::string>());
284 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
285 return MakeResolve(name_space, symbol);
286 }
287
288 // Resolve class object self.
MakeResolveClassObject()289 AnfNodePtr FunctionBlock::MakeResolveClassObject() {
290 auto ast = parser_.ast();
291 MS_EXCEPTION_IF_NULL(ast);
292 py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
293 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_OBJECT, namespace_var, ast->obj());
294 constexpr auto self_name = "self";
295 SymbolPtr symbol = std::make_shared<Symbol>(self_name); // Must be 'self'.
296 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
297 return MakeResolve(name_space, symbol);
298 }
299
300 // Resolve class member: method, member variable.
MakeResolveClassMember(const std::string & attr_or_self)301 AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr_or_self) {
302 auto ast = parser_.ast();
303 MS_EXCEPTION_IF_NULL(ast);
304 py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
305 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var, ast->obj());
306 SymbolPtr symbol = std::make_shared<Symbol>(attr_or_self);
307 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
308 return MakeResolve(name_space, symbol);
309 }
310
CheckUndefinedSymbol(const std::string & var,const AnfNodePtr & node) const311 void FunctionBlock::CheckUndefinedSymbol(const std::string &var, const AnfNodePtr &node) const {
312 if (node->isa<ValueNode>()) {
313 auto value = GetValuePtr<ValueProblem>(node->cast<ValueNodePtr>());
314 if (!is_dead_block() && value != nullptr && value->IsUndefined()) {
315 MS_EXCEPTION(NameError) << "The name '" << var << "' is not defined, or not supported in graph mode.";
316 }
317 }
318 }
319
HandleNamespaceSymbol(const std::string & var_name)320 AnfNodePtr FunctionBlock::HandleNamespaceSymbol(const std::string &var_name) {
321 auto ast = parser_.ast();
322 MS_EXCEPTION_IF_NULL(ast);
323 const py::tuple &info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, var_name);
324
325 constexpr size_t closure_info_size = 3;
326 constexpr size_t global_info_size = 4;
327 constexpr size_t namespace_index = 0;
328 constexpr size_t symbol_index = 1;
329 constexpr size_t value_index = 2;
330 constexpr size_t flag_index = 3;
331 if (info.size() != closure_info_size && info.size() != global_info_size) {
332 MS_INTERNAL_EXCEPTION(NameError) << "The namespace info size should be 3 or 4, but got " << info.size();
333 }
334 // If namespace is None, the symbol is an undefined name.
335 if (info[namespace_index].is_none()) {
336 const auto undefined_symbol = std::make_shared<ValueProblem>(ValueProblemType::kUndefined);
337 MS_LOG(WARNING) << "Undefined symbol: " << var_name << ", during parsing " << py::str(ast->function()) << " of "
338 << py::str(ast->obj());
339 return NewValueNode(undefined_symbol);
340 }
341
342 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, info[namespace_index]);
343 SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
344 auto resolved_node = MakeResolve(name_space, symbol);
345
346 // Handle closure namespace info.
347 if (info.size() == closure_info_size) {
348 return resolved_node;
349 }
350
351 // Handle global namespace info.
352 auto syntax_support = info[flag_index].cast<int32_t>();
353 py::object py_obj = info[value_index];
354 if (syntax_support != SYNTAX_SUPPORTED && syntax_support != SYNTAX_HYBRID_TYPE) {
355 resolved_node->set_interpret(true);
356 if (syntax_support == SYNTAX_UNSUPPORTED_INTERNAL_TYPE) {
357 resolved_node->set_interpret_internal_type(true);
358 resolved_node->set_user_data<py::object>(kClassTensorObject, std::make_shared<py::object>(py_obj));
359 }
360 }
361
362 auto symbol_name = info[symbol_index].cast<std::string>();
363 if (symbol_name == "Shard") {
364 MS_LOG(EXCEPTION) << "Cell.shard or ms.shard not supported in jit syntax, please use shard out of jit"
365 << " or construct scope.";
366 }
367 AddGlobalPyParam(symbol_name, py_obj);
368 MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol_name << " : "
369 << py::str(py_obj) << "}";
370 fallback::SetPyObjectToNode(resolved_node, py_obj);
371 return resolved_node;
372 }
373
374 // Make a resolve node for symbol string
MakeResolveSymbol(const std::string & var_name)375 AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &var_name) {
376 MS_LOG(DEBUG) << "var_name: " << var_name << ", ast object type: " << parser_.ast()->target_type();
377
378 // Handle self. The prefix of var_name is "self".
379 constexpr auto self_name = "self";
380 const auto self_name_len = strlen(self_name);
381 // For PARSE_TARGET_METHOD or PARSE_TARGET_OBJECT_INSTANCE, should deal with self here, exclude PARSE_TARGET_FUNCTION.
382 if ((parser_.ast()->target_type() == PARSE_TARGET_METHOD ||
383 parser_.ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) &&
384 var_name.compare(0, self_name_len, self_name) == 0) {
385 auto start = var_name.find_first_of('.');
386 if (start != std::string::npos) { // 'self.xxx'
387 ++start;
388 if (start >= var_name.size()) {
389 MS_LOG(ERROR) << "Find invalid resolve symbol str: " << var_name;
390 return nullptr;
391 }
392 auto bits_str = var_name.substr(start);
393 auto resolve_node = MakeResolveClassMember(bits_str);
394 if (!HasGlobalPyParam(var_name)) {
395 UpdateLocalPyParam(var_name, resolve_node);
396 }
397 return resolve_node;
398 } else if (var_name.size() == self_name_len) { // 'self'
399 auto resolve_node = MakeResolveClassObject();
400 if (!HasGlobalPyParam(var_name)) {
401 UpdateLocalPyParam(var_name, resolve_node);
402 }
403 return resolve_node;
404 }
405 }
406
407 // Handle non-self.
408 return HandleNamespaceSymbol(var_name);
409 }
410
MakeResolveOperation(const std::string & value)411 AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
412 auto ast = parser_.ast();
413 MS_EXCEPTION_IF_NULL(ast);
414 py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
415 const size_t namespace_var_size = 2;
416 if (namespace_var.size() < namespace_var_size) {
417 MS_INTERNAL_EXCEPTION(NameError) << "namespace_var is less than 2";
418 }
419 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
420 SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
421 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
422 return MakeResolve(name_space, symbol);
423 }
424
425 namespace {
426 // The same as TransformVectorFuncValueNode() in mindspore/ccsrc/pipeline/jit/parse/resolve.cc, but not add to manager.
TransformVectorFuncValueNode(const FuncGraphPtr & func_graph,const ValuePtr & value,AnfNodePtr * const transformed)427 bool TransformVectorFuncValueNode(const FuncGraphPtr &func_graph, const ValuePtr &value,
428 AnfNodePtr *const transformed) {
429 MS_EXCEPTION_IF_NULL(value);
430 const auto &value_vec = GetValue<ValuePtrList>(value);
431 if (value_vec.empty()) {
432 return false;
433 }
434 std::vector<AnfNodePtr> nodes;
435 (void)nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
436 bool is_all_func = true;
437 for (auto &elem : value_vec) {
438 MS_EXCEPTION_IF_NULL(elem);
439 AnfNodePtr node = nullptr;
440 if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
441 is_all_func = is_all_func && TransformVectorFuncValueNode(func_graph, elem, &node);
442 } else if (elem->isa<FuncGraph>()) {
443 FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
444 node = NewValueNode(new_fg);
445 } else if (elem->isa<Primitive>()) {
446 node = NewValueNode(elem);
447 } else {
448 is_all_func = false;
449 }
450 (void)nodes.emplace_back(node);
451 }
452 if (is_all_func) {
453 // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
454 // So if has graph in list, try to replace the node with make tuple of graph value node.
455 // We do this because the graph manager won't investigate the graph inside valuetuple,
456 // change the vector of graph to be make_tuple of graph value node.
457 // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
458 // independent nodes.
459 *transformed = func_graph->NewCNode(std::move(nodes));
460 }
461 return is_all_func;
462 }
463 } // namespace
464
MakeResolve(const NameSpacePtr & name_space,const SymbolPtr & resolve_symbol)465 AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
466 MS_LOG(DEBUG) << "MakeResolve for "
467 << (name_space ? (std::string)py::str(name_space->namespace_obj()) : "null namespace") << " , "
468 << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
469 ValueNodePtr module_node = NewValueNode(name_space);
470 ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
471 auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
472
473 // Directly resolve the symbol.
474 return DoResolve(node, name_space, resolve_symbol);
475 }
476
// Resolves the Resolve node `node` immediately when allowed, otherwise returns `node` unchanged.
// Immediate resolving happens only when the parser does not defer resolving and the compile config
// "GREED_PARSE" (read once per process) equals "1". A Cell resolved through the symbol 'self' is
// never resolved directly. A resolved value tuple/list is converted to a MakeTuple CNode when possible.
AnfNodePtr FunctionBlock::DoResolve(const AnfNodePtr &node, const std::shared_ptr<NameSpace> &name_space,
                                    const std::shared_ptr<Symbol> &resolve_symbol) {
  static const auto boost_parse = common::GetCompileConfig("GREED_PARSE");
  const bool resolve_now = !Parser::defer_resolve() && boost_parse == "1";
  if (!resolve_now) {
    return node;
  }

  // Directly resolve the symbol.
  const auto &obj = GetSymbolObject(name_space, resolve_symbol, node);
  // Avoid recursively resolving Cell.
  const bool is_cell_self = py::isinstance<Cell>(obj) && resolve_symbol->symbol() == "self";
  if (is_cell_self) {
    MS_LOG(ERROR) << "Not direct resolve Cell self. node: " << node->DebugString() << ", ns: " << name_space->ToString()
                  << ", sym: " << resolve_symbol->ToString();
    return node;
  }

  AnfNodePtr resolved_node = nullptr;
  const bool success = ResolveObjectToNode(node, obj, &resolved_node);
  if (!success || resolved_node == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parse Resolve covert failed." << node->DebugString()
                               << ", ns: " << name_space->ToString() << ", sym: " << resolve_symbol->ToString();
  }

  // A constant tuple/list of graphs or primitives is replaced by a MakeTuple CNode of value nodes.
  const bool is_value_sequence = IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node);
  if (is_value_sequence) {
    auto value = resolved_node->cast<ValueNodePtr>()->value();
    const bool converted = TransformVectorFuncValueNode(func_graph_, value, &resolved_node);
    if (!converted) {
      MS_LOG(INFO) << "Fail to convert value tuple/list to CNode, " << resolved_node->DebugString();
    }
  }
  MS_LOG(DEBUG) << "node: " << node->DebugString() << ", ns: " << name_space->ToString()
                << ", sym: " << resolve_symbol->ToString() << ", resolved_node: " << resolved_node->DebugString();
  return resolved_node;
}
508
// Creates a PyInterpret CNode (in order) that evaluates `script_text` with the given global and local
// dict nodes. The new node takes over the debug info and the interpret-internal-type mark of `orig_node`.
AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
                                        const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
  MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
  MS_EXCEPTION_IF_NULL(orig_node);
  auto script_node = NewValueNode(std::make_shared<parse::Script>(script_text));
  auto interpret_prim_node = NewValueNode(prim::kPrimPyInterpret);
  auto node = func_graph_->NewCNodeInOrder({interpret_prim_node, script_node, global_dict_node, local_dict_node});
  MS_EXCEPTION_IF_NULL(node);
  node->set_debug_info(orig_node->debug_info());
  node->set_interpret_internal_type(orig_node->interpret_internal_type());
  return node;
}
522
523 // Add input for the block's phi parameter
SetPhiArgument(const ParameterPtr & phi)524 void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
525 MS_EXCEPTION_IF_NULL(phi);
526 TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
527 std::string var = phi_nodes_[phi];
528 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
529 << " for var `" << var << "`";
530 CollectRemovablePhi(phi);
531 for (auto &pred : prev_blocks_) {
532 MS_EXCEPTION_IF_NULL(pred);
533 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
534 << (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
535 AnfNodePtr arg_node = pred->ReadVariable(var);
536 auto jump = pred->GetJumpNode(this);
537 if (jump == nullptr) {
538 // If prev block is a switch call's prev block, no jumps here.
539 continue;
540 }
541 jump->add_input(arg_node);
542 }
543 }
544
545 namespace {
GetVariableDefinedLocation(const FunctionBlock * block,const std::string & var,int start_line)546 std::string GetVariableDefinedLocation(const FunctionBlock *block, const std::string &var, int start_line) {
547 MS_EXCEPTION_IF_NULL(block);
548 HashSet<FunctionBlock *> visited;
549 std::vector<FunctionBlock *> todo_list = {};
550 (void)std::copy(block->prev_blocks().cbegin(), block->prev_blocks().cend(), std::back_inserter(todo_list));
551 while (!todo_list.empty()) {
552 auto cur_block = todo_list.front();
553 (void)todo_list.erase(todo_list.begin());
554 if (visited.find(cur_block) != visited.cend()) {
555 continue;
556 }
557 (void)visited.insert(cur_block);
558 (void)std::copy(cur_block->prev_blocks().cbegin(), cur_block->prev_blocks().cend(), std::back_inserter(todo_list));
559 auto node = cur_block->ReadLocalVariable(var);
560 if (node != nullptr) {
561 const auto &debug_info = trace::GetSourceCodeDebugInfo(node->debug_info());
562 const auto &location = debug_info->location();
563 return location->ToString(kSourceSectionTipNextLineHere, start_line);
564 }
565 }
566 return "";
567 }
568 } // namespace
569
CheckVariableNotDefined(const std::pair<std::string,AnfNodePtr> & not_defined_branch,const std::string & var)570 void FunctionBlock::CheckVariableNotDefined(const std::pair<std::string, AnfNodePtr> ¬_defined_branch,
571 const std::string &var) {
572 std::ostringstream oss;
573 std::string not_defined_branch_name = not_defined_branch.first;
574 const auto &debug_info = trace::GetSourceCodeDebugInfo(this->func_graph()->debug_info());
575 const auto &location = debug_info->location();
576 int start_line = location->line();
577 if ((not_defined_branch_name == "while") || (not_defined_branch_name == "for")) {
578 oss << "The local variable '" << var << "' defined in the '" << not_defined_branch_name
579 << "' loop body cannot be used outside of the loop body. "
580 << "Please define variable '" << var << "' before '" << not_defined_branch_name << "'.\n";
581 }
582 if ((not_defined_branch_name == "true branch") || (not_defined_branch_name == "false branch")) {
583 oss << "The local variable '" << var << "' is not defined in " << not_defined_branch_name << ", but defined in "
584 << (not_defined_branch_name == "true branch" ? "false branch" : "true branch") << ".\n";
585 }
586 std::string location_info = GetVariableDefinedLocation(this, var, start_line);
587 if (location_info.empty()) {
588 return;
589 }
590 oss << location_info;
591 MS_EXCEPTION(UnboundLocalError) << oss.str();
592 }
593
SearchAllArgsOfPhiNode(const std::string & var,const ParameterPtr & phi)594 std::set<AnfNodePtr> FunctionBlock::SearchAllArgsOfPhiNode(const std::string &var, const ParameterPtr &phi) {
595 std::vector<std::pair<std::string, AnfNodePtr>> defined_branch;
596 std::pair<std::string, AnfNodePtr> not_defined_branch;
597 MS_LOG(DEBUG) << "Search block:" << ToString() << "Prev_blocks size: " << prev_blocks_.size();
598 for (auto &prev : prev_blocks_) {
599 MS_EXCEPTION_IF_NULL(prev);
600 AnfNodePtr temp_node = prev->ReadVariable(var);
601 MS_EXCEPTION_IF_NULL(temp_node);
602 MS_LOG(DEBUG) << "Read from prev block:" << prev->ToString() << "Found var: " << var
603 << ", as: " << temp_node->DebugString();
604 bool undefined_symbol_flag = false;
605 if (temp_node->isa<ValueNode>()) {
606 auto value = GetValuePtr<ValueProblem>(temp_node->cast<ValueNodePtr>());
607 if ((value != nullptr) && (value->IsUndefined())) {
608 undefined_symbol_flag = true;
609 }
610 }
611 if (undefined_symbol_flag) {
612 not_defined_branch = std::make_pair(prev->block_name(), temp_node);
613 } else {
614 defined_branch.push_back(std::make_pair(prev->block_name(), temp_node));
615 }
616 }
617 if (defined_branch.size() == 1) {
618 auto arg_node = defined_branch.front().second;
619 MS_EXCEPTION_IF_NULL(arg_node);
620 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi "
621 << (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
622 }
623
624 if (not_defined_branch.second != nullptr) {
625 if (!defined_branch.empty()) {
626 auto locaction = trace::GetSourceCodeDebugInfo(phi->debug_info())->location();
627 MS_EXCEPTION_IF_NULL(locaction);
628 TraceGuard trace_guard(locaction);
629 CheckVariableNotDefined(not_defined_branch, var);
630 }
631 MS_EXCEPTION(NameError) << "The name '" << var << "' is not defined, or not supported in graph mode.";
632 }
633
634 std::set<AnfNodePtr> all_arg_nodes;
635 for (auto &item : defined_branch) {
636 (void)all_arg_nodes.insert(item.second);
637 }
638 return all_arg_nodes;
639 }
640
641 // Check if there is removable unnecessary phi node in this graph.
642 // As per the FIRM TR 3.2, a phi node can be remove if:
643 // <Quote>
644 // If all arguments of a φ-function are the same value s or the φfunction itself,
645 // then we remove the φ-function and let all users directly uses. We call such a
646 // φ-function obviously unnecessary.
647 // When we removed a φ-function p, then we recursively try to apply this simplification
648 // rule with all (former) users of p, because they may have become obviously unnecessary
649 // due to the removal of p
650 // <Quote>
651 // phi node in graph will be removed after the whole function is parsed in a DFS visit
652 // of that graph.The reason is :
653 // 1. when this function is called, not all usage of this phi node had bound to the
654 // graph of this function block, some may stay in vars_ in other blocks.
655 // 2. it's costly to iterate the graph to replace the phi for each phi.
656 // Args: phi: This parameter node is functioning as a phi node.
CollectRemovablePhi(const ParameterPtr & phi)657 void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
658 MS_EXCEPTION_IF_NULL(phi);
659 const auto &var_name = phi_nodes_[phi];
660 MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var_name;
661 if (prev_blocks_.empty()) {
662 MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var_name;
663 return;
664 }
665 auto arg_nodes = SearchAllArgsOfPhiNode(var_name, phi);
666 phi_args_[phi] = arg_nodes;
667 if (arg_nodes.size() == 1) {
668 auto arg_node = *arg_nodes.begin();
669 if (arg_node->debug_info() == nullptr) {
670 arg_node->set_debug_info(phi->debug_info());
671 }
672 MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
673 << " can be replaced with " << arg_node->DebugString();
674 // Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
675 WriteVariable(var_name, arg_node);
676 bool interpret_without_internal =
677 IsPrimitiveCNode(arg_node, prim::kPrimPyInterpret) && !arg_node->interpret_internal_type();
678 if (arg_node->interpret() || interpret_without_internal) {
679 phi->set_interpret(true);
680 if (arg_node->interpret_internal_type()) {
681 phi->set_interpret_internal_type(true);
682 }
683 }
684 }
685 }
686
687 // A block should be marked matured if its predecessor blocks have been processed
Mature()688 void FunctionBlock::Mature() {
689 const auto &graph_params = func_graph_->parameters();
690 for (auto ¶m_itr : graph_params) {
691 MS_EXCEPTION_IF_NULL(param_itr);
692 auto param = param_itr->cast<ParameterPtr>();
693 if (phi_nodes_.find(param) != phi_nodes_.cend()) {
694 SetPhiArgument(param);
695 }
696 }
697 matured_ = true;
698 }
699
700 // Get the truth value testing for cond node.
ForceToCondNode(const AnfNodePtr & cond,bool is_while_cond)701 CNodePtr FunctionBlock::ForceToCondNode(const AnfNodePtr &cond, bool is_while_cond) {
702 MS_EXCEPTION_IF_NULL(cond);
703 CNodePtr op_apply_node =
704 func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimCond), cond, NewValueNode(MakeValue(is_while_cond))});
705 return op_apply_node;
706 }
707
708 // Perform a jump from this block to target block
Jump(const FunctionBlockPtr & target_block,const std::vector<AnfNodePtr> & args)709 void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
710 MS_EXCEPTION_IF_NULL(target_block);
711 MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
712 if (is_dead_block_) {
713 MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
714 return;
715 }
716 if (func_graph_->get_return() != nullptr) {
717 MS_LOG(INTERNAL_EXCEPTION) << "Failure: have return node! NodeInfo: "
718 << trace::GetDebugInfoStr(func_graph_->get_return()->debug_info());
719 }
720 std::vector<AnfNodePtr> input_nodes;
721 input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
722 (void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
723
724 CNodePtr jump = func_graph_->NewCNodeInOrder(std::move(input_nodes));
725 jumps_[target_block.get()] = jump;
726 target_block->AddPrevBlock(shared_from_this());
727 func_graph_->set_output(jump);
728 }
729
730 // Perform a conditional jump using switch operation.
731 // The first CNode select graph with condition, and than execute this graph
ConditionalJump(const AnfNodePtr & cond_node,const AnfNodePtr & true_block_call,const AnfNodePtr & false_block_call)732 CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
733 const AnfNodePtr &false_block_call) {
734 MS_EXCEPTION_IF_NULL(true_block_call);
735 MS_EXCEPTION_IF_NULL(false_block_call);
736 if (func_graph_->get_return() != nullptr) {
737 MS_LOG(INTERNAL_EXCEPTION) << "Failure: have return node! fg: " << func_graph_->ToString()
738 << "\nNodeInfo: " << trace::GetDebugInfoStr(func_graph_->get_return()->debug_info())
739 << "\ncond_node: " << cond_node->DebugString()
740 << "\nNodeInfo: " << trace::GetDebugInfoStr(cond_node->debug_info());
741 }
742 CNodePtr switch_app =
743 func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, true_block_call, false_block_call});
744 CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
745 func_graph_->set_output(switch_app_new);
746 return switch_app_new;
747 }
748
ConditionalJump(const AnfNodePtr & cond_node,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block)749 CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
750 const FunctionBlockPtr &false_block) {
751 MS_EXCEPTION_IF_NULL(true_block);
752 MS_EXCEPTION_IF_NULL(false_block);
753 return ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph()));
754 }
755
756 // Create cnode for the assign statement like 'self.target = source'.
757 // convert it to 'P.Assign(self.target, source)' and then add the cnode as isolate node.
SetStateAssign(const AnfNodePtr & target,const AnfNodePtr & source)758 void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source) {
759 const std::string primitive_name("assign");
760 const std::string module_name("mindspore.ops.functional");
761 ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
762 auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
763 const int recursive_level = 2;
764 MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(recursive_level)
765 << ", block: " << this << "/" << func_graph_->ToString()
766 << ", Line: " << trace::GetDebugInfoStr(assign_node->debug_info(), "", kSourceLineTipDiscard);
767 AddIsolatedNode(assign_node);
768 }
769
ConvertUnusedNodesToIsolated(const std::pair<std::string,std::pair<AnfNodePtr,bool>> var)770 void FunctionBlock::ConvertUnusedNodesToIsolated(const std::pair<std::string, std::pair<AnfNodePtr, bool>> var) {
771 auto &node = var.second.first;
772 bool is_used = var.second.second;
773 if (node == nullptr || is_used) {
774 return;
775 }
776 auto &var_name = var.first;
777 if (CanBeIsolatedNode(var_name, node)) {
778 const int recursive_level = 2;
779 MS_LOG(INFO) << "Isolated node found(NoUse), node: " << node->DebugString(recursive_level)
780 << ", var_name: " << var_name << ", block: " << this << "/"
781 << (func_graph() ? func_graph()->ToString() : "FG(Null)")
782 << ", Line: " << trace::GetDebugInfoStr(node->debug_info(), "", kSourceLineTipDiscard);
783 AddIsolatedNode(node);
784 }
785 }
786
FindIsolatedNodes()787 void FunctionBlock::FindIsolatedNodes() {
788 //
789 // Search isolate nodes from variables, for example,
790 // variable 'a' is an isolate node in below code:
791 //
792 // def construct(self, x, y):
793 // a = print(x) # isolate node
794 // return x + y
795 //
796 // Add isolated nodes which is unused var but not found in used set.
797 for (const auto &var : assigned_vars_) {
798 ConvertUnusedNodesToIsolated(var);
799 }
800 }
801
AddIsolatedNode(const AnfNodePtr & target)802 void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
803
AttachIsolatedNodesBeforeReturn()804 void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
805 if (isolated_nodes_.empty()) {
806 return;
807 }
808 std::vector<AnfNodePtr> states;
809 states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
810 constexpr int recursive_level = 2;
811 for (const auto &node : isolated_nodes_) {
812 MS_EXCEPTION_IF_NULL(node);
813 MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
814 << func_graph_->ToString();
815 if (node->func_graph() == func_graph_) {
816 states.emplace_back(node);
817 } else {
818 MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(recursive_level) << " in "
819 << func_graph_->ToString();
820 }
821 }
822 isolated_nodes_.clear();
823
824 AnfNodePtr state = nullptr;
825 constexpr size_t no_state_size = 1;
826 constexpr size_t only_one_state_size = 2;
827 if (states.size() == no_state_size) {
828 // Only MakeTuple, no state left.
829 return;
830 } else if (states.size() == only_one_state_size) {
831 // If there are only MakeTuple and another node in states(the states size is 2),
832 // do not need to MakeTuple, just use the node.
833 state = states[1];
834 } else {
835 state = func_graph_->NewCNode(std::move(states));
836 if (state != nullptr && state->debug_info() != nullptr) {
837 state->debug_info()->set_location(nullptr);
838 }
839 }
840
841 AnfNodePtr old_output = nullptr;
842 auto return_node = func_graph_->get_return();
843 if (return_node != nullptr) {
844 const size_t return_input_size = 2;
845 if (return_node->size() < return_input_size) {
846 MS_LOG(INTERNAL_EXCEPTION) << "Length of inputs of output node is less than 2";
847 }
848 old_output = return_node->input(1);
849 } else {
850 old_output = NewValueNode(kNone);
851 }
852 AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
853 CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
854 if (stop_grad_node->debug_info()) {
855 stop_grad_node->debug_info()->set_location(nullptr);
856 }
857 if (depend_node->debug_info()) {
858 depend_node->debug_info()->set_location(old_output->debug_info()->location());
859 }
860 // We add this attribute for @constexpr use scene, since we must infer them before other nodes.
861 // That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
862 depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
863 MS_EXCEPTION_IF_NULL(state);
864 MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
865 << ", state: " << state->DebugString(recursive_level);
866 func_graph_->set_output(depend_node, true);
867 // Update new return node's debug_info with old one.
868 if (return_node != nullptr && return_node->debug_info()) {
869 auto new_return = func_graph_->get_return();
870 MS_EXCEPTION_IF_NULL(new_return);
871 new_return->set_debug_info(return_node->debug_info());
872 }
873 }
874
SetAsDeadBlock()875 void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }
876
GetJumpNode(FunctionBlock * target_block)877 CNodePtr FunctionBlock::GetJumpNode(FunctionBlock *target_block) {
878 auto it = jumps_.find(target_block);
879 if (it == jumps_.end()) {
880 MS_LOG(DEBUG) << "Can't find jump node from block:" << ToString() << " to block:" << target_block->ToString();
881 return nullptr;
882 }
883 return it->second;
884 }
885
set_is_return_statement_inside()886 void FunctionBlock::set_is_return_statement_inside() { is_return_statement_inside_ = true; }
set_break_continue_statement_inside()887 void FunctionBlock::set_break_continue_statement_inside() { is_break_continue_statement_inside_ = true; }
888 } // namespace parse
889 } // namespace mindspore
890