1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/pi/graph_compiler/func_graph_builder.h"
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "frontend/operator/composite/unpack_call.h"
26 #include "frontend/operator/ops.h"
27 #include "include/common/utils/convert_utils_py.h"
28 #include "include/common/utils/python_adapter.h"
29 #include "ops/framework_ops.h"
30 #include "ops/math_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/structure_ops.h"
33 #include "pipeline/jit/pi/graph_compiler/func_wrapper.h"
34 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_mutator.h"
35 #include "pipeline/jit/pi/graph_compiler/utils.h"
36 #include "pipeline/jit/ps/parse/parse.h"
37 #include "utils/log_adapter.h"
38
39 namespace mindspore {
40 namespace pijit {
41 namespace ir {
__anonf3d3890b0102(const NodePtr &node, IRMutator *m) 42 STATIC_IR_FUNCTOR(IRMutator, vtable).set_dispatch<MindNode>([](const NodePtr &node, IRMutator *m) { return node; });
43 } // namespace ir
44
BuildFuncGraph(const ir::FunctionNodePtr & func,const py::tuple & args,const py::dict & kwargs)45 FuncGraphPtr FuncGraphBuilder::BuildFuncGraph(const ir::FunctionNodePtr &func, const py::tuple &args,
46 const py::dict &kwargs) {
47 AnfNodePtrList anf_args;
48 bool broaden = func->GetAttr("enable_tuple_broaden");
49 std::transform(args.begin(), args.end(), std::back_inserter(anf_args), [broaden](const py::handle &arg) {
50 auto node = GraphUtils::ConvertPythonObjectToAnfNode(py::cast<py::object>(arg));
51 node->set_abstract(GraphUtils::ArgsToAbstract(py::cast<py::object>(arg), GetValueNode<ValuePtr>(node), broaden));
52 return node;
53 });
54 AnfNodePtr anf_kwargs = GraphUtils::ConvertPythonObjectToAnfNode(kwargs);
55 anf_kwargs->set_abstract(GraphUtils::ArgsToAbstract(kwargs, GetValueNode<ValuePtr>(anf_kwargs), broaden));
56 return BuildFuncGraph(func, anf_args, anf_kwargs);
57 }
58
BuildFuncGraph(const ir::FunctionNodePtr & func,const AnfNodePtrList & args,const AnfNodePtr & kwargs)59 FuncGraphPtr FuncGraphBuilder::BuildFuncGraph(const ir::FunctionNodePtr &func, const AnfNodePtrList &args,
60 const AnfNodePtr &kwargs) {
61 auto builder = std::make_shared<FuncGraphBuilder>(func, args, kwargs);
62 parse::Parser::UpdateTopFuncGraph(builder->func_graph_);
63 auto func_graph = GetValueNode<FuncGraphPtr>(builder->Mutate(func)->cast<MindNodePtr>()->GetAnfNode());
64 return func_graph;
65 }
66
67 #define DEFINE_UN_NODE_MUTATE_(OP) \
68 ir::NodePtr FuncGraphBuilder::Mutate_(const OP &node) { \
69 auto op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode()); \
70 auto n = func_graph_->NewCNodeInOrder({op, GetAnfNode(node->GetArg())}); \
71 UpdateLocation(n, node); \
72 return std::make_shared<MindNode>(n); \
73 }
74
75 DEFINE_UN_NODE_MUTATE_(ir::NegativeNodePtr)
DEFINE_UN_NODE_MUTATE_(ir::NotNodePtr)76 DEFINE_UN_NODE_MUTATE_(ir::NotNodePtr)
77
78 #define DEFINE_BIN_NODE_MUTATE_(OP) \
79 ir::NodePtr FuncGraphBuilder::Mutate_(const OP &node) { \
80 auto op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode()); \
81 auto left = GetAnfNode(node->GetLeftArg()); \
82 auto right = GetAnfNode(node->GetRightArg()); \
83 CNodePtr n = func_graph_->NewCNodeInOrder({op, left, right}); \
84 UpdateLocation(n, node); \
85 return std::make_shared<MindNode>(n); \
86 }
87
88 DEFINE_BIN_NODE_MUTATE_(ir::AddNodePtr)
89 DEFINE_BIN_NODE_MUTATE_(ir::SubNodePtr)
90 DEFINE_BIN_NODE_MUTATE_(ir::MulNodePtr)
91 DEFINE_BIN_NODE_MUTATE_(ir::DivNodePtr)
92 DEFINE_BIN_NODE_MUTATE_(ir::BitwiseNodePtr)
93 DEFINE_BIN_NODE_MUTATE_(ir::BinaryOperationPtr)
94
95 ir::NodePtr FuncGraphBuilder::Mutate_(const ir::RefNodePtr &node) {
96 node->SetRealNode(Mutate(node->GetRealNode()));
97 return node;
98 }
99
/**
 * Lower an IR parameter.
 *
 * If the function does not generate its own parameters, the name is bound to the caller-supplied
 * argument at the parameter's index and the IR node is returned unchanged. The branch graphs built
 * in Mutate_(IfNodePtr) are one such case.
 *
 * Otherwise a graph Parameter is created and processed as follows:
 * - Its abstract is copied from kwargs_ (KEYWORD category) or from the argument at the same index.
 * - Ref-tensor parameters get "_<ref key>" appended to their name.
 * - A default value, if any, is registered on the graph.
 * - It is recorded in assigned_vars_ under the original IR name.
 *
 * Raises when the index has no matching argument or the abstract source has no abstract.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ParameterPtr &node) {
  auto name = node->GetName();
  auto index = node->GetIndex();
  if (!func_->NeedGenParameters()) {
    MS_EXCEPTION_IF_CHECK_FAIL(static_cast<size_t>(index) < args_.size(), "Invalid parameter[" + name + "].");
    assigned_vars_[name] = args_[index];
    return node;
  }
  auto param = std::make_shared<Parameter>(func_graph_);
  MS_EXCEPTION_IF_NULL(param);
  param->set_name(name);
  MS_EXCEPTION_IF_NULL(param->debug_info());
  param->debug_info()->set_name(name);
  UpdateLocation(param, node);
  if (node->GetCategory() == ir::Parameter::KEYWORD) {
    // The **kwargs parameter takes its abstract from the kwargs node passed to the builder.
    MS_EXCEPTION_IF_NULL(kwargs_);
    param->set_abstract(kwargs_->abstract());
  } else {
    MS_EXCEPTION_IF_CHECK_FAIL(static_cast<size_t>(index) < args_.size(), "Parameter " + name + " has no arguments");
    param->set_abstract(args_[index]->abstract());
  }
  // Without an abstract the isa<> check below would dereference null; fail with a clear message instead.
  MS_EXCEPTION_IF_NULL(param->abstract());
  if (param->abstract()->isa<abstract::AbstractRefTensor>()) {
    // Make ref-tensor parameter names unique by appending the ref key.
    auto abs_ref = param->abstract()->cast<abstract::AbstractRefPtr>();
    auto new_name = name + "_" + abs_ref->ref_key_value()->ToString();
    param->set_name(new_name);
    param->debug_info()->set_name(new_name);
  }
  auto default_value = node->GetDefaultValue();
  if (default_value != nullptr) {
    AnfNodePtr value_node = GetAnfNode(default_value);
    func_graph_->set_param_default_value(param->name(), value_node);
  }
  assigned_vars_[name] = param;
  func_graph_->add_parameter(param);
  return std::make_shared<MindNode>(param);
}
137
/**
 * Lower an IR function into func_graph_.
 *
 * Copies the vararg flag, kw-only argument count, kwarg flag and name onto the graph.
 *
 * If the node list contains an IfNode, the list is restructured around the first one:
 * - Every node after that IfNode is removed from the function body.
 * - Those nodes are appended to each branch that does not already end in a ReturnNode; a trailing
 *   JumpNode in such a branch is dropped first.
 * - A ReturnNode returning the IfNode is appended to the function.
 *
 * Parameters and then body nodes are mutated. The result is a MindNode holding a ValueNode of
 * func_graph_.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::FunctionNodePtr &node) {
  func_graph_->set_has_vararg(node->HasVarArg());
  func_graph_->set_kwonlyargs_count(node->GetKwOnlyArgsCnt());
  func_graph_->set_has_kwarg(node->HasKwArg());
  func_graph_->debug_info()->set_name(node->GetName());
  // Sort the node list before it is split below; this is needed when creating the sub functions.
  node->Sort();
  auto first_if_node = std::find_if(node->GetNodes().begin(), node->GetNodes().end(),
                                    [](const ir::NodePtr &n) { return n->isa<ir::IfNode>(); });
  if (first_if_node != node->GetNodes().end()) {
    size_t index = std::distance(node->GetNodes().begin(), first_if_node);
    // Detach everything after the IfNode; it becomes the tail of the branches.
    std::vector<ir::NodePtr> nodes(node->GetNodes().begin() + index + 1, node->GetNodes().end());
    node->GetNodes().resize(index + 1);
    auto if_node = node->GetNodes().back()->cast<ir::IfNodePtr>();
    // A branch that already returns never reaches the tail, so only non-returning branches get it.
    if (if_node->GetThen().empty() || !if_node->GetThen().back()->isa<ir::ReturnNode>()) {
      if (!if_node->GetThen().empty() && if_node->GetThen().back()->isa<ir::JumpNode>()) {
        if_node->GetThen().pop_back();
      }
      std::for_each(nodes.begin(), nodes.end(), [&if_node](const ir::NodePtr &n) { if_node->AddThen(n); });
    }
    if (if_node->GetElse().empty() || !if_node->GetElse().back()->isa<ir::ReturnNode>()) {
      if (!if_node->GetElse().empty() && if_node->GetElse().back()->isa<ir::JumpNode>()) {
        if_node->GetElse().pop_back();
      }
      std::for_each(nodes.begin(), nodes.end(), [&if_node](const ir::NodePtr &n) { if_node->AddElse(n); });
    }
    // The function's result is now the result of the IfNode.
    node->AddNode(std::make_shared<ir::ReturnNode>(if_node));
  }
  MUTATE_NODE_LIST(node->GetParameters())
  MUTATE_NODE_LIST(node->GetNodes())
  return std::make_shared<MindNode>(NewValueNode(func_graph_));
}
170
UpdateLocation(const AnfNodePtr & anf_node,const ir::NodePtr & node)171 void FuncGraphBuilder::UpdateLocation(const AnfNodePtr &anf_node, const ir::NodePtr &node) {
172 if (!enable_debug_info_) {
173 return;
174 }
175 // Refer to Location::Location() for each node: line, column, line_end, column_end, expr_src.
176 auto debug_info = node->GetDebugInfo();
177 auto line_no = debug_info->GetLineNo();
178 line_no = (line_no == 0) ? last_line_no_ : line_no;
179 last_line_no_ = line_no;
180 auto loc = anf_node->debug_info()->location();
181 if (loc == nullptr) {
182 std::vector<std::string> comments;
183 anf_node->debug_info()->set_location(std::make_shared<Location>(debug_info->GetFileName(), line_no, 0, line_no, 0,
184 debug_info->GetDesc(), std::move(comments)));
185 } else {
186 loc->set_file_name(debug_info->GetFileName());
187 loc->set_line(line_no);
188 loc->set_line_end(line_no);
189 loc->set_expr_src(debug_info->GetDesc());
190 }
191 }
192
ConvertListOrTupleToCNode(const py::object & obj)193 AnfNodePtr FuncGraphBuilder::ConvertListOrTupleToCNode(const py::object &obj) {
194 MS_EXCEPTION_IF_CHECK_FAIL((py::isinstance<py::list>(obj) || py::isinstance<py::tuple>(obj)),
195 "Should be a list or tuple.");
196 auto tuple = py::cast<py::tuple>(obj);
197 auto parameter = python_adapter::GetPyObjAttr(python_adapter::GetPyModule("mindspore"), "Parameter");
198 auto prim = py::isinstance<py::list>(obj) ? prim::kPrimMakeList : prim::kPrimMakeTuple;
199 CNodePtr cnode = func_graph_->NewCNodeInOrder(prim, {});
200 for (size_t idx = 0; idx < tuple.size(); idx++) {
201 if (py::isinstance(tuple[idx], parameter)) {
202 cnode->add_input(parse::ResolveParameterObj(func_graph_, tuple[idx]));
203 } else if (py::isinstance<py::list>(tuple[idx]) || py::isinstance<py::tuple>(tuple[idx])) {
204 cnode->add_input(ConvertListOrTupleToCNode(tuple[idx]));
205 } else {
206 cnode->add_input(GraphUtils::ConvertPythonObjectToAnfNode(tuple[idx]));
207 }
208 }
209 return cnode;
210 }
211
/**
 * Lower a constant Python value.
 *
 * Values listed in mindspore._extends.parse.resources.convert_object_map are first replaced by their
 * mapped object. Lists and tuples become MakeList/MakeTuple CNodes; any other value is converted
 * with GraphUtils::ConvertPythonObjectToAnfNode.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ValuePtr &node) {
  py::object value_obj = node->GetValue();
  auto convert_map = python_adapter::GetPyObjAttr(python_adapter::GetPyModule("mindspore._extends.parse.resources"),
                                                  "convert_object_map");
  // PyDict_GetItem returns a borrowed reference (or NULL when absent), so borrow it into a py::object.
  PyObject *replacement = PyDict_GetItem(convert_map.ptr(), value_obj.ptr());
  if (replacement != nullptr) {
    value_obj = py::cast<py::object>(replacement);
  }
  AnfNodePtr anf_value = nullptr;
  if (py::isinstance<py::list>(value_obj) || py::isinstance<py::tuple>(value_obj)) {
    anf_value = ConvertListOrTupleToCNode(value_obj);
  } else {
    anf_value = GraphUtils::ConvertPythonObjectToAnfNode(value_obj);
  }
  UpdateLocation(anf_value, node);
  return std::make_shared<MindNode>(anf_value);
}
225
/**
 * Lower an if statement to `switch(condition, true_graph, false_graph)()`.
 *
 * Condition handling:
 * - The condition must be a JumpNode.
 * - For POP_JUMP_IF_TRUE the mutated condition is negated with the UNARY_NOT meta graph.
 *
 * Branch handling:
 * - Each branch, minus a trailing JumpNode, is wrapped into a sub-function named
 *   "<func>_sub_func_<offset>_true" or "_false".
 * - Both sub-functions are given the same output list: the union of both branches' outputs,
 *   de-duplicated by name, then-branch first.
 * - Each sub-function's parameters are bound to this builder's assigned local variables.
 * - Each sub-function is built by its own FuncGraphBuilder.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::IfNodePtr &node) {
  auto cond = node->GetCondition();
  MS_EXCEPTION_IF_CHECK_FAIL(cond->isa<ir::JumpNode>(), cond->ToString() + " can't be a condition.");
  auto jump = cond->cast<ir::JumpNodePtr>();
  auto condition = Mutate(jump->GetCondition())->cast<MindNodePtr>()->GetAnfNode();
  if (jump->GetOpCode() == POP_JUMP_IF_TRUE) {
    condition = func_graph_->NewCNodeInOrder({GraphUtils::GetMetaFuncGraph(UNARY_NOT), condition});
  }

  // Copy a branch and drop its trailing jump; the switch replaces that control flow.
  auto drop_trailing_jump = [](auto branch) {
    if (!branch.empty() && branch.back()->isa<ir::JumpNode>()) {
      branch.pop_back();
    }
    return branch;
  };
  auto then_nodes = drop_trailing_jump(node->GetThen());
  const std::string prefix = func_->GetName() + "_sub_func_" + std::to_string(cond->GetOffset());
  FuncWrapperPtr wrapper_then = std::make_shared<FuncWrapper>(prefix + "_true", then_nodes);
  auto outputs_then = wrapper_then->GetOutputs();
  auto else_nodes = drop_trailing_jump(node->GetElse());
  FuncWrapperPtr wrapper_else = std::make_shared<FuncWrapper>(prefix + "_false", else_nodes);
  auto outputs_else = wrapper_else->GetOutputs();

  // Union of both branches' outputs, keeping first occurrence of each variable name.
  std::set<std::string> seen_names;
  std::vector<ir::ValuePtr> outputs;
  auto collect_unique = [&outputs, &seen_names](const auto &branch_outputs) {
    for (const ir::ValuePtr &var : branch_outputs) {
      if (seen_names.insert(var->GetValue().cast<std::string>()).second) {
        outputs.push_back(var);
      }
    }
  };
  collect_unique(outputs_then);
  collect_unique(outputs_else);
  wrapper_then->SpecifyOutputs(outputs);
  wrapper_else->SpecifyOutputs(outputs);

  // Wrap one branch into a function whose parameters are this builder's current locals, then build it.
  auto build_branch_graph = [this](const FuncWrapperPtr &wrapper) {
    ir::FunctionNodePtr branch_func = wrapper->Wrapper();
    branch_func->MarkNoNeedGenParameters();
    AnfNodePtrList branch_args;
    for (const ir::NodePtr &param : branch_func->GetParameters()) {
      auto name = param->cast<ir::ParameterPtr>()->GetName();
      auto found = assigned_vars_.find(name);
      MS_EXCEPTION_IF_CHECK_FAIL(found != assigned_vars_.end(), "Local var " + name + " is not defined.");
      branch_args.push_back(found->second);
    }
    FuncGraphBuilderPtr branch_builder = std::make_shared<FuncGraphBuilder>(branch_func, branch_args, NewValueNode(kNone));
    return GetValueNode<FuncGraphPtr>(branch_builder->Mutate(branch_func)->cast<MindNodePtr>()->GetAnfNode());
  };
  auto graph_true = build_branch_graph(wrapper_then);
  auto graph_false = build_branch_graph(wrapper_else);

  CNodePtr switch_node =
    func_graph_->NewCNodeInOrder(prim::kPrimSwitch, {condition, NewValueNode(graph_true), NewValueNode(graph_false)});
  CNodePtr call_switch = func_graph_->NewCNodeInOrder({switch_node});
  return std::make_shared<MindNode>(call_switch);
}
296
/**
 * Lower an InvertNode.
 *
 * A scalar value-node operand uses the primitive/meta graph mapped from the opcode. Any other operand
 * goes through mindspore.ops.functional.logical_not.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::InvertNodePtr &node) {
  auto operand = GetAnfNode(node->GetArg());
  AnfNodePtr invert_op = nullptr;
  if (IsValueNode<Scalar>(operand)) {
    invert_op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode());
  } else {
    invert_op = NewValueNode(prim::GetPythonOps("logical_not", "mindspore.ops.functional"));
  }
  auto invert_cnode = func_graph_->NewCNodeInOrder({invert_op, operand});
  return std::make_shared<MindNode>(invert_cnode);
}
303
// Emit the return CNode for the node's argument and install it as the graph's return.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ReturnNodePtr &node) {
  auto return_op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode());
  auto returned_value = GetAnfNode(node->GetArg());
  auto return_cnode = func_graph_->NewCNodeInOrder({return_op, returned_value});
  func_graph_->set_return(return_cnode);
  return std::make_shared<MindNode>(return_cnode);
}
311
/**
 * Lower list_to_tuple.
 *
 * A constant ValueList is expanded into a new CNode `op(v0, v1, ...)`. A MakeList CNode has its input 0
 * replaced by the cast op in place. Any other operand raises.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::CastNodePtr &node) {
  auto cast_op = GraphUtils::GetPrimOrMetaFuncGraph(node->GetOpCode());
  auto operand = GetAnfNode(node->GetArg());
  if (IsValueNode<ValueList>(operand)) {
    auto const_list = GetValueNode<ValueListPtr>(operand);
    AnfNodePtrList cast_inputs{cast_op};
    for (const ValuePtr &element : const_list->value()) {
      cast_inputs.push_back(NewValueNode(element));
    }
    operand = func_graph_->NewCNodeInOrder(cast_inputs);
  } else {
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(operand, prim::kPrimMakeList),
                               operand->DebugString() + " is invalid for list_to_tuple.");
    // NOTE(review): this rewrites the existing MakeList node, so any other user of it sees the cast too.
    operand->cast<CNodePtr>()->set_input(0, cast_op);
  }
  UpdateLocation(operand, node);
  return std::make_shared<MindNode>(operand);
}
329
/**
 * Lower FORMAT_VALUE by formatting the constant operand eagerly in Python.
 *
 * The format type is decoded as follows:
 * - Bits 0-1 select the conversion applied first: none, str(), repr() or ascii().
 * - Bit 2 (0x04) means a format spec is supplied as the node's second argument.
 *
 * The result of `format(converted, spec)` becomes a constant ANF node. Python errors raised while
 * converting or formatting propagate as py::error_already_set.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::FormatNodePtr &node) {
  auto arg = node->GetArg(0);
  MS_EXCEPTION_IF_CHECK_FAIL(arg->isa<ir::Value>(), "The arg of format must be object.");
  py::object top = arg->cast<ir::ValuePtr>()->GetValue();
  auto fmt_flag = node->GetFormatType();
  // Conversion bits (CPython's FVC_MASK). The PyObject_* calls return new references, hence steal.
  switch (fmt_flag & 0x03) {
    case 0x01:
      top = py::reinterpret_steal<py::object>(PyObject_Str(top.ptr()));
      break;
    case 0x02:
      top = py::reinterpret_steal<py::object>(PyObject_Repr(top.ptr()));
      break;
    case 0x03:
      top = py::reinterpret_steal<py::object>(PyObject_ASCII(top.ptr()));
      break;
    default:
      break;
  }
  if (!top) {
    throw py::error_already_set();
  }
  // Format-spec bit (CPython's FVS_HAVE_SPEC). It must be tested on its own: `fmt_flag & 0x03` only
  // yields 0..3, so a check nested under that switch could never run.
  py::object format;
  if ((fmt_flag & 0x04) == 0x04) {
    auto spec_arg = node->GetArg(1);
    MS_EXCEPTION_IF_CHECK_FAIL(spec_arg->isa<ir::Value>(), "The fmt must be object.");
    format = spec_arg->cast<ir::ValuePtr>()->GetValue();
  }
  // A null `format` means "no spec", which PyObject_Format accepts. Its result is a new reference.
  auto formatted = py::reinterpret_steal<py::object>(PyObject_Format(top.ptr(), format.ptr()));
  if (!formatted) {
    throw py::error_already_set();
  }
  py::str obj = py::cast<py::str>(formatted);
  AnfNodePtr value = GraphUtils::ConvertPythonObjectToAnfNode(obj);
  UpdateLocation(value, node);
  return std::make_shared<MindNode>(value);
}
366
// Lower an identity comparison to the Is_ primitive, or to IsNot when the node is inverted.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::IsNodePtr &node) {
  auto lhs = GetAnfNode(node->GetLeftArg());
  auto rhs = GetAnfNode(node->GetRightArg());
  AnfNodePtr identity_check = nullptr;
  if (node->IsInvert()) {
    identity_check = func_graph_->NewCNodeInOrder(prim::kPrimIsNot, {lhs, rhs});
  } else {
    identity_check = func_graph_->NewCNodeInOrder(prim::kPrimIs_, {lhs, rhs});
  }
  UpdateLocation(identity_check, node);
  return std::make_shared<MindNode>(identity_check);
}
375
// Lower a membership test through the "in_" meta func graph, or "not_in_" when the node is inverted.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::ContainsNodePtr &node) {
  auto lhs = GetAnfNode(node->GetLeftArg());
  auto rhs = GetAnfNode(node->GetRightArg());
  const char *meta_name = "in_";
  if (node->IsInvert()) {
    meta_name = "not_in_";
  }
  auto membership_op = GraphUtils::GetMetaFuncGraph(meta_name);
  CNodePtr membership = func_graph_->NewCNodeInOrder({membership_op, lhs, rhs});
  UpdateLocation(membership, node);
  return std::make_shared<MindNode>(membership);
}
384
/**
 * Record a store: the right argument must be a constant variable name, and the lowered left argument is
 * bound to it in assigned_vars_. A store produces no graph node, so nullptr is returned.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::StoreNodePtr &node) {
  auto stored_value = GetAnfNode(node->GetLeftArg());
  auto target = GetAnfNode(node->GetRightArg());
  MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<StringImm>(target), "Excepted var name.");
  const auto var_name = GetValue<std::string>(GetValueNode(target));
  assigned_vars_[var_name] = stored_value;
  return nullptr;
}
392
/**
 * Lower COMPARE_OP.
 *
 * The instruction argument indexes the table below, whose order mirrors CPython's `dis.cmp_op`. On
 * Python 3.7/3.8 the table also covers `in`, `not in`, `is` and `is not`. `is` / `is not` become
 * primitives; every other entry uses the meta func graph of the same name.
 *
 * Raises for arguments outside the table (e.g. "exception match") instead of reading past its end.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::CompareNodePtr &node) {
  // Built once; the table is constant.
  static const std::vector<std::string> ops = {
    "less",
    "less_equal",
    "equal",
    "not_equal",
    "greater",
    "greater_equal"
#if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
    ,
    "in_",
    "not_in_",
    "is",
    "is_not"
#endif  // #if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
  };
  auto left = GetAnfNode(node->GetLeftArg());
  auto right = GetAnfNode(node->GetRightArg());
  auto instr_arg = node->GetInstrArg();
  // A negative argument wraps to a huge size_t and is rejected here as well.
  MS_EXCEPTION_IF_CHECK_FAIL(static_cast<size_t>(instr_arg) < ops.size(),
                             "Unsupported compare op " + std::to_string(instr_arg) + ".");
  const std::string &op = ops[static_cast<size_t>(instr_arg)];
#if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
  if (op == "is" || op == "is_not") {
    auto prim = (op == "is") ? prim::kPrimIs_ : prim::kPrimIsNot;
    CNodePtr is_node = func_graph_->NewCNodeInOrder(prim, {left, right});
    UpdateLocation(is_node, node);
    return std::make_shared<MindNode>(is_node);
  }
#endif  // #if (PY_MAJOR_VERSION == 3 && (PY_MINOR_VERSION == 7 || PY_MINOR_VERSION == 8))
  CNodePtr n = func_graph_->NewCNodeInOrder({GraphUtils::GetMetaFuncGraph(op), left, right});
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
422
/**
 * Lower container-update opcodes. LIST_EXTEND goes through MergeList; every other opcode handled here
 * goes through MergeDict. On Python 3.7 the operands of MAP_ADD are swapped first.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::UpdateNodePtr &node) {
  auto target = GetAnfNode(node->GetLeftArg());
  auto update = GetAnfNode(node->GetRightArg());
#if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 7)
  if (node->GetOpCode() == MAP_ADD) {
    std::swap(target, update);
  }
#endif  // #if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 7)
  AnfNodePtr merged = (node->GetOpCode() == LIST_EXTEND) ? MergeList(target, update) : MergeDict(target, update);
  UpdateLocation(merged, node);
  return std::make_shared<MindNode>(merged);
}
440
/**
 * Lower a value load. The first argument is always lowered first; the opcode then decides the result:
 * - LOAD_FAST: the lowered argument must be a constant variable name, which is looked up in
 *   assigned_vars_.
 * - LOAD_DEREF / LOAD_CLASSDEREF: the node's argument must be a Python cell object; its current
 *   contents are converted to an ANF node.
 * - LOAD_CONST / LOAD_GLOBAL / LOAD_CLOSURE: the lowered argument is used as is.
 *
 * Any other opcode raises.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::LoadValueNodePtr &node) {
  AnfNodePtr n = GetAnfNode(node->GetArg(0));
  ir::OpCode op_code = node->GetOpCode();
  if (op_code == LOAD_FAST) {
    MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<StringImm>(n), "Invalid var name.");
    std::string key = GetValue<std::string>(GetValueNode(n));
    auto found = assigned_vars_.find(key);
    MS_EXCEPTION_IF_CHECK_FAIL(found != assigned_vars_.end(), "Found not define var " + key + ".");
    n = found->second;
  } else if (op_code == LOAD_DEREF || op_code == LOAD_CLASSDEREF) {
    auto arg = node->GetArg();
    MS_EXCEPTION_IF_CHECK_FAIL(arg->isa<ir::Value>(), "Excepted a python object as arg of load_closure.");
    auto cell = arg->cast<ir::ValuePtr>()->GetValue();
    // PyCell_Get returns a new reference, so steal it rather than borrow (borrowing leaked it).
    // NULL means "not a cell" (Python error set) or "empty cell" (no error set).
    auto contents = py::reinterpret_steal<py::object>(PyCell_Get(cell.ptr()));
    if (!contents && PyErr_Occurred() != nullptr) {
      throw py::error_already_set();
    }
    MS_EXCEPTION_IF_CHECK_FAIL(static_cast<bool>(contents), "Free variable referenced before assignment.");
    n = GraphUtils::ConvertPythonObjectToAnfNode(contents);
  } else {
    // The converted argument already is the loaded value.
    MS_EXCEPTION_IF_CHECK_FAIL((op_code == LOAD_CONST || op_code == LOAD_GLOBAL || op_code == LOAD_CLOSURE),
                               "Not Expected bytecode.");
  }
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
463
// Lower a field load to getattr(instance, field); the field must be a constant string.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::LoadFieldNodePtr &node) {
  auto owner = GetAnfNode(node->GetArg(0));
  auto attr_name = GetAnfNode(node->GetArg(1));
  MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<StringImm>(attr_name), "Excepted attr/name.");
  CNodePtr getattr_node = func_graph_->NewCNodeInOrder(prim::kPrimGetAttr, {owner, attr_name});
  return std::make_shared<MindNode>(getattr_node);
}
470
GetConstKeys(const AnfNodePtr & node)471 AnfNodePtrList GetConstKeys(const AnfNodePtr &node) {
472 AnfNodePtrList keys;
473 if (node->isa<CNode>()) {
474 auto inputs = node->cast<CNodePtr>()->inputs();
475 keys.insert(keys.begin(), inputs.begin() + 1, inputs.end());
476 } else {
477 MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<ValueTuple>(node), "The keys must be a ValueTuple.");
478 auto tuple = GetValueNode<ValueTuplePtr>(node);
479 std::transform(tuple->value().begin(), tuple->value().end(), std::back_inserter(keys),
480 [](const ValuePtr &value) { return NewValueNode(value); });
481 }
482 return keys;
483 }
484
/**
 * Lower a build node to the primitive mapped from its opcode, after lowering every argument.
 *
 * Some primitives need their arguments reshaped first:
 * - kPrimStringConcat is binary, so the parts are folded pairwise from the right while keeping source
 *   order: [a, b, c] -> [a, concat(b, c)] -> concat(a, concat(b, c)).
 * - kPrimMakeSlice: missing components are padded with None up to three.
 * - kPrimMakeDict: the arguments are split into a keys MakeTuple and a values MakeTuple. For BUILD_MAP
 *   they are read as pairs; otherwise the last argument holds the keys (const-key map).
 *
 * Raises when the dict arguments cannot be split consistently.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::BuildNodePtr &node) {
  AnfNodePtrList array;
  std::transform(node->GetArgs().begin(), node->GetArgs().end(), std::back_inserter(array),
                 [&](const ir::NodePtr &arg) { return GetAnfNode(arg); });
  auto prim = GraphUtils::GetPrimitive(node->GetOpCode());
  if (prim == prim::kPrimStringConcat) {
    while (array.size() > 2) {
      size_t array_new_size = array.size() - 2;
      // The second-to-last part is the left operand and the last part the right one; swapping them
      // would reorder the string ("abc" became "acb").
      CNodePtr string_concat =
        func_graph_->NewCNodeInOrder(prim::kPrimStringConcat, {array[array_new_size], array.back()});
      array.resize(array_new_size);
      array.push_back(string_concat);
    }
  }
  if (prim == prim::kPrimMakeSlice) {
    while (array.size() < 3) {
      array.push_back(NewValueNode(kNone));
    }
  }
  if (prim == prim::kPrimMakeDict) {
    AnfNodePtrList keys;
    AnfNodePtrList values;
    if (node->GetOpCode() == BUILD_MAP) {
      // Arguments come in pairs; an odd count would read past the end of `array`.
      MS_EXCEPTION_IF_CHECK_FAIL((array.size() % 2 == 0), "The keys and values of Dict are not match.");
      // NOTE(review): even positions are used as values and odd ones as keys. CPython pushes the key
      // before the value for BUILD_MAP, so confirm how the IR orders these arguments.
      for (size_t index = 0; index < array.size(); index += 2) {
        values.push_back(array[index]);
        keys.push_back(array[index + 1]);
      }
    } else {
      // Const-key map: the last argument holds the keys, the preceding ones are the values.
      MS_EXCEPTION_IF_CHECK_FAIL(!array.empty(), "The keys and values of Dict are not match.");
      auto key_list = GetConstKeys(array.back());
      keys.insert(keys.begin(), key_list.begin(), key_list.end());
      values.insert(values.begin(), array.begin(), array.end() - 1);
      MS_EXCEPTION_IF_CHECK_FAIL((keys.size() == values.size()), "The keys and values of Dict are not match.");
    }
    CNodePtr cnode_keys = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, keys);
    CNodePtr cnode_values = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, values);
    array.clear();
    array.push_back(cnode_keys);
    array.push_back(cnode_values);
  }
  CNodePtr n = func_graph_->NewCNodeInOrder(prim, array);
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
528
/**
 * Lower a call node to a CNode whose inputs are the lowered call arguments.
 *
 * For CALL_FUNCTION_KW and CALL_FUNCTION_EX, an `unpack_call` meta func graph is prepended as input 0.
 *
 * For CALL_FUNCTION_KW, the last lowered argument must be a MakeTuple CNode of keyword names.
 * - The arguments before it are split into leading entries and one keyword value per name.
 * - Leading entries beyond the first two (unpack_call and the node's first argument) are packed into a
 *   MakeTuple.
 * - The keyword values and names are packed into a MakeDict, which is appended last.
 */
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::CallNodePtr &node) {
  AnfNodePtrList nodes;
  if (node->GetOpCode() == CALL_FUNCTION_KW || node->GetOpCode() == CALL_FUNCTION_EX) {
    nodes.push_back(NewValueNode(std::make_shared<prim::UnpackCall>("unpack_call")));
  }
  std::transform(node->GetArgs().begin(), node->GetArgs().end(), std::back_inserter(nodes),
                 [&](const ir::NodePtr &arg) { return GetAnfNode(arg); });

  if (node->GetOpCode() == CALL_FUNCTION_KW) {
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(nodes.back(), prim::kPrimMakeTuple), "Expected tuple node.");
    CNodePtr keys_cnode = nodes.back()->cast<CNodePtr>();
    // keys_cnode->size() is 1 (the MakeTuple prim) + number of keyword names. Removing that many entries
    // from the end (the names tuple itself plus one value per name) leaves the leading entries.
    size_t args_cnt = nodes.size() - keys_cnode->size();
    nodes.pop_back();
    AnfNodePtrList values(nodes.begin() + args_cnt, nodes.end());
    nodes.resize(args_cnt);
    if (args_cnt > 2) {
      // Pack the leading entries after the first two into a single positional-args tuple.
      AnfNodePtrList pos_args(nodes.begin() + 2, nodes.end());
      nodes.resize(2);
      CNodePtr pos_args_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, std::move(pos_args));
      nodes.push_back(pos_args_cnode);
    }
    CNodePtr values_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, std::move(values));
    CNodePtr kwargs_node = func_graph_->NewCNodeInOrder(prim::kPrimMakeDict, {keys_cnode, values_cnode});
    nodes.push_back(kwargs_node);
  }
  auto n = func_graph_->NewCNodeInOrder(std::move(nodes));
  UpdateLocation(n, node);
  return std::make_shared<MindNode>(n);
}
558
// Lower `object[subscr]` through the BINARY_SUBSCR meta func graph.
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::SubscrNodePtr &node) {
  auto container = GetAnfNode(node->GetObject());
  auto key = GetAnfNode(node->GetSubscr());
  auto getitem_op = GraphUtils::GetMetaFuncGraph(BINARY_SUBSCR);
  CNodePtr getitem = func_graph_->NewCNodeInOrder({getitem_op, container, key});
  UpdateLocation(getitem, node);
  return std::make_shared<MindNode>(getitem);
}
566
// Lower `object.attr` to getattr(object, attr).
ir::NodePtr FuncGraphBuilder::Mutate_(const ir::AttrNodePtr &node) {
  auto owner = GetAnfNode(node->GetObject());
  auto attr_name = GetAnfNode(node->GetAttr());
  CNodePtr getattr_node = func_graph_->NewCNodeInOrder(prim::kPrimGetAttr, {owner, attr_name});
  UpdateLocation(getattr_node, node);
  return std::make_shared<MindNode>(getattr_node);
}
574
GetAnfNode(const ir::NodePtr & node)575 AnfNodePtr FuncGraphBuilder::GetAnfNode(const ir::NodePtr &node) {
576 if (node->isa<MindNode>()) {
577 return node->cast<MindNodePtr>()->GetAnfNode();
578 }
579 if (node->isa<ir::RefNode>()) {
580 return GetAnfNode(node->cast<ir::RefNodePtr>()->GetRealNode());
581 }
582 return GetAnfNode(Mutate(node));
583 }
584
/**
 * Lower LIST_EXTEND: append the elements of `right` to the MakeList CNode `left`.
 *
 * `right` may be a constant tuple or list, or a MakeTuple / MakeList CNode. The result is a new
 * MakeList CNode, because extending a list yields a list. Both LIST_TO_TUPLE (Mutate_(CastNodePtr))
 * and a further LIST_EXTEND require a MakeList here; returning a MakeTuple broke both and turned list
 * literals into tuples.
 *
 * Raises for any other `left` / `right`.
 */
AnfNodePtr FuncGraphBuilder::MergeList(const AnfNodePtr &left, const AnfNodePtr &right) {
  MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(left, prim::kPrimMakeList), "Invalid args of list extend target.");
  const bool right_is_const = IsValueNode<ValueTuple>(right) || IsValueNode<ValueList>(right);
  const bool right_is_cnode =
    IsPrimitiveCNode(right, prim::kPrimMakeTuple) || IsPrimitiveCNode(right, prim::kPrimMakeList);
  MS_EXCEPTION_IF_CHECK_FAIL(right_is_const || right_is_cnode, "Invalid args of list extend.");
  const auto &left_inputs = left->cast<CNodePtr>()->inputs();
  AnfNodePtrList values(left_inputs.begin() + 1, left_inputs.end());
  auto append_const_elements = [&values](const auto &sequence) {
    for (const ValuePtr &element : sequence->value()) {
      values.push_back(NewValueNode(element));
    }
  };
  if (IsValueNode<ValueTuple>(right)) {
    append_const_elements(GetValueNode<ValueTuplePtr>(right));
  } else if (IsValueNode<ValueList>(right)) {
    append_const_elements(GetValueNode<ValueListPtr>(right));
  } else {
    const auto &right_inputs = right->cast<CNodePtr>()->inputs();
    values.insert(values.end(), right_inputs.begin() + 1, right_inputs.end());
  }
  return func_graph_->NewCNodeInOrder(prim::kPrimMakeList, values);
}
595
GetKeysAndValueOfDict(const AnfNodePtr & node)596 std::pair<AnfNodePtrList, AnfNodePtrList> FuncGraphBuilder::GetKeysAndValueOfDict(const AnfNodePtr &node) {
597 AnfNodePtrList keys;
598 AnfNodePtrList values;
599 if (node->isa<Parameter>()) {
600 auto param = node->cast<ParameterPtr>();
601 return GetKeysAndValueOfDict(assigned_vars_.at(param->name()));
602 } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
603 auto key_tuple = node->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->inputs();
604 keys.assign(key_tuple.begin() + 1, key_tuple.end());
605 auto value_tuple = node->cast<CNodePtr>()->input(2)->cast<CNodePtr>()->inputs();
606 values.assign(value_tuple.begin() + 1, value_tuple.end());
607 } else {
608 MS_EXCEPTION_IF_CHECK_FAIL(IsValueNode<ValueDictionary>(node), "Can't convert non-dictionary to dict node.");
609 auto dict = GetValueNode<ValueDictionaryPtr>(node);
610 std::for_each(dict->value().begin(), dict->value().end(), [&](const auto &kv) {
611 keys.push_back(NewValueNode(kv.first));
612 values.push_back(NewValueNode(kv.second));
613 });
614 }
615 return std::make_pair(keys, values);
616 }
617
IsEmptyTuple(const AnfNodePtr & node)618 bool IsEmptyTuple(const AnfNodePtr &node) {
619 return (IsValueNode<ValueTuple>(node) && GetValueNode<ValueTuplePtr>(node)->size() == 0) ||
620 (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && node->cast<CNodePtr>()->size() == 1);
621 }
622
IsEmptyDict(const AnfNodePtr & node)623 bool IsEmptyDict(const AnfNodePtr &node) {
624 return (IsValueNode<ValueDictionary>(node) && GetValueNode<ValueDictionaryPtr>(node)->size() == 0) ||
625 (IsPrimitiveCNode(node, prim::kPrimMakeDict) &&
626 (node->cast<CNodePtr>()->size() == 1 || IsEmptyTuple(node->cast<CNodePtr>()->input(1))));
627 }
628
/**
 * Merge two dicts; `left` must be a MakeDict CNode.
 *
 * If either side is empty, the other one is returned as is. Otherwise a new MakeDict is built whose
 * keys and values are those of `left` followed by those of `right`. Keys are concatenated, not
 * de-duplicated, at this point.
 */
AnfNodePtr FuncGraphBuilder::MergeDict(const AnfNodePtr &left, const AnfNodePtr &right) {
  MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(left, prim::kPrimMakeDict), "Invalid args of dict merge target.");
  if (IsEmptyDict(left)) {
    return right;
  }
  if (IsEmptyDict(right)) {
    return left;
  }
  auto left_kv = GetKeysAndValueOfDict(left);
  auto right_kv = GetKeysAndValueOfDict(right);
  AnfNodePtrList merged_keys = std::move(left_kv.first);
  merged_keys.insert(merged_keys.end(), right_kv.first.begin(), right_kv.first.end());
  AnfNodePtrList merged_values = std::move(left_kv.second);
  merged_values.insert(merged_values.end(), right_kv.second.begin(), right_kv.second.end());
  CNodePtr keys_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, merged_keys);
  CNodePtr values_cnode = func_graph_->NewCNodeInOrder(prim::kPrimMakeTuple, merged_values);
  return func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), keys_cnode, values_cnode});
}
647 } // namespace pijit
648 } // namespace mindspore
649