• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/ps/fallback.h"
18 
19 #include <algorithm>
20 #include <iostream>
21 #include <memory>
22 #include <regex>
23 #include <string>
24 #include <vector>
25 #include <utility>
26 
27 #include "mindspore/core/ops/structure_ops.h"
28 #include "mindspore/core/ops/sequence_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "include/common/fallback.h"
31 #include "include/common/utils/python_adapter.h"
32 #include "include/common/utils/convert_utils_py.h"
33 #include "utils/log_adapter.h"
34 #include "utils/ms_context.h"
35 #include "utils/compile_config.h"
36 #include "utils/interpret_node_recorder.h"
37 #include "pipeline/jit/ps/debug/trace.h"
38 #include "pipeline/jit/ps/parse/resolve.h"
39 #include "abstract/abstract_value.h"
40 #include "ir/func_graph.h"
41 
42 namespace mindspore {
43 namespace fallback {
44 namespace {
45 // Get the type from python type string, defined in Python module 'mindspore.common.dtype'.
GetTypeFromString(const std::string & dtype)46 TypePtr GetTypeFromString(const std::string &dtype) {
47   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
48   constexpr auto get_dtype_python_function = "get_dtype";
49   auto type = python_adapter::CallPyModFn(mod, get_dtype_python_function, py::str(dtype));
50   MS_LOG(DEBUG) << "type: " << type;
51   if (py::isinstance<py::none>(type)) {
52     return nullptr;
53   }
54   auto type_ptr = py::cast<TypePtr>(type);
55   if (type_ptr == nullptr) {
56     return nullptr;
57   }
58   return type_ptr->Clone();
59 }
60 
GetErrorFormatMessage(const AnfNodePtr & node,const std::string & comment)61 std::string GetErrorFormatMessage(const AnfNodePtr &node, const std::string &comment) {
62   std::stringstream err_buf;
63   err_buf << "Wrong comment format for JIT type annotation: '" << comment
64           << "'.\ne.g. '# @jit.typing: () -> tensor_type[int32]' or:"
65           << "\n---\n\tdtype_var = ms.int32\n\t# @jit.typing: () -> tensor_type[{dtype_var}]\n\t...\n---\n\n"
66           << trace::GetDebugInfoStr(node->debug_info());
67   return err_buf.str();
68 }
69 
HandleBaseTypeForAnnotation(const std::string & dtype_str,const std::string & container_type_str,const FormatedVariableTypeFunc & format_type_func,const AnfNodePtr & node,const std::string & comment)70 TypePtr HandleBaseTypeForAnnotation(const std::string &dtype_str, const std::string &container_type_str,
71                                     const FormatedVariableTypeFunc &format_type_func, const AnfNodePtr &node,
72                                     const std::string &comment) {
73   if (!dtype_str.empty()) {
74     return nullptr;
75   }
76   TypePtr base_type = nullptr;
77   // Handle dtype.
78   if (container_type_str.front() == '{' && container_type_str.back() == '}') {  // Handle format variable type.
79     if (!format_type_func) {
80       MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
81     }
82     constexpr auto excluded_size = 2;
83     const auto &variable_base_type = container_type_str.substr(1, container_type_str.size() - excluded_size);
84     // Find variable type.
85     if (!variable_base_type.empty()) {
86       base_type = format_type_func(variable_base_type);
87       if (base_type == nullptr) {  // Not throw exception if not match any variable.
88         return nullptr;
89       }
90     }
91   } else {  // Handle string type.
92     const auto &base_type_str = container_type_str;
93     base_type = GetTypeFromString(base_type_str);
94   }
95   if (base_type == nullptr) {
96     MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
97   }
98   return base_type;
99 }
100 
GetDTypeFromDTypeStr(const std::string & dtype_str,const FormatedVariableTypeFunc & format_type_func,const AnfNodePtr & node,const std::string & comment)101 std::pair<bool, TypePtr> GetDTypeFromDTypeStr(const std::string &dtype_str,
102                                               const FormatedVariableTypeFunc &format_type_func, const AnfNodePtr &node,
103                                               const std::string &comment) {
104   TypePtr dtype = nullptr;
105   if (dtype_str.front() == '{' && dtype_str.back() == '}') {  // Handle format variable dtype.
106     if (!format_type_func) {
107       MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
108     }
109     constexpr auto excluded_size = 2;
110     const auto &variable_dtype = dtype_str.substr(1, dtype_str.size() - excluded_size);
111     // Find variable dtype.
112     if (!variable_dtype.empty()) {
113       dtype = format_type_func(variable_dtype);
114       if (dtype == nullptr) {  // Not throw exception if not match any variable.
115         return std::make_pair(false, nullptr);
116       }
117     }
118   } else {  // Handle string dtype.
119     dtype = GetTypeFromString(dtype_str);
120   }
121   return std::make_pair(true, dtype);
122 }
123 
HandleContainerTypeForAnnotation(const std::string & dtype_str,const std::string & container_type_str,const FormatedVariableTypeFunc & format_type_func,const AnfNodePtr & node,const std::string & comment)124 TypePtr HandleContainerTypeForAnnotation(const std::string &dtype_str, const std::string &container_type_str,
125                                          const FormatedVariableTypeFunc &format_type_func, const AnfNodePtr &node,
126                                          const std::string &comment) {
127   const auto &container_type = GetTypeFromString(container_type_str);
128   if (container_type == nullptr) {
129     MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
130   }
131   if (!container_type->isa<Tuple>() && !container_type->isa<List>() && !container_type->isa<TensorType>()) {
132     MS_LOG(EXCEPTION) << "JIT type annotation only support tensor/list_/tuple_, but got '" << container_type_str;
133   }
134 
135   auto [is_match, dtype] = GetDTypeFromDTypeStr(dtype_str, format_type_func, node, comment);
136   if (!is_match) {
137     return nullptr;
138   }
139   if (dtype == nullptr) {
140     MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
141   }
142   if (container_type->isa<TensorType>()) {  // Handle tensor type.
143     if (!dtype->isa<Number>()) {
144       MS_LOG(EXCEPTION) << "Cannot get dtype for by input string: '" << dtype_str << "', for '" << container_type_str
145                         << "'\n"
146                         << trace::GetDebugInfoStr(node->debug_info());
147     }
148     container_type->cast<TensorTypePtr>()->set_element(dtype);
149   } else if (container_type->isa<Tuple>() || container_type->isa<List>()) {  // Handle list_/tuple_ type.
150     // To handle nested sequence later.
151     if (!dtype->isa<Number>() && !dtype->isa<TensorType>()) {
152       MS_LOG(EXCEPTION) << "Cannot get element type for by input string: '" << dtype_str << "', for '"
153                         << container_type_str << "'\n"
154                         << trace::GetDebugInfoStr(node->debug_info());
155     }
156     if (container_type->isa<Tuple>()) {
157       container_type->cast<TuplePtr>()->set_elements(TypePtrList({dtype}));
158     } else if (container_type->isa<List>()) {
159       container_type->cast<ListPtr>()->set_elements(TypePtrList({dtype}));
160     }
161     return nullptr;  // Supports tuple_[...] / list_[...] later.
162   }
163   return container_type;
164 }
165 }  // namespace
166 
CreatePyExecuteCNode(const FuncGraphPtr & fg,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values,const NodeDebugInfoPtr & debug_info)167 CNodePtr CreatePyExecuteCNode(const FuncGraphPtr &fg, const AnfNodePtr &script, const AnfNodePtr &keys,
168                               const AnfNodePtr &values, const NodeDebugInfoPtr &debug_info) {
169   const auto interpreted_cnode = fg->NewCNode({NewValueNode(prim::kPrimPyExecute), script, keys, values});
170   if (debug_info != nullptr) {
171     interpreted_cnode->set_debug_info(debug_info);
172   }
173   // Record the PyExecute node.
174   InterpretNodeRecorder::GetInstance().PushPyExecuteNode(interpreted_cnode);
175   return interpreted_cnode;
176 }
177 
CreatePyExecuteCNode(const AnfNodePtr & orig_node,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values)178 CNodePtr CreatePyExecuteCNode(const AnfNodePtr &orig_node, const AnfNodePtr &script, const AnfNodePtr &keys,
179                               const AnfNodePtr &values) {
180   const FuncGraphPtr &fg = orig_node->func_graph();
181   if (fg == nullptr) {
182     MS_LOG(INTERNAL_EXCEPTION) << "The func graph is null. orig_node: " << orig_node->DebugString();
183   }
184   const auto interpreted_cnode = CreatePyExecuteCNode(fg, script, keys, values, orig_node->debug_info());
185   return interpreted_cnode;
186 }
187 
CreatePyExecuteCNodeInOrder(const FuncGraphPtr & fg,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values,const NodeDebugInfoPtr & debug_info)188 CNodePtr CreatePyExecuteCNodeInOrder(const FuncGraphPtr &fg, const AnfNodePtr &script, const AnfNodePtr &keys,
189                                      const AnfNodePtr &values, const NodeDebugInfoPtr &debug_info) {
190   const auto interpreted_cnode = fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script, keys, values});
191   interpreted_cnode->set_debug_info(debug_info);
192   // Record the PyExecute node.
193   InterpretNodeRecorder::GetInstance().PushPyExecuteNode(interpreted_cnode);
194   return interpreted_cnode;
195 }
196 
CreatePyExecuteCNodeInOrder(const AnfNodePtr & orig_node,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values)197 CNodePtr CreatePyExecuteCNodeInOrder(const AnfNodePtr &orig_node, const AnfNodePtr &script, const AnfNodePtr &keys,
198                                      const AnfNodePtr &values) {
199   const FuncGraphPtr &fg = orig_node->func_graph();
200   if (fg == nullptr) {
201     MS_LOG(INTERNAL_EXCEPTION) << "The func graph is null. orig_node: " << orig_node->DebugString();
202   }
203   const auto interpreted_cnode = CreatePyExecuteCNodeInOrder(fg, script, keys, values, orig_node->debug_info());
204   return interpreted_cnode;
205 }
206 
CreatePyInterpretCNode(const FuncGraphPtr & fg,const std::string & script_text,const py::object & global_dict_obj,const AnfNodePtr & local_dict_node,const NodeDebugInfoPtr & debug_info)207 CNodePtr CreatePyInterpretCNode(const FuncGraphPtr &fg, const std::string &script_text,
208                                 const py::object &global_dict_obj, const AnfNodePtr &local_dict_node,
209                                 const NodeDebugInfoPtr &debug_info) {
210   auto script = std::make_shared<parse::Script>(script_text);
211   auto script_node = NewValueNode(script);
212   parse::PyObjectWrapperPtr global_dict_wrapper = std::make_shared<parse::InterpretedObject>(global_dict_obj);
213   auto global_dict_node = NewValueNode(global_dict_wrapper);
214   auto node = fg->NewCNode({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
215   if (debug_info != nullptr) {
216     node->set_debug_info(debug_info);
217   }
218   InterpretNodeRecorder::GetInstance().PushPyInterpretNode(node);
219   return node;
220 }
221 
CreatePyInterpretCNodeInOrder(const FuncGraphPtr & fg,const std::string & script_text,const py::object & global_dict_obj,const AnfNodePtr & local_dict_node,const NodeDebugInfoPtr & debug_info)222 CNodePtr CreatePyInterpretCNodeInOrder(const FuncGraphPtr &fg, const std::string &script_text,
223                                        const py::object &global_dict_obj, const AnfNodePtr &local_dict_node,
224                                        const NodeDebugInfoPtr &debug_info) {
225   auto script = std::make_shared<parse::Script>(script_text);
226   auto script_node = NewValueNode(script);
227   parse::PyObjectWrapperPtr global_dict_wrapper = std::make_shared<parse::InterpretedObject>(global_dict_obj);
228   auto global_dict_node = NewValueNode(global_dict_wrapper);
229   auto node =
230     fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
231   if (debug_info != nullptr) {
232     node->set_debug_info(debug_info);
233   }
234   InterpretNodeRecorder::GetInstance().PushPyInterpretNode(node);
235   return node;
236 }
237 
SetPyObjectToLocalVariable(const std::string & key,const py::object & value)238 void SetPyObjectToLocalVariable(const std::string &key, const py::object &value) {
239   py::module mod = python_adapter::GetPyModule("mindspore.common._jit_fallback_utils");
240   constexpr auto set_local_variable = "set_local_variable";
241   MS_LOG(DEBUG) << set_local_variable << "([" << key << "]/" << key << ", " << value << ")";
242   (void)python_adapter::CallPyModFn(mod, set_local_variable, key, value);
243 }
244 
ConvertPyObjectToPyExecute(const FuncGraphPtr & fg,const std::string & key,const py::object value,const AnfNodePtr & node,bool replace)245 AnfNodePtr ConvertPyObjectToPyExecute(const FuncGraphPtr &fg, const std::string &key, const py::object value,
246                                       const AnfNodePtr &node, bool replace) {
247   auto value_node_key = ConvertRealStrToUnicodeStr(key, 0);
248   // Set the value node into dict firstly.
249   SetPyObjectToLocalVariable(value_node_key, value);
250 
251   // Get the value node from the dict in IR.
252   std::stringstream script_buffer;
253   script_buffer << "__import__('mindspore').common._jit_fallback_utils.get_local_variable(" << value_node_key << ")";
254   const std::string &script = script_buffer.str();
255   const auto script_str = std::make_shared<StringImm>(script);
256 
257   // Build new CNode for value node.
258   ValuePtrList keys({std::make_shared<StringImm>(value_node_key)});
259   ValuePtrList values({std::make_shared<StringImm>(value_node_key)});
260   const auto interpreted_cnode =
261     CreatePyExecuteCNode(fg, NewValueNode(script_str), NewValueNode(std::make_shared<ValueTuple>(keys)),
262                          NewValueNode(std::make_shared<ValueTuple>(values)), node->debug_info());
263   constexpr auto debug_recursive_level = 2;
264   MS_LOG(DEBUG) << "original node: " << node->DebugString(debug_recursive_level)
265                 << ", interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
266   if (replace) {
267     fg->ReplaceInOrder(node, interpreted_cnode);
268   }
269   return interpreted_cnode;
270 }
271 
ConvertPyObjectToPyInterpret(const FuncGraphPtr & fg,const std::string & key,const py::object value,const AnfNodePtr & node,bool replace)272 AnfNodePtr ConvertPyObjectToPyInterpret(const FuncGraphPtr &fg, const std::string &key, const py::object value,
273                                         const AnfNodePtr &node, bool replace) {
274   auto value_node_key = ConvertRealStrToUnicodeStr(key, 0);
275   // Set the value node into dict firstly.
276   SetPyObjectToLocalVariable(value_node_key, value);
277 
278   // Build the script
279   std::stringstream script_buffer;
280   script_buffer << "__import__('mindspore').common._jit_fallback_utils.get_local_variable(" << value_node_key << ")";
281   const std::string &script = script_buffer.str();
282   auto script_str = std::make_shared<parse::Script>(script);
283   auto script_node = NewValueNode(script_str);
284 
285   // Build the global dict.
286   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
287   constexpr auto python_get_dict = "get_global_params";
288   const auto &global_dict = python_adapter::CallPyModFn(mod, python_get_dict);
289   parse::PyObjectWrapperPtr interpreted_global_dict = std::make_shared<parse::InterpretedObject>(global_dict);
290   auto global_dict_node = NewValueNode(interpreted_global_dict);
291 
292   // Build the local dict.
293   ValuePtrList local_keys({std::make_shared<StringImm>(value_node_key)});
294   ValuePtrList local_values({std::make_shared<StringImm>(value_node_key)});
295   auto local_key_tuple = NewValueNode(std::make_shared<ValueTuple>(local_keys));
296   auto local_value_tuple = NewValueNode(std::make_shared<ValueTuple>(local_values));
297   auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_tuple, local_value_tuple});
298   auto prim = NewValueNode(prim::kPrimPyInterpret);
299   auto interpret_node = fg->NewCNode({prim, script_node, global_dict_node, local_dict_node});
300   InterpretNodeRecorder::GetInstance().PushPyInterpretNode(interpret_node);
301   if (replace) {
302     fg->ReplaceInOrder(node, interpret_node);
303   }
304   return interpret_node;
305 }
306 
// If `value` is a parse::MsClassObject, replace `node` in `fg` with a PyExecute node that yields the class object,
// and return that node. Otherwise return nullptr.
AnfNodePtr ConvertMsClassObjectToPyExecute(const FuncGraphPtr &fg, const ValuePtr &value, const AnfNodePtr &node) {
  const auto ms_class = dyn_cast<parse::MsClassObject>(value);
  if (ms_class != nullptr) {
    return ConvertPyObjectToPyExecute(fg, ms_class->name(), ms_class->obj(), node, true);
  }
  return nullptr;
}
314 
GetJitAnnotationTypeFromComment(const AnfNodePtr & node,const FormatedVariableTypeFunc & format_type_func)315 TypePtr GetJitAnnotationTypeFromComment(const AnfNodePtr &node, const FormatedVariableTypeFunc &format_type_func) {
316   const auto &debug_info = trace::GetSourceCodeDebugInfo(node->debug_info());
317   const auto &location = debug_info->location();
318   if (location == nullptr) {
319     MS_LOG(INFO) << "Location info is null, node: " << node->DebugString();
320     return nullptr;
321   }
322   const auto &comments = location->comments();
323   if (comments.empty()) {
324     return nullptr;
325   }
326   // Only use the last comment.
327   const auto &comment = comments.back();
328   std::regex regex("^#\\s*@jit.typing\\s*:\\s*\\(\\)\\s*->\\s*([a-zA-Z0-9{}_]+)?\\[?([a-zA-Z0-9{}_]+)?\\]?$");
329   std::smatch matched_results;
330   if (std::regex_match(comment, matched_results, regex)) {
331     constexpr auto container_match_count = 3;
332     // Not match.
333     if (matched_results.size() != container_match_count) {
334       return nullptr;
335     }
336     const auto &container_type_str = matched_results[1].str();
337     const auto &dtype_str = matched_results[container_match_count - 1].str();
338     MS_LOG(DEBUG) << "matched_results: " << matched_results[0] << ", " << container_type_str << ", " << dtype_str;
339     // Match nothing.
340     if (container_type_str.empty()) {
341       MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
342     }
343     // Handle base type only.
344     auto base_type = HandleBaseTypeForAnnotation(dtype_str, container_type_str, format_type_func, node, comment);
345     if (base_type != nullptr) {
346       return base_type;
347     }
348     // Handle container type: tensor, list_ and tuple_.
349     return HandleContainerTypeForAnnotation(dtype_str, container_type_str, format_type_func, node, comment);
350   }
351   return nullptr;
352 }
353 
GetJitAnnotationSideEffectFromComment(const AnfNodePtr & node)354 bool GetJitAnnotationSideEffectFromComment(const AnfNodePtr &node) {
355   MS_EXCEPTION_IF_NULL(node);
356   const auto &debug_info = trace::GetSourceCodeDebugInfo(node->debug_info());
357   const auto &location = debug_info->location();
358   if (location == nullptr) {
359     MS_LOG(DEBUG) << "Location info is null, node: " << node->DebugString();
360     return false;
361   }
362   const auto &comments = location->comments();
363   if (comments.empty()) {
364     return false;
365   }
366   // Only use the last comment.
367   const auto &comment = comments.back();
368   std::regex regex("^#\\s*@jit.typing:\\s*side_effect");
369   if (std::regex_match(comment, regex)) {
370     return true;
371   }
372   return false;
373 }
374 
ConvertRealStrToUnicodeStr(const std::string & target,size_t index)375 std::string ConvertRealStrToUnicodeStr(const std::string &target, size_t index) {
376   std::stringstream script_buffer;
377   script_buffer << kPyExecPrefix << std::to_string(index);
378   std::vector<size_t> convert_pos;
379   for (size_t i = 0; i < target.size(); ++i) {
380     auto c = target[i];
381     if (!std::isalnum(c)) {
382       convert_pos.push_back(i);
383     }
384   }
385   size_t start = 0;
386   for (auto end : convert_pos) {
387     std::string sub_non_convert = target.substr(start, end - start);
388     if (sub_non_convert.size() != 0) {
389       script_buffer << kUnderLine << sub_non_convert;
390     }
391     char sub_convert = target[end];
392     std::stringstream hex_s;
393     hex_s << kUnderLine << kHexPrefix << std::hex << static_cast<int>(sub_convert);
394     script_buffer << hex_s.str();
395     start = end + 1;
396   }
397   if (target.substr(start).size() != 0) {
398     script_buffer << kUnderLine << target.substr(start);
399   }
400   script_buffer << kPyExecSuffix;
401   auto unicode_str = script_buffer.str();
402   MS_LOG(DEBUG) << "Get Unicode str: " << unicode_str;
403   return script_buffer.str();
404 }
405 
GeneratePyExecuteNodeForCallObj(const FuncGraphPtr & func_graph,const py::object & meta_obj,const AnfNodePtr & node,const std::string & name)406 AnfNodePtr GeneratePyExecuteNodeForCallObj(const FuncGraphPtr &func_graph, const py::object &meta_obj,
407                                            const AnfNodePtr &node, const std::string &name) {
408   if (py::isinstance<py::none>(meta_obj)) {
409     return nullptr;
410   }
411   auto res = fallback::ConvertPyObjectToPyInterpret(func_graph, name, meta_obj, node, false);
412   // '__keep_metafg_obj_flag__' is to keep metafg obj rather than convert to prim.
413   res->set_user_data("__keep_metafg_obj_flag__", std::make_shared<bool>(true));
414   return res;
415 }
416 
ContainsSequenceAnyType(const AbstractBasePtr & abs)417 bool ContainsSequenceAnyType(const AbstractBasePtr &abs) {
418   if (abs == nullptr) {
419     return false;
420   }
421   if (abs->isa<abstract::AbstractSequence>()) {
422     auto seq_abs = abs->cast_ptr<abstract::AbstractSequence>();
423     MS_EXCEPTION_IF_NULL(seq_abs);
424     if (seq_abs->dynamic_len()) {
425       auto element_abs = seq_abs->dynamic_len_element_abs();
426       if (ContainsSequenceAnyType(element_abs)) {
427         return true;
428       }
429     } else {
430       const auto &elements = seq_abs->elements();
431       for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
432         const auto &item_abs = elements[item_index];
433         if (ContainsSequenceAnyType(item_abs)) {
434           return true;
435         }
436       }
437     }
438   }
439   return abs->isa<abstract::AbstractAny>();
440 }
441 
SequenceAllElementsIsScalar(const AbstractBasePtr & abs)442 bool SequenceAllElementsIsScalar(const AbstractBasePtr &abs) {
443   if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
444     return false;
445   }
446   auto seq_abs = abs->cast_ptr<abstract::AbstractSequence>();
447   MS_EXCEPTION_IF_NULL(seq_abs);
448   if (seq_abs->dynamic_len()) {
449     auto element_abs = seq_abs->dynamic_len_element_abs();
450     if (element_abs == nullptr || !element_abs->isa<abstract::AbstractScalar>()) {
451       return false;
452     }
453     auto arg_type = element_abs->BuildType();
454     MS_EXCEPTION_IF_NULL(arg_type);
455     return arg_type->isa<Number>();
456   }
457   const auto &elements = seq_abs->elements();
458   for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
459     const auto &item_abs = elements[item_index];
460     if (item_abs == nullptr || !item_abs->isa<abstract::AbstractScalar>()) {
461       return false;
462     }
463     auto item_arg_type = item_abs->BuildType();
464     MS_EXCEPTION_IF_NULL(item_arg_type);
465     if (!item_arg_type->isa<Number>()) {
466       return false;
467     }
468   }
469   return true;
470 }
471 
GeneratePyObj(const abstract::AbstractBasePtr & abs)472 py::object GeneratePyObj(const abstract::AbstractBasePtr &abs) {
473   MS_EXCEPTION_IF_NULL(abs);
474   if (abs->isa<abstract::AbstractList>()) {
475     auto abs_list = abs->cast<abstract::AbstractListPtr>();
476     if (HasObjInExtraInfoHolder(abs_list)) {
477       return GetObjFromExtraInfoHolder(abs_list);
478     }
479     py::list ret = py::list(abs_list->size());
480     const auto &elements = abs_list->elements();
481     for (size_t i = 0; i < elements.size(); ++i) {
482       ret[i] = GeneratePyObj(elements[i]);
483     }
484     return ret;
485   } else if (abs->isa<abstract::AbstractTuple>()) {
486     auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
487     py::tuple ret = py::tuple(abs_tuple->size());
488     const auto &elements = abs_tuple->elements();
489     for (size_t i = 0; i < elements.size(); ++i) {
490       ret[i] = GeneratePyObj(elements[i]);
491     }
492     return ret;
493   } else if (abs->isa<abstract::AbstractDictionary>()) {
494     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
495     py::dict ret = py::dict();
496     const auto &key_value_pairs = abs_dict->elements();
497     for (size_t i = 0; i < key_value_pairs.size(); ++i) {
498       py::object key = GeneratePyObj(key_value_pairs[i].first);
499       // The key should be unique.
500       key = py::isinstance<py::none>(key) ? py::str(std::to_string(i)) : key;
501       ret[key] = GeneratePyObj(key_value_pairs[i].second);
502     }
503     return ret;
504   }
505   return ValueToPyData(abs->BuildValue());
506 }
507 
EnableFallbackListDictInplace()508 bool EnableFallbackListDictInplace() {
509   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
510   static const auto allow_inplace_ops = common::GetCompileConfig("FALLBACK_SUPPORT_LIST_DICT_INPLACE") != "0";
511   return allow_fallback_runtime && allow_inplace_ops;
512 }
513 
AttachPyObjToExtraInfoHolder(const abstract::AbstractBasePtr & abs,const py::object & obj,bool create_in_graph)514 void AttachPyObjToExtraInfoHolder(const abstract::AbstractBasePtr &abs, const py::object &obj, bool create_in_graph) {
515   MS_EXCEPTION_IF_NULL(abs);
516   constexpr auto py_object_key = "py_obj_key";
517   constexpr auto create_in_graph_key = "create_in_graph_key";
518   if (abs->isa<abstract::AbstractList>()) {
519     auto abs_list = abs->cast<abstract::AbstractListPtr>();
520     abs_list->SetData<py::object>(py_object_key, std::make_shared<py::object>(obj));
521     abs_list->SetData<bool>(create_in_graph_key, std::make_shared<bool>(create_in_graph));
522     return;
523   }
524   if (abs->isa<abstract::AbstractDictionary>()) {
525     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
526     abs_dict->SetData<py::object>(py_object_key, std::make_shared<py::object>(obj));
527     abs_dict->SetData<bool>(create_in_graph_key, std::make_shared<bool>(create_in_graph));
528     return;
529   }
530   MS_INTERNAL_EXCEPTION(TypeError) << "The abstract should be a ExtraInfoHolder but got : " << abs->ToString();
531 }
532 
GetObjFromExtraInfoHolder(const abstract::AbstractBasePtr & abs)533 py::object GetObjFromExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
534   MS_EXCEPTION_IF_NULL(abs);
535   constexpr auto py_object_key = "py_obj_key";
536   if (abs->isa<abstract::AbstractList>()) {
537     auto abs_list = abs->cast<abstract::AbstractListPtr>();
538     return *abs_list->GetData<py::object>(py_object_key);
539   }
540   if (abs->isa<abstract::AbstractDictionary>()) {
541     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
542     return *abs_dict->GetData<py::object>(py_object_key);
543   }
544   MS_INTERNAL_EXCEPTION(TypeError) << "The abstract should be a ExtraInfoHolder but got : " << abs->ToString();
545 }
546 
HasCreateInGraphInExtraInfoHolder(const abstract::AbstractBasePtr & abs)547 bool HasCreateInGraphInExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
548   MS_EXCEPTION_IF_NULL(abs);
549   constexpr auto create_in_graph_key = "create_in_graph_key";
550   if (abs->isa<abstract::AbstractList>()) {
551     auto abs_list = abs->cast<abstract::AbstractListPtr>();
552     return abs_list->HasData(create_in_graph_key);
553   }
554   if (abs->isa<abstract::AbstractDictionary>()) {
555     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
556     return abs_dict->HasData(create_in_graph_key);
557   }
558   return false;
559 }
560 
GetCreateInGraphFromExtraInfoHolder(const abstract::AbstractBasePtr & abs)561 bool GetCreateInGraphFromExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
562   MS_EXCEPTION_IF_NULL(abs);
563   constexpr auto create_in_graph_key = "create_in_graph_key";
564   if (abs->isa<abstract::AbstractList>()) {
565     auto abs_list = abs->cast<abstract::AbstractListPtr>();
566     return *abs_list->GetData<bool>(create_in_graph_key);
567   }
568   if (abs->isa<abstract::AbstractDictionary>()) {
569     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
570     return *abs_dict->GetData<bool>(create_in_graph_key);
571   }
572   MS_INTERNAL_EXCEPTION(TypeError) << "The abstract should be a ExtraInfoHolder but got : " << abs->ToString();
573 }
574 
HasObjInExtraInfoHolder(const abstract::AbstractBasePtr & abs)575 bool HasObjInExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
576   MS_EXCEPTION_IF_NULL(abs);
577   constexpr auto py_object_key = "py_obj_key";
578   if (abs->isa<abstract::AbstractList>()) {
579     auto abs_list = abs->cast<abstract::AbstractListPtr>();
580     return abs_list->HasData(py_object_key);
581   }
582   if (abs->isa<abstract::AbstractDictionary>()) {
583     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
584     return abs_dict->HasData(py_object_key);
585   }
586   return false;
587 }
588 
589 // Nested attach list and dict object to corresponding abstract.
AttachPyObjToAbs(const AbstractBasePtr & abs,const py::object & obj,bool create_in_graph)590 void AttachPyObjToAbs(const AbstractBasePtr &abs, const py::object &obj, bool create_in_graph) {
591   if (!EnableFallbackListDictInplace()) {
592     return;
593   }
594   if (abs->isa<abstract::AbstractNamedTuple>()) {
595     return;
596   }
597   if (!abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractDictionary>()) {
598     return;
599   }
600   if (py::hasattr(obj, PYTHON_CELL_AS_LIST) || py::hasattr(obj, PYTHON_CELL_AS_DICT)) {
601     // CellList and CellDict do not support inplace operations, do not need to attach python object.
602     return;
603   }
604   if (abs->isa<abstract::AbstractCSRTensor>() || abs->isa<abstract::AbstractCOOTensor>()) {
605     return;
606   }
607   if (abs->isa<abstract::AbstractList>()) {
608     MS_LOG(DEBUG) << "Attach list python" << obj << " to abstract: " << abs->ToString();
609     if (!py::isinstance<py::list>(obj)) {
610       MS_INTERNAL_EXCEPTION(TypeError) << "Object should be list but got: " << py::str(obj);
611     }
612     auto abs_list = abs->cast<abstract::AbstractListPtr>();
613     AttachPyObjToExtraInfoHolder(abs_list, obj, create_in_graph);
614     auto list_obj = py::list(obj);
615     for (size_t i = 0; i < abs_list->size(); ++i) {
616       auto element_abs = abs_list->elements()[i];
617       auto element_obj = list_obj[i];
618       AttachPyObjToAbs(element_abs, element_obj, create_in_graph);
619     }
620     return;
621   }
622   if (abs->isa<abstract::AbstractDictionary>()) {
623     if (!py::isinstance<py::dict>(obj)) {
624       MS_INTERNAL_EXCEPTION(TypeError) << "Object should be dict but got: " << py::str(obj);
625     }
626     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
627     MS_LOG(DEBUG) << "Attach dict python" << obj << " to abstract: " << abs->ToString();
628     AttachPyObjToExtraInfoHolder(abs_dict, obj, create_in_graph);
629     auto dict_obj = py::dict(obj);
630     auto key_list_obj = py::list(obj);
631     const auto &key_value_pairs = abs_dict->elements();
632     for (size_t i = 0; i < key_value_pairs.size(); ++i) {
633       auto value_abs = key_value_pairs[i].second;
634       auto value_obj = dict_obj[key_list_obj[i]];
635       AttachPyObjToAbs(value_abs, value_obj, create_in_graph);
636     }
637     return;
638   }
639   auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
640   if (!py::isinstance<py::tuple>(obj)) {
641     MS_INTERNAL_EXCEPTION(TypeError) << "Object should be tuple but got: " << py::str(obj);
642   }
643   auto tuple_obj = py::tuple(obj);
644   for (size_t i = 0; i < abs_tuple->size(); ++i) {
645     auto element_abs = abs_tuple->elements()[i];
646     auto element_obj = tuple_obj[i];
647     AttachPyObjToAbs(element_abs, element_obj, create_in_graph);
648   }
649 }
650 
GetPyObjectPtrStr(const py::object & obj)651 std::string GetPyObjectPtrStr(const py::object &obj) {
652   std::stringstream ss;
653   ss << obj.ptr();
654   return ss.str();
655 }
656 
CheckInterpretInput(const AnfNodePtr & node)657 bool CheckInterpretInput(const AnfNodePtr &node) {
658   MS_EXCEPTION_IF_NULL(node);
659   if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
660     return true;
661   }
662   if (node->isa<CNode>()) {
663     auto cnode = node->cast<CNodePtr>();
664     const auto &inputs = cnode->inputs();
665     return std::any_of(inputs.begin(), inputs.end(),
666                        [](const AnfNodePtr &input) { return CheckInterpretInput(input); });
667   }
668   return false;
669 }
670 
SetPyObjectToNode(const AnfNodePtr & node,const py::object & obj)671 void SetPyObjectToNode(const AnfNodePtr &node, const py::object &obj) {
672   MS_EXCEPTION_IF_NULL(node);
673   if (!EnableFallbackListDictInplace()) {
674     return;
675   }
676   constexpr auto py_obj_str = "__py_object__";
677   if (py::isinstance<py::list>(obj)) {
678     node->set_user_data<py::list>(py_obj_str, std::make_shared<py::list>(py::list(obj)));
679   } else if (py::isinstance<py::tuple>(obj)) {
680     node->set_user_data<py::tuple>(py_obj_str, std::make_shared<py::tuple>(py::tuple(obj)));
681   } else if (py::isinstance<py::dict>(obj)) {
682     node->set_user_data<py::dict>(py_obj_str, std::make_shared<py::dict>(py::dict(obj)));
683   }
684 }
685 
HasPyObjectInNode(const AnfNodePtr & node)686 bool HasPyObjectInNode(const AnfNodePtr &node) {
687   MS_EXCEPTION_IF_NULL(node);
688   constexpr auto py_obj_str = "__py_object__";
689   return node->has_user_data(py_obj_str);
690 }
691 
GetPyObjectFromNode(const AnfNodePtr & node)692 py::object GetPyObjectFromNode(const AnfNodePtr &node) {
693   MS_EXCEPTION_IF_NULL(node);
694   constexpr auto py_obj_str = "__py_object__";
695   return *node->user_data<py::object>(py_obj_str);
696 }
697 
698 // Convert node to pyinterpret with specific function name.
699 //    ConvertCNodeToPyInterpretForPrim(prim(x1, x2), func_name)
700 //    --->
701 //    PyInterpret("func_name(__input1__, __input2__)", global_dict, {"__input1__": x1, "__input2__": x2})
ConvertCNodeToPyInterpretForPrim(const CNodePtr & cnode,const string & name)702 AnfNodePtr ConvertCNodeToPyInterpretForPrim(const CNodePtr &cnode, const string &name) {
703   MS_EXCEPTION_IF_NULL(cnode);
704   const auto &fg = cnode->func_graph();
705   MS_EXCEPTION_IF_NULL(fg);
706   std::stringstream script_buffer;
707   script_buffer << name << "(";
708   const auto &inputs = cnode->inputs();
709   std::vector<AnfNodePtr> keys_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
710   std::vector<AnfNodePtr> values_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
711   for (size_t index = 1; index < inputs.size(); ++index) {
712     const auto &internal_arg = fallback::ConvertRealStrToUnicodeStr(name, index);
713     script_buffer << internal_arg << ", ";
714     auto key_node = NewValueNode(std::make_shared<StringImm>(internal_arg));
715     auto value_node = inputs[index];
716     (void)keys_tuple_node_inputs.emplace_back(key_node);
717     (void)values_tuple_node_inputs.emplace_back(value_node);
718   }
719   script_buffer << ")";
720   const std::string &script = script_buffer.str();
721   auto keys_tuple_node = fg->NewCNodeInOrder(keys_tuple_node_inputs);
722   auto values_tuple_node = fg->NewCNodeInOrder(values_tuple_node_inputs);
723   auto local_dict_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), keys_tuple_node, values_tuple_node});
724   auto pyinterpret_node = CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node, cnode->debug_info());
725   MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyinterpret_node->DebugString();
726   return pyinterpret_node;
727 }
728 
729 // Convert some CNode to PyExectue, eg:
730 // isinstance(xxx.asnumpy(), np.ndarray)  -- > PyExectue("isinstance(arg1, arg2)", local_keys, local_values)
ConvertCNodeToPyExecuteForPrim(const CNodePtr & cnode,const string & name)731 AnfNodePtr ConvertCNodeToPyExecuteForPrim(const CNodePtr &cnode, const string &name) {
732   MS_EXCEPTION_IF_NULL(cnode);
733   const auto &fg = cnode->func_graph();
734   MS_EXCEPTION_IF_NULL(fg);
735   std::string script = name + "(";
736   std::string internal_arg;
737   size_t arg_nums = cnode->size() - 1;
738   std::vector<AnfNodePtr> keys_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
739   std::vector<AnfNodePtr> values_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
740   for (size_t index = 1; index < arg_nums; ++index) {
741     internal_arg = fallback::ConvertRealStrToUnicodeStr(name, index);
742     script = script + internal_arg + ", ";
743     auto key_node = NewValueNode(std::make_shared<StringImm>(internal_arg));
744     auto value_node = cnode->input(index);
745     (void)keys_tuple_node_inputs.emplace_back(key_node);
746     (void)values_tuple_node_inputs.emplace_back(value_node);
747   }
748   string last_input = fallback::ConvertRealStrToUnicodeStr(name, arg_nums);
749   script = script + last_input + ")";
750   (void)keys_tuple_node_inputs.emplace_back(NewValueNode(std::make_shared<StringImm>(last_input)));
751   (void)values_tuple_node_inputs.emplace_back(cnode->input(arg_nums));
752   auto script_node = NewValueNode(std::make_shared<StringImm>(script));
753   auto keys_tuple_node = fg->NewCNodeInOrder(keys_tuple_node_inputs);
754   auto values_tuple_node = fg->NewCNodeInOrder(values_tuple_node_inputs);
755   auto pyexecute_node =
756     CreatePyExecuteCNodeInOrder(fg, script_node, keys_tuple_node, values_tuple_node, cnode->debug_info());
757   MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyexecute_node->DebugString();
758   return pyexecute_node;
759 }
760 
GeneratePyInterpretWithAbstract(const FuncGraphPtr & fg,const std::vector<std::string> & funcs_str,const size_t input_size)761 AnfNodePtr GeneratePyInterpretWithAbstract(const FuncGraphPtr &fg, const std::vector<std::string> &funcs_str,
762                                            const size_t input_size) {
763   AnfNodePtrList node_inputs{NewValueNode(prim::kPrimMakeTuple)};
764   AnfNodePtrList keys_inputs{NewValueNode(prim::kPrimMakeTuple)};
765   std::stringstream script_buffer;
766   for (size_t i = 0; i < funcs_str.size(); ++i) {
767     script_buffer << funcs_str[i] << "(";
768   }
769   for (size_t i = 0; i < input_size; ++i) {
770     const std::string cur_name = "__input_" + std::to_string(i) + "__";
771     script_buffer << cur_name << ",";
772     (void)keys_inputs.emplace_back(NewValueNode(cur_name));
773     (void)node_inputs.emplace_back(fg->add_parameter());
774   }
775   for (size_t i = 0; i < funcs_str.size(); ++i) {
776     script_buffer << ")";
777   }
778   auto script_text = script_buffer.str();
779   auto script = std::make_shared<parse::Script>(script_text);
780   auto script_node = NewValueNode(script);
781   auto global_dict_node = NewValueNode(std::make_shared<parse::InterpretedObject>(py::dict()));
782   auto keys_tuple = fg->NewCNode(keys_inputs);
783   auto values_tuple = fg->NewCNode(node_inputs);
784   auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), keys_tuple, values_tuple});
785   auto ret_node = fg->NewCNode({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
786   return ret_node;
787 }
788 
ConvertGetAttrNodeToPyInterpret(const FuncGraphPtr & fg,const CNodePtr & cnode,const std::string & name)789 AnfNodePtr ConvertGetAttrNodeToPyInterpret(const FuncGraphPtr &fg, const CNodePtr &cnode, const std::string &name) {
790   MS_EXCEPTION_IF_NULL(cnode);
791   MS_EXCEPTION_IF_NULL(fg);
792   const std::unordered_map<std::string, std::string> internal_attr_map = {
793     {"__ms_next__", "__import__('mindspore').common._utils._jit_fallback_next_func"}};
794   auto iter = internal_attr_map.find(name);
795   if (iter == internal_attr_map.end()) {
796     return ConvertCNodeToPyInterpretForPrim(cnode, "getattr");
797   }
798   AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
799   AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
800   std::stringstream script_buffer;
801   script_buffer << iter->second << "(";
802 
803   const std::string data_str = "__data__";
804   script_buffer << data_str << ")";
805   (void)local_key_inputs.emplace_back(NewValueNode(data_str));
806   constexpr size_t data_index = 1;
807   (void)local_value_inputs.emplace_back(cnode->input(data_index));
808 
809   const auto &script = script_buffer.str();
810   auto local_key_node = fg->NewCNode(local_key_inputs);
811   auto local_value_node = fg->NewCNode(local_value_inputs);
812   auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
813 
814   auto ret = CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node, cnode->debug_info());
815   MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << ret->DebugString();
816   return ret;
817 }
818 
GetPyObjForFuncGraphAbstractClosure(const AbstractBasePtr & abs)819 py::object GetPyObjForFuncGraphAbstractClosure(const AbstractBasePtr &abs) {
820   if (!abs->isa<abstract::FuncGraphAbstractClosure>()) {
821     return py::none();
822   }
823   auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
824   auto fg = abs_func->func_graph();
825   MS_EXCEPTION_IF_NULL(fg);
826   auto wrapper_obj = fg->python_obj();
827   if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
828     auto obj = wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj();
829     return obj;
830   }
831   // Handle lambda expression scene. Graph generated from lambda function does not have attached python object.
832   auto fg_debug_info = fg->debug_info();
833   MS_EXCEPTION_IF_NULL(fg_debug_info);
834   const auto &fg_name = fg_debug_info->name();
835   const std::string lambda_suffix = "_lambda_";
836   bool end_with_lambda_suffix =
837     (fg_name.size() >= lambda_suffix.size() && fg_name.substr(fg_name.size() - lambda_suffix.size()) == lambda_suffix);
838   if (end_with_lambda_suffix) {
839     auto location = trace::GetSourceCodeDebugInfo(fg_debug_info)->location();
840     MS_EXCEPTION_IF_NULL(location);
841     const auto &lambda_script = location->expr_src();
842     py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
843     return python_adapter::CallPyModFn(mod, "generate_lambda_object", lambda_script);
844   }
845   return py::none();
846 }
847 
GeneratePyInterpretNodeFromMetaFuncGraph(const FuncGraphPtr & func_graph,const AnfNodePtrList & node_inputs,const py::object & meta_obj,const TypePtrList & types,const std::string & name)848 AnfNodePtr GeneratePyInterpretNodeFromMetaFuncGraph(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_inputs,
849                                                     const py::object &meta_obj, const TypePtrList &types,
850                                                     const std::string &name) {
851   std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
852   std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
853   AnfNodePtr call_node = GeneratePyExecuteNodeForCallObj(func_graph, meta_obj, node_inputs[0], name);
854   auto node_inputs_size = node_inputs.size();
855   std::stringstream script_buffer;
856   if (call_node != nullptr) {
857     (void)key_value_list.emplace_back(call_node);
858     std::string uniname = fallback::ConvertRealStrToUnicodeStr(name, 0);
859     (void)key_value_names_list.push_back(NewValueNode(uniname));
860     script_buffer << uniname << "(";
861   } else {
862     script_buffer << "__import__('mindspore').ops.composite.multitype_ops." << name << "(";
863   }
864   for (size_t i = 0; i < node_inputs_size; i++) {
865     if (types[i]->isa<Slice>()) {
866       (void)key_value_names_list.emplace_back(NewValueNode("__start__"));
867       (void)key_value_names_list.emplace_back(NewValueNode("__stop__"));
868       (void)key_value_names_list.emplace_back(NewValueNode("__step__"));
869       auto start_node =
870         func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), node_inputs[i], NewValueNode("start")});
871       auto end_node =
872         func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), node_inputs[i], NewValueNode("stop")});
873       auto step_node =
874         func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), node_inputs[i], NewValueNode("step")});
875       (void)key_value_list.emplace_back(start_node);
876       (void)key_value_list.emplace_back(end_node);
877       (void)key_value_list.emplace_back(step_node);
878       script_buffer << "slice(__start__,__stop__,__step__)";
879     } else {
880       std::stringstream input_key;
881       input_key << "__input_key_" << i << "__";
882       (void)key_value_names_list.push_back(NewValueNode(input_key.str()));
883       (void)key_value_list.emplace_back(node_inputs[i]);
884       script_buffer << input_key.str();
885     }
886     if (i != node_inputs_size) {
887       script_buffer << ",";
888     }
889   }
890   script_buffer << ")";
891   const auto script_str = script_buffer.str();
892   const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
893   const auto key_value_tuple = func_graph->NewCNode(key_value_list);
894 
895   // Generate PyInterpret node with
896   auto local_dict = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), key_value_name_tuple, key_value_tuple});
897   auto res = CreatePyInterpretCNode(func_graph, script_str, py::dict(), local_dict, key_value_name_tuple->debug_info());
898   res->set_user_data(kCheckListDictInplace, std::make_shared<bool>(true));
899   MS_LOG(DEBUG) << "Generate PyInterpret node: " << res->DebugString();
900   return res;
901 }
902 }  // namespace fallback
903 
904 namespace raiseutils {
905 namespace {
CheckIsStr(const AbstractBasePtr & abs)906 bool CheckIsStr(const AbstractBasePtr &abs) {
907   auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
908   MS_EXCEPTION_IF_NULL(scalar);
909   auto scalar_type = scalar->BuildType();
910   MS_EXCEPTION_IF_NULL(scalar_type);
911   if (scalar_type->IsSameTypeId(String::kTypeId)) {
912     return true;
913   }
914   return false;
915 }
916 
GetScalarStringValue(const AbstractBasePtr & abs)917 std::string GetScalarStringValue(const AbstractBasePtr &abs) {
918   MS_EXCEPTION_IF_NULL(abs);
919   auto scalar = abs->cast<abstract::AbstractScalarPtr>();
920   MS_EXCEPTION_IF_NULL(scalar);
921   auto scalar_value = scalar->BuildValue();
922   return scalar_value->ToString();
923 }
924 
GetVariable(const AnfNodePtr & input,const std::shared_ptr<KeyValueInfo> & key_value,const std::string & exception_str,bool need_symbol)925 std::string GetVariable(const AnfNodePtr &input, const std::shared_ptr<KeyValueInfo> &key_value,
926                         const std::string &exception_str, bool need_symbol) {
927   std::string key = MakeRaiseKey(key_value->num_str);
928   std::stringstream script_buffer;
929   key_value->num_str += 1;
930   if (need_symbol) {
931     script_buffer << exception_str << "'+f'{" << key << "}'+'";
932   } else {
933     script_buffer << exception_str << key;
934   }
935   (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
936   (void)key_value->values.emplace_back(input);
937   return script_buffer.str();
938 }
939 
GetTupleOrListString(const AbstractBasePtr & arg,const AnfNodePtr & input,const std::shared_ptr<KeyValueInfo> & key_value,bool need_symbol,bool need_comma)940 std::string GetTupleOrListString(const AbstractBasePtr &arg, const AnfNodePtr &input,
941                                  const std::shared_ptr<KeyValueInfo> &key_value, bool need_symbol, bool need_comma) {
942   MS_EXCEPTION_IF_NULL(arg);
943   bool has_variable = CheckHasVariable(arg);
944   std::stringstream exception_str;
945   bool is_tuple = arg->isa<abstract::AbstractTuple>();
946   // Process raise ValueError("str")
947   auto arg_tuple = arg->cast_ptr<abstract::AbstractSequence>();
948   MS_EXCEPTION_IF_NULL(arg_tuple);
949   const auto &arg_tuple_elements = arg_tuple->elements();
950   if (!input->isa<CNode>() && has_variable) {
951     return GetVariable(input, key_value, exception_str.str(), need_symbol);
952   }
953   if (arg_tuple_elements.size() > 1 && !IsPrimitiveCNode(input, prim::kPrimJoinedStr)) {
954     if (is_tuple) {
955       exception_str << "(";
956     } else {
957       exception_str << "[";
958     }
959   }
960   if (has_variable) {
961     auto cnode = input->cast_ptr<CNode>();
962     MS_EXCEPTION_IF_NULL(cnode);
963     bool not_variable =
964       (!arg->BuildValue()->ContainsValueAny()) || IsValueNode<prim::DoSignaturePrimitive>(cnode->input(0));
965     for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
966       auto &element = arg_tuple_elements[index];
967       const auto &inputs = cnode->inputs();
968       if (arg_tuple_elements.size() >= cnode->size()) {
969         MS_LOG(EXCEPTION) << "Size of cnode should be greater than arg_tuple_elements, "
970                           << "but got cnode size: " << cnode->size()
971                           << " arg_tuple_elements size: " << arg_tuple_elements.size();
972       }
973       auto inputs_in_tuple = inputs[index + 1];
974       exception_str << GetExceptionString(element, inputs_in_tuple, key_value, need_symbol, need_comma);
975       if (index != arg_tuple_elements.size() - 1 && need_comma && not_variable) {
976         exception_str << ", ";
977       }
978     }
979   } else {
980     for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
981       auto &element = arg_tuple_elements[index];
982       exception_str << GetExceptionString(element, input, key_value, need_symbol, need_comma);
983       if (index != arg_tuple_elements.size() - 1 && need_comma) {
984         exception_str << ", ";
985       }
986     }
987   }
988   if (arg_tuple_elements.size() > 1 && !IsPrimitiveCNode(input, prim::kPrimJoinedStr)) {
989     if (is_tuple) {
990       exception_str << ")";
991     } else {
992       exception_str << "]";
993     }
994   }
995   return exception_str.str();
996 }
997 }  // namespace
998 
// Returns the local-variable name of the index-th internal raise-message value: "__internal_error_value<index>__".
std::string MakeRaiseKey(const int index) {
  std::string key = "__internal_error_value";
  key += std::to_string(index);
  key += "__";
  return key;
}
1000 
CheckNeedSymbol(const AbstractBasePtr & abs)1001 bool CheckNeedSymbol(const AbstractBasePtr &abs) {
1002   MS_EXCEPTION_IF_NULL(abs);
1003   bool need_symbol = false;
1004   if (abs->isa<abstract::AbstractScalar>()) {
1005     need_symbol = CheckIsStr(abs);
1006   } else if (abs->isa<abstract::AbstractSequence>()) {
1007     auto abs_list = abs->cast_ptr<abstract::AbstractSequence>();
1008     MS_EXCEPTION_IF_NULL(abs_list);
1009     const auto &elements = abs_list->elements();
1010     for (auto &element : elements) {
1011       MS_EXCEPTION_IF_NULL(element);
1012       if (element->isa<abstract::AbstractScalar>()) {
1013         need_symbol = CheckIsStr(element);
1014         if (need_symbol) {
1015           return need_symbol;
1016         }
1017       }
1018     }
1019   }
1020   return need_symbol;
1021 }
1022 
GetExceptionString(const AbstractBasePtr & arg,const AnfNodePtr & input,const std::shared_ptr<KeyValueInfo> & key_value,bool need_symbol,bool need_comma)1023 std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input,
1024                                const std::shared_ptr<KeyValueInfo> &key_value, bool need_symbol, bool need_comma) {
1025   std::string exception_str;
1026   MS_EXCEPTION_IF_NULL(arg);
1027   if (arg->isa<abstract::AbstractSequence>() && !IsPrimitiveCNode(input, prim::kPrimGetAttr)) {
1028     return GetTupleOrListString(arg, input, key_value, need_symbol, need_comma);
1029   } else if (arg->BuildValue()->ContainsValueAny() || arg->isa<abstract::AbstractTensor>() ||
1030              IsPrimitiveCNode(input, prim::kPrimGetAttr)) {
1031     exception_str = GetVariable(input, key_value, exception_str, need_symbol);
1032   } else if (arg->isa<abstract::AbstractDictionary>()) {
1033     MS_LOG(EXCEPTION) << "Dictionary type is currently not supporting";
1034   } else if (arg->isa<abstract::AbstractScalar>()) {
1035     // Process raise ValueError
1036     exception_str += GetScalarStringValue(arg);
1037   } else {
1038     MS_LOG(INTERNAL_EXCEPTION) << "Unexpected abstract: " << arg->ToString();
1039   }
1040   return exception_str;
1041 }
1042 
CheckHasVariable(const AbstractBasePtr & arg)1043 bool CheckHasVariable(const AbstractBasePtr &arg) {
1044   if (arg->isa<abstract::AbstractSequence>()) {
1045     auto arg_tuple = arg->cast_ptr<abstract::AbstractSequence>();
1046     MS_EXCEPTION_IF_NULL(arg_tuple);
1047     const auto &arg_tuple_elements = arg_tuple->elements();
1048     if (arg_tuple_elements.size() == 0) {
1049       MS_LOG(INTERNAL_EXCEPTION) << "The arg_tuple_elements can't be empty.";
1050     }
1051     for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
1052       auto &element = arg_tuple_elements[index];
1053       if (CheckHasVariable(element)) {
1054         return true;
1055       }
1056     }
1057   } else if (arg->BuildValue()->ContainsValueAny() || arg->isa<abstract::AbstractTensor>()) {
1058     return true;
1059   }
1060   return false;
1061 }
1062 
GetExceptionType(const AbstractBasePtr & abs,const AnfNodePtr & node,const std::shared_ptr<KeyValueInfo> & key_value,bool has_variable)1063 std::string GetExceptionType(const AbstractBasePtr &abs, const AnfNodePtr &node,
1064                              const std::shared_ptr<KeyValueInfo> &key_value, bool has_variable) {
1065   MS_EXCEPTION_IF_NULL(node);
1066   auto clt = GetValueNode<ClassTypePtr>(node);
1067   if (clt != nullptr) {
1068     const auto &class_name = clt->name();
1069     auto begin = class_name.find("'") + 1;
1070     auto end = class_name.substr(begin).find("'");
1071     auto class_type = class_name.substr(begin, end);
1072     return class_type;
1073   }
1074   std::string str;
1075   if (abs->isa<abstract::AbstractScalar>()) {
1076     auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
1077     MS_EXCEPTION_IF_NULL(scalar);
1078     auto scalar_value = scalar->BuildValue();
1079     MS_EXCEPTION_IF_NULL(scalar_value);
1080     if (scalar_value->isa<StringImm>()) {
1081       str = GetValue<std::string>(scalar_value);
1082       if (GetValueNode<StringImmPtr>(node) == nullptr && has_variable) {
1083         (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(str)));
1084         (void)key_value->values.emplace_back(node);
1085       }
1086       return str;
1087     }
1088   }
1089   MS_LOG(EXCEPTION) << "The abstract of exception type is not scalar: " << abs->ToString();
1090 }
1091 
1092 namespace {
HasVariableCondition(const FuncGraphPtr & cur_graph,std::vector<FuncGraphPtr> * prev_graph)1093 bool HasVariableCondition(const FuncGraphPtr &cur_graph, std::vector<FuncGraphPtr> *prev_graph) {
1094   if (cur_graph == nullptr) {
1095     return false;
1096   }
1097   if (cur_graph->is_tensor_condition_branch()) {
1098     return true;
1099   }
1100   auto cur_fg_map = cur_graph->func_graph_cnodes_index();
1101   for (auto &cur_fg_use : cur_fg_map) {
1102     auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
1103     MS_EXCEPTION_IF_NULL(temp_node);
1104     if (std::find(prev_graph->begin(), prev_graph->end(), cur_graph) != prev_graph->end()) {
1105       continue;
1106     }
1107     prev_graph->push_back(cur_graph);
1108     if (HasVariableCondition(temp_node->func_graph(), prev_graph)) {
1109       return true;
1110     }
1111   }
1112   if (HasVariableCondition(cur_graph->parent(), prev_graph)) {
1113     return true;
1114   }
1115   return false;
1116 }
1117 }  // namespace
1118 
HasVariableCondition(const FuncGraphPtr & cur_graph)1119 bool HasVariableCondition(const FuncGraphPtr &cur_graph) {
1120   std::vector<FuncGraphPtr> prev_graph;
1121   return HasVariableCondition(cur_graph, &prev_graph);
1122 }
1123 }  // namespace raiseutils
1124 }  // namespace mindspore
1125