• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/pi/graph_compiler/func_graph_builder.h"
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "frontend/operator/composite/unpack_call.h"
26 #include "frontend/operator/ops.h"
27 #include "include/common/utils/convert_utils_py.h"
28 #include "include/common/utils/python_adapter.h"
29 #include "ops/framework_ops.h"
30 #include "ops/math_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/structure_ops.h"
33 #include "pipeline/jit/pi/graph_compiler/func_wrapper.h"
34 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_mutator.h"
35 #include "pipeline/jit/pi/graph_compiler/utils.h"
36 #include "pipeline/jit/ps/parse/parse.h"
37 #include "utils/log_adapter.h"
38 
39 namespace mindspore {
40 namespace pijit {
41 namespace ir {
__anonf3d3890b0102(const NodePtr &node, IRMutator *m) 42 STATIC_IR_FUNCTOR(IRMutator, vtable).set_dispatch<MindNode>([](const NodePtr &node, IRMutator *m) { return node; });
43 }  // namespace ir
44 
BuildFuncGraph(const ir::FunctionNodePtr & func,const py::tuple & args,const py::dict & kwargs)45 FuncGraphPtr FuncGraphBuilder::BuildFuncGraph(const ir::FunctionNodePtr &func, const py::tuple &args,
46                                               const py::dict &kwargs) {
47   AnfNodePtrList anf_args;
48   bool broaden = func->GetAttr("enable_tuple_broaden");
49   std::transform(args.begin(), args.end(), std::back_inserter(anf_args), [broaden](const py::handle &arg) {
50     auto node = GraphUtils::ConvertPythonObjectToAnfNode(py::cast<py::object>(arg));
51     node->set_abstract(GraphUtils::ArgsToAbstract(py::cast<py::object>(arg), GetValueNode<ValuePtr>(node), broaden));
52     return node;
53   });
54   AnfNodePtr anf_kwargs = GraphUtils::ConvertPythonObjectToAnfNode(kwargs);
55   anf_kwargs->set_abstract(GraphUtils::ArgsToAbstract(kwargs, GetValueNode<ValuePtr>(anf_kwargs), broaden));
56   return BuildFuncGraph(func, anf_args, anf_kwargs);
57 }
58 
BuildFuncGraph(const ir::FunctionNodePtr & func,const AnfNodePtrList & args,const AnfNodePtr & kwargs)59 FuncGraphPtr FuncGraphBuilder::BuildFuncGraph(const ir::FunctionNodePtr &func, const AnfNodePtrList &args,
60                                               const AnfNodePtr &kwargs) {
61   auto builder = std::make_shared<FuncGraphBuilder>(func, args, kwargs);
62   parse::Parser::UpdateTopFuncGraph(builder->func_graph_);
63   auto func_graph = GetValueNode<FuncGraphPtr>(builder->Mutate(func)->cast<MindNodePtr>()->GetAnfNode());
64   return func_graph;
65 }
66 
67 #define DEFINE_UN_NODE_MUTATE_(OP)                                           \
68   ir::NodePtr FuncGraphBuilder::Mutate_(const OP &node) {                    \
69     auto op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode());         \
70     auto n = func_graph_->NewCNodeInOrder({op, GetAnfNode(node->GetArg())}); \
71     UpdateLocation(n, node);                                                 \
72     return std::make_shared<MindNode>(n);                                    \
73   }
74 
75 DEFINE_UN_NODE_MUTATE_(ir::NegativeNodePtr)
DEFINE_UN_NODE_MUTATE_(ir::NotNodePtr)76 DEFINE_UN_NODE_MUTATE_(ir::NotNodePtr)
77 
78 #define DEFINE_BIN_NODE_MUTATE_(OP)                                  \
79   ir::NodePtr FuncGraphBuilder::Mutate_(const OP &node) {            \
80     auto op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode()); \
81     auto left = GetAnfNode(node->GetLeftArg());                      \
82     auto right = GetAnfNode(node->GetRightArg());                    \
83     CNodePtr n = func_graph_->NewCNodeInOrder({op, left, right});    \
84     UpdateLocation(n, node);                                         \
85     return std::make_shared<MindNode>(n);                            \
86   }
87 
88 DEFINE_BIN_NODE_MUTATE_(ir::AddNodePtr)
89 DEFINE_BIN_NODE_MUTATE_(ir::SubNodePtr)
90 DEFINE_BIN_NODE_MUTATE_(ir::MulNodePtr)
91 DEFINE_BIN_NODE_MUTATE_(ir::DivNodePtr)
92 DEFINE_BIN_NODE_MUTATE_(ir::BitwiseNodePtr)
93 DEFINE_BIN_NODE_MUTATE_(ir::BinaryOperationPtr)
94 
95 ir::NodePtr FuncGraphBuilder::Mutate_(const ir::RefNodePtr &node) {
96   node->SetRealNode(Mutate(node->GetRealNode()));
97   return node;
98 }
99 
/**
 * Lower an IR parameter.
 *
 * When the function does not generate its own parameters (e.g. sub functions built for if branches), the
 * parameter is bound directly to the caller-supplied argument with the same index. Otherwise a new graph
 * Parameter is created:
 * - its abstract comes from the matching positional argument, or from kwargs for the KEYWORD parameter;
 * - ref-tensor parameters get their ref key appended to the name;
 * - a default value is registered when present.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ParameterPtr &node) {
  auto name = node->GetName();
  auto index = node->GetIndex();
  // Convert once; if GetIndex() is signed, a negative index becomes huge and fails the bounds checks below.
  const size_t arg_index = static_cast<size_t>(index);
  if (!func_->NeedGenParameters()) {
    MS_EXCEPTION_IF_CHECK_FAIL(arg_index < args_.size(), "Invalid parameter[" + name + "].");
    assigned_vars_[name] = args_[arg_index];
    return node;
  }
  auto param = std::make_shared<Parameter>(func_graph_);
  MS_EXCEPTION_IF_NULL(param);
  param->set_name(name);
  MS_EXCEPTION_IF_NULL(param->debug_info());
  param->debug_info()->set_name(name);
  UpdateLocation(param, node);
  if (node->GetCategory() == ir::Parameter::KEYWORD) {
    // kwargs
    MS_EXCEPTION_IF_NULL(kwargs_);
    param->set_abstract(kwargs_->abstract());
  } else {
    MS_EXCEPTION_IF_CHECK_FAIL(arg_index < args_.size(), "Parameter " + name + " has no arguments");
    param->set_abstract(args_[arg_index]->abstract());
  }
  auto param_abs = param->abstract();
  // Fail with a clear message instead of dereferencing a missing abstract.
  MS_EXCEPTION_IF_NULL(param_abs);
  if (param_abs->isa<abstract::AbstractRefTensor>()) {
    // Presumably keeps parameters bound to different refs distinguishable by name.
    auto abs_ref = param_abs->cast<abstract::AbstractRefPtr>();
    MS_EXCEPTION_IF_NULL(abs_ref->ref_key_value());
    auto new_name = name + "_" + abs_ref->ref_key_value()->ToString();
    param->set_name(new_name);
    param->debug_info()->set_name(new_name);
  }
  auto default_value = node->GetDefaultValue();
  if (default_value != nullptr) {
    AnfNodePtr value_node = GetAnfNode(default_value);
    func_graph_->set_param_default_value(param->name(), value_node);
  }
  // Bound under the original (unsuffixed) name, which is what LOAD_FAST looks up.
  assigned_vars_[name] = param;
  func_graph_->add_parameter(param);
  return std::make_shared<MindNode>(param);
}
137 
/**
 * Lower the IR function held by this builder into func_graph_.
 *
 * Copies the vararg / keyword-only / kwarg flags and the name onto the graph and sorts the body. If the body
 * contains an IfNode, it is restructured:
 * - every node after the first IfNode is appended to each branch that does not already end in a ReturnNode
 *   (a trailing JumpNode is dropped first);
 * - the function then returns the IfNode.
 * Finally the parameters and body nodes are mutated, and a value node holding func_graph_ is returned.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::FunctionNodePtr &node) {
  func_graph_->set_has_vararg(node->HasVarArg());
  func_graph_->set_kwonlyargs_count(node->GetKwOnlyArgsCnt());
  func_graph_->set_has_kwarg(node->HasKwArg());
  func_graph_->debug_info()->set_name(node->GetName());
  // used for create sub function
  node->Sort();
  const auto first_if = std::find_if(node->GetNodes().begin(), node->GetNodes().end(),
                                     [](const ir::NodePtr &body_node) { return body_node->isa<ir::IfNode>(); });
  const bool has_if = (first_if != node->GetNodes().end());
  if (has_if) {
    size_t if_pos = std::distance(node->GetNodes().begin(), first_if);
    // The nodes after the IfNode become the continuation of its branches.
    std::vector<ir::NodePtr> continuation(node->GetNodes().begin() + if_pos + 1, node->GetNodes().end());
    node->GetNodes().resize(if_pos + 1);
    auto if_node = node->GetNodes().back()->cast<ir::IfNodePtr>();

    const bool then_returns = !if_node->GetThen().empty() && if_node->GetThen().back()->isa<ir::ReturnNode>();
    if (!then_returns) {
      if (!if_node->GetThen().empty() && if_node->GetThen().back()->isa<ir::JumpNode>()) {
        if_node->GetThen().pop_back();
      }
      for (const ir::NodePtr &tail_node : continuation) {
        if_node->AddThen(tail_node);
      }
    }

    const bool else_returns = !if_node->GetElse().empty() && if_node->GetElse().back()->isa<ir::ReturnNode>();
    if (!else_returns) {
      if (!if_node->GetElse().empty() && if_node->GetElse().back()->isa<ir::JumpNode>()) {
        if_node->GetElse().pop_back();
      }
      for (const ir::NodePtr &tail_node : continuation) {
        if_node->AddElse(tail_node);
      }
    }
    node->AddNode(std::make_shared<ir::ReturnNode>(if_node));
  }
  MUTATE_NODE_LIST(node->GetParameters())
  MUTATE_NODE_LIST(node->GetNodes())
  return std::make_shared<MindNode>(NewValueNode(func_graph_));
}
170 
UpdateLocation(const AnfNodePtr & anf_node,const ir::NodePtr & node)171 void FuncGraphBuilder::UpdateLocation(const AnfNodePtr &anf_node, const ir::NodePtr &node) {
172   if (!enable_debug_info_) {
173     return;
174   }
175   // Refer to Location::Location() for each node: line, column, line_end, column_end, expr_src.
176   auto debug_info = node->GetDebugInfo();
177   auto line_no = debug_info->GetLineNo();
178   line_no = (line_no == 0) ? last_line_no_ : line_no;
179   last_line_no_ = line_no;
180   auto loc = anf_node->debug_info()->location();
181   if (loc == nullptr) {
182     std::vector<std::string> comments;
183     anf_node->debug_info()->set_location(std::make_shared<Location>(debug_info->GetFileName(), line_no, 0, line_no, 0,
184                                                                     debug_info->GetDesc(), std::move(comments)));
185   } else {
186     loc->set_file_name(debug_info->GetFileName());
187     loc->set_line(line_no);
188     loc->set_line_end(line_no);
189     loc->set_expr_src(debug_info->GetDesc());
190   }
191 }
192 
ConvertListOrTupleToCNode(const py::object & obj)193 AnfNodePtr FuncGraphBuilder::ConvertListOrTupleToCNode(const py::object &obj) {
194   MS_EXCEPTION_IF_CHECK_FAIL((py::isinstance<py::list>(obj) || py::isinstance<py::tuple>(obj)),
195                              "Should be a list or tuple.");
196   auto tuple = py::cast<py::tuple>(obj);
197   auto parameter = python_adapter::GetPyObjAttr(python_adapter::GetPyModule("mindspore"), "Parameter");
198   auto prim = py::isinstance<py::list>(obj) ? prim::kPrimMakeList : prim::kPrimMakeTuple;
199   CNodePtr cnode = func_graph_->NewCNodeInOrder(prim, {});
200   for (size_t idx = 0; idx < tuple.size(); idx++) {
201     if (py::isinstance(tuple[idx], parameter)) {
202       cnode->add_input(parse::ResolveParameterObj(func_graph_, tuple[idx]));
203     } else if (py::isinstance<py::list>(tuple[idx]) || py::isinstance<py::tuple>(tuple[idx])) {
204       cnode->add_input(ConvertListOrTupleToCNode(tuple[idx]));
205     } else {
206       cnode->add_input(GraphUtils::ConvertPythonObjectToAnfNode(tuple[idx]));
207     }
208   }
209   return cnode;
210 }
211 
/**
 * Lower a Python constant.
 *
 * Objects present in mindspore._extends.parse.resources.convert_object_map are first replaced by their
 * mapped object. Lists and tuples become MakeList/MakeTuple CNodes; everything else is converted directly.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ValuePtr &node) {
  py::object py_value = node->GetValue();
  auto resources = python_adapter::GetPyModule("mindspore._extends.parse.resources");
  auto convert_map = python_adapter::GetPyObjAttr(resources, "convert_object_map");
  // PyDict_GetItem yields a borrowed reference, or nullptr when the key is absent.
  PyObject *mapped = PyDict_GetItem(convert_map.ptr(), py_value.ptr());
  if (mapped != nullptr) {
    py_value = py::cast<py::object>(mapped);
  }
  AnfNodePtr anf_value;
  if (py::isinstance<py::list>(py_value) || py::isinstance<py::tuple>(py_value)) {
    anf_value = ConvertListOrTupleToCNode(py_value);
  } else {
    anf_value = GraphUtils::ConvertPythonObjectToAnfNode(py_value);
  }
  UpdateLocation(anf_value, node);
  return std::make_shared<MindNode>(anf_value);
}
225 
Mutate_(const ir::IfNodePtr & node)226 ir::NodePtr FuncGraphBuilder::Mutate_(const ir::IfNodePtr &node) {
227   auto cond = node->GetCondition();
228   MS_EXCEPTION_IF_CHECK_FAIL(cond->isa<ir::JumpNode>(), cond->ToString() + " can't be a condition.");
229   auto jump = cond->cast<ir::JumpNodePtr>();
230   auto condition = Mutate(jump->GetCondition())->cast<MindNodePtr>()->GetAnfNode();
231   if (jump->GetOpCode() == POP_JUMP_IF_TRUE) {
232     condition = func_graph_->NewCNodeInOrder({GraphUtils::GetMetaFuncGraph(UNARY_NOT), condition});
233   }
234   auto _then = node->GetThen();
235   if (!_then.empty() && _then.back()->isa<ir::JumpNode>()) {
236     _then.pop_back();
237   }
238   const std::string prefix = func_->GetName() + "_sub_func_" + std::to_string(cond->GetOffset());
239   FuncWrapperPtr wrapper_then = std::make_shared<FuncWrapper>(prefix + "_true", _then);
240   auto outputs_then = wrapper_then->GetOutputs();
241   auto _else = node->GetElse();
242   if (!_else.empty() && _else.back()->isa<ir::JumpNode>()) {
243     _else.pop_back();
244   }
245   FuncWrapperPtr wrapper_else = std::make_shared<FuncWrapper>(prefix + "_false", _else);
246   auto outputs_else = wrapper_else->GetOutputs();
247   std::set<std::string> var_names;
248   std::vector<ir::ValuePtr> outputs;
249   std::for_each(outputs_then.begin(), outputs_then.end(), [&outputs, &var_names](const ir::ValuePtr &var) {
250     auto var_name = var->GetValue().cast<std::string>();
251     if (var_names.find(var_name) == var_names.end()) {
252       outputs.push_back(var);
253     }
254     var_names.insert(var_name);
255   });
256   std::for_each(outputs_else.begin(), outputs_else.end(), [&outputs, &var_names](const ir::ValuePtr &var) {
257     auto var_name = var->GetValue().cast<std::string>();
258     if (var_names.find(var_name) == var_names.end()) {
259       outputs.push_back(var);
260     }
261     var_names.insert(var_name);
262   });
263   wrapper_then->SpecifyOutputs(outputs);
264   wrapper_else->SpecifyOutputs(outputs);
265 
266   ir::FunctionNodePtr func = wrapper_then->Wrapper();
267   func->MarkNoNeedGenParameters();
268   AnfNodePtrList args;
269   std::transform(func->GetParameters().begin(), func->GetParameters().end(), std::back_inserter(args),
270                  [this](const ir::NodePtr &param) {
271                    auto name = param->cast<ir::ParameterPtr>()->GetName();
272                    MS_EXCEPTION_IF_CHECK_FAIL(assigned_vars_.find(name) != assigned_vars_.end(),
273                                               "Local var " + name + " is not defined.");
274                    return assigned_vars_.at(name);
275                  });
276   FuncGraphBuilderPtr builder = std::make_shared<FuncGraphBuilder>(func, args, NewValueNode(kNone));
277   auto graph_true = GetValueNode<FuncGraphPtr>(builder->Mutate(func)->cast<MindNodePtr>()->GetAnfNode());
278 
279   func = wrapper_else->Wrapper();
280   func->MarkNoNeedGenParameters();
281   args.clear();
282   std::transform(func->GetParameters().begin(), func->GetParameters().end(), std::back_inserter(args),
283                  [this](const ir::NodePtr &param) {
284                    auto name = param->cast<ir::ParameterPtr>()->GetName();
285                    MS_EXCEPTION_IF_CHECK_FAIL(assigned_vars_.find(name) != assigned_vars_.end(),
286                                               "Local var " + name + " is not defined.");
287                    return assigned_vars_.at(name);
288                  });
289   builder = std::make_shared<FuncGraphBuilder>(func, args, NewValueNode(kNone));
290   auto graph_false = GetValueNode<FuncGraphPtr>(builder->Mutate(func)->cast<MindNodePtr>()->GetAnfNode());
291   CNodePtr switch_node =
292     func_graph_->NewCNodeInOrder(prim::kPrimSwitch, {condition, NewValueNode(graph_true), NewValueNode(graph_false)});
293   CNodePtr call_switch = func_graph_->NewCNodeInOrder({switch_node});
294   return std::make_shared<MindNode>(call_switch);
295 }
296 
/**
 * Lower `~x`.
 *
 * Scalar constants use the opcode's primitive/meta func graph. Any other operand is routed to
 * mindspore.ops.functional.logical_not.
 * NOTE(review): for non-bool tensors Python's `~` is bitwise; confirm logical_not is the intended mapping.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::InvertNodePtr &node) {
  auto arg = GetAnfNode(node->GetArg());
  auto op = IsValueNode<Scalar>(arg) ? GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode())
                                     : NewValueNode(prim::GetPythonOps("logical_not", "mindspore.ops.functional"));
  CNodePtr invert_cnode = func_graph_->NewCNodeInOrder({op, arg});
  // Record the source location, as every other CNode-producing lowering in this builder does.
  UpdateLocation(invert_cnode, node);
  return std::make_shared<MindNode>(invert_cnode);
}
303 
/**
 * Lower a return: build the return CNode (the opcode's primitive applied to the lowered value), record its
 * source location and install it as func_graph_'s return.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ReturnNodePtr &node) {
  auto op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode());
  auto arg = GetAnfNode(node->GetArg());
  auto ret = func_graph_->NewCNodeInOrder({op, arg});
  // Previously the only graph-output node without location info.
  UpdateLocation(ret, node);
  func_graph_->set_return(ret);
  return std::make_shared<MindNode>(ret);
}
311 
/**
 * Lower LIST_TO_TUPLE.
 *
 * Accepts either a constant ValueList or a MakeList CNode. It produces a new CNode that applies the cast
 * opcode's primitive/meta func graph to the same elements.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::CastNodePtr &node) {
  auto op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode());
  auto arg = GetAnfNode(node->GetArg());
  AnfNodePtrList inputs = {op};
  if (IsValueNode<ValueList>(arg)) {
    auto value_list = GetValueNode<ValueListPtr>(arg);
    std::transform(value_list->value().begin(), value_list->value().end(), std::back_inserter(inputs),
                   [](const ValuePtr &elem) { return NewValueNode(elem); });
  } else {
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(arg, prim::kPrimMakeList),
                               arg->DebugString() + " is invalid for list_to_tuple.");
    // Build a new node rather than overwriting input 0 of the MakeList in place: the list node may also be
    // referenced elsewhere (e.g. stored in assigned_vars_), and those users must keep seeing a list.
    auto list_inputs = arg->cast<CNodePtr>()->inputs();
    inputs.insert(inputs.end(), list_inputs.begin() + 1, list_inputs.end());
  }
  CNodePtr cast_cnode = func_graph_->NewCNodeInOrder(inputs);
  UpdateLocation(cast_cnode, node);
  return std::make_shared<MindNode>(cast_cnode);
}
329 
/**
 * Lower FORMAT_VALUE on a constant operand by formatting it eagerly in Python.
 *
 * Flag bits 0-1 select the conversion applied first (0: none, 1: str(), 2: repr(), 3: ascii()). Bit 2 (0x04)
 * means a format spec is supplied as the second argument. The formatted string becomes a constant node.
 * Python errors raised while converting or formatting are rethrown as py::error_already_set.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::FormatNodePtr &node) {
  auto arg = node->GetArg(0);
  MS_EXCEPTION_IF_CHECK_FAIL(arg->isa<ir::Value>(), "The arg of format must be object.");
  py::object top = arg->cast<ir::ValuePtr>()->GetValue();
  auto fmt_flag = node->GetFormatType();
  switch (fmt_flag & 0x03) {
    case 0x01:
      top = py::reinterpret_steal<py::object>(PyObject_Str(top.ptr()));
      break;
    case 0x02:
      top = py::reinterpret_steal<py::object>(PyObject_Repr(top.ptr()));
      break;
    case 0x03:
      top = py::reinterpret_steal<py::object>(PyObject_ASCII(top.ptr()));
      break;
    default:
      // 0x00: no conversion.
      break;
  }
  if (!top) {
    // The conversion raised; surface the Python error instead of formatting a null object.
    throw py::error_already_set();
  }
  // The spec bit must be tested outside the switch: `fmt_flag & 0x03` is always 0..3, so a `default:` label
  // there can never observe it and the spec would be silently ignored.
  py::object format;
  if ((fmt_flag & 0x04) == 0x04) {
    auto spec = node->GetArg(1);
    MS_EXCEPTION_IF_CHECK_FAIL(spec->isa<ir::Value>(), "The fmt must be object.");
    format = spec->cast<ir::ValuePtr>()->GetValue();
  }
  // PyObject_Format accepts a null spec (same as format(obj)) and returns a new reference, which is stolen
  // here so it is released with `obj` rather than leaked.
  auto obj = py::reinterpret_steal<py::str>(PyObject_Format(top.ptr(), format.ptr()));
  if (!obj) {
    throw py::error_already_set();
  }
  AnfNodePtr value = GraphUtils::ConvertPythonObjectToAnfNode(obj);
  UpdateLocation(value, node);
  return std::make_shared<MindNode>(value);
}
366 
// Lower `a is b` / `a is not b` (IsInvert) into the Is_ / IsNot primitive applied to both lowered operands.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::IsNodePtr &node) {
  auto lhs = GetAnfNode(node->GetLeftArg());
  auto rhs = GetAnfNode(node->GetRightArg());
  PrimitivePtr identity_prim = prim::kPrimIs_;
  if (node->IsInvert()) {
    identity_prim = prim::kPrimIsNot;
  }
  AnfNodePtr identity_check = func_graph_->NewCNodeInOrder(identity_prim, {lhs, rhs});
  UpdateLocation(identity_check, node);
  return std::make_shared<MindNode>(identity_check);
}
375 
// Lower `a in b` / `a not in b` (IsInvert) through the "in_" / "not_in_" meta func graph.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ContainsNodePtr &node) {
  auto lhs = GetAnfNode(node->GetLeftArg());
  auto rhs = GetAnfNode(node->GetRightArg());
  const char *meta_name = "in_";
  if (node->IsInvert()) {
    meta_name = "not_in_";
  }
  CNodePtr contains_cnode = func_graph_->NewCNodeInOrder({GraphUtils::GetMetaFuncGraph(meta_name), lhs, rhs});
  UpdateLocation(contains_cnode, node);
  return std::make_shared<MindNode>(contains_cnode);
}
384 
/**
 * Lower a store: the left operand is the value and the right operand must be a string naming the variable.
 * The binding is recorded in assigned_vars_; no graph node is produced, so nullptr is returned.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::StoreNodePtr &node) {
  auto stored_value = GetAnfNode(node->GetLeftArg());
  auto target = GetAnfNode(node->GetRightArg());
  MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<StringImm>(target), "Excepted var name.");
  const std::string var_name = GetValue<std::string>(GetValueNode(target));
  assigned_vars_[var_name] = stored_value;
  return nullptr;
}
392 
/**
 * Lower COMPARE_OP.
 *
 * The instruction argument indexes the table below, which follows Python's rich-comparison order. On Python
 * 3.7/3.8 COMPARE_OP also encodes `in`, `not in`, `is` and `is not`. `is`/`is not` map to the Is_/IsNot
 * primitives; every other entry names a meta func graph.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::CompareNodePtr &node) {
  // Constant for a given build, so it is built once instead of on every call.
  static const std::vector<std::string> ops = {
    "less",
    "less_equal",
    "equal",
    "not_equal",
    "greater",
    "greater_equal"
#if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
    ,
    "in_",
    "not_in_",
    "is",
    "is_not"
#endif  // #if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
  };
  auto left = GetAnfNode(node->GetLeftArg());
  auto right = GetAnfNode(node->GetRightArg());
  // An unexpected argument (e.g. "exception match") would otherwise index past the table (undefined behavior).
  const size_t op_index = static_cast<size_t>(node->GetInstrArg());
  MS_EXCEPTION_IF_CHECK_FAIL(op_index < ops.size(), "Unsupported compare op.");
  const std::string &op = ops[op_index];
#if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
  if (op == "is" || op == "is_not") {
    auto prim = (op == "is") ? prim::kPrimIs_ : prim::kPrimIsNot;
    CNodePtr identity_cnode = func_graph_->NewCNodeInOrder(prim, {left, right});
    UpdateLocation(identity_cnode, node);
    return std::make_shared<MindNode>(identity_cnode);
  }
#endif  // #if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
  CNodePtr n = func_graph_->NewCNodeInOrder({GraphUtils::GetMetaFuncGraph(op), left, right});
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
422 
/**
 * Lower in-place container updates.
 *
 * LIST_EXTEND is handled by MergeList; every other opcode reaching here is handled by MergeDict. On Python
 * 3.7 the operands of MAP_ADD are swapped first, presumably because CPython 3.8 reversed MAP_ADD's
 * key/value stack order.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::UpdateNodePtr &node) {
  auto target = GetAnfNode(node->GetLeftArg());
  auto source = GetAnfNode(node->GetRightArg());
#if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 7)
  if (node->GetOpCode() == MAP_ADD) {
    std::swap(target, source);
  }
#endif  // #if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 7)
  AnfNodePtr merged =
    (node->GetOpCode() == LIST_EXTEND) ? MergeList(target, source) : MergeDict(target, source);
  UpdateLocation(merged, node);
  return std::make_shared<MindNode>(merged);
}
440 
/**
 * Lower LOAD_* opcodes.
 *
 * - LOAD_FAST: resolves the named local from assigned_vars_.
 * - LOAD_DEREF / LOAD_CLASSDEREF: converts the current contents of the closure cell to an ANF node.
 * - LOAD_CONST / LOAD_GLOBAL / LOAD_CLOSURE: use the lowered argument as-is.
 * Any other opcode is rejected.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::LoadValueNodePtr &node) {
  AnfNodePtr n = GetAnfNode(node->GetArg(0));
  ir::OpCode op_code = node->GetOpCode();
  if (op_code == LOAD_FAST) {
    MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<StringImm>(n), "Invalid var name.");
    std::string key = GetValue<std::string>(GetValueNode(n));
    auto found = assigned_vars_.find(key);
    MS_EXCEPTION_IF_CHECK_FAIL(found != assigned_vars_.end(), "Found not define var " + key + ".");
    n = found->second;
  } else if (op_code == LOAD_DEREF || op_code == LOAD_CLASSDEREF) {
    auto arg = node->GetArg();
    MS_EXCEPTION_IF_CHECK_FAIL(arg->isa<ir::Value>(), "Excepted a python object as arg of load_closure.");
    auto cell = arg->cast<ir::ValuePtr>()->GetValue();
    // PyCell_Get returns a new reference, so steal it (py::cast would add a second reference and leak one).
    // It returns nullptr for an empty cell, or with an exception set when `cell` is not a cell object.
    auto contents = py::reinterpret_steal<py::object>(PyCell_Get(cell.ptr()));
    if (!contents) {
      if (PyErr_Occurred() != nullptr) {
        throw py::error_already_set();
      }
      MS_EXCEPTION_IF_CHECK_FAIL(false, "Free variable referenced before assignment.");
    }
    n = GraphUtils::ConvertPythonObjectToAnfNode(contents);
  } else {
    // no need to do anything
    MS_EXCEPTION_IF_CHECK_FAIL((op_code == LOAD_CONST || op_code == LOAD_GLOBAL || op_code == LOAD_CLOSURE),
                               "Not Expected bytecode.");
  }
  // NOTE(review): for LOAD_FAST `n` is the stored node itself, so this overwrites that node's location with the
  // load site; confirm this is intended.
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
463 
/**
 * Lower an attribute/method load: GetAttr(instance, name), where the second argument must be a string constant.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::LoadFieldNodePtr &node) {
  auto instance = GetAnfNode(node->GetArg(0));
  auto field = GetAnfNode(node->GetArg(1));
  MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<StringImm>(field), "Excepted attr/name.");
  CNodePtr getattr_cnode = func_graph_->NewCNodeInOrder(prim::kPrimGetAttr, {instance, field});
  // Keep location info consistent with AttrNode lowering, which also emits GetAttr.
  UpdateLocation(getattr_cnode, node);
  return std::make_shared<MindNode>(getattr_cnode);
}
470 
GetConstKeys(const AnfNodePtr & node)471 AnfNodePtrList GetConstKeys(const AnfNodePtr &node) {
472   AnfNodePtrList keys;
473   if (node->isa<CNode>()) {
474     auto inputs = node->cast<CNodePtr>()->inputs();
475     keys.insert(keys.begin(), inputs.begin() + 1, inputs.end());
476   } else {
477     MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<ValueTuple>(node), "The keys must be a ValueTuple.");
478     auto tuple = GetValueNode<ValueTuplePtr>(node);
479     std::transform(tuple->value().begin(), tuple->value().end(), std::back_inserter(keys),
480                    [](const ValuePtr &value) { return NewValueNode(value); });
481   }
482   return keys;
483 }
484 
/**
 * Lower BUILD_* opcodes into the primitive chosen by GraphUtils::GetPrimitive.
 *
 * - StringConcat (BUILD_STRING): more than two pieces are folded into nested two-operand concats, preserving
 *   the pieces' left-to-right order.
 * - MakeSlice: missing bounds are padded with None up to three operands.
 * - MakeDict: for BUILD_MAP the operands are taken pairwise. Otherwise the last operand holds the constant
 *   keys and the preceding operands are the values. The dict is built from a keys tuple and a values tuple.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::BuildNodePtr &node) {
  AnfNodePtrList array;
  std::transform(node->GetArgs().begin(), node->GetArgs().end(), std::back_inserter(array),
                 [&](const ir::NodePtr &arg) { return GetAnfNode(arg); });
  auto prim = GraphUtils::GetPrimitive(node->GetOpCode());
  if (prim == prim::kPrimStringConcat) {
    // Fold the last two pieces until two remain: [a, b, c] -> [a, concat(b, c)] -> concat(a, concat(b, c)).
    // Operands are passed as (second-to-last, last), the same order the final concat below uses, so the
    // result reads a + b + c.
    while (array.size() > 2) {
      const size_t tail_start = array.size() - 2;
      CNodePtr string_concat =
        func_graph_->NewCNodeInOrder(prim::kPrimStringConcat, {array[tail_start], array.back()});
      array.resize(tail_start);
      array.push_back(string_concat);
    }
  }
  if (prim == prim::kPrimMakeSlice) {
    while (array.size() < 3) {
      array.push_back(NewValueNode(kNone));
    }
  }
  if (prim == prim::kPrimMakeDict) {
    AnfNodePtrList keys;
    AnfNodePtrList values;
    if (node->GetOpCode() == BUILD_MAP) {
      // Operands come in pairs; an odd count would read past the end of `array`.
      MS_EXCEPTION_IF_CHECK_FAIL((array.size() % 2 == 0), "The keys and values of Dict are not match.");
      // NOTE(review): CPython pushes BUILD_MAP operands as key, value; this treats the first of each pair as
      // the value. Confirm the IR's operand order for BUILD_MAP.
      for (size_t index = 0; index < array.size(); index += 2) {
        values.push_back(array[index]);
        keys.push_back(array[index + 1]);
      }
    } else {
      auto key_list = GetConstKeys(array.back());
      keys.insert(keys.begin(), key_list.begin(), key_list.end());
      values.insert(values.begin(), array.begin(), array.end() - 1);
      MS_EXCEPTION_IF_CHECK_FAIL((keys.size() == values.size()), "The keys and values of Dict are not match.");
    }
    CNodePtr cnode_keys = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, keys);
    CNodePtr cnode_values = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, values);
    array.clear();
    array.push_back(cnode_keys);
    array.push_back(cnode_values);
  }
  CNodePtr n = func_graph_->NewCNodeInOrder(prim, array);
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
528 
/**
 * Lower CALL_FUNCTION* opcodes into a call CNode.
 *
 * CALL_FUNCTION_KW and CALL_FUNCTION_EX go through an unpack_call meta func graph. For CALL_FUNCTION_KW the
 * trailing MakeTuple of keyword names is paired with the last operands into a MakeDict, and any positional
 * operands after the callee are packed into a MakeTuple. The result is
 * unpack_call(callee, [positional tuple], kwargs dict).
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::CallNodePtr &node) {
  const bool has_kw_names = (node->GetOpCode() == CALL_FUNCTION_KW);
  AnfNodePtrList call_inputs;
  if (has_kw_names || node->GetOpCode() == CALL_FUNCTION_EX) {
    call_inputs.push_back(NewValueNode(std::make_shared<prim::UnpackCall>("unpack_call")));
  }
  for (const ir::NodePtr &operand : node->GetArgs()) {
    call_inputs.push_back(GetAnfNode(operand));
  }

  if (has_kw_names) {
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(call_inputs.back(), prim::kPrimMakeTuple), "Expected tuple node.");
    CNodePtr keys_cnode = call_inputs.back()->cast<CNodePtr>();
    call_inputs.pop_back();
    // keys_cnode holds the MakeTuple primitive plus one input per keyword name; that many trailing operands
    // are the keyword values, and everything before them is unpack_call, the callee and the positionals.
    const size_t kw_value_count = keys_cnode->size() - 1;
    const size_t head_count = call_inputs.size() - kw_value_count;
    AnfNodePtrList kw_values(call_inputs.begin() + head_count, call_inputs.end());
    call_inputs.resize(head_count);
    const size_t unpack_and_callee = 2;
    if (head_count > unpack_and_callee) {
      AnfNodePtrList positional(call_inputs.begin() + unpack_and_callee, call_inputs.end());
      call_inputs.resize(unpack_and_callee);
      call_inputs.push_back(func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, std::move(positional)));
    }
    CNodePtr values_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, std::move(kw_values));
    call_inputs.push_back(func_graph_->NewCNodeInOrder(prim::kPrimMakeDict, {keys_cnode, values_cnode}));
  }
  auto call_cnode = func_graph_->NewCNodeInOrder(std::move(call_inputs));
  UpdateLocation(call_cnode, node);
  return std::make_shared<MindNode>(call_cnode);
}
558 
// Lower `obj[subscr]` through the BINARY_SUBSCR meta func graph.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::SubscrNodePtr &node) {
  auto container = GetAnfNode(node->GetObject());
  auto index = GetAnfNode(node->GetSubscr());
  auto getitem = GraphUtils::GetMetaFuncGraph(BINARY_SUBSCR);
  CNodePtr subscr_cnode = func_graph_->NewCNodeInOrder({getitem, container, index});
  UpdateLocation(subscr_cnode, node);
  return std::make_shared<MindNode>(subscr_cnode);
}
566 
// Lower `obj.attr` into GetAttr(obj, attr).
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::AttrNodePtr &node) {
  auto owner = GetAnfNode(node->GetObject());
  auto attr_name = GetAnfNode(node->GetAttr());
  CNodePtr getattr_cnode = func_graph_->NewCNodeInOrder(prim::kPrimGetAttr, {owner, attr_name});
  UpdateLocation(getattr_cnode, node);
  return std::make_shared<MindNode>(getattr_cnode);
}
574 
GetAnfNode(const ir::NodePtr & node)575 AnfNodePtr FuncGraphBuilder::GetAnfNode(const ir::NodePtr &node) {
576   if (node->isa<MindNode>()) {
577     return node->cast<MindNodePtr>()->GetAnfNode();
578   }
579   if (node->isa<ir::RefNode>()) {
580     return GetAnfNode(node->cast<ir::RefNodePtr>()->GetRealNode());
581   }
582   return GetAnfNode(Mutate(node));
583 }
584 
/**
 * Lower LIST_EXTEND: append the elements of a constant tuple to a MakeList CNode.
 *
 * The result is a new MakeList node. Python's list.extend keeps the list type, and a MakeList result can in
 * turn be extended again or converted by LIST_TO_TUPLE, both of which require a MakeList operand.
 */
AnfNodePtr FuncGraphBuilder::MergeList(const AnfNodePtr &left, const AnfNodePtr &right) {
  MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(left, prim::kPrimMakeList), "Invalid args of list extend target.");
  MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<ValueTuple>(right), "Invalid args of list extend.");
  auto list_inputs = left->cast<CNodePtr>()->inputs();
  AnfNodePtrList elements(list_inputs.begin() + 1, list_inputs.end());
  auto extension = GetValuePtr<ValueTuple>(right);
  for (const ValuePtr &value : extension->value()) {
    elements.push_back(NewValueNode(value));
  }
  return func_graph_->NewCNodeInOrder(prim::kPrimMakeList, elements);
}
595 
GetKeysAndValueOfDict(const AnfNodePtr & node)596 std::pair<AnfNodePtrList, AnfNodePtrList> FuncGraphBuilder::GetKeysAndValueOfDict(const AnfNodePtr &node) {
597   AnfNodePtrList keys;
598   AnfNodePtrList values;
599   if (node->isa<Parameter>()) {
600     auto param = node->cast<ParameterPtr>();
601     return GetKeysAndValueOfDict(assigned_vars_.at(param->name()));
602   } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
603     auto key_tuple = node->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->inputs();
604     keys.assign(key_tuple.begin() + 1, key_tuple.end());
605     auto value_tuple = node->cast<CNodePtr>()->input(2)->cast<CNodePtr>()->inputs();
606     values.assign(value_tuple.begin() + 1, value_tuple.end());
607   } else {
608     MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<ValueDictionary>(node), "Can't convert non-dictionary to dict node.");
609     auto dict = GetValueNode<ValueDictionaryPtr>(node);
610     std::for_each(dict->value().begin(), dict->value().end(), [&](const auto &kv) {
611       keys.push_back(NewValueNode(kv.first));
612       values.push_back(NewValueNode(kv.second));
613     });
614   }
615   return std::make_pair(keys, values);
616 }
617 
IsEmptyTuple(const AnfNodePtr & node)618 bool IsEmptyTuple(const AnfNodePtr &node) {
619   return (IsValueNode<ValueTuple>(node) && GetValueNode<ValueTuplePtr>(node)->size() == 0) ||
620          (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && node->cast<CNodePtr>()->size() == 1);
621 }
622 
IsEmptyDict(const AnfNodePtr & node)623 bool IsEmptyDict(const AnfNodePtr &node) {
624   return (IsValueNode<ValueDictionary>(node) && GetValueNode<ValueDictionaryPtr>(node)->size() == 0) ||
625          (IsPrimitiveCNode(node, prim::kPrimMakeDict) &&
626           (node->cast<CNodePtr>()->size() == 1 || IsEmptyTuple(node->cast<CNodePtr>()->input(1))));
627 }
628 
/**
 * Lower a dict merge/update.
 *
 * `left` must be a MakeDict CNode. If either side is empty the other side is returned unchanged. Otherwise
 * the keys and values of `left` followed by those of `right` form a new MakeDict.
 * NOTE(review): keys are concatenated, not de-duplicated; Python gives later keys precedence.
 */
AnfNodePtr FuncGraphBuilder::MergeDict(const AnfNodePtr &left, const AnfNodePtr &right) {
  MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(left, prim::kPrimMakeDict), "Invalid args of dict merge target.");
  if (IsEmptyDict(left)) {
    return right;
  }
  if (IsEmptyDict(right)) {
    return left;
  }
  auto merged = GetKeysAndValueOfDict(left);
  auto extra = GetKeysAndValueOfDict(right);
  merged.first.insert(merged.first.end(), extra.first.begin(), extra.first.end());
  merged.second.insert(merged.second.end(), extra.second.begin(), extra.second.end());
  CNodePtr keys_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, merged.first);
  CNodePtr values_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, merged.second);
  return func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), keys_cnode, values_cnode});
}
647 }  // namespace pijit
648 }  // namespace mindspore
649