1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/pi/graph_build/func_graph_builder.h"
18 #include <algorithm>
19 #include <utility>
20 #include <set>
21 #include <queue>
22 #include "frontend/operator/composite/do_signature.h"
23 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
24 #include "pipeline/jit/ps/action.h"
25 #include "pipeline/jit/ps/parse/parse_base.h"
26 #include "pipeline/jit/ps/parse/data_converter.h"
27 #include "pipeline/jit/pi/pi_jit_config.h"
28 #include "ops/arithmetic_ops.h"
29 #include "ops/structure_ops.h"
30 #include "include/common/utils/convert_utils_py.h"
31 #include "ir/tensor.h"
32 #include "ir/anf.h"
33
34 namespace mindspore {
35 namespace {
// User-data key under which the originating Python object is attached to ANF nodes built by this file.
constexpr const char *kPiJitPyObjKey = "pi_jit_py_obj";
// Python module whose `Tensor` class is used to wrap C++ tensors back into Python tensors.
constexpr const char *kTensorModule = "mindspore.common";
// Attribute checked on a tensor object to decide whether the result must be converted to an adapter tensor.
constexpr const char *kAdapterFlag = "adapter_flag";
// Python module providing `convert_to_adapter_tensor`.
constexpr const char *kInnerOpsModule = "mindspore.ops.operations._inner_ops";
40
ShouldFallBackInRuntime(const PrimitivePtr & prim)41 bool ShouldFallBackInRuntime(const PrimitivePtr &prim) {
42 static HashSet<std::string> prims_should_fallback_in_runtime = {kListInplaceExtendOpName,
43 kListInplaceInsertOpName,
44 kListInplacePopOpName,
45 kListInplaceReverseOpName,
46 kListInplaceClearOpName,
47 kDictInplaceSetItemOpName,
48 kRaiseOpName,
49 kJoinedStrOpName,
50 kFormatOpName};
51 return prims_should_fallback_in_runtime.find(prim->name()) != prims_should_fallback_in_runtime.end();
52 }
53
IsValidScalar(const AbstractBasePtr & abs)54 bool IsValidScalar(const AbstractBasePtr &abs) {
55 auto build_type = abs->BuildType();
56 return build_type->isa<String>() || build_type->isa<Number>();
57 }
58
Mutable(const py::object & obj,const ValuePtr & value=nullptr)59 bool Mutable(const py::object &obj, const ValuePtr &value = nullptr) {
60 // If a tensor has been set const arg, it should not be mutable.
61 if (value != nullptr && value->isa<tensor::MetaTensor>()) {
62 constexpr char const_arg_attr[] = "const_arg";
63 if (py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr))) {
64 return false;
65 }
66 }
67 constexpr char mutable_attr[] = "__ms_mutable__";
68 return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
69 }
70
IsConstant(const py::object & obj)71 bool IsConstant(const py::object &obj) {
72 if (obj.ptr() == nullptr || Mutable(obj)) {
73 return false;
74 }
75 if (py::isinstance<py::tuple>(obj)) {
76 auto list_obj = py::cast<py::tuple>(obj);
77 return std::all_of(list_obj.begin(), list_obj.end(),
78 [](const auto &obj) { return IsConstant(py::cast<py::object>(obj)); });
79 }
80 if (py::isinstance<py::list>(obj)) {
81 auto list_obj = py::cast<py::list>(obj);
82 return std::all_of(list_obj.begin(), list_obj.end(),
83 [](const auto &obj) { return IsConstant(py::cast<py::object>(obj)); });
84 }
85 if (py::isinstance<py::dict>(obj)) {
86 auto dict_obj = py::cast<py::dict>(obj);
87 return std::all_of(dict_obj.begin(), dict_obj.end(), [](const auto &pair) {
88 return IsConstant(py::cast<py::object>(pair.first)) && IsConstant(py::cast<py::object>(pair.second));
89 });
90 }
91 // Attention: should exclude BaseTensor in the future (when the BaseTensor PR is merged)
92 return !py::isinstance<tensor::Tensor>(obj) && !IsStubTensor(obj);
93 }
94
TensorArgMutable(const py::object & obj,const ValuePtr & value)95 bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
96 if (!value->isa<tensor::MetaTensor>()) {
97 return false;
98 }
99 constexpr char const_arg_attr[] = "const_arg";
100 return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
101 }
102
NeedBroaden(const py::object & obj,const ValuePtr & value)103 bool NeedBroaden(const py::object &obj, const ValuePtr &value) {
104 return TensorArgMutable(obj, value) || Mutable(obj, value) || value->isa<tensor::MetaSparseTensor>();
105 }
106
GetTypeIdFromClassName(const std::string & class_name)107 TypeId GetTypeIdFromClassName(const std::string &class_name) {
108 static HashMap<std::string, TypeId> class_name_to_type_ids = {
109 {"Tensor", kObjectTypeTensorType}, {"list", kObjectTypeList},
110 {"tuple", kObjectTypeTuple}, {"int", kNumberTypeInt},
111 {"float", kNumberTypeFloat}, {"CellList", kObjectTypeList},
112 {"CellDict", kObjectTypeDictionary}};
113 auto iter = class_name_to_type_ids.find(class_name);
114 if (iter == class_name_to_type_ids.end()) {
115 return kTypeUnknown;
116 }
117 return iter->second;
118 }
119
MaybeMakeEmptyTensor(const AbstractBasePtr & abs)120 ValuePtr MaybeMakeEmptyTensor(const AbstractBasePtr &abs) {
121 auto build_value = abs->BuildValue();
122 if (abs->isa<abstract::AbstractSequence>()) {
123 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
124 std::vector<ValuePtr> value_vec;
125 for (auto &elem : abs_seq->elements()) {
126 (void)value_vec.emplace_back(MaybeMakeEmptyTensor(elem));
127 }
128 if (abs->isa<abstract::AbstractTuple>()) {
129 return std::make_shared<ValueTuple>(value_vec);
130 } else {
131 return std::make_shared<ValueList>(value_vec);
132 }
133 }
134 if (abs->isa<abstract::AbstractDictionary>()) {
135 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
136 const auto &elements = abs_dict->elements();
137 std::vector<std::pair<ValuePtr, ValuePtr>> val_dict;
138 for (auto &element : elements) {
139 auto key_value = MaybeMakeEmptyTensor(element.first);
140 auto val_value = MaybeMakeEmptyTensor(element.second);
141 (void)val_dict.emplace_back(std::pair<ValuePtr, ValuePtr>{key_value, val_value});
142 }
143 return std::make_shared<ValueDictionary>(val_dict);
144 }
145 if (build_value == kValueAny && abs->isa<abstract::AbstractTensor>()) {
146 auto abs_tensor = abs->cast<abstract::AbstractTensorPtr>();
147 TypePtr tensor_type_ptr = abs_tensor->element()->BuildType();
148 ShapeVector tensor_shape = abs_tensor->shape()->shape();
149 auto tensor = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
150 if (abs->isa<abstract::AbstractRefTensor>()) {
151 auto abs_ref_tensor = abs->cast<abstract::AbstractRefPtr>();
152 // We only need the parameter name, it was used to find the python Parameter object later
153 auto param_info = std::make_shared<ParamInfo>();
154 param_info->set_name(abs_ref_tensor->ref_key_value()->ToString());
155 tensor->set_param_info(param_info);
156 }
157 return tensor;
158 }
159 return build_value;
160 }
161
FunctionShouldBeParseInAst(const py::object & obj)162 bool FunctionShouldBeParseInAst(const py::object &obj) {
163 static mindspore::HashSet<std::string> func_names{"cast_to_adapter_tensor", "cast_to_ms_tensor"};
164 if (!py::hasattr(obj, "__name__")) {
165 return false;
166 }
167 return func_names.find(py::cast<std::string>(obj.attr("__name__"))) != func_names.end();
168 }
169
// Recursively replaces C++ Tensor objects in `obj` with the result of `tensor_convert_func`.
// Objects with a truthy `__ms_class__` attribute are returned as-is. Lists, tuples and dicts are rebuilt
// as new containers of the same kind (plain list/tuple/dict), and the input container is never modified.
// Any other object is returned unchanged.
py::object ConvertToPythonTensor(const py::object &obj,
                                 const FuncGraphBuilder::PyTensorConverter &tensor_convert_func) {
  constexpr auto ms_class_attr = "__ms_class__";
  if (py::hasattr(obj, ms_class_attr) && py::cast<bool>(py::getattr(obj, ms_class_attr))) {
    return obj;
  }
  if (py::isinstance<tensor::Tensor>(obj)) {
    return tensor_convert_func(obj);
  }
  if (py::isinstance<py::list>(obj)) {
    // Build the list directly instead of round-tripping list -> tuple -> list.
    auto obj_list = py::cast<py::list>(obj);
    py::list ret(obj_list.size());
    for (size_t i = 0; i < obj_list.size(); ++i) {
      ret[i] = ConvertToPythonTensor(obj_list[i], tensor_convert_func);
    }
    return ret;
  }
  if (py::isinstance<py::tuple>(obj)) {
    auto obj_tuple = py::cast<py::tuple>(obj);
    py::tuple ret(obj_tuple.size());
    for (size_t i = 0; i < obj_tuple.size(); ++i) {
      ret[i] = ConvertToPythonTensor(obj_tuple[i], tensor_convert_func);
    }
    return ret;
  }
  if (py::isinstance<py::dict>(obj)) {
    // Build a new dict (like the sequence branches) rather than overwriting the caller's dict in place.
    auto obj_dict = py::cast<py::dict>(obj);
    py::dict ret;
    for (const auto &item : obj_dict) {
      ret[item.first] = ConvertToPythonTensor(py::cast<py::object>(item.second), tensor_convert_func);
    }
    return ret;
  }
  return obj;
}
199
// Wraps a C++ Tensor object into a Python `mindspore.common.Tensor`. If the source object has a truthy
// adapter flag, the result is further passed through `convert_to_adapter_tensor`.
// Returns an empty (null) object when the input is null or not a C++ Tensor.
py::object ConvertCppTensorToPyTensor(const py::object &cpp_tensor) {
  const bool is_cpp_tensor = cpp_tensor.ptr() != nullptr && py::isinstance<tensor::Tensor>(cpp_tensor);
  if (!is_cpp_tensor) {
    return py::object();
  }
  const bool wants_adapter =
    py::hasattr(cpp_tensor, kAdapterFlag) && py::cast<bool>(py::getattr(cpp_tensor, kAdapterFlag));
  py::module tensor_mod = python_adapter::GetPyModule(kTensorModule);
  py::object wrapped =
    python_adapter::CallPyModFn(tensor_mod, "Tensor", cpp_tensor, py::none(), py::none(), py::none(), true);
  if (!wants_adapter) {
    return wrapped;
  }
  py::module inner_ops_mod = python_adapter::GetPyModule(kInnerOpsModule);
  return python_adapter::CallPyModFn(inner_ops_mod, "convert_to_adapter_tensor", wrapped);
}
214 } // namespace
215
ConvertPyObjToValue(const py::object & obj)216 ValuePtr FuncGraphBuilder::ConvertPyObjToValue(const py::object &obj) {
217 if (obj.ptr() == nullptr) {
218 return nullptr;
219 }
220 ValuePtr ret = nullptr;
221 try {
222 MS_LOG_TRY_CATCH_SCOPE;
223 if (!parse::ConvertData(obj, &ret)) {
224 return nullptr;
225 }
226 } catch (const std::exception &e) {
227 MS_LOG(DEBUG) << "Failed to convert python object << " << py::str(obj) << " to value. The exception:\n" << e.what();
228 return nullptr;
229 }
230 return ret;
231 }
232
ConvertToPyObj(const AbstractBasePtr & abs)233 py::object FuncGraphBuilder::ConvertToPyObj(const AbstractBasePtr &abs) {
234 static auto convert_func = [](const py::object &tensor) { return ConvertCppTensorToPyTensor(tensor); };
235 return FuncGraphBuilder::ConvertToPyObj(abs, convert_func);
236 }
237
ConvertToPyObj(const AbstractBasePtr & abs,const PyTensorConverter & tensor_convert_func)238 py::object FuncGraphBuilder::ConvertToPyObj(const AbstractBasePtr &abs, const PyTensorConverter &tensor_convert_func) {
239 if (abs->isa<abstract::AbstractNone>()) {
240 return py::none();
241 }
242
243 auto build_value = MaybeMakeEmptyTensor(abs);
244 auto py_obj = ValueToPyData(build_value, abs);
245 // Return none means failed converting.
246 if (py::isinstance<py::none>(py_obj)) {
247 return py::object();
248 }
249
250 if (pijit::kPIJitConfigDefault.GetBoolConfig(pijit::GraphJitConfig::kTraceFlag)) {
251 return ConvertToPythonTensor(py_obj, tensor_convert_func);
252 }
253
254 return py_obj;
255 }
256
ConvertObjToNode(const py::object & input_obj)257 AnfNodePtr FuncGraphBuilder::ConvertObjToNode(const py::object &input_obj) {
258 if (py::hasattr(input_obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(input_obj)) {
259 // Add the fv parameter and set its abstract.
260 return parse::ResolveParameterObj(graph_, input_obj);
261 }
262 auto val = ConvertPyObjToValue(input_obj);
263 if (val == nullptr) {
264 MS_LOG(INFO) << "The input object " << py::str(input_obj) << " convert to value failed.";
265 return nullptr;
266 }
267 // Constant value input scene, the object should be converted to value node.
268 auto node = NewValueNode(val);
269 node->set_abstract(val->ToAbstract());
270 return node;
271 }
272
EvalValue(const ValuePtr & value,const AbstractBasePtrList & inputs_abs_list)273 AbstractBasePtr FuncGraphBuilder::EvalValue(const ValuePtr &value, const AbstractBasePtrList &inputs_abs_list) {
274 if (value == nullptr) {
275 return nullptr;
276 }
277 try {
278 MS_LOG_TRY_CATCH_SCOPE;
279 if (value->isa<Primitive>()) {
280 auto prim = value->cast<PrimitivePtr>();
281 auto eval_res = abstract::EvalOnePrim(prim, inputs_abs_list);
282 if (eval_res != nullptr) {
283 return eval_res->abstract();
284 }
285 } else if (value->ToAbstract()->isa<abstract::AbstractFunction>()) {
286 auto analyze_res = pipeline::AbstractAnalyze(value, inputs_abs_list);
287 if (analyze_res.eval_result != nullptr) {
288 return analyze_res.eval_result->abstract();
289 }
290 }
291 return nullptr;
292 } catch (const std::exception &e) {
293 MS_LOG(INFO) << "Failed to EvalValue for value: " << value->ToString();
294 return nullptr;
295 }
296 }
297
CheckCallable(const ValuePtr & value,const AbstractBasePtr & abs)298 bool FuncGraphBuilder::CheckCallable(const ValuePtr &value, const AbstractBasePtr &abs) {
299 if (value == nullptr || abs == nullptr || abs->isa<abstract::AbstractAny>()) {
300 return false;
301 }
302 if (value->isa<Primitive>() && ShouldFallBackInRuntime(value->cast<PrimitivePtr>())) {
303 return false;
304 }
305 return true;
306 }
307
CheckGraphOutput(const AbstractBasePtr & abs)308 bool FuncGraphBuilder::CheckGraphOutput(const AbstractBasePtr &abs) {
309 if (abs == nullptr) {
310 return false;
311 }
312 if (abs->isa<abstract::AbstractSequence>()) {
313 const auto elements = abs->cast<abstract::AbstractSequencePtr>()->elements();
314 return std::all_of(elements.begin(), elements.end(),
315 [](const AbstractBasePtr &elem) { return CheckGraphOutput(elem); });
316 }
317 if (abs->isa<abstract::AbstractScalar>()) {
318 return IsValidScalar(abs);
319 }
320 return abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractRowTensor>() ||
321 abs->isa<abstract::AbstractMapTensor>();
322 }
323
AddLocalVariable(const py::object & obj)324 bool FuncGraphBuilder::AddLocalVariable(const py::object &obj) {
325 if (obj.ptr() == nullptr) {
326 MS_LOG(INFO) << "Failed to add local variable, py object is null";
327 return false;
328 }
329
330 auto iter = py_obj_to_node_.find(obj.ptr());
331 if (iter != py_obj_to_node_.end()) {
332 MS_LOG(INFO) << "Py object already in map, no need to add. Associated node: "
333 << ((iter->second != nullptr) ? iter->second->DebugString() : "NULL");
334 return true;
335 }
336
337 auto node = ConvertObjToNode(obj);
338 if (node == nullptr) {
339 MS_LOG(INFO) << "Failed to add local variable, convert python object to anf node failed";
340 return false;
341 }
342
343 node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(obj));
344 (void)py_obj_to_node_.emplace(obj.ptr(), node);
345 return true;
346 }
347
ReadLocalVariable(const py::object & obj)348 AnfNodePtr FuncGraphBuilder::ReadLocalVariable(const py::object &obj) {
349 auto iter = py_obj_to_node_.find(obj.ptr());
350 if (iter == py_obj_to_node_.end()) {
351 return nullptr;
352 }
353 return iter->second;
354 }
355
GetNodeByObject(const py::object & obj)356 AnfNodePtr FuncGraphBuilder::GetNodeByObject(const py::object &obj) {
357 // Search the predecessors of the current builder for the local parameter with BFS.
358 mindspore::HashSet<FuncGraphBuilder *> visited_builders;
359 std::queue<FuncGraphBuilder *> builder_queue;
360 builder_queue.push(this);
361 while (!builder_queue.empty()) {
362 const auto cur_builder = builder_queue.front();
363 MS_EXCEPTION_IF_NULL(cur_builder);
364 builder_queue.pop();
365 (void)visited_builders.insert(cur_builder);
366 auto node = cur_builder->ReadLocalVariable(obj);
367 if (node != nullptr) {
368 MS_LOG(INFO) << "Found node: " << node->DebugString() << " for python object: " << std::string(py::str(obj))
369 << " " << obj.ptr();
370 return node;
371 }
372 for (const auto &cur_pred_builder : cur_builder->prev_builders()) {
373 if (visited_builders.count(cur_pred_builder) == 0) {
374 builder_queue.push(cur_pred_builder);
375 }
376 }
377 }
378 return nullptr;
379 }
380
AddTopGraphArgsInputs(const py::object & object)381 bool FuncGraphBuilder::AddTopGraphArgsInputs(const py::object &object) {
382 // args object should always be list object.
383 if (object.ptr() == nullptr || !py::isinstance<py::list>(object)) {
384 MS_LOG(INFO) << "Get top graph args failed.";
385 return false;
386 }
387 auto args = object.cast<py::list>();
388 for (size_t i = 0; i < args.size(); ++i) {
389 auto arg = args[i].cast<py::object>();
390 if (arg.ptr() == nullptr) {
391 return false;
392 }
393 auto value = ConvertPyObjToValue(arg);
394 if (value == nullptr) {
395 return false;
396 }
397 bool broaden = NeedBroaden(arg, value);
398 AbstractBasePtr abs = abstract::ToAbstract(value, nullptr, nullptr);
399 if (broaden) {
400 abs = AbstractBroaden(abs);
401 }
402 if (abs == nullptr) {
403 MS_LOG(INFO) << "Failed to add input for python object: " << std::string(py::str(arg)) << " " << arg.ptr();
404 return false;
405 }
406 auto para = graph_->add_parameter();
407 para->set_abstract(abs);
408 para->set_is_top_graph_param(true);
409 MS_LOG(INFO) << "Add top arg input success, python object: " << py::str(arg) << ", node: " << para->DebugString()
410 << ", abstract: " << abs->ToString();
411 (void)py_obj_to_node_.emplace(arg.ptr(), para);
412 para->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(arg));
413 }
414 return true;
415 }
416
AddTopGraphVargsInputs(const py::object & vargs)417 bool FuncGraphBuilder::AddTopGraphVargsInputs(const py::object &vargs) {
418 if (vargs.ptr() == nullptr) {
419 MS_LOG(INFO) << "Top graph has no vargs input.";
420 return true;
421 }
422 auto vargs_tuple = vargs.cast<py::tuple>();
423 if (vargs_tuple.ptr() == nullptr) {
424 MS_LOG(INFO) << "Vargs object should be tuple but got: " << py::str(vargs) << ", add top graph vargs failed.";
425 return false;
426 }
427 auto value = ConvertPyObjToValue(vargs);
428 if (value == nullptr || !value->isa<ValueTuple>()) {
429 MS_LOG(INFO) << "Convert vargs to value failed, vargs: " << py::str(vargs);
430 return false;
431 }
432 auto value_tuple = value->cast<ValueTuplePtr>();
433 const auto &elements = value_tuple->value();
434 if (elements.size() != vargs_tuple.size()) {
435 MS_LOG(INFO) << "For top graph vargs, converted value element size is " << elements.size()
436 << ", python tuple element size is " << vargs_tuple.size() << ". Size not matched.";
437 return false;
438 }
439 std::vector<AbstractBasePtr> new_elements;
440 for (size_t i = 0; i < elements.size(); ++i) {
441 auto cur_obj = vargs_tuple[i].cast<py::object>();
442 auto cur_val = elements[i];
443 bool broaden = NeedBroaden(cur_obj, cur_val);
444 auto cur_abs = abstract::ToAbstract(cur_val, nullptr, nullptr);
445 if (broaden) {
446 cur_abs = AbstractBroaden(cur_abs);
447 }
448 if (cur_abs == nullptr) {
449 MS_LOG(INFO) << "Fail to convert args element " << cur_val->ToString();
450 return false;
451 }
452 new_elements.push_back(cur_abs);
453 }
454 auto new_vargs_abs = std::make_shared<abstract::AbstractTuple>(new_elements);
455 auto para = graph_->add_parameter();
456 para->set_abstract(new_vargs_abs);
457 para->set_is_top_graph_param(true);
458 MS_LOG(INFO) << "Add top vargs input success, python object: " << py::str(vargs) << ", node: " << para->DebugString()
459 << ", abstract: " << new_vargs_abs->ToString();
460 (void)py_obj_to_node_.emplace(vargs.ptr(), para);
461 para->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(vargs));
462 return true;
463 }
464
AddTopGraphKwargsInputs(const py::object & kwargs)465 bool FuncGraphBuilder::AddTopGraphKwargsInputs(const py::object &kwargs) {
466 if (kwargs.ptr() == nullptr) {
467 MS_LOG(INFO) << "Top graph has no kwargs input.";
468 return true;
469 }
470 auto kwargs_dict = kwargs.cast<py::dict>();
471 if (kwargs_dict.ptr() == nullptr) {
472 MS_LOG(INFO) << "Kwargs object should be tuple but got: " << py::str(kwargs) << ", add top graph kwargs failed.";
473 return false;
474 }
475 auto value = ConvertPyObjToValue(kwargs);
476 if (value == nullptr || !value->isa<ValueDictionary>()) {
477 MS_LOG(INFO) << "Convert kwargs to value failed, kwargs: " << py::str(kwargs);
478 return false;
479 }
480 auto value_dict = value->cast<ValueDictionaryPtr>();
481 const auto &elements = value_dict->value();
482 if (elements.size() != kwargs_dict.size()) {
483 MS_LOG(INFO) << "Kwargs dict size is " << kwargs_dict.size() << " and corresponding value dict size is "
484 << elements.size() << ". Size not matched.";
485 }
486 std::vector<abstract::AbstractElementPair> new_key_values;
487 for (size_t i = 0; i < elements.size(); ++i) {
488 auto cur_key_val = elements[i].first;
489 auto cur_val = elements[i].second;
490 auto cur_key_obj = ValueToPyData(cur_key_val);
491 if (!kwargs_dict.contains(cur_key_obj)) {
492 return false;
493 }
494 auto cur_val_obj = kwargs_dict[cur_key_obj];
495 auto cur_value_abs = abstract::ToAbstract(cur_val, nullptr, nullptr);
496 bool broaden = NeedBroaden(cur_val_obj, cur_val);
497 if (broaden) {
498 cur_value_abs = AbstractBroaden(cur_value_abs);
499 }
500 if (cur_value_abs == nullptr) {
501 MS_LOG(INFO) << "Fail to convert kwargs value element " << cur_val->ToString();
502 return false;
503 }
504 auto cur_key_abs = abstract::ToAbstract(cur_key_val, nullptr, nullptr);
505 new_key_values.push_back(abstract::AbstractElementPair{cur_key_abs, cur_value_abs});
506 }
507 auto new_kwargs_abs = std::make_shared<abstract::AbstractDictionary>(new_key_values);
508 auto para = graph_->add_parameter();
509 para->set_abstract(new_kwargs_abs);
510 para->set_is_top_graph_param(true);
511 MS_LOG(INFO) << "Add top kwargs input success, python object: " << py::str(kwargs)
512 << ", node: " << para->DebugString() << ", abstract: " << new_kwargs_abs->ToString();
513 (void)py_obj_to_node_.emplace(kwargs.ptr(), para);
514 para->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(kwargs));
515 return true;
516 }
517
AddTopGraphInputs(std::vector<py::object> packed_inputs)518 bool FuncGraphBuilder::AddTopGraphInputs(std::vector<py::object> packed_inputs) {
519 constexpr size_t args_index = 0;
520 constexpr size_t vargs_index = 1;
521 constexpr size_t kwargs_index = 2;
522 constexpr size_t packed_inputs_size = 3;
523 if (!prev_builders_.empty()) {
524 MS_LOG(INFO) << "Current builder has prev builder, add top graph parameter failed.";
525 return false;
526 }
527 if (packed_inputs.size() != packed_inputs_size) {
528 MS_LOG(INFO) << "Top graph packed inputs size is not three but " << packed_inputs.size()
529 << ", add top graph parameter failed.";
530 return false;
531 }
532 if (!AddTopGraphArgsInputs(packed_inputs[args_index])) {
533 MS_LOG(INFO) << "Add top graph args inputs failed.";
534 return false;
535 }
536 if (!AddTopGraphVargsInputs(packed_inputs[vargs_index])) {
537 MS_LOG(INFO) << "Add top graph vargs inputs failed";
538 return false;
539 }
540 if (!AddTopGraphKwargsInputs(packed_inputs[kwargs_index])) {
541 MS_LOG(INFO) << "Add top graph kwargs inputs failed";
542 return false;
543 }
544 MS_LOG(INFO) << "Add top graph inputs success.";
545 return true;
546 }
547
// Adds a non-top-graph parameter for `obj` and maps the object to it.
// The abstract comes from the node already mapped to `obj` in this or a predecessor builder. Failing
// that, for a constant object, it comes from the object's converted value.
// Returns `obj` on success, or an empty (null) object when no abstract can be determined.
py::object FuncGraphBuilder::AddSubGraphInput(const py::object &obj) {
  MS_LOG(INFO) << "Try add sub graph parameter for object: " << std::string(py::str(obj)) << " " << obj.ptr();
  AbstractBasePtr input_abs = nullptr;
  const AnfNodePtr known_node = GetNodeByObject(obj);
  if (known_node != nullptr) {
    input_abs = known_node->abstract();
  }
  // Handle constant subgraph input.
  const bool try_as_constant = (input_abs == nullptr) && IsConstant(obj);
  if (try_as_constant) {
    const ValuePtr const_value = ConvertPyObjToValue(obj);
    if (const_value != nullptr) {
      input_abs = abstract::ToAbstract(const_value, nullptr, nullptr);
    }
  }
  if (input_abs == nullptr) {
    MS_LOG(INFO) << "Failed to add input for python object: " << std::string(py::str(obj)) << " " << obj.ptr();
    return py::object();
  }
  auto sub_param = graph_->add_parameter();
  sub_param->set_abstract(input_abs);
  sub_param->set_is_top_graph_param(false);
  MS_LOG(INFO) << "Add input success, node: " << sub_param->DebugString() << " obj: " << py::str(obj) << " "
               << obj.ptr();
  (void)py_obj_to_node_.emplace(obj.ptr(), sub_param);
  sub_param->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(obj));
  return obj;
}
574
// Adds a call node for a Python callable. The callable must pass CheckCallable and convert to a Value.
// The adapter cast helpers matched by FunctionShouldBeParseInAst go through TryToAddNode; everything else
// goes through the Value-based AddNode overload.
// Returns an empty (null) object on failure.
py::object FuncGraphBuilder::AddNode(const py::object &callable_obj, const std::vector<py::object> &inputs_obj) {
  if (!CheckCallable(callable_obj)) {
    MS_LOG(INFO) << "The python obj " << py::str(callable_obj) << " is not callable.";
    return py::object();
  }
  const ValuePtr callee_value = ConvertPyObjToValue(callable_obj);
  if (callee_value == nullptr) {
    MS_LOG(INFO) << "Convert python object " << py::str(callable_obj) << " to value failed.";
    return py::object();
  }
  const bool use_try_path = FunctionShouldBeParseInAst(callable_obj);
  return use_try_path ? TryToAddNode(callee_value, inputs_obj) : AddNode(callee_value, inputs_obj);
}
590
AddAttrPythonObject(const py::object & object)591 bool FuncGraphBuilder::AddAttrPythonObject(const py::object &object) {
592 if (object.ptr() == nullptr) {
593 MS_LOG(INFO) << "Convert python object with empty object, convert failed.";
594 return false;
595 }
596 // Attribute object is constant or Parameter, do not need to check constant.
597 auto node = ConvertObjToNode(object);
598 if (node == nullptr) {
599 MS_LOG(INFO) << "Convert python object " << py::str(object) << " to anf node failed.";
600 return false;
601 }
602 node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(object));
603 (void)py_obj_to_node_.emplace(object.ptr(), node);
604 return true;
605 }
606
GetInputNodesAndAbstracts(const ValuePtr & callable_value,const vector<py::object> & inputs_obj,std::vector<AnfNodePtr> * input_node_list,std::vector<AbstractBasePtr> * input_abs_list)607 bool FuncGraphBuilder::GetInputNodesAndAbstracts(const ValuePtr &callable_value, const vector<py::object> &inputs_obj,
608 std::vector<AnfNodePtr> *input_node_list,
609 std::vector<AbstractBasePtr> *input_abs_list) {
610 input_node_list->reserve(inputs_obj.size() + 1);
611 input_abs_list->reserve(inputs_obj.size());
612
613 (void)input_node_list->emplace_back(NewValueNode(callable_value));
614 for (const auto &input_obj : inputs_obj) {
615 if (input_obj.ptr() == nullptr) {
616 MS_LOG(INFO) << "The input python object of " << callable_value->ToString() << ", is NULL";
617 return false;
618 }
619 // Node with input of generator may cause change of generator, skip it in build node now.
620 if (PyGen_CheckExact(input_obj.ptr())) {
621 MS_LOG(INFO) << "The input python object is generator " << std::string(py::str(input_obj))
622 << ", do not build graph.";
623 return false;
624 }
625 auto node = GetNodeByObject(input_obj);
626 if (node == nullptr) {
627 if (!IsConstant(input_obj)) {
628 MS_LOG(INFO) << "Can not convert non-constant value to value node for obj: " << py::str(input_obj);
629 return false;
630 }
631 auto new_node = ConvertObjToNode(input_obj);
632 if (new_node == nullptr) {
633 MS_LOG(INFO) << "Convert input python object " << py::str(input_obj) << " to anf node failed.";
634 return false;
635 }
636 new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(input_obj));
637 (void)py_obj_to_node_.emplace(input_obj.ptr(), new_node);
638 (void)input_node_list->emplace_back(new_node);
639 (void)input_abs_list->emplace_back(new_node->abstract());
640 MS_LOG(INFO) << "Add python input " << py::str(input_obj) << " with new node " << new_node->DebugString();
641 } else {
642 (void)input_node_list->emplace_back(node);
643 (void)input_abs_list->emplace_back(node->abstract());
644 }
645 }
646 return true;
647 }
648
DoPrimitiveInferAndCheck(const PrimitivePtr & primitive,const AnfNodePtrList & input_node_list,const AbstractBasePtrList & args_abs_list)649 CNodePtr FuncGraphBuilder::DoPrimitiveInferAndCheck(const PrimitivePtr &primitive,
650 const AnfNodePtrList &input_node_list,
651 const AbstractBasePtrList &args_abs_list) {
652 try {
653 MS_LOG_TRY_CATCH_SCOPE;
654 const CNodePtr &new_node = AddPrimitiveCNode(primitive, input_node_list, args_abs_list);
655 if (new_node == nullptr) {
656 MS_LOG(INFO) << "Failed to add CNode for Primitive: " << primitive->name();
657 return nullptr;
658 }
659
660 const AbstractBasePtr &abs = GetAbstractOf(new_node);
661
662 if (!CheckCallable(primitive, abs)) {
663 MS_LOG(INFO) << "Check callable failed for Primitive: " << primitive->name();
664 return nullptr;
665 }
666 new_node->set_abstract(abs);
667 return new_node;
668 } catch (const std::exception &e) {
669 MS_LOG(INFO) << "Failed to infer Primitive: " << primitive->name() << ". The exception:\n" << e.what();
670 return nullptr;
671 }
672 }
673
AddPrimitiveCNode(const PrimitivePtr & primitive,const AnfNodePtrList & input_node_list,const AbstractBasePtrList & args_abs_list)674 CNodePtr FuncGraphBuilder::AddPrimitiveCNode(const PrimitivePtr &primitive, const AnfNodePtrList &input_node_list,
675 const AbstractBasePtrList &args_abs_list) {
676 auto op_def = mindspore::ops::GetOpDef(primitive->name());
677
678 if (op_def == nullptr) {
679 if (primitive->has_signature()) {
680 // Follow the implementations in DoSignatureEvaluator
681 AnfNodePtrList args_node_list(input_node_list.cbegin() + 1, input_node_list.cend());
682 AnfNodePtrList new_node_list =
683 prim::GetNewInputsBySignatures(graph_, primitive->ToString(), primitive, args_abs_list, args_node_list);
684
685 new_node_list.insert(new_node_list.begin(), input_node_list[0]);
686 return graph_->NewCNodeInOrder(new_node_list);
687 }
688 } else if (primitive->isa<PrimitivePy>()) {
689 // Follow the implementations in PrimitiveArgsToInputsEvaluator and DoTransPrimitiveFunctionEvaluator
690 auto arg_signatures = op_def->signatures_;
691 primitive->set_signatures(arg_signatures);
692 primitive->set_has_signature(!arg_signatures.empty());
693
694 const AnfNodePtrList &init_args = abstract::GetPrimitiveInitArgs(primitive->cast<PrimitivePyPtr>(), op_def);
695
696 AnfNodePtrList call_args(input_node_list.cbegin() + 1, input_node_list.cend());
697 AbstractBasePtrList call_abs_list;
698 (void)std::transform(call_args.cbegin(), call_args.cend(), std::back_inserter(call_abs_list),
699 [](const AnfNodePtr &node) { return FuncGraphBuilder::GetAbstractOf(node); });
700 const AnfNodePtrList &new_call_args =
701 prim::GetNewInputsBySignatures(graph_, primitive->name(), primitive, call_abs_list, call_args);
702
703 return abstract::GeneratePrimitiveCNode(
704 primitive, op_def, graph_, init_args, new_call_args,
705 [](const AnfNodePtr &node) { return FuncGraphBuilder::GetAbstractOf(node); });
706 }
707 MS_LOG(DEBUG) << "Primitive " << primitive->name() << " no need to process signatures and OpDef";
708 return graph_->NewCNodeInOrder(input_node_list);
709 }
710
GetAbstractOf(const AnfNodePtr & node)711 AbstractBasePtr FuncGraphBuilder::GetAbstractOf(const AnfNodePtr &node) {
712 if (node == nullptr) {
713 return nullptr;
714 }
715 if (node->abstract() != nullptr) {
716 return node->abstract();
717 }
718 if (node->isa<ValueNode>()) {
719 return node->cast<ValueNodePtr>()->value()->ToAbstract();
720 } else if (node->isa<CNode>()) {
721 auto cnode = node->cast<CNodePtr>();
722 if (cnode->empty() || !cnode->input(0)->isa<ValueNode>()) {
723 return nullptr;
724 }
725 ValuePtr value = cnode->input(0)->cast<ValueNodePtr>()->value();
726 std::vector<AbstractBasePtr> abs_list;
727 std::transform(cnode->inputs().begin() + 1, cnode->inputs().end(), std::back_inserter(abs_list),
728 [](const AnfNodePtr &node) {
729 if (node->abstract() == nullptr) {
730 node->set_abstract(FuncGraphBuilder::GetAbstractOf(node));
731 }
732 return node->abstract();
733 });
734 return EvalValue(value, abs_list);
735 }
736 MS_LOG(INFO) << "Unsupported Node type for GetAbstractOf() method, node: " << node->DebugString();
737 return nullptr;
738 }
739
DoInferAndCheck(const ValuePtr & callable_value,const vector<AbstractBasePtr> & input_abs_list)740 AbstractBasePtr FuncGraphBuilder::DoInferAndCheck(const ValuePtr &callable_value,
741 const vector<AbstractBasePtr> &input_abs_list) {
742 auto abs = EvalValue(callable_value, input_abs_list);
743 if (abs == nullptr) {
744 MS_LOG(DEBUG) << "Eval failed for value: " << callable_value->ToString();
745 return nullptr;
746 }
747 if (!CheckCallable(callable_value, abs)) {
748 MS_LOG(DEBUG) << "Check callable failed for value: " << callable_value->ToString() << ", abs: " << abs->ToString();
749 return nullptr;
750 }
751 return abs;
752 }
753
TryToAddNode(const ValuePtr & callable_value,const std::vector<py::object> & inputs_obj)754 py::object FuncGraphBuilder::TryToAddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj) {
755 // Collect the input nodes and input abstracts.
756 std::vector<AnfNodePtr> input_node_list;
757 std::vector<AbstractBasePtr> input_abs_list;
758 if (!GetInputNodesAndAbstracts(callable_value, inputs_obj, &input_node_list, &input_abs_list)) {
759 return py::object();
760 }
761
762 CNodePtr new_node;
763 AbstractBasePtr abs;
764 if (callable_value->isa<Primitive>()) {
765 new_node = DoPrimitiveInferAndCheck(callable_value->cast<PrimitivePtr>(), input_node_list, input_abs_list);
766 if (new_node != nullptr) {
767 abs = new_node->abstract();
768 }
769 } else {
770 // Do infer and check callable.
771 abs = DoInferAndCheck(callable_value, input_abs_list);
772 if (abs != nullptr) {
773 new_node = graph_->NewCNodeInOrder(input_node_list);
774 }
775 }
776 if (new_node == nullptr || abs == nullptr) {
777 return py::object();
778 }
779
780 // Return the converted python object.
781 py::object output_py_obj;
782 if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
783 auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
784 auto fg = abs_func->func_graph();
785 if (fg == nullptr) {
786 return py::object();
787 }
788 auto obj = fg->python_obj();
789 if (obj == nullptr || !obj->isa<parse::PyObjectWrapper>()) {
790 return py::object();
791 }
792 output_py_obj = obj->cast_ptr<parse::PyObjectWrapper>()->obj();
793 } else {
794 auto convert_func = [this](const py::object &tensor) { return ConvertToPyTensorOrParameter(tensor); };
795 output_py_obj = ConvertToPyObj(abs, convert_func);
796 if (output_py_obj.ptr() == nullptr) {
797 MS_LOG(INFO) << "Convert abs " << abs->ToString() << " to python object failed.";
798 return py::object();
799 }
800 }
801
802 new_node->set_abstract(abs);
803 MS_LOG(INFO) << "Add node: " << new_node->DebugString() << " for python object: " << py::str(output_py_obj);
804 (void)py_obj_to_node_.emplace(output_py_obj.ptr(), new_node);
805 new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(output_py_obj));
806 return output_py_obj;
807 }
808
// Converts a C++ tensor object (as produced by abstract-to-python conversion) into the python object the builder
// should expose.
//  - A non-parameter tensor is converted with ConvertCppTensorToPyTensor().
//  - A parameter is resolved to the python object already bound to a graph node: the parameter name is matched
//    against the ref key of every node in py_obj_to_node_ whose abstract is an AbstractRefTensor.
// Returns an empty py::object when the input is not a tensor, or when a parameter has no param info or no
// matching node.
py::object FuncGraphBuilder::ConvertToPyTensorOrParameter(const py::object &cpp_tensor) {
  if (cpp_tensor.ptr() == nullptr || !py::isinstance<tensor::Tensor>(cpp_tensor)) {
    return py::object();
  }
  auto tensor = py::cast<tensor::TensorPtr>(cpp_tensor);
  if (!tensor->is_parameter()) {
    return ConvertCppTensorToPyTensor(cpp_tensor);
  }

  // Defensive: a parameter tensor is expected to carry param info, but do not dereference it blindly.
  const auto &param_info = tensor->param_info();
  if (param_info == nullptr) {
    MS_LOG(INFO) << "Python Parameter has no param info, can not resolve it.";
    return py::object();
  }
  const std::string &name = param_info->name();
  for (const auto &it : py_obj_to_node_) {
    if (it.second == nullptr) {
      continue;
    }
    const AbstractBasePtr &abs = it.second->abstract();
    if (abs == nullptr || !abs->isa<abstract::AbstractRefTensor>()) {
      continue;
    }
    auto abs_ref_tensor = abs->cast<abstract::AbstractRefPtr>();
    auto ref_key = abs_ref_tensor->ref_key_value();
    // A ref abstract without a ref key can not match any parameter name.
    if (ref_key != nullptr && ref_key->ToString() == name) {
      return py::reinterpret_borrow<py::object>(it.first);
    }
  }
  MS_LOG(INFO) << "Python Parameter not found: " << name;
  return py::object();
}
834
AddNode(const ValuePtr & callable_value,const std::vector<py::object> & inputs_obj)835 py::object FuncGraphBuilder::AddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj) {
836 if (!callable_value->ToAbstract()->isa<abstract::AbstractFunction>()) {
837 MS_LOG(INFO) << "The value " << callable_value->ToString() << " is not callable.";
838 return py::object();
839 }
840 if (callable_value->isa<FuncGraph>()) {
841 return AddFgCallNode(callable_value->cast<FuncGraphPtr>(), inputs_obj);
842 }
843 return TryToAddNode(callable_value, inputs_obj);
844 }
845
AddMultiNode(const std::string & name,const std::vector<py::object> & inputs_obj)846 py::object FuncGraphBuilder::AddMultiNode(const std::string &name, const std::vector<py::object> &inputs_obj) {
847 const std::string mod_str = "mindspore.ops.composite.multitype_ops";
848 py::module mod = py::module::import(mod_str.c_str());
849 if (!py::hasattr(mod, name.c_str())) {
850 MS_LOG(INFO) << "Fail to find multitype function graph for name " << name;
851 return py::object();
852 }
853 py::object fn = mod.attr(name.c_str());
854 return AddNode(fn, inputs_obj);
855 }
856
AddOutput(const py::object & output_obj,bool is_top_graph)857 bool FuncGraphBuilder::AddOutput(const py::object &output_obj, bool is_top_graph) {
858 auto iter = py_obj_to_node_.find(output_obj.ptr());
859 if (iter == py_obj_to_node_.end()) {
860 MS_LOG(INFO) << "The output python object " << py::str(output_obj) << " should have been added to the graph.";
861 return false;
862 }
863 auto node = iter->second;
864 MS_EXCEPTION_IF_NULL(node);
865 auto abs = node->abstract();
866 // Only top graph has restriction on return value type.
867 if (is_top_graph && !CheckGraphOutput(abs)) {
868 MS_LOG(INFO) << "The output python object " << py::str(output_obj)
869 << " should not be the graph output, abstract: " << (abs == nullptr ? "null" : abs->ToString());
870 return false;
871 }
872 (void)output_nodes_.emplace_back(node);
873 return true;
874 }
875
graph()876 FuncGraphPtr FuncGraphBuilder::graph() {
877 if (has_set_output_) {
878 return graph_;
879 }
880 if (output_nodes_.empty()) {
881 MS_LOG(DEBUG) << "The graph " << graph_->ToString() << " has not been set output.";
882 return nullptr;
883 }
884 bool all_value_node = std::any_of(output_nodes_.begin(), output_nodes_.end(),
885 [](const AnfNodePtr &node) { return node->isa<ValueNode>(); });
886 if (all_value_node) {
887 MS_LOG(INFO) << "All graph output is value node, no need to run graph.";
888 return nullptr;
889 }
890 // Single output case.
891 if (output_nodes_.size() == 1) {
892 // Use the python obj of the output node as the python obj of the func_graph output.
893 auto node_output_py_obj = output_nodes_[0]->user_data<py::object>(kPiJitPyObjKey);
894 if (node_output_py_obj == nullptr) {
895 MS_LOG(DEBUG) << "Can not find the python object of the node " << output_nodes_[0]->DebugString();
896 return nullptr;
897 }
898 graph_->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(*node_output_py_obj));
899 graph_->set_output(output_nodes_[0]);
900 has_set_output_ = true;
901 return graph_;
902 }
903 // multiple output case.
904 // Make the python tuple obj of the output nodes as the python obj of the func_graph output.
905 py::tuple output_py_obj(output_nodes_.size());
906 for (size_t i = 0; i < output_nodes_.size(); ++i) {
907 auto node_output_py_obj = output_nodes_[i]->user_data<py::object>(kPiJitPyObjKey);
908 if (node_output_py_obj == nullptr) {
909 MS_LOG(DEBUG) << "Can not find the python object of the node " << output_nodes_[i]->DebugString();
910 return nullptr;
911 }
912 output_py_obj[i] = *node_output_py_obj;
913 }
914 // Create make_tuple node and set its abstract.
915 graph_->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(output_py_obj));
916 output_nodes_.insert(output_nodes_.begin(), NewValueNode(prim::kPrimMakeTuple));
917 AbstractBasePtrList abstract_list;
918 (void)std::transform(output_nodes_.begin() + 1, output_nodes_.end(), std::back_inserter(abstract_list),
919 [](const AnfNodePtr &node) -> AbstractBasePtr { return node->abstract(); });
920 auto output_node = graph_->NewCNodeInOrder(output_nodes_);
921 auto fg_output_abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
922 output_node->set_abstract(fg_output_abs);
923
924 graph_->set_output(output_node);
925 has_set_output_ = true;
926 return graph_;
927 }
928
ClearNodeAbstract()929 void FuncGraphBuilder::ClearNodeAbstract() {
930 if (!has_set_output_) {
931 MS_LOG(INTERNAL_EXCEPTION) << "Graph not generated, can not clear abstract.";
932 }
933 // Clear all node abstract.
934 auto mng = Manage(graph_, false);
935 MS_EXCEPTION_IF_NULL(mng);
936 static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
937 for (const auto &node : mng->all_nodes()) {
938 MS_EXCEPTION_IF_NULL(node);
939 const AbstractBasePtr &prev_inferred = node->abstract();
940 auto is_func =
941 node->isa<mindspore::ValueNode>() && prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>();
942 // Keep previous inferred value for parameter and ValueNode if the inferred value is not AbstractFunction.
943 if (!node->isa<Parameter>() && !is_func) {
944 // Reset tuple/list abstract use flags.
945 if (enable_eliminate_unused_element && prev_inferred != nullptr &&
946 prev_inferred->isa<abstract::AbstractSequence>()) {
947 SetSequenceNodeElementsUseFlags(node, nullptr);
948 }
949 node->set_abstract(nullptr);
950 MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
951 }
952 }
953 }
954
AddFgCallNode(const FuncGraphPtr & fg,const vector<py::object> & inputs_obj)955 py::object FuncGraphBuilder::AddFgCallNode(const FuncGraphPtr &fg, const vector<py::object> &inputs_obj) {
956 std::vector<AnfNodePtr> input_node_list;
957 input_node_list.reserve(inputs_obj.size() + 1);
958
959 (void)input_node_list.emplace_back(NewValueNode(fg));
960 for (const auto &input_obj : inputs_obj) {
961 auto node = GetNodeByObject(input_obj);
962 if (node == nullptr) {
963 if (!IsConstant(input_obj)) {
964 MS_LOG(INFO) << "Can not convert non-constant value to value node for obj: " << py::str(input_obj);
965 return py::object();
966 }
967 auto new_node = ConvertObjToNode(input_obj);
968 if (new_node == nullptr) {
969 MS_LOG(INFO) << "Convert input python object " << py::str(input_obj) << " to anf node failed.";
970 return py::object();
971 }
972 new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(input_obj));
973 (void)py_obj_to_node_.emplace(input_obj.ptr(), new_node);
974 (void)input_node_list.emplace_back(new_node);
975 MS_LOG(DEBUG) << "Add constant python input " << py::str(input_obj) << " with node " << new_node->DebugString();
976 } else {
977 (void)input_node_list.emplace_back(node);
978 }
979 }
980
981 auto new_node = graph_->NewCNodeInOrder(input_node_list);
982 auto fg_output = fg->output();
983 MS_EXCEPTION_IF_NULL(fg_output);
984 auto fg_output_abs = fg_output->abstract();
985 MS_EXCEPTION_IF_NULL(fg_output_abs);
986 new_node->set_abstract(fg_output_abs);
987
988 // Use the python obj of the func_graph output as the python obj of the output node.
989 auto fg_output_obj_ptr = fg->user_data<py::object>(kPiJitPyObjKey);
990 if (fg_output_obj_ptr == nullptr) {
991 MS_LOG(DEBUG) << "Can not find the output python object of func_graph " << fg->ToString();
992 return py::object();
993 }
994 auto fg_output_obj = *fg_output_obj_ptr;
995 (void)py_obj_to_node_.emplace(fg_output_obj.ptr(), new_node);
996 new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(fg_output_obj));
997 return fg_output_obj;
998 }
999
CheckCallable(const py::object & obj)1000 bool FuncGraphBuilder::CheckCallable(const py::object &obj) {
1001 constexpr auto ms_class_attr = "__ms_class__";
1002 return py::isinstance<MetaFuncGraph>(obj) ||
1003 (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG) &&
1004 parse::data_converter::GetObjType(obj) != parse::RESOLVE_TYPE_CLASS_TYPE) ||
1005 FunctionShouldBeParseInAst(obj) ||
1006 (py::hasattr(obj, ms_class_attr) && py::cast<bool>(py::getattr(obj, ms_class_attr)));
1007 }
1008
// Maps the python method `obj` to the callable the graph builder should use in its place.
// The class and method names come from the parse module's method-info helper. The (type id, method name) pair is
// looked up among the registered methods first, then among the registered attributes:
//  - a string registration names a function in the standard_method module, which is returned;
//  - a Primitive registration is instantiated through mindspore.ops.Primitive(<name>).
// Methods of the "Tensor" class are only accepted when the parse module reports them as MindSpore tensor methods.
// Returns an empty py::object when no usable mapping exists.
py::object FuncGraphBuilder::ConvertMethod(const py::object &obj) {
  py::module parse_mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  py::tuple method_info = python_adapter::CallPyModFn(parse_mod, parse::PYTHON_MOD_GET_METHOD_INFO, obj);
  py::object class_name_obj = method_info[0];
  if (py::isinstance<py::none>(class_name_obj)) {
    MS_LOG(INFO) << "Can not get the method info of " << py::str(obj);
    return py::object();
  }
  const auto class_name = class_name_obj.cast<std::string>();
  const bool is_tensor_class = (class_name == "Tensor");
  if (is_tensor_class &&
      !py::cast<bool>(python_adapter::CallPyModFn(parse_mod, parse::PYTHON_MOD_IS_MS_TENSOR_METHOD, obj))) {
    return py::object();
  }

  auto type_id = GetTypeIdFromClassName(class_name);
  auto method_name = method_info[1].cast<std::string>();
  MS_LOG(DEBUG) << "type_id: " << type_id << ", method_name: " << method_name;
  Any require = pipeline::Resource::GetMethodPtr(type_id, method_name);
  if (require.empty()) {
    require = pipeline::Resource::GetAttrPtr(type_id, method_name);
  }
  if (require.empty()) {
    MS_LOG(DEBUG) << "Can not find the method registered.";
    return py::object();
  }

  if (require.is<PrimitivePtr>()) {
    auto ops_mod = python_adapter::GetPyModule("mindspore.ops");
    auto primitive_class = python_adapter::GetPyObjAttr(ops_mod, "Primitive");
    return primitive_class(require.cast<PrimitivePtr>()->name());
  }
  if (!require.is<std::string>()) {
    MS_LOG(DEBUG) << "The method or attr should be a string or a Primitive, but got " << require.ToString();
    return py::object();
  }
  const auto std_method_name = require.cast<std::string>();
  py::function fn = mindspore::python_adapter::GetPyFn(parse::kStandardMethodModelName, std_method_name);
  if (py::isinstance<py::none>(fn)) {
    MS_LOG(DEBUG) << "Can not find the method '" << std_method_name << "' defined in standard_method.";
    return py::object();
  }
  return fn;
}
1050
RemoveOutput(const py::object & output_obj)1051 void FuncGraphBuilder::RemoveOutput(const py::object &output_obj) {
1052 auto iter = py_obj_to_node_.find(output_obj.ptr());
1053 if (iter == py_obj_to_node_.end()) {
1054 MS_LOG(WARNING) << "The output python object " << py::str(output_obj) << " should have been added to the graph.";
1055 return;
1056 }
1057 auto output_nodes_iter = std::find(output_nodes_.begin(), output_nodes_.end(), iter->second);
1058 if (output_nodes_iter == output_nodes_.end()) {
1059 MS_LOG(WARNING) << "The node " << iter->second->DebugString() << " has not been added to the graph outputs.";
1060 return;
1061 }
1062 output_nodes_.erase(output_nodes_iter);
1063 }
1064
// Looks `obj` up in mindspore._extends.parse.resources.convert_object_map and returns the mapped python object,
// or an empty py::object when `obj` has no entry.
// PyDict_GetItem yields a borrowed reference (and suppresses lookup errors), so the hit is wrapped with
// reinterpret_borrow to take a new reference.
py::object FuncGraphBuilder::ConvertFunction(const py::object &obj) {
  auto resources_mod = python_adapter::GetPyModule("mindspore._extends.parse.resources");
  auto convert_map = python_adapter::GetPyObjAttr(resources_mod, "convert_object_map");
  PyObject *mapped = PyDict_GetItem(convert_map.ptr(), obj.ptr());
  if (mapped == nullptr) {
    return py::object();
  }
  return py::reinterpret_borrow<py::object>(mapped);
}
1071
CanConstantFoldFunc(const py::object & obj)1072 bool FuncGraphBuilder::CanConstantFoldFunc(const py::object &obj) {
1073 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1074 py::object can_constant_fold = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CAN_CONSTANT_FOLD, obj);
1075 return can_constant_fold.cast<bool>();
1076 }
1077
SetGraphName(const std::string & name)1078 void FuncGraphBuilder::SetGraphName(const std::string &name) {
1079 if (name.empty()) {
1080 return;
1081 }
1082 MS_EXCEPTION_IF_NULL(graph_->debug_info());
1083 graph_->debug_info()->set_name(name);
1084 }
1085
AddPrevBuilder(const FuncGraphBuilderPtr & builder)1086 void FuncGraphBuilder::AddPrevBuilder(const FuncGraphBuilderPtr &builder) { prev_builders_.push_back(builder.get()); }
1087
ValidateCallableObject(const py::object & obj)1088 bool FuncGraphBuilder::ValidateCallableObject(const py::object &obj) {
1089 if (obj.ptr() == nullptr) {
1090 return false;
1091 }
1092 // Check if object is invalid method for CellList/CellDict, which should not be converted to graph.
1093 if (CheckInvalidCellListDictMethod(obj)) {
1094 MS_LOG(INFO) << "The object " << py::str(obj) << " is a invalid CellList/CellDict method, "
1095 << "can not convert to graph";
1096 return false;
1097 }
1098 return true;
1099 }
1100
CheckInvalidCellListDictMethod(const py::object & obj)1101 bool FuncGraphBuilder::CheckInvalidCellListDictMethod(const py::object &obj) {
1102 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1103 py::tuple method_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_METHOD_INFO, obj);
1104 constexpr size_t class_index = 0;
1105 constexpr size_t method_index = 1;
1106 py::object class_name_obj = method_info[class_index];
1107 if (class_name_obj.ptr() == nullptr || py::isinstance<py::none>(class_name_obj)) {
1108 return false;
1109 }
1110 auto class_name = class_name_obj.cast<std::string>();
1111 MS_LOG(INFO) << "class name: " << class_name;
1112 if (class_name != "CellList" && class_name != "CellDict") {
1113 return false;
1114 }
1115 auto method_name_obj = method_info[method_index];
1116 if (method_name_obj.ptr() == nullptr || py::isinstance<py::none>(method_name_obj)) {
1117 return false;
1118 }
1119 auto method_name = method_name_obj.cast<std::string>();
1120 static std::vector<std::string> inplace_method_name = {"clear", "update"};
1121 if (std::any_of(inplace_method_name.begin(), inplace_method_name.end(),
1122 [&method_name](const std::string &name) { return name == method_name; })) {
1123 MS_LOG(INFO) << "CellDict/CellList inplace function " << method_name << " found";
1124 return true;
1125 }
1126 auto type_id = GetTypeIdFromClassName(class_name);
1127 Any require = pipeline::Resource::GetMethodPtr(type_id, method_name);
1128 return require.empty();
1129 }
1130 } // namespace mindspore
1131