• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/pi/graph_build/func_graph_builder.h"
18 #include <algorithm>
19 #include <utility>
20 #include <set>
21 #include <queue>
22 #include "frontend/operator/composite/do_signature.h"
23 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
24 #include "pipeline/jit/ps/action.h"
25 #include "pipeline/jit/ps/parse/parse_base.h"
26 #include "pipeline/jit/ps/parse/data_converter.h"
27 #include "pipeline/jit/pi/pi_jit_config.h"
28 #include "ops/arithmetic_ops.h"
29 #include "ops/structure_ops.h"
30 #include "include/common/utils/convert_utils_py.h"
31 #include "ir/tensor.h"
32 #include "ir/anf.h"
33 
34 namespace mindspore {
35 namespace {
36 constexpr auto kPiJitPyObjKey = "pi_jit_py_obj";
37 constexpr auto kTensorModule = "mindspore.common";
38 constexpr auto kAdapterFlag = "adapter_flag";
39 constexpr auto kInnerOpsModule = "mindspore.ops.operations._inner_ops";
40 
ShouldFallBackInRuntime(const PrimitivePtr & prim)41 bool ShouldFallBackInRuntime(const PrimitivePtr &prim) {
42   static HashSet<std::string> prims_should_fallback_in_runtime = {kListInplaceExtendOpName,
43                                                                   kListInplaceInsertOpName,
44                                                                   kListInplacePopOpName,
45                                                                   kListInplaceReverseOpName,
46                                                                   kListInplaceClearOpName,
47                                                                   kDictInplaceSetItemOpName,
48                                                                   kRaiseOpName,
49                                                                   kJoinedStrOpName,
50                                                                   kFormatOpName};
51   return prims_should_fallback_in_runtime.find(prim->name()) != prims_should_fallback_in_runtime.end();
52 }
53 
IsValidScalar(const AbstractBasePtr & abs)54 bool IsValidScalar(const AbstractBasePtr &abs) {
55   auto build_type = abs->BuildType();
56   return build_type->isa<String>() || build_type->isa<Number>();
57 }
58 
Mutable(const py::object & obj,const ValuePtr & value=nullptr)59 bool Mutable(const py::object &obj, const ValuePtr &value = nullptr) {
60   // If a tensor has been set const arg, it should not be mutable.
61   if (value != nullptr && value->isa<tensor::MetaTensor>()) {
62     constexpr char const_arg_attr[] = "const_arg";
63     if (py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr))) {
64       return false;
65     }
66   }
67   constexpr char mutable_attr[] = "__ms_mutable__";
68   return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
69 }
70 
IsConstant(const py::object & obj)71 bool IsConstant(const py::object &obj) {
72   if (obj.ptr() == nullptr || Mutable(obj)) {
73     return false;
74   }
75   if (py::isinstance<py::tuple>(obj)) {
76     auto list_obj = py::cast<py::tuple>(obj);
77     return std::all_of(list_obj.begin(), list_obj.end(),
78                        [](const auto &obj) { return IsConstant(py::cast<py::object>(obj)); });
79   }
80   if (py::isinstance<py::list>(obj)) {
81     auto list_obj = py::cast<py::list>(obj);
82     return std::all_of(list_obj.begin(), list_obj.end(),
83                        [](const auto &obj) { return IsConstant(py::cast<py::object>(obj)); });
84   }
85   if (py::isinstance<py::dict>(obj)) {
86     auto dict_obj = py::cast<py::dict>(obj);
87     return std::all_of(dict_obj.begin(), dict_obj.end(), [](const auto &pair) {
88       return IsConstant(py::cast<py::object>(pair.first)) && IsConstant(py::cast<py::object>(pair.second));
89     });
90   }
91   // Attention: should exclude BaseTensor in the future (when the BaseTensor PR is merged)
92   return !py::isinstance<tensor::Tensor>(obj) && !IsStubTensor(obj);
93 }
94 
TensorArgMutable(const py::object & obj,const ValuePtr & value)95 bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
96   if (!value->isa<tensor::MetaTensor>()) {
97     return false;
98   }
99   constexpr char const_arg_attr[] = "const_arg";
100   return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
101 }
102 
NeedBroaden(const py::object & obj,const ValuePtr & value)103 bool NeedBroaden(const py::object &obj, const ValuePtr &value) {
104   return TensorArgMutable(obj, value) || Mutable(obj, value) || value->isa<tensor::MetaSparseTensor>();
105 }
106 
GetTypeIdFromClassName(const std::string & class_name)107 TypeId GetTypeIdFromClassName(const std::string &class_name) {
108   static HashMap<std::string, TypeId> class_name_to_type_ids = {
109     {"Tensor", kObjectTypeTensorType},  {"list", kObjectTypeList},
110     {"tuple", kObjectTypeTuple},        {"int", kNumberTypeInt},
111     {"float", kNumberTypeFloat},        {"CellList", kObjectTypeList},
112     {"CellDict", kObjectTypeDictionary}};
113   auto iter = class_name_to_type_ids.find(class_name);
114   if (iter == class_name_to_type_ids.end()) {
115     return kTypeUnknown;
116   }
117   return iter->second;
118 }
119 
MaybeMakeEmptyTensor(const AbstractBasePtr & abs)120 ValuePtr MaybeMakeEmptyTensor(const AbstractBasePtr &abs) {
121   auto build_value = abs->BuildValue();
122   if (abs->isa<abstract::AbstractSequence>()) {
123     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
124     std::vector<ValuePtr> value_vec;
125     for (auto &elem : abs_seq->elements()) {
126       (void)value_vec.emplace_back(MaybeMakeEmptyTensor(elem));
127     }
128     if (abs->isa<abstract::AbstractTuple>()) {
129       return std::make_shared<ValueTuple>(value_vec);
130     } else {
131       return std::make_shared<ValueList>(value_vec);
132     }
133   }
134   if (abs->isa<abstract::AbstractDictionary>()) {
135     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
136     const auto &elements = abs_dict->elements();
137     std::vector<std::pair<ValuePtr, ValuePtr>> val_dict;
138     for (auto &element : elements) {
139       auto key_value = MaybeMakeEmptyTensor(element.first);
140       auto val_value = MaybeMakeEmptyTensor(element.second);
141       (void)val_dict.emplace_back(std::pair<ValuePtr, ValuePtr>{key_value, val_value});
142     }
143     return std::make_shared<ValueDictionary>(val_dict);
144   }
145   if (build_value == kValueAny && abs->isa<abstract::AbstractTensor>()) {
146     auto abs_tensor = abs->cast<abstract::AbstractTensorPtr>();
147     TypePtr tensor_type_ptr = abs_tensor->element()->BuildType();
148     ShapeVector tensor_shape = abs_tensor->shape()->shape();
149     auto tensor = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
150     if (abs->isa<abstract::AbstractRefTensor>()) {
151       auto abs_ref_tensor = abs->cast<abstract::AbstractRefPtr>();
152       // We only need the parameter name, it was used to find the python Parameter object later
153       auto param_info = std::make_shared<ParamInfo>();
154       param_info->set_name(abs_ref_tensor->ref_key_value()->ToString());
155       tensor->set_param_info(param_info);
156     }
157     return tensor;
158   }
159   return build_value;
160 }
161 
FunctionShouldBeParseInAst(const py::object & obj)162 bool FunctionShouldBeParseInAst(const py::object &obj) {
163   static mindspore::HashSet<std::string> func_names{"cast_to_adapter_tensor", "cast_to_ms_tensor"};
164   if (!py::hasattr(obj, "__name__")) {
165     return false;
166   }
167   return func_names.find(py::cast<std::string>(obj.attr("__name__"))) != func_names.end();
168 }
169 
// Recursively replaces C++ tensors inside `obj` with the result of `tensor_convert_func`.
// - Objects with a truthy `__ms_class__` attribute are returned unchanged.
// - Lists and tuples are rebuilt element by element (a list input yields a list).
// - Dicts are rebuilt into a NEW dict with converted values (keys are copied as-is). This keeps the function
//   free of side effects on its argument, consistent with the list/tuple branch.
// - Any other object is returned unchanged.
py::object ConvertToPythonTensor(const py::object &obj,
                                 const FuncGraphBuilder::PyTensorConverter &tensor_convert_func) {
  constexpr auto ms_class_attr = "__ms_class__";
  if (py::hasattr(obj, ms_class_attr) && py::cast<bool>(py::getattr(obj, ms_class_attr))) {
    return obj;
  }
  if (py::isinstance<tensor::Tensor>(obj)) {
    return tensor_convert_func(obj);
  }
  if (py::isinstance<py::list>(obj) || py::isinstance<py::tuple>(obj)) {
    auto obj_tuple = py::cast<py::tuple>(obj);
    py::tuple ret(obj_tuple.size());
    for (size_t i = 0; i < obj_tuple.size(); ++i) {
      ret[i] = ConvertToPythonTensor(obj_tuple[i], tensor_convert_func);
    }
    if (py::isinstance<py::list>(obj)) {
      return ret.cast<py::list>();
    }
    return ret;
  }
  if (py::isinstance<py::dict>(obj)) {
    auto obj_dict = py::cast<py::dict>(obj);
    py::dict ret;
    for (const auto &item : obj_dict) {
      ret[item.first] = ConvertToPythonTensor(py::cast<py::object>(item.second), tensor_convert_func);
    }
    return ret;
  }
  return obj;
}
199 
// Wraps a C++ tensor::Tensor python object into `mindspore.common.Tensor`. If the source carries a truthy
// `adapter_flag`, the result is additionally passed through `convert_to_adapter_tensor`.
// Returns an empty (null) py::object when the input is null or not a tensor.
py::object ConvertCppTensorToPyTensor(const py::object &cpp_tensor) {
  const bool is_tensor = cpp_tensor.ptr() != nullptr && py::isinstance<tensor::Tensor>(cpp_tensor);
  if (!is_tensor) {
    return py::object();
  }
  const bool is_adapter_tensor =
    py::hasattr(cpp_tensor, kAdapterFlag) && py::cast<bool>(py::getattr(cpp_tensor, kAdapterFlag));
  py::module tensor_mod = python_adapter::GetPyModule(kTensorModule);
  auto py_tensor =
    python_adapter::CallPyModFn(tensor_mod, "Tensor", cpp_tensor, py::none(), py::none(), py::none(), true);
  if (!is_adapter_tensor) {
    return py_tensor;
  }
  py::module inner_ops_mod = python_adapter::GetPyModule(kInnerOpsModule);
  return python_adapter::CallPyModFn(inner_ops_mod, "convert_to_adapter_tensor", py_tensor);
}
214 }  // namespace
215 
ConvertPyObjToValue(const py::object & obj)216 ValuePtr FuncGraphBuilder::ConvertPyObjToValue(const py::object &obj) {
217   if (obj.ptr() == nullptr) {
218     return nullptr;
219   }
220   ValuePtr ret = nullptr;
221   try {
222     MS_LOG_TRY_CATCH_SCOPE;
223     if (!parse::ConvertData(obj, &ret)) {
224       return nullptr;
225     }
226   } catch (const std::exception &e) {
227     MS_LOG(DEBUG) << "Failed to convert python object << " << py::str(obj) << " to value. The exception:\n" << e.what();
228     return nullptr;
229   }
230   return ret;
231 }
232 
ConvertToPyObj(const AbstractBasePtr & abs)233 py::object FuncGraphBuilder::ConvertToPyObj(const AbstractBasePtr &abs) {
234   static auto convert_func = [](const py::object &tensor) { return ConvertCppTensorToPyTensor(tensor); };
235   return FuncGraphBuilder::ConvertToPyObj(abs, convert_func);
236 }
237 
ConvertToPyObj(const AbstractBasePtr & abs,const PyTensorConverter & tensor_convert_func)238 py::object FuncGraphBuilder::ConvertToPyObj(const AbstractBasePtr &abs, const PyTensorConverter &tensor_convert_func) {
239   if (abs->isa<abstract::AbstractNone>()) {
240     return py::none();
241   }
242 
243   auto build_value = MaybeMakeEmptyTensor(abs);
244   auto py_obj = ValueToPyData(build_value, abs);
245   // Return none means failed converting.
246   if (py::isinstance<py::none>(py_obj)) {
247     return py::object();
248   }
249 
250   if (pijit::kPIJitConfigDefault.GetBoolConfig(pijit::GraphJitConfig::kTraceFlag)) {
251     return ConvertToPythonTensor(py_obj, tensor_convert_func);
252   }
253 
254   return py_obj;
255 }
256 
ConvertObjToNode(const py::object & input_obj)257 AnfNodePtr FuncGraphBuilder::ConvertObjToNode(const py::object &input_obj) {
258   if (py::hasattr(input_obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(input_obj)) {
259     // Add the fv parameter and set its abstract.
260     return parse::ResolveParameterObj(graph_, input_obj);
261   }
262   auto val = ConvertPyObjToValue(input_obj);
263   if (val == nullptr) {
264     MS_LOG(INFO) << "The input object " << py::str(input_obj) << " convert to value failed.";
265     return nullptr;
266   }
267   // Constant value input scene, the object should be converted to value node.
268   auto node = NewValueNode(val);
269   node->set_abstract(val->ToAbstract());
270   return node;
271 }
272 
EvalValue(const ValuePtr & value,const AbstractBasePtrList & inputs_abs_list)273 AbstractBasePtr FuncGraphBuilder::EvalValue(const ValuePtr &value, const AbstractBasePtrList &inputs_abs_list) {
274   if (value == nullptr) {
275     return nullptr;
276   }
277   try {
278     MS_LOG_TRY_CATCH_SCOPE;
279     if (value->isa<Primitive>()) {
280       auto prim = value->cast<PrimitivePtr>();
281       auto eval_res = abstract::EvalOnePrim(prim, inputs_abs_list);
282       if (eval_res != nullptr) {
283         return eval_res->abstract();
284       }
285     } else if (value->ToAbstract()->isa<abstract::AbstractFunction>()) {
286       auto analyze_res = pipeline::AbstractAnalyze(value, inputs_abs_list);
287       if (analyze_res.eval_result != nullptr) {
288         return analyze_res.eval_result->abstract();
289       }
290     }
291     return nullptr;
292   } catch (const std::exception &e) {
293     MS_LOG(INFO) << "Failed to EvalValue for value: " << value->ToString();
294     return nullptr;
295   }
296 }
297 
CheckCallable(const ValuePtr & value,const AbstractBasePtr & abs)298 bool FuncGraphBuilder::CheckCallable(const ValuePtr &value, const AbstractBasePtr &abs) {
299   if (value == nullptr || abs == nullptr || abs->isa<abstract::AbstractAny>()) {
300     return false;
301   }
302   if (value->isa<Primitive>() && ShouldFallBackInRuntime(value->cast<PrimitivePtr>())) {
303     return false;
304   }
305   return true;
306 }
307 
CheckGraphOutput(const AbstractBasePtr & abs)308 bool FuncGraphBuilder::CheckGraphOutput(const AbstractBasePtr &abs) {
309   if (abs == nullptr) {
310     return false;
311   }
312   if (abs->isa<abstract::AbstractSequence>()) {
313     const auto elements = abs->cast<abstract::AbstractSequencePtr>()->elements();
314     return std::all_of(elements.begin(), elements.end(),
315                        [](const AbstractBasePtr &elem) { return CheckGraphOutput(elem); });
316   }
317   if (abs->isa<abstract::AbstractScalar>()) {
318     return IsValidScalar(abs);
319   }
320   return abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractRowTensor>() ||
321          abs->isa<abstract::AbstractMapTensor>();
322 }
323 
AddLocalVariable(const py::object & obj)324 bool FuncGraphBuilder::AddLocalVariable(const py::object &obj) {
325   if (obj.ptr() == nullptr) {
326     MS_LOG(INFO) << "Failed to add local variable, py object is null";
327     return false;
328   }
329 
330   auto iter = py_obj_to_node_.find(obj.ptr());
331   if (iter != py_obj_to_node_.end()) {
332     MS_LOG(INFO) << "Py object already in map, no need to add. Associated node: "
333                  << ((iter->second != nullptr) ? iter->second->DebugString() : "NULL");
334     return true;
335   }
336 
337   auto node = ConvertObjToNode(obj);
338   if (node == nullptr) {
339     MS_LOG(INFO) << "Failed to add local variable, convert python object to anf node failed";
340     return false;
341   }
342 
343   node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(obj));
344   (void)py_obj_to_node_.emplace(obj.ptr(), node);
345   return true;
346 }
347 
ReadLocalVariable(const py::object & obj)348 AnfNodePtr FuncGraphBuilder::ReadLocalVariable(const py::object &obj) {
349   auto iter = py_obj_to_node_.find(obj.ptr());
350   if (iter == py_obj_to_node_.end()) {
351     return nullptr;
352   }
353   return iter->second;
354 }
355 
GetNodeByObject(const py::object & obj)356 AnfNodePtr FuncGraphBuilder::GetNodeByObject(const py::object &obj) {
357   // Search the predecessors of the current builder for the local parameter with BFS.
358   mindspore::HashSet<FuncGraphBuilder *> visited_builders;
359   std::queue<FuncGraphBuilder *> builder_queue;
360   builder_queue.push(this);
361   while (!builder_queue.empty()) {
362     const auto cur_builder = builder_queue.front();
363     MS_EXCEPTION_IF_NULL(cur_builder);
364     builder_queue.pop();
365     (void)visited_builders.insert(cur_builder);
366     auto node = cur_builder->ReadLocalVariable(obj);
367     if (node != nullptr) {
368       MS_LOG(INFO) << "Found node: " << node->DebugString() << " for python object: " << std::string(py::str(obj))
369                    << "  " << obj.ptr();
370       return node;
371     }
372     for (const auto &cur_pred_builder : cur_builder->prev_builders()) {
373       if (visited_builders.count(cur_pred_builder) == 0) {
374         builder_queue.push(cur_pred_builder);
375       }
376     }
377   }
378   return nullptr;
379 }
380 
AddTopGraphArgsInputs(const py::object & object)381 bool FuncGraphBuilder::AddTopGraphArgsInputs(const py::object &object) {
382   // args object should always be list object.
383   if (object.ptr() == nullptr || !py::isinstance<py::list>(object)) {
384     MS_LOG(INFO) << "Get top graph args failed.";
385     return false;
386   }
387   auto args = object.cast<py::list>();
388   for (size_t i = 0; i < args.size(); ++i) {
389     auto arg = args[i].cast<py::object>();
390     if (arg.ptr() == nullptr) {
391       return false;
392     }
393     auto value = ConvertPyObjToValue(arg);
394     if (value == nullptr) {
395       return false;
396     }
397     bool broaden = NeedBroaden(arg, value);
398     AbstractBasePtr abs = abstract::ToAbstract(value, nullptr, nullptr);
399     if (broaden) {
400       abs = AbstractBroaden(abs);
401     }
402     if (abs == nullptr) {
403       MS_LOG(INFO) << "Failed to add input for python object: " << std::string(py::str(arg)) << "  " << arg.ptr();
404       return false;
405     }
406     auto para = graph_->add_parameter();
407     para->set_abstract(abs);
408     para->set_is_top_graph_param(true);
409     MS_LOG(INFO) << "Add top arg input success, python object: " << py::str(arg) << ", node: " << para->DebugString()
410                  << ", abstract: " << abs->ToString();
411     (void)py_obj_to_node_.emplace(arg.ptr(), para);
412     para->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(arg));
413   }
414   return true;
415 }
416 
AddTopGraphVargsInputs(const py::object & vargs)417 bool FuncGraphBuilder::AddTopGraphVargsInputs(const py::object &vargs) {
418   if (vargs.ptr() == nullptr) {
419     MS_LOG(INFO) << "Top graph has no vargs input.";
420     return true;
421   }
422   auto vargs_tuple = vargs.cast<py::tuple>();
423   if (vargs_tuple.ptr() == nullptr) {
424     MS_LOG(INFO) << "Vargs object should be tuple but got: " << py::str(vargs) << ", add top graph vargs failed.";
425     return false;
426   }
427   auto value = ConvertPyObjToValue(vargs);
428   if (value == nullptr || !value->isa<ValueTuple>()) {
429     MS_LOG(INFO) << "Convert vargs to value failed, vargs: " << py::str(vargs);
430     return false;
431   }
432   auto value_tuple = value->cast<ValueTuplePtr>();
433   const auto &elements = value_tuple->value();
434   if (elements.size() != vargs_tuple.size()) {
435     MS_LOG(INFO) << "For top graph vargs, converted value element size is " << elements.size()
436                  << ", python tuple element size is " << vargs_tuple.size() << ". Size not matched.";
437     return false;
438   }
439   std::vector<AbstractBasePtr> new_elements;
440   for (size_t i = 0; i < elements.size(); ++i) {
441     auto cur_obj = vargs_tuple[i].cast<py::object>();
442     auto cur_val = elements[i];
443     bool broaden = NeedBroaden(cur_obj, cur_val);
444     auto cur_abs = abstract::ToAbstract(cur_val, nullptr, nullptr);
445     if (broaden) {
446       cur_abs = AbstractBroaden(cur_abs);
447     }
448     if (cur_abs == nullptr) {
449       MS_LOG(INFO) << "Fail to convert args element " << cur_val->ToString();
450       return false;
451     }
452     new_elements.push_back(cur_abs);
453   }
454   auto new_vargs_abs = std::make_shared<abstract::AbstractTuple>(new_elements);
455   auto para = graph_->add_parameter();
456   para->set_abstract(new_vargs_abs);
457   para->set_is_top_graph_param(true);
458   MS_LOG(INFO) << "Add top vargs input success, python object: " << py::str(vargs) << ", node: " << para->DebugString()
459                << ", abstract: " << new_vargs_abs->ToString();
460   (void)py_obj_to_node_.emplace(vargs.ptr(), para);
461   para->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(vargs));
462   return true;
463 }
464 
AddTopGraphKwargsInputs(const py::object & kwargs)465 bool FuncGraphBuilder::AddTopGraphKwargsInputs(const py::object &kwargs) {
466   if (kwargs.ptr() == nullptr) {
467     MS_LOG(INFO) << "Top graph has no kwargs input.";
468     return true;
469   }
470   auto kwargs_dict = kwargs.cast<py::dict>();
471   if (kwargs_dict.ptr() == nullptr) {
472     MS_LOG(INFO) << "Kwargs object should be tuple but got: " << py::str(kwargs) << ", add top graph kwargs failed.";
473     return false;
474   }
475   auto value = ConvertPyObjToValue(kwargs);
476   if (value == nullptr || !value->isa<ValueDictionary>()) {
477     MS_LOG(INFO) << "Convert kwargs to value failed, kwargs: " << py::str(kwargs);
478     return false;
479   }
480   auto value_dict = value->cast<ValueDictionaryPtr>();
481   const auto &elements = value_dict->value();
482   if (elements.size() != kwargs_dict.size()) {
483     MS_LOG(INFO) << "Kwargs dict size is " << kwargs_dict.size() << " and corresponding value dict size is "
484                  << elements.size() << ". Size not matched.";
485   }
486   std::vector<abstract::AbstractElementPair> new_key_values;
487   for (size_t i = 0; i < elements.size(); ++i) {
488     auto cur_key_val = elements[i].first;
489     auto cur_val = elements[i].second;
490     auto cur_key_obj = ValueToPyData(cur_key_val);
491     if (!kwargs_dict.contains(cur_key_obj)) {
492       return false;
493     }
494     auto cur_val_obj = kwargs_dict[cur_key_obj];
495     auto cur_value_abs = abstract::ToAbstract(cur_val, nullptr, nullptr);
496     bool broaden = NeedBroaden(cur_val_obj, cur_val);
497     if (broaden) {
498       cur_value_abs = AbstractBroaden(cur_value_abs);
499     }
500     if (cur_value_abs == nullptr) {
501       MS_LOG(INFO) << "Fail to convert kwargs value element " << cur_val->ToString();
502       return false;
503     }
504     auto cur_key_abs = abstract::ToAbstract(cur_key_val, nullptr, nullptr);
505     new_key_values.push_back(abstract::AbstractElementPair{cur_key_abs, cur_value_abs});
506   }
507   auto new_kwargs_abs = std::make_shared<abstract::AbstractDictionary>(new_key_values);
508   auto para = graph_->add_parameter();
509   para->set_abstract(new_kwargs_abs);
510   para->set_is_top_graph_param(true);
511   MS_LOG(INFO) << "Add top kwargs input success, python object: " << py::str(kwargs)
512                << ", node: " << para->DebugString() << ", abstract: " << new_kwargs_abs->ToString();
513   (void)py_obj_to_node_.emplace(kwargs.ptr(), para);
514   para->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(kwargs));
515   return true;
516 }
517 
AddTopGraphInputs(std::vector<py::object> packed_inputs)518 bool FuncGraphBuilder::AddTopGraphInputs(std::vector<py::object> packed_inputs) {
519   constexpr size_t args_index = 0;
520   constexpr size_t vargs_index = 1;
521   constexpr size_t kwargs_index = 2;
522   constexpr size_t packed_inputs_size = 3;
523   if (!prev_builders_.empty()) {
524     MS_LOG(INFO) << "Current builder has prev builder, add top graph parameter failed.";
525     return false;
526   }
527   if (packed_inputs.size() != packed_inputs_size) {
528     MS_LOG(INFO) << "Top graph packed inputs size is not three but " << packed_inputs.size()
529                  << ", add top graph parameter failed.";
530     return false;
531   }
532   if (!AddTopGraphArgsInputs(packed_inputs[args_index])) {
533     MS_LOG(INFO) << "Add top graph args inputs failed.";
534     return false;
535   }
536   if (!AddTopGraphVargsInputs(packed_inputs[vargs_index])) {
537     MS_LOG(INFO) << "Add top graph vargs inputs failed";
538     return false;
539   }
540   if (!AddTopGraphKwargsInputs(packed_inputs[kwargs_index])) {
541     MS_LOG(INFO) << "Add top graph kwargs inputs failed";
542     return false;
543   }
544   MS_LOG(INFO) << "Add top graph inputs success.";
545   return true;
546 }
547 
// Adds a (non-top) graph parameter for `obj` and binds the object to it. The parameter's abstract is taken
// from a node already bound to `obj` in this builder or its predecessors; failing that, a constant object's
// abstract is derived from its converted value. Returns `obj` on success, an empty py::object otherwise.
py::object FuncGraphBuilder::AddSubGraphInput(const py::object &obj) {
  MS_LOG(INFO) << "Try add sub graph parameter for object: " << std::string(py::str(obj)) << "  " << obj.ptr();
  AbstractBasePtr abs = nullptr;
  if (auto bound_node = GetNodeByObject(obj); bound_node != nullptr) {
    abs = bound_node->abstract();
  }
  // Handle constant subgraph input.
  if (abs == nullptr && IsConstant(obj)) {
    if (auto const_value = ConvertPyObjToValue(obj); const_value != nullptr) {
      abs = abstract::ToAbstract(const_value, nullptr, nullptr);
    }
  }
  if (abs == nullptr) {
    MS_LOG(INFO) << "Failed to add input for python object: " << std::string(py::str(obj)) << "  " << obj.ptr();
    return py::object();
  }
  auto sub_param = graph_->add_parameter();
  sub_param->set_abstract(abs);
  sub_param->set_is_top_graph_param(false);
  MS_LOG(INFO) << "Add input success, node: " << sub_param->DebugString() << " obj: " << py::str(obj) << "  "
               << obj.ptr();
  (void)py_obj_to_node_.emplace(obj.ptr(), sub_param);
  sub_param->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(obj));
  return obj;
}
574 
// Adds a call node for the python callable `callable_obj` with python inputs `inputs_obj`.
// Returns an empty py::object if the callable is rejected by CheckCallable or cannot be converted to a value.
// Adapter tensor conversion functions (see FunctionShouldBeParseInAst) go through TryToAddNode.
py::object FuncGraphBuilder::AddNode(const py::object &callable_obj, const std::vector<py::object> &inputs_obj) {
  if (!CheckCallable(callable_obj)) {
    MS_LOG(INFO) << "The python obj " << py::str(callable_obj) << " is not callable.";
    return py::object();
  }
  auto callee_value = ConvertPyObjToValue(callable_obj);
  if (callee_value == nullptr) {
    MS_LOG(INFO) << "Convert python object " << py::str(callable_obj) << " to value failed.";
    return py::object();
  }
  const bool parse_in_ast = FunctionShouldBeParseInAst(callable_obj);
  if (!parse_in_ast) {
    return AddNode(callee_value, inputs_obj);
  }
  return TryToAddNode(callee_value, inputs_obj);
}
590 
AddAttrPythonObject(const py::object & object)591 bool FuncGraphBuilder::AddAttrPythonObject(const py::object &object) {
592   if (object.ptr() == nullptr) {
593     MS_LOG(INFO) << "Convert python object with empty object, convert failed.";
594     return false;
595   }
596   // Attribute object is constant or Parameter, do not need to check constant.
597   auto node = ConvertObjToNode(object);
598   if (node == nullptr) {
599     MS_LOG(INFO) << "Convert python object " << py::str(object) << " to anf node failed.";
600     return false;
601   }
602   node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(object));
603   (void)py_obj_to_node_.emplace(object.ptr(), node);
604   return true;
605 }
606 
GetInputNodesAndAbstracts(const ValuePtr & callable_value,const vector<py::object> & inputs_obj,std::vector<AnfNodePtr> * input_node_list,std::vector<AbstractBasePtr> * input_abs_list)607 bool FuncGraphBuilder::GetInputNodesAndAbstracts(const ValuePtr &callable_value, const vector<py::object> &inputs_obj,
608                                                  std::vector<AnfNodePtr> *input_node_list,
609                                                  std::vector<AbstractBasePtr> *input_abs_list) {
610   input_node_list->reserve(inputs_obj.size() + 1);
611   input_abs_list->reserve(inputs_obj.size());
612 
613   (void)input_node_list->emplace_back(NewValueNode(callable_value));
614   for (const auto &input_obj : inputs_obj) {
615     if (input_obj.ptr() == nullptr) {
616       MS_LOG(INFO) << "The input python object of " << callable_value->ToString() << ", is NULL";
617       return false;
618     }
619     // Node with input of generator may cause change of generator, skip it in build node now.
620     if (PyGen_CheckExact(input_obj.ptr())) {
621       MS_LOG(INFO) << "The input python object is generator " << std::string(py::str(input_obj))
622                    << ", do not build graph.";
623       return false;
624     }
625     auto node = GetNodeByObject(input_obj);
626     if (node == nullptr) {
627       if (!IsConstant(input_obj)) {
628         MS_LOG(INFO) << "Can not convert non-constant value to value node for obj: " << py::str(input_obj);
629         return false;
630       }
631       auto new_node = ConvertObjToNode(input_obj);
632       if (new_node == nullptr) {
633         MS_LOG(INFO) << "Convert input python object " << py::str(input_obj) << " to anf node failed.";
634         return false;
635       }
636       new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(input_obj));
637       (void)py_obj_to_node_.emplace(input_obj.ptr(), new_node);
638       (void)input_node_list->emplace_back(new_node);
639       (void)input_abs_list->emplace_back(new_node->abstract());
640       MS_LOG(INFO) << "Add python input " << py::str(input_obj) << " with new node " << new_node->DebugString();
641     } else {
642       (void)input_node_list->emplace_back(node);
643       (void)input_abs_list->emplace_back(node->abstract());
644     }
645   }
646   return true;
647 }
648 
DoPrimitiveInferAndCheck(const PrimitivePtr & primitive,const AnfNodePtrList & input_node_list,const AbstractBasePtrList & args_abs_list)649 CNodePtr FuncGraphBuilder::DoPrimitiveInferAndCheck(const PrimitivePtr &primitive,
650                                                     const AnfNodePtrList &input_node_list,
651                                                     const AbstractBasePtrList &args_abs_list) {
652   try {
653     MS_LOG_TRY_CATCH_SCOPE;
654     const CNodePtr &new_node = AddPrimitiveCNode(primitive, input_node_list, args_abs_list);
655     if (new_node == nullptr) {
656       MS_LOG(INFO) << "Failed to add CNode for Primitive: " << primitive->name();
657       return nullptr;
658     }
659 
660     const AbstractBasePtr &abs = GetAbstractOf(new_node);
661 
662     if (!CheckCallable(primitive, abs)) {
663       MS_LOG(INFO) << "Check callable failed for Primitive: " << primitive->name();
664       return nullptr;
665     }
666     new_node->set_abstract(abs);
667     return new_node;
668   } catch (const std::exception &e) {
669     MS_LOG(INFO) << "Failed to infer Primitive: " << primitive->name() << ". The exception:\n" << e.what();
670     return nullptr;
671   }
672 }
673 
AddPrimitiveCNode(const PrimitivePtr & primitive,const AnfNodePtrList & input_node_list,const AbstractBasePtrList & args_abs_list)674 CNodePtr FuncGraphBuilder::AddPrimitiveCNode(const PrimitivePtr &primitive, const AnfNodePtrList &input_node_list,
675                                              const AbstractBasePtrList &args_abs_list) {
676   auto op_def = mindspore::ops::GetOpDef(primitive->name());
677 
678   if (op_def == nullptr) {
679     if (primitive->has_signature()) {
680       // Follow the implementations in DoSignatureEvaluator
681       AnfNodePtrList args_node_list(input_node_list.cbegin() + 1, input_node_list.cend());
682       AnfNodePtrList new_node_list =
683         prim::GetNewInputsBySignatures(graph_, primitive->ToString(), primitive, args_abs_list, args_node_list);
684 
685       new_node_list.insert(new_node_list.begin(), input_node_list[0]);
686       return graph_->NewCNodeInOrder(new_node_list);
687     }
688   } else if (primitive->isa<PrimitivePy>()) {
689     // Follow the implementations in PrimitiveArgsToInputsEvaluator and DoTransPrimitiveFunctionEvaluator
690     auto arg_signatures = op_def->signatures_;
691     primitive->set_signatures(arg_signatures);
692     primitive->set_has_signature(!arg_signatures.empty());
693 
694     const AnfNodePtrList &init_args = abstract::GetPrimitiveInitArgs(primitive->cast<PrimitivePyPtr>(), op_def);
695 
696     AnfNodePtrList call_args(input_node_list.cbegin() + 1, input_node_list.cend());
697     AbstractBasePtrList call_abs_list;
698     (void)std::transform(call_args.cbegin(), call_args.cend(), std::back_inserter(call_abs_list),
699                          [](const AnfNodePtr &node) { return FuncGraphBuilder::GetAbstractOf(node); });
700     const AnfNodePtrList &new_call_args =
701       prim::GetNewInputsBySignatures(graph_, primitive->name(), primitive, call_abs_list, call_args);
702 
703     return abstract::GeneratePrimitiveCNode(
704       primitive, op_def, graph_, init_args, new_call_args,
705       [](const AnfNodePtr &node) { return FuncGraphBuilder::GetAbstractOf(node); });
706   }
707   MS_LOG(DEBUG) << "Primitive " << primitive->name() << " no need to process signatures and OpDef";
708   return graph_->NewCNodeInOrder(input_node_list);
709 }
710 
GetAbstractOf(const AnfNodePtr & node)711 AbstractBasePtr FuncGraphBuilder::GetAbstractOf(const AnfNodePtr &node) {
712   if (node == nullptr) {
713     return nullptr;
714   }
715   if (node->abstract() != nullptr) {
716     return node->abstract();
717   }
718   if (node->isa<ValueNode>()) {
719     return node->cast<ValueNodePtr>()->value()->ToAbstract();
720   } else if (node->isa<CNode>()) {
721     auto cnode = node->cast<CNodePtr>();
722     if (cnode->empty() || !cnode->input(0)->isa<ValueNode>()) {
723       return nullptr;
724     }
725     ValuePtr value = cnode->input(0)->cast<ValueNodePtr>()->value();
726     std::vector<AbstractBasePtr> abs_list;
727     std::transform(cnode->inputs().begin() + 1, cnode->inputs().end(), std::back_inserter(abs_list),
728                    [](const AnfNodePtr &node) {
729                      if (node->abstract() == nullptr) {
730                        node->set_abstract(FuncGraphBuilder::GetAbstractOf(node));
731                      }
732                      return node->abstract();
733                    });
734     return EvalValue(value, abs_list);
735   }
736   MS_LOG(INFO) << "Unsupported Node type for GetAbstractOf() method, node: " << node->DebugString();
737   return nullptr;
738 }
739 
DoInferAndCheck(const ValuePtr & callable_value,const vector<AbstractBasePtr> & input_abs_list)740 AbstractBasePtr FuncGraphBuilder::DoInferAndCheck(const ValuePtr &callable_value,
741                                                   const vector<AbstractBasePtr> &input_abs_list) {
742   auto abs = EvalValue(callable_value, input_abs_list);
743   if (abs == nullptr) {
744     MS_LOG(DEBUG) << "Eval failed for value: " << callable_value->ToString();
745     return nullptr;
746   }
747   if (!CheckCallable(callable_value, abs)) {
748     MS_LOG(DEBUG) << "Check callable failed for value: " << callable_value->ToString() << ", abs: " << abs->ToString();
749     return nullptr;
750   }
751   return abs;
752 }
753 
TryToAddNode(const ValuePtr & callable_value,const std::vector<py::object> & inputs_obj)754 py::object FuncGraphBuilder::TryToAddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj) {
755   // Collect the input nodes and input abstracts.
756   std::vector<AnfNodePtr> input_node_list;
757   std::vector<AbstractBasePtr> input_abs_list;
758   if (!GetInputNodesAndAbstracts(callable_value, inputs_obj, &input_node_list, &input_abs_list)) {
759     return py::object();
760   }
761 
762   CNodePtr new_node;
763   AbstractBasePtr abs;
764   if (callable_value->isa<Primitive>()) {
765     new_node = DoPrimitiveInferAndCheck(callable_value->cast<PrimitivePtr>(), input_node_list, input_abs_list);
766     if (new_node != nullptr) {
767       abs = new_node->abstract();
768     }
769   } else {
770     // Do infer and check callable.
771     abs = DoInferAndCheck(callable_value, input_abs_list);
772     if (abs != nullptr) {
773       new_node = graph_->NewCNodeInOrder(input_node_list);
774     }
775   }
776   if (new_node == nullptr || abs == nullptr) {
777     return py::object();
778   }
779 
780   // Return the converted python object.
781   py::object output_py_obj;
782   if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
783     auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
784     auto fg = abs_func->func_graph();
785     if (fg == nullptr) {
786       return py::object();
787     }
788     auto obj = fg->python_obj();
789     if (obj == nullptr || !obj->isa<parse::PyObjectWrapper>()) {
790       return py::object();
791     }
792     output_py_obj = obj->cast_ptr<parse::PyObjectWrapper>()->obj();
793   } else {
794     auto convert_func = [this](const py::object &tensor) { return ConvertToPyTensorOrParameter(tensor); };
795     output_py_obj = ConvertToPyObj(abs, convert_func);
796     if (output_py_obj.ptr() == nullptr) {
797       MS_LOG(INFO) << "Convert abs " << abs->ToString() << " to python object failed.";
798       return py::object();
799     }
800   }
801 
802   new_node->set_abstract(abs);
803   MS_LOG(INFO) << "Add node: " << new_node->DebugString() << " for python object: " << py::str(output_py_obj);
804   (void)py_obj_to_node_.emplace(output_py_obj.ptr(), new_node);
805   new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(output_py_obj));
806   return output_py_obj;
807 }
808 
// Converts a C++ tensor object produced during inference back to its Python counterpart.
// - Non-parameter tensor: converted with ConvertCppTensorToPyTensor().
// - Parameter tensor: resolved by name to a Python object already registered in
//   py_obj_to_node_. The match requires the registered node's abstract to be an
//   AbstractRefTensor whose ref key string equals the parameter name.
// Returns an empty py::object if `cpp_tensor` is not a tensor or no matching parameter is found.
py::object FuncGraphBuilder::ConvertToPyTensorOrParameter(const py::object &cpp_tensor) {
  if (cpp_tensor.ptr() == nullptr || !py::isinstance<tensor::Tensor>(cpp_tensor)) {
    return py::object();
  }
  auto tensor = py::cast<tensor::TensorPtr>(cpp_tensor);
  if (!tensor->is_parameter()) {
    return ConvertCppTensorToPyTensor(cpp_tensor);
  }

  const auto &param_info = tensor->param_info();
  if (param_info == nullptr) {
    MS_LOG(INFO) << "The parameter tensor has no param info, can not find its python object.";
    return py::object();
  }
  const std::string &name = param_info->name();
  for (const auto &it : py_obj_to_node_) {
    if (it.second == nullptr) {
      continue;
    }
    const AbstractBasePtr &abs = it.second->abstract();
    if (abs == nullptr || !abs->isa<abstract::AbstractRefTensor>()) {
      continue;
    }
    // A ref abstract without a ref key cannot identify any parameter.
    const auto ref_key = abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
    if (ref_key != nullptr && ref_key->ToString() == name) {
      return py::reinterpret_borrow<py::object>(it.first);
    }
  }
  MS_LOG(INFO) << "Python Parameter not found: " << name;
  return py::object();
}
834 
AddNode(const ValuePtr & callable_value,const std::vector<py::object> & inputs_obj)835 py::object FuncGraphBuilder::AddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj) {
836   if (!callable_value->ToAbstract()->isa<abstract::AbstractFunction>()) {
837     MS_LOG(INFO) << "The value " << callable_value->ToString() << " is not callable.";
838     return py::object();
839   }
840   if (callable_value->isa<FuncGraph>()) {
841     return AddFgCallNode(callable_value->cast<FuncGraphPtr>(), inputs_obj);
842   }
843   return TryToAddNode(callable_value, inputs_obj);
844 }
845 
AddMultiNode(const std::string & name,const std::vector<py::object> & inputs_obj)846 py::object FuncGraphBuilder::AddMultiNode(const std::string &name, const std::vector<py::object> &inputs_obj) {
847   const std::string mod_str = "mindspore.ops.composite.multitype_ops";
848   py::module mod = py::module::import(mod_str.c_str());
849   if (!py::hasattr(mod, name.c_str())) {
850     MS_LOG(INFO) << "Fail to find multitype function graph for name " << name;
851     return py::object();
852   }
853   py::object fn = mod.attr(name.c_str());
854   return AddNode(fn, inputs_obj);
855 }
856 
AddOutput(const py::object & output_obj,bool is_top_graph)857 bool FuncGraphBuilder::AddOutput(const py::object &output_obj, bool is_top_graph) {
858   auto iter = py_obj_to_node_.find(output_obj.ptr());
859   if (iter == py_obj_to_node_.end()) {
860     MS_LOG(INFO) << "The output python object " << py::str(output_obj) << " should have been added to the graph.";
861     return false;
862   }
863   auto node = iter->second;
864   MS_EXCEPTION_IF_NULL(node);
865   auto abs = node->abstract();
866   // Only top graph has restriction on return value type.
867   if (is_top_graph && !CheckGraphOutput(abs)) {
868     MS_LOG(INFO) << "The output python object " << py::str(output_obj)
869                  << " should not be the graph output, abstract: " << (abs == nullptr ? "null" : abs->ToString());
870     return false;
871   }
872   (void)output_nodes_.emplace_back(node);
873   return true;
874 }
875 
graph()876 FuncGraphPtr FuncGraphBuilder::graph() {
877   if (has_set_output_) {
878     return graph_;
879   }
880   if (output_nodes_.empty()) {
881     MS_LOG(DEBUG) << "The graph " << graph_->ToString() << " has not been set output.";
882     return nullptr;
883   }
884   bool all_value_node = std::any_of(output_nodes_.begin(), output_nodes_.end(),
885                                     [](const AnfNodePtr &node) { return node->isa<ValueNode>(); });
886   if (all_value_node) {
887     MS_LOG(INFO) << "All graph output is value node, no need to run graph.";
888     return nullptr;
889   }
890   // Single output case.
891   if (output_nodes_.size() == 1) {
892     // Use the python obj of the output node as the python obj of the func_graph output.
893     auto node_output_py_obj = output_nodes_[0]->user_data<py::object>(kPiJitPyObjKey);
894     if (node_output_py_obj == nullptr) {
895       MS_LOG(DEBUG) << "Can not find the python object of the node " << output_nodes_[0]->DebugString();
896       return nullptr;
897     }
898     graph_->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(*node_output_py_obj));
899     graph_->set_output(output_nodes_[0]);
900     has_set_output_ = true;
901     return graph_;
902   }
903   // multiple output case.
904   // Make the python tuple obj of the output nodes as the python obj of the func_graph output.
905   py::tuple output_py_obj(output_nodes_.size());
906   for (size_t i = 0; i < output_nodes_.size(); ++i) {
907     auto node_output_py_obj = output_nodes_[i]->user_data<py::object>(kPiJitPyObjKey);
908     if (node_output_py_obj == nullptr) {
909       MS_LOG(DEBUG) << "Can not find the python object of the node " << output_nodes_[i]->DebugString();
910       return nullptr;
911     }
912     output_py_obj[i] = *node_output_py_obj;
913   }
914   // Create make_tuple node and set its abstract.
915   graph_->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(output_py_obj));
916   output_nodes_.insert(output_nodes_.begin(), NewValueNode(prim::kPrimMakeTuple));
917   AbstractBasePtrList abstract_list;
918   (void)std::transform(output_nodes_.begin() + 1, output_nodes_.end(), std::back_inserter(abstract_list),
919                        [](const AnfNodePtr &node) -> AbstractBasePtr { return node->abstract(); });
920   auto output_node = graph_->NewCNodeInOrder(output_nodes_);
921   auto fg_output_abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
922   output_node->set_abstract(fg_output_abs);
923 
924   graph_->set_output(output_node);
925   has_set_output_ = true;
926   return graph_;
927 }
928 
ClearNodeAbstract()929 void FuncGraphBuilder::ClearNodeAbstract() {
930   if (!has_set_output_) {
931     MS_LOG(INTERNAL_EXCEPTION) << "Graph not generated, can not clear abstract.";
932   }
933   // Clear all node abstract.
934   auto mng = Manage(graph_, false);
935   MS_EXCEPTION_IF_NULL(mng);
936   static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
937   for (const auto &node : mng->all_nodes()) {
938     MS_EXCEPTION_IF_NULL(node);
939     const AbstractBasePtr &prev_inferred = node->abstract();
940     auto is_func =
941       node->isa<mindspore::ValueNode>() && prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>();
942     // Keep previous inferred value for parameter and ValueNode if the inferred value is not AbstractFunction.
943     if (!node->isa<Parameter>() && !is_func) {
944       // Reset tuple/list abstract use flags.
945       if (enable_eliminate_unused_element && prev_inferred != nullptr &&
946           prev_inferred->isa<abstract::AbstractSequence>()) {
947         SetSequenceNodeElementsUseFlags(node, nullptr);
948       }
949       node->set_abstract(nullptr);
950       MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
951     }
952   }
953 }
954 
AddFgCallNode(const FuncGraphPtr & fg,const vector<py::object> & inputs_obj)955 py::object FuncGraphBuilder::AddFgCallNode(const FuncGraphPtr &fg, const vector<py::object> &inputs_obj) {
956   std::vector<AnfNodePtr> input_node_list;
957   input_node_list.reserve(inputs_obj.size() + 1);
958 
959   (void)input_node_list.emplace_back(NewValueNode(fg));
960   for (const auto &input_obj : inputs_obj) {
961     auto node = GetNodeByObject(input_obj);
962     if (node == nullptr) {
963       if (!IsConstant(input_obj)) {
964         MS_LOG(INFO) << "Can not convert non-constant value to value node for obj: " << py::str(input_obj);
965         return py::object();
966       }
967       auto new_node = ConvertObjToNode(input_obj);
968       if (new_node == nullptr) {
969         MS_LOG(INFO) << "Convert input python object " << py::str(input_obj) << " to anf node failed.";
970         return py::object();
971       }
972       new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(input_obj));
973       (void)py_obj_to_node_.emplace(input_obj.ptr(), new_node);
974       (void)input_node_list.emplace_back(new_node);
975       MS_LOG(DEBUG) << "Add constant python input " << py::str(input_obj) << " with node " << new_node->DebugString();
976     } else {
977       (void)input_node_list.emplace_back(node);
978     }
979   }
980 
981   auto new_node = graph_->NewCNodeInOrder(input_node_list);
982   auto fg_output = fg->output();
983   MS_EXCEPTION_IF_NULL(fg_output);
984   auto fg_output_abs = fg_output->abstract();
985   MS_EXCEPTION_IF_NULL(fg_output_abs);
986   new_node->set_abstract(fg_output_abs);
987 
988   // Use the python obj of the func_graph output as the python obj of the output node.
989   auto fg_output_obj_ptr = fg->user_data<py::object>(kPiJitPyObjKey);
990   if (fg_output_obj_ptr == nullptr) {
991     MS_LOG(DEBUG) << "Can not find the output python object of func_graph " << fg->ToString();
992     return py::object();
993   }
994   auto fg_output_obj = *fg_output_obj_ptr;
995   (void)py_obj_to_node_.emplace(fg_output_obj.ptr(), new_node);
996   new_node->set_user_data(kPiJitPyObjKey, std::make_shared<py::object>(fg_output_obj));
997   return fg_output_obj;
998 }
999 
CheckCallable(const py::object & obj)1000 bool FuncGraphBuilder::CheckCallable(const py::object &obj) {
1001   constexpr auto ms_class_attr = "__ms_class__";
1002   return py::isinstance<MetaFuncGraph>(obj) ||
1003          (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG) &&
1004           parse::data_converter::GetObjType(obj) != parse::RESOLVE_TYPE_CLASS_TYPE) ||
1005          FunctionShouldBeParseInAst(obj) ||
1006          (py::hasattr(obj, ms_class_attr) && py::cast<bool>(py::getattr(obj, ms_class_attr)));
1007 }
1008 
// Converts the Python method `obj` into a callable registered for its class.
// The (class name, method name) pair comes from PYTHON_MOD_GET_METHOD_INFO. The registration
// is looked up with pipeline::Resource::GetMethodPtr, falling back to GetAttrPtr:
//   - a string entry names a function in the standard_method module, which is returned;
//   - a PrimitivePtr entry is returned as `mindspore.ops.Primitive(<name>)`.
// For class "Tensor", only methods accepted by PYTHON_MOD_IS_MS_TENSOR_METHOD are converted.
// Returns an empty py::object whenever the method cannot be resolved.
py::object FuncGraphBuilder::ConvertMethod(const py::object &obj) {
  py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  py::tuple method_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_METHOD_INFO, obj);
  py::object class_name_obj = method_info[0];
  if (class_name_obj.ptr() == nullptr || py::isinstance<py::none>(class_name_obj)) {
    MS_LOG(INFO) << "Can not get the method info of " << py::str(obj);
    return py::object();
  }
  auto class_name = class_name_obj.cast<std::string>();
  if (class_name == "Tensor" &&
      !py::cast<bool>(python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_IS_MS_TENSOR_METHOD, obj))) {
    return py::object();
  }
  // Casting a None method name to std::string would throw; treat it as "not resolvable",
  // like CheckInvalidCellListDictMethod() does.
  py::object method_name_obj = method_info[1];
  if (method_name_obj.ptr() == nullptr || py::isinstance<py::none>(method_name_obj)) {
    MS_LOG(INFO) << "Can not get the method name of " << py::str(obj);
    return py::object();
  }
  auto type_id = GetTypeIdFromClassName(class_name);
  auto method_name = method_name_obj.cast<std::string>();
  MS_LOG(DEBUG) << "type_id: " << type_id << ", method_name: " << method_name;
  Any require = pipeline::Resource::GetMethodPtr(type_id, method_name);
  if (require.empty()) {
    require = pipeline::Resource::GetAttrPtr(type_id, method_name);
  }

  if (require.empty()) {
    MS_LOG(DEBUG) << "Can not find the method registered.";
    return py::object();
  }

  if (require.is<std::string>()) {
    py::function fn = python_adapter::GetPyFn(parse::kStandardMethodModelName, require.cast<std::string>());
    if (py::isinstance<py::none>(fn)) {
      MS_LOG(DEBUG) << "Can not find the method '" << require.cast<std::string>() << "' defined in standard_method.";
      return py::object();
    }
    return fn;
  }
  if (require.is<PrimitivePtr>()) {
    auto ops_mod = python_adapter::GetPyModule("mindspore.ops");
    auto primitive_class = python_adapter::GetPyObjAttr(ops_mod, "Primitive");
    return primitive_class(require.cast<PrimitivePtr>()->name());
  }
  MS_LOG(DEBUG) << "The method or attr should be a string or a Primitive, but got " << require.ToString();
  return py::object();
}
1050 
RemoveOutput(const py::object & output_obj)1051 void FuncGraphBuilder::RemoveOutput(const py::object &output_obj) {
1052   auto iter = py_obj_to_node_.find(output_obj.ptr());
1053   if (iter == py_obj_to_node_.end()) {
1054     MS_LOG(WARNING) << "The output python object " << py::str(output_obj) << " should have been added to the graph.";
1055     return;
1056   }
1057   auto output_nodes_iter = std::find(output_nodes_.begin(), output_nodes_.end(), iter->second);
1058   if (output_nodes_iter == output_nodes_.end()) {
1059     MS_LOG(WARNING) << "The node " << iter->second->DebugString() << " has not been added to the graph outputs.";
1060     return;
1061   }
1062   output_nodes_.erase(output_nodes_iter);
1063 }
1064 
// Looks `obj` up in mindspore._extends.parse.resources.convert_object_map.
// Returns the mapped object, or an empty py::object when there is no entry.
py::object FuncGraphBuilder::ConvertFunction(const py::object &obj) {
  auto resources_mod = python_adapter::GetPyModule("mindspore._extends.parse.resources");
  auto convert_map = python_adapter::GetPyObjAttr(resources_mod, "convert_object_map");
  // PyDict_GetItem returns a borrowed reference, or nullptr when the key is absent.
  PyObject *mapped = PyDict_GetItem(convert_map.ptr(), obj.ptr());
  if (mapped == nullptr) {
    return py::object();
  }
  return py::reinterpret_borrow<py::object>(mapped);
}
1071 
CanConstantFoldFunc(const py::object & obj)1072 bool FuncGraphBuilder::CanConstantFoldFunc(const py::object &obj) {
1073   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1074   py::object can_constant_fold = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CAN_CONSTANT_FOLD, obj);
1075   return can_constant_fold.cast<bool>();
1076 }
1077 
SetGraphName(const std::string & name)1078 void FuncGraphBuilder::SetGraphName(const std::string &name) {
1079   if (name.empty()) {
1080     return;
1081   }
1082   MS_EXCEPTION_IF_NULL(graph_->debug_info());
1083   graph_->debug_info()->set_name(name);
1084 }
1085 
AddPrevBuilder(const FuncGraphBuilderPtr & builder)1086 void FuncGraphBuilder::AddPrevBuilder(const FuncGraphBuilderPtr &builder) { prev_builders_.push_back(builder.get()); }
1087 
ValidateCallableObject(const py::object & obj)1088 bool FuncGraphBuilder::ValidateCallableObject(const py::object &obj) {
1089   if (obj.ptr() == nullptr) {
1090     return false;
1091   }
1092   // Check if object is invalid method for CellList/CellDict, which should not be converted to graph.
1093   if (CheckInvalidCellListDictMethod(obj)) {
1094     MS_LOG(INFO) << "The object " << py::str(obj) << " is a invalid CellList/CellDict method, "
1095                  << "can not convert to graph";
1096     return false;
1097   }
1098   return true;
1099 }
1100 
CheckInvalidCellListDictMethod(const py::object & obj)1101 bool FuncGraphBuilder::CheckInvalidCellListDictMethod(const py::object &obj) {
1102   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1103   py::tuple method_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_METHOD_INFO, obj);
1104   constexpr size_t class_index = 0;
1105   constexpr size_t method_index = 1;
1106   py::object class_name_obj = method_info[class_index];
1107   if (class_name_obj.ptr() == nullptr || py::isinstance<py::none>(class_name_obj)) {
1108     return false;
1109   }
1110   auto class_name = class_name_obj.cast<std::string>();
1111   MS_LOG(INFO) << "class name: " << class_name;
1112   if (class_name != "CellList" && class_name != "CellDict") {
1113     return false;
1114   }
1115   auto method_name_obj = method_info[method_index];
1116   if (method_name_obj.ptr() == nullptr || py::isinstance<py::none>(method_name_obj)) {
1117     return false;
1118   }
1119   auto method_name = method_name_obj.cast<std::string>();
1120   static std::vector<std::string> inplace_method_name = {"clear", "update"};
1121   if (std::any_of(inplace_method_name.begin(), inplace_method_name.end(),
1122                   [&method_name](const std::string &name) { return name == method_name; })) {
1123     MS_LOG(INFO) << "CellDict/CellList inplace function " << method_name << " found";
1124     return true;
1125   }
1126   auto type_id = GetTypeIdFromClassName(class_name);
1127   Any require = pipeline::Resource::GetMethodPtr(type_id, method_name);
1128   return require.empty();
1129 }
1130 }  // namespace mindspore
1131