1 /**
2 * Copyright 2023-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/ps/fallback.h"
18
19 #include <algorithm>
20 #include <iostream>
21 #include <memory>
22 #include <regex>
23 #include <string>
24 #include <vector>
25 #include <utility>
26
27 #include "mindspore/core/ops/structure_ops.h"
28 #include "mindspore/core/ops/sequence_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "include/common/fallback.h"
31 #include "include/common/utils/python_adapter.h"
32 #include "include/common/utils/convert_utils_py.h"
33 #include "utils/log_adapter.h"
34 #include "utils/ms_context.h"
35 #include "utils/compile_config.h"
36 #include "utils/interpret_node_recorder.h"
37 #include "pipeline/jit/ps/debug/trace.h"
38 #include "pipeline/jit/ps/parse/resolve.h"
39 #include "abstract/abstract_value.h"
40 #include "ir/func_graph.h"
41
42 namespace mindspore {
43 namespace fallback {
44 namespace {
45 // Get the type from python type string, defined in Python module 'mindspore.common.dtype'.
GetTypeFromString(const std::string & dtype)46 TypePtr GetTypeFromString(const std::string &dtype) {
47 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
48 constexpr auto get_dtype_python_function = "get_dtype";
49 auto type = python_adapter::CallPyModFn(mod, get_dtype_python_function, py::str(dtype));
50 MS_LOG(DEBUG) << "type: " << type;
51 if (py::isinstance<py::none>(type)) {
52 return nullptr;
53 }
54 auto type_ptr = py::cast<TypePtr>(type);
55 if (type_ptr == nullptr) {
56 return nullptr;
57 }
58 return type_ptr->Clone();
59 }
60
GetErrorFormatMessage(const AnfNodePtr & node,const std::string & comment)61 std::string GetErrorFormatMessage(const AnfNodePtr &node, const std::string &comment) {
62 std::stringstream err_buf;
63 err_buf << "Wrong comment format for JIT type annotation: '" << comment
64 << "'.\ne.g. '# @jit.typing: () -> tensor_type[int32]' or:"
65 << "\n---\n\tdtype_var = ms.int32\n\t# @jit.typing: () -> tensor_type[{dtype_var}]\n\t...\n---\n\n"
66 << trace::GetDebugInfoStr(node->debug_info());
67 return err_buf.str();
68 }
69
HandleBaseTypeForAnnotation(const std::string & dtype_str,const std::string & container_type_str,const FormatedVariableTypeFunc & format_type_func,const AnfNodePtr & node,const std::string & comment)70 TypePtr HandleBaseTypeForAnnotation(const std::string &dtype_str, const std::string &container_type_str,
71 const FormatedVariableTypeFunc &format_type_func, const AnfNodePtr &node,
72 const std::string &comment) {
73 if (!dtype_str.empty()) {
74 return nullptr;
75 }
76 TypePtr base_type = nullptr;
77 // Handle dtype.
78 if (container_type_str.front() == '{' && container_type_str.back() == '}') { // Handle format variable type.
79 if (!format_type_func) {
80 MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
81 }
82 constexpr auto excluded_size = 2;
83 const auto &variable_base_type = container_type_str.substr(1, container_type_str.size() - excluded_size);
84 // Find variable type.
85 if (!variable_base_type.empty()) {
86 base_type = format_type_func(variable_base_type);
87 if (base_type == nullptr) { // Not throw exception if not match any variable.
88 return nullptr;
89 }
90 }
91 } else { // Handle string type.
92 const auto &base_type_str = container_type_str;
93 base_type = GetTypeFromString(base_type_str);
94 }
95 if (base_type == nullptr) {
96 MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
97 }
98 return base_type;
99 }
100
GetDTypeFromDTypeStr(const std::string & dtype_str,const FormatedVariableTypeFunc & format_type_func,const AnfNodePtr & node,const std::string & comment)101 std::pair<bool, TypePtr> GetDTypeFromDTypeStr(const std::string &dtype_str,
102 const FormatedVariableTypeFunc &format_type_func, const AnfNodePtr &node,
103 const std::string &comment) {
104 TypePtr dtype = nullptr;
105 if (dtype_str.front() == '{' && dtype_str.back() == '}') { // Handle format variable dtype.
106 if (!format_type_func) {
107 MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
108 }
109 constexpr auto excluded_size = 2;
110 const auto &variable_dtype = dtype_str.substr(1, dtype_str.size() - excluded_size);
111 // Find variable dtype.
112 if (!variable_dtype.empty()) {
113 dtype = format_type_func(variable_dtype);
114 if (dtype == nullptr) { // Not throw exception if not match any variable.
115 return std::make_pair(false, nullptr);
116 }
117 }
118 } else { // Handle string dtype.
119 dtype = GetTypeFromString(dtype_str);
120 }
121 return std::make_pair(true, dtype);
122 }
123
HandleContainerTypeForAnnotation(const std::string & dtype_str,const std::string & container_type_str,const FormatedVariableTypeFunc & format_type_func,const AnfNodePtr & node,const std::string & comment)124 TypePtr HandleContainerTypeForAnnotation(const std::string &dtype_str, const std::string &container_type_str,
125 const FormatedVariableTypeFunc &format_type_func, const AnfNodePtr &node,
126 const std::string &comment) {
127 const auto &container_type = GetTypeFromString(container_type_str);
128 if (container_type == nullptr) {
129 MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
130 }
131 if (!container_type->isa<Tuple>() && !container_type->isa<List>() && !container_type->isa<TensorType>()) {
132 MS_LOG(EXCEPTION) << "JIT type annotation only support tensor/list_/tuple_, but got '" << container_type_str;
133 }
134
135 auto [is_match, dtype] = GetDTypeFromDTypeStr(dtype_str, format_type_func, node, comment);
136 if (!is_match) {
137 return nullptr;
138 }
139 if (dtype == nullptr) {
140 MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
141 }
142 if (container_type->isa<TensorType>()) { // Handle tensor type.
143 if (!dtype->isa<Number>()) {
144 MS_LOG(EXCEPTION) << "Cannot get dtype for by input string: '" << dtype_str << "', for '" << container_type_str
145 << "'\n"
146 << trace::GetDebugInfoStr(node->debug_info());
147 }
148 container_type->cast<TensorTypePtr>()->set_element(dtype);
149 } else if (container_type->isa<Tuple>() || container_type->isa<List>()) { // Handle list_/tuple_ type.
150 // To handle nested sequence later.
151 if (!dtype->isa<Number>() && !dtype->isa<TensorType>()) {
152 MS_LOG(EXCEPTION) << "Cannot get element type for by input string: '" << dtype_str << "', for '"
153 << container_type_str << "'\n"
154 << trace::GetDebugInfoStr(node->debug_info());
155 }
156 if (container_type->isa<Tuple>()) {
157 container_type->cast<TuplePtr>()->set_elements(TypePtrList({dtype}));
158 } else if (container_type->isa<List>()) {
159 container_type->cast<ListPtr>()->set_elements(TypePtrList({dtype}));
160 }
161 return nullptr; // Supports tuple_[...] / list_[...] later.
162 }
163 return container_type;
164 }
165 } // namespace
166
CreatePyExecuteCNode(const FuncGraphPtr & fg,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values,const NodeDebugInfoPtr & debug_info)167 CNodePtr CreatePyExecuteCNode(const FuncGraphPtr &fg, const AnfNodePtr &script, const AnfNodePtr &keys,
168 const AnfNodePtr &values, const NodeDebugInfoPtr &debug_info) {
169 const auto interpreted_cnode = fg->NewCNode({NewValueNode(prim::kPrimPyExecute), script, keys, values});
170 if (debug_info != nullptr) {
171 interpreted_cnode->set_debug_info(debug_info);
172 }
173 // Record the PyExecute node.
174 InterpretNodeRecorder::GetInstance().PushPyExecuteNode(interpreted_cnode);
175 return interpreted_cnode;
176 }
177
CreatePyExecuteCNode(const AnfNodePtr & orig_node,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values)178 CNodePtr CreatePyExecuteCNode(const AnfNodePtr &orig_node, const AnfNodePtr &script, const AnfNodePtr &keys,
179 const AnfNodePtr &values) {
180 const FuncGraphPtr &fg = orig_node->func_graph();
181 if (fg == nullptr) {
182 MS_LOG(INTERNAL_EXCEPTION) << "The func graph is null. orig_node: " << orig_node->DebugString();
183 }
184 const auto interpreted_cnode = CreatePyExecuteCNode(fg, script, keys, values, orig_node->debug_info());
185 return interpreted_cnode;
186 }
187
CreatePyExecuteCNodeInOrder(const FuncGraphPtr & fg,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values,const NodeDebugInfoPtr & debug_info)188 CNodePtr CreatePyExecuteCNodeInOrder(const FuncGraphPtr &fg, const AnfNodePtr &script, const AnfNodePtr &keys,
189 const AnfNodePtr &values, const NodeDebugInfoPtr &debug_info) {
190 const auto interpreted_cnode = fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script, keys, values});
191 interpreted_cnode->set_debug_info(debug_info);
192 // Record the PyExecute node.
193 InterpretNodeRecorder::GetInstance().PushPyExecuteNode(interpreted_cnode);
194 return interpreted_cnode;
195 }
196
CreatePyExecuteCNodeInOrder(const AnfNodePtr & orig_node,const AnfNodePtr & script,const AnfNodePtr & keys,const AnfNodePtr & values)197 CNodePtr CreatePyExecuteCNodeInOrder(const AnfNodePtr &orig_node, const AnfNodePtr &script, const AnfNodePtr &keys,
198 const AnfNodePtr &values) {
199 const FuncGraphPtr &fg = orig_node->func_graph();
200 if (fg == nullptr) {
201 MS_LOG(INTERNAL_EXCEPTION) << "The func graph is null. orig_node: " << orig_node->DebugString();
202 }
203 const auto interpreted_cnode = CreatePyExecuteCNodeInOrder(fg, script, keys, values, orig_node->debug_info());
204 return interpreted_cnode;
205 }
206
CreatePyInterpretCNode(const FuncGraphPtr & fg,const std::string & script_text,const py::object & global_dict_obj,const AnfNodePtr & local_dict_node,const NodeDebugInfoPtr & debug_info)207 CNodePtr CreatePyInterpretCNode(const FuncGraphPtr &fg, const std::string &script_text,
208 const py::object &global_dict_obj, const AnfNodePtr &local_dict_node,
209 const NodeDebugInfoPtr &debug_info) {
210 auto script = std::make_shared<parse::Script>(script_text);
211 auto script_node = NewValueNode(script);
212 parse::PyObjectWrapperPtr global_dict_wrapper = std::make_shared<parse::InterpretedObject>(global_dict_obj);
213 auto global_dict_node = NewValueNode(global_dict_wrapper);
214 auto node = fg->NewCNode({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
215 if (debug_info != nullptr) {
216 node->set_debug_info(debug_info);
217 }
218 InterpretNodeRecorder::GetInstance().PushPyInterpretNode(node);
219 return node;
220 }
221
CreatePyInterpretCNodeInOrder(const FuncGraphPtr & fg,const std::string & script_text,const py::object & global_dict_obj,const AnfNodePtr & local_dict_node,const NodeDebugInfoPtr & debug_info)222 CNodePtr CreatePyInterpretCNodeInOrder(const FuncGraphPtr &fg, const std::string &script_text,
223 const py::object &global_dict_obj, const AnfNodePtr &local_dict_node,
224 const NodeDebugInfoPtr &debug_info) {
225 auto script = std::make_shared<parse::Script>(script_text);
226 auto script_node = NewValueNode(script);
227 parse::PyObjectWrapperPtr global_dict_wrapper = std::make_shared<parse::InterpretedObject>(global_dict_obj);
228 auto global_dict_node = NewValueNode(global_dict_wrapper);
229 auto node =
230 fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
231 if (debug_info != nullptr) {
232 node->set_debug_info(debug_info);
233 }
234 InterpretNodeRecorder::GetInstance().PushPyInterpretNode(node);
235 return node;
236 }
237
SetPyObjectToLocalVariable(const std::string & key,const py::object & value)238 void SetPyObjectToLocalVariable(const std::string &key, const py::object &value) {
239 py::module mod = python_adapter::GetPyModule("mindspore.common._jit_fallback_utils");
240 constexpr auto set_local_variable = "set_local_variable";
241 MS_LOG(DEBUG) << set_local_variable << "([" << key << "]/" << key << ", " << value << ")";
242 (void)python_adapter::CallPyModFn(mod, set_local_variable, key, value);
243 }
244
ConvertPyObjectToPyExecute(const FuncGraphPtr & fg,const std::string & key,const py::object value,const AnfNodePtr & node,bool replace)245 AnfNodePtr ConvertPyObjectToPyExecute(const FuncGraphPtr &fg, const std::string &key, const py::object value,
246 const AnfNodePtr &node, bool replace) {
247 auto value_node_key = ConvertRealStrToUnicodeStr(key, 0);
248 // Set the value node into dict firstly.
249 SetPyObjectToLocalVariable(value_node_key, value);
250
251 // Get the value node from the dict in IR.
252 std::stringstream script_buffer;
253 script_buffer << "__import__('mindspore').common._jit_fallback_utils.get_local_variable(" << value_node_key << ")";
254 const std::string &script = script_buffer.str();
255 const auto script_str = std::make_shared<StringImm>(script);
256
257 // Build new CNode for value node.
258 ValuePtrList keys({std::make_shared<StringImm>(value_node_key)});
259 ValuePtrList values({std::make_shared<StringImm>(value_node_key)});
260 const auto interpreted_cnode =
261 CreatePyExecuteCNode(fg, NewValueNode(script_str), NewValueNode(std::make_shared<ValueTuple>(keys)),
262 NewValueNode(std::make_shared<ValueTuple>(values)), node->debug_info());
263 constexpr auto debug_recursive_level = 2;
264 MS_LOG(DEBUG) << "original node: " << node->DebugString(debug_recursive_level)
265 << ", interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
266 if (replace) {
267 fg->ReplaceInOrder(node, interpreted_cnode);
268 }
269 return interpreted_cnode;
270 }
271
ConvertPyObjectToPyInterpret(const FuncGraphPtr & fg,const std::string & key,const py::object value,const AnfNodePtr & node,bool replace)272 AnfNodePtr ConvertPyObjectToPyInterpret(const FuncGraphPtr &fg, const std::string &key, const py::object value,
273 const AnfNodePtr &node, bool replace) {
274 auto value_node_key = ConvertRealStrToUnicodeStr(key, 0);
275 // Set the value node into dict firstly.
276 SetPyObjectToLocalVariable(value_node_key, value);
277
278 // Build the script
279 std::stringstream script_buffer;
280 script_buffer << "__import__('mindspore').common._jit_fallback_utils.get_local_variable(" << value_node_key << ")";
281 const std::string &script = script_buffer.str();
282 auto script_str = std::make_shared<parse::Script>(script);
283 auto script_node = NewValueNode(script_str);
284
285 // Build the global dict.
286 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
287 constexpr auto python_get_dict = "get_global_params";
288 const auto &global_dict = python_adapter::CallPyModFn(mod, python_get_dict);
289 parse::PyObjectWrapperPtr interpreted_global_dict = std::make_shared<parse::InterpretedObject>(global_dict);
290 auto global_dict_node = NewValueNode(interpreted_global_dict);
291
292 // Build the local dict.
293 ValuePtrList local_keys({std::make_shared<StringImm>(value_node_key)});
294 ValuePtrList local_values({std::make_shared<StringImm>(value_node_key)});
295 auto local_key_tuple = NewValueNode(std::make_shared<ValueTuple>(local_keys));
296 auto local_value_tuple = NewValueNode(std::make_shared<ValueTuple>(local_values));
297 auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_tuple, local_value_tuple});
298 auto prim = NewValueNode(prim::kPrimPyInterpret);
299 auto interpret_node = fg->NewCNode({prim, script_node, global_dict_node, local_dict_node});
300 InterpretNodeRecorder::GetInstance().PushPyInterpretNode(interpret_node);
301 if (replace) {
302 fg->ReplaceInOrder(node, interpret_node);
303 }
304 return interpret_node;
305 }
306
// If `value` is a parse::MsClassObject, converts its Python object into a PyExecute node that replaces
// `node` in `fg`'s order list. Returns nullptr for any other value.
AnfNodePtr ConvertMsClassObjectToPyExecute(const FuncGraphPtr &fg, const ValuePtr &value, const AnfNodePtr &node) {
  if (const auto ms_class_obj = dyn_cast<parse::MsClassObject>(value); ms_class_obj != nullptr) {
    return ConvertPyObjectToPyExecute(fg, ms_class_obj->name(), ms_class_obj->obj(), node, true);
  }
  return nullptr;
}
314
GetJitAnnotationTypeFromComment(const AnfNodePtr & node,const FormatedVariableTypeFunc & format_type_func)315 TypePtr GetJitAnnotationTypeFromComment(const AnfNodePtr &node, const FormatedVariableTypeFunc &format_type_func) {
316 const auto &debug_info = trace::GetSourceCodeDebugInfo(node->debug_info());
317 const auto &location = debug_info->location();
318 if (location == nullptr) {
319 MS_LOG(INFO) << "Location info is null, node: " << node->DebugString();
320 return nullptr;
321 }
322 const auto &comments = location->comments();
323 if (comments.empty()) {
324 return nullptr;
325 }
326 // Only use the last comment.
327 const auto &comment = comments.back();
328 std::regex regex("^#\\s*@jit.typing\\s*:\\s*\\(\\)\\s*->\\s*([a-zA-Z0-9{}_]+)?\\[?([a-zA-Z0-9{}_]+)?\\]?$");
329 std::smatch matched_results;
330 if (std::regex_match(comment, matched_results, regex)) {
331 constexpr auto container_match_count = 3;
332 // Not match.
333 if (matched_results.size() != container_match_count) {
334 return nullptr;
335 }
336 const auto &container_type_str = matched_results[1].str();
337 const auto &dtype_str = matched_results[container_match_count - 1].str();
338 MS_LOG(DEBUG) << "matched_results: " << matched_results[0] << ", " << container_type_str << ", " << dtype_str;
339 // Match nothing.
340 if (container_type_str.empty()) {
341 MS_LOG(EXCEPTION) << GetErrorFormatMessage(node, comment);
342 }
343 // Handle base type only.
344 auto base_type = HandleBaseTypeForAnnotation(dtype_str, container_type_str, format_type_func, node, comment);
345 if (base_type != nullptr) {
346 return base_type;
347 }
348 // Handle container type: tensor, list_ and tuple_.
349 return HandleContainerTypeForAnnotation(dtype_str, container_type_str, format_type_func, node, comment);
350 }
351 return nullptr;
352 }
353
GetJitAnnotationSideEffectFromComment(const AnfNodePtr & node)354 bool GetJitAnnotationSideEffectFromComment(const AnfNodePtr &node) {
355 MS_EXCEPTION_IF_NULL(node);
356 const auto &debug_info = trace::GetSourceCodeDebugInfo(node->debug_info());
357 const auto &location = debug_info->location();
358 if (location == nullptr) {
359 MS_LOG(DEBUG) << "Location info is null, node: " << node->DebugString();
360 return false;
361 }
362 const auto &comments = location->comments();
363 if (comments.empty()) {
364 return false;
365 }
366 // Only use the last comment.
367 const auto &comment = comments.back();
368 std::regex regex("^#\\s*@jit.typing:\\s*side_effect");
369 if (std::regex_match(comment, regex)) {
370 return true;
371 }
372 return false;
373 }
374
ConvertRealStrToUnicodeStr(const std::string & target,size_t index)375 std::string ConvertRealStrToUnicodeStr(const std::string &target, size_t index) {
376 std::stringstream script_buffer;
377 script_buffer << kPyExecPrefix << std::to_string(index);
378 std::vector<size_t> convert_pos;
379 for (size_t i = 0; i < target.size(); ++i) {
380 auto c = target[i];
381 if (!std::isalnum(c)) {
382 convert_pos.push_back(i);
383 }
384 }
385 size_t start = 0;
386 for (auto end : convert_pos) {
387 std::string sub_non_convert = target.substr(start, end - start);
388 if (sub_non_convert.size() != 0) {
389 script_buffer << kUnderLine << sub_non_convert;
390 }
391 char sub_convert = target[end];
392 std::stringstream hex_s;
393 hex_s << kUnderLine << kHexPrefix << std::hex << static_cast<int>(sub_convert);
394 script_buffer << hex_s.str();
395 start = end + 1;
396 }
397 if (target.substr(start).size() != 0) {
398 script_buffer << kUnderLine << target.substr(start);
399 }
400 script_buffer << kPyExecSuffix;
401 auto unicode_str = script_buffer.str();
402 MS_LOG(DEBUG) << "Get Unicode str: " << unicode_str;
403 return script_buffer.str();
404 }
405
GeneratePyExecuteNodeForCallObj(const FuncGraphPtr & func_graph,const py::object & meta_obj,const AnfNodePtr & node,const std::string & name)406 AnfNodePtr GeneratePyExecuteNodeForCallObj(const FuncGraphPtr &func_graph, const py::object &meta_obj,
407 const AnfNodePtr &node, const std::string &name) {
408 if (py::isinstance<py::none>(meta_obj)) {
409 return nullptr;
410 }
411 auto res = fallback::ConvertPyObjectToPyInterpret(func_graph, name, meta_obj, node, false);
412 // '__keep_metafg_obj_flag__' is to keep metafg obj rather than convert to prim.
413 res->set_user_data("__keep_metafg_obj_flag__", std::make_shared<bool>(true));
414 return res;
415 }
416
ContainsSequenceAnyType(const AbstractBasePtr & abs)417 bool ContainsSequenceAnyType(const AbstractBasePtr &abs) {
418 if (abs == nullptr) {
419 return false;
420 }
421 if (abs->isa<abstract::AbstractSequence>()) {
422 auto seq_abs = abs->cast_ptr<abstract::AbstractSequence>();
423 MS_EXCEPTION_IF_NULL(seq_abs);
424 if (seq_abs->dynamic_len()) {
425 auto element_abs = seq_abs->dynamic_len_element_abs();
426 if (ContainsSequenceAnyType(element_abs)) {
427 return true;
428 }
429 } else {
430 const auto &elements = seq_abs->elements();
431 for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
432 const auto &item_abs = elements[item_index];
433 if (ContainsSequenceAnyType(item_abs)) {
434 return true;
435 }
436 }
437 }
438 }
439 return abs->isa<abstract::AbstractAny>();
440 }
441
SequenceAllElementsIsScalar(const AbstractBasePtr & abs)442 bool SequenceAllElementsIsScalar(const AbstractBasePtr &abs) {
443 if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
444 return false;
445 }
446 auto seq_abs = abs->cast_ptr<abstract::AbstractSequence>();
447 MS_EXCEPTION_IF_NULL(seq_abs);
448 if (seq_abs->dynamic_len()) {
449 auto element_abs = seq_abs->dynamic_len_element_abs();
450 if (element_abs == nullptr || !element_abs->isa<abstract::AbstractScalar>()) {
451 return false;
452 }
453 auto arg_type = element_abs->BuildType();
454 MS_EXCEPTION_IF_NULL(arg_type);
455 return arg_type->isa<Number>();
456 }
457 const auto &elements = seq_abs->elements();
458 for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
459 const auto &item_abs = elements[item_index];
460 if (item_abs == nullptr || !item_abs->isa<abstract::AbstractScalar>()) {
461 return false;
462 }
463 auto item_arg_type = item_abs->BuildType();
464 MS_EXCEPTION_IF_NULL(item_arg_type);
465 if (!item_arg_type->isa<Number>()) {
466 return false;
467 }
468 }
469 return true;
470 }
471
GeneratePyObj(const abstract::AbstractBasePtr & abs)472 py::object GeneratePyObj(const abstract::AbstractBasePtr &abs) {
473 MS_EXCEPTION_IF_NULL(abs);
474 if (abs->isa<abstract::AbstractList>()) {
475 auto abs_list = abs->cast<abstract::AbstractListPtr>();
476 if (HasObjInExtraInfoHolder(abs_list)) {
477 return GetObjFromExtraInfoHolder(abs_list);
478 }
479 py::list ret = py::list(abs_list->size());
480 const auto &elements = abs_list->elements();
481 for (size_t i = 0; i < elements.size(); ++i) {
482 ret[i] = GeneratePyObj(elements[i]);
483 }
484 return ret;
485 } else if (abs->isa<abstract::AbstractTuple>()) {
486 auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
487 py::tuple ret = py::tuple(abs_tuple->size());
488 const auto &elements = abs_tuple->elements();
489 for (size_t i = 0; i < elements.size(); ++i) {
490 ret[i] = GeneratePyObj(elements[i]);
491 }
492 return ret;
493 } else if (abs->isa<abstract::AbstractDictionary>()) {
494 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
495 py::dict ret = py::dict();
496 const auto &key_value_pairs = abs_dict->elements();
497 for (size_t i = 0; i < key_value_pairs.size(); ++i) {
498 py::object key = GeneratePyObj(key_value_pairs[i].first);
499 // The key should be unique.
500 key = py::isinstance<py::none>(key) ? py::str(std::to_string(i)) : key;
501 ret[key] = GeneratePyObj(key_value_pairs[i].second);
502 }
503 return ret;
504 }
505 return ValueToPyData(abs->BuildValue());
506 }
507
EnableFallbackListDictInplace()508 bool EnableFallbackListDictInplace() {
509 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
510 static const auto allow_inplace_ops = common::GetCompileConfig("FALLBACK_SUPPORT_LIST_DICT_INPLACE") != "0";
511 return allow_fallback_runtime && allow_inplace_ops;
512 }
513
AttachPyObjToExtraInfoHolder(const abstract::AbstractBasePtr & abs,const py::object & obj,bool create_in_graph)514 void AttachPyObjToExtraInfoHolder(const abstract::AbstractBasePtr &abs, const py::object &obj, bool create_in_graph) {
515 MS_EXCEPTION_IF_NULL(abs);
516 constexpr auto py_object_key = "py_obj_key";
517 constexpr auto create_in_graph_key = "create_in_graph_key";
518 if (abs->isa<abstract::AbstractList>()) {
519 auto abs_list = abs->cast<abstract::AbstractListPtr>();
520 abs_list->SetData<py::object>(py_object_key, std::make_shared<py::object>(obj));
521 abs_list->SetData<bool>(create_in_graph_key, std::make_shared<bool>(create_in_graph));
522 return;
523 }
524 if (abs->isa<abstract::AbstractDictionary>()) {
525 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
526 abs_dict->SetData<py::object>(py_object_key, std::make_shared<py::object>(obj));
527 abs_dict->SetData<bool>(create_in_graph_key, std::make_shared<bool>(create_in_graph));
528 return;
529 }
530 MS_INTERNAL_EXCEPTION(TypeError) << "The abstract should be a ExtraInfoHolder but got : " << abs->ToString();
531 }
532
GetObjFromExtraInfoHolder(const abstract::AbstractBasePtr & abs)533 py::object GetObjFromExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
534 MS_EXCEPTION_IF_NULL(abs);
535 constexpr auto py_object_key = "py_obj_key";
536 if (abs->isa<abstract::AbstractList>()) {
537 auto abs_list = abs->cast<abstract::AbstractListPtr>();
538 return *abs_list->GetData<py::object>(py_object_key);
539 }
540 if (abs->isa<abstract::AbstractDictionary>()) {
541 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
542 return *abs_dict->GetData<py::object>(py_object_key);
543 }
544 MS_INTERNAL_EXCEPTION(TypeError) << "The abstract should be a ExtraInfoHolder but got : " << abs->ToString();
545 }
546
HasCreateInGraphInExtraInfoHolder(const abstract::AbstractBasePtr & abs)547 bool HasCreateInGraphInExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
548 MS_EXCEPTION_IF_NULL(abs);
549 constexpr auto create_in_graph_key = "create_in_graph_key";
550 if (abs->isa<abstract::AbstractList>()) {
551 auto abs_list = abs->cast<abstract::AbstractListPtr>();
552 return abs_list->HasData(create_in_graph_key);
553 }
554 if (abs->isa<abstract::AbstractDictionary>()) {
555 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
556 return abs_dict->HasData(create_in_graph_key);
557 }
558 return false;
559 }
560
GetCreateInGraphFromExtraInfoHolder(const abstract::AbstractBasePtr & abs)561 bool GetCreateInGraphFromExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
562 MS_EXCEPTION_IF_NULL(abs);
563 constexpr auto create_in_graph_key = "create_in_graph_key";
564 if (abs->isa<abstract::AbstractList>()) {
565 auto abs_list = abs->cast<abstract::AbstractListPtr>();
566 return *abs_list->GetData<bool>(create_in_graph_key);
567 }
568 if (abs->isa<abstract::AbstractDictionary>()) {
569 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
570 return *abs_dict->GetData<bool>(create_in_graph_key);
571 }
572 MS_INTERNAL_EXCEPTION(TypeError) << "The abstract should be a ExtraInfoHolder but got : " << abs->ToString();
573 }
574
HasObjInExtraInfoHolder(const abstract::AbstractBasePtr & abs)575 bool HasObjInExtraInfoHolder(const abstract::AbstractBasePtr &abs) {
576 MS_EXCEPTION_IF_NULL(abs);
577 constexpr auto py_object_key = "py_obj_key";
578 if (abs->isa<abstract::AbstractList>()) {
579 auto abs_list = abs->cast<abstract::AbstractListPtr>();
580 return abs_list->HasData(py_object_key);
581 }
582 if (abs->isa<abstract::AbstractDictionary>()) {
583 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
584 return abs_dict->HasData(py_object_key);
585 }
586 return false;
587 }
588
589 // Nested attach list and dict object to corresponding abstract.
AttachPyObjToAbs(const AbstractBasePtr & abs,const py::object & obj,bool create_in_graph)590 void AttachPyObjToAbs(const AbstractBasePtr &abs, const py::object &obj, bool create_in_graph) {
591 if (!EnableFallbackListDictInplace()) {
592 return;
593 }
594 if (abs->isa<abstract::AbstractNamedTuple>()) {
595 return;
596 }
597 if (!abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractDictionary>()) {
598 return;
599 }
600 if (py::hasattr(obj, PYTHON_CELL_AS_LIST) || py::hasattr(obj, PYTHON_CELL_AS_DICT)) {
601 // CellList and CellDict do not support inplace operations, do not need to attach python object.
602 return;
603 }
604 if (abs->isa<abstract::AbstractCSRTensor>() || abs->isa<abstract::AbstractCOOTensor>()) {
605 return;
606 }
607 if (abs->isa<abstract::AbstractList>()) {
608 MS_LOG(DEBUG) << "Attach list python" << obj << " to abstract: " << abs->ToString();
609 if (!py::isinstance<py::list>(obj)) {
610 MS_INTERNAL_EXCEPTION(TypeError) << "Object should be list but got: " << py::str(obj);
611 }
612 auto abs_list = abs->cast<abstract::AbstractListPtr>();
613 AttachPyObjToExtraInfoHolder(abs_list, obj, create_in_graph);
614 auto list_obj = py::list(obj);
615 for (size_t i = 0; i < abs_list->size(); ++i) {
616 auto element_abs = abs_list->elements()[i];
617 auto element_obj = list_obj[i];
618 AttachPyObjToAbs(element_abs, element_obj, create_in_graph);
619 }
620 return;
621 }
622 if (abs->isa<abstract::AbstractDictionary>()) {
623 if (!py::isinstance<py::dict>(obj)) {
624 MS_INTERNAL_EXCEPTION(TypeError) << "Object should be dict but got: " << py::str(obj);
625 }
626 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
627 MS_LOG(DEBUG) << "Attach dict python" << obj << " to abstract: " << abs->ToString();
628 AttachPyObjToExtraInfoHolder(abs_dict, obj, create_in_graph);
629 auto dict_obj = py::dict(obj);
630 auto key_list_obj = py::list(obj);
631 const auto &key_value_pairs = abs_dict->elements();
632 for (size_t i = 0; i < key_value_pairs.size(); ++i) {
633 auto value_abs = key_value_pairs[i].second;
634 auto value_obj = dict_obj[key_list_obj[i]];
635 AttachPyObjToAbs(value_abs, value_obj, create_in_graph);
636 }
637 return;
638 }
639 auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
640 if (!py::isinstance<py::tuple>(obj)) {
641 MS_INTERNAL_EXCEPTION(TypeError) << "Object should be tuple but got: " << py::str(obj);
642 }
643 auto tuple_obj = py::tuple(obj);
644 for (size_t i = 0; i < abs_tuple->size(); ++i) {
645 auto element_abs = abs_tuple->elements()[i];
646 auto element_obj = tuple_obj[i];
647 AttachPyObjToAbs(element_abs, element_obj, create_in_graph);
648 }
649 }
650
GetPyObjectPtrStr(const py::object & obj)651 std::string GetPyObjectPtrStr(const py::object &obj) {
652 std::stringstream ss;
653 ss << obj.ptr();
654 return ss.str();
655 }
656
CheckInterpretInput(const AnfNodePtr & node)657 bool CheckInterpretInput(const AnfNodePtr &node) {
658 MS_EXCEPTION_IF_NULL(node);
659 if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
660 return true;
661 }
662 if (node->isa<CNode>()) {
663 auto cnode = node->cast<CNodePtr>();
664 const auto &inputs = cnode->inputs();
665 return std::any_of(inputs.begin(), inputs.end(),
666 [](const AnfNodePtr &input) { return CheckInterpretInput(input); });
667 }
668 return false;
669 }
670
SetPyObjectToNode(const AnfNodePtr & node,const py::object & obj)671 void SetPyObjectToNode(const AnfNodePtr &node, const py::object &obj) {
672 MS_EXCEPTION_IF_NULL(node);
673 if (!EnableFallbackListDictInplace()) {
674 return;
675 }
676 constexpr auto py_obj_str = "__py_object__";
677 if (py::isinstance<py::list>(obj)) {
678 node->set_user_data<py::list>(py_obj_str, std::make_shared<py::list>(py::list(obj)));
679 } else if (py::isinstance<py::tuple>(obj)) {
680 node->set_user_data<py::tuple>(py_obj_str, std::make_shared<py::tuple>(py::tuple(obj)));
681 } else if (py::isinstance<py::dict>(obj)) {
682 node->set_user_data<py::dict>(py_obj_str, std::make_shared<py::dict>(py::dict(obj)));
683 }
684 }
685
HasPyObjectInNode(const AnfNodePtr & node)686 bool HasPyObjectInNode(const AnfNodePtr &node) {
687 MS_EXCEPTION_IF_NULL(node);
688 constexpr auto py_obj_str = "__py_object__";
689 return node->has_user_data(py_obj_str);
690 }
691
GetPyObjectFromNode(const AnfNodePtr & node)692 py::object GetPyObjectFromNode(const AnfNodePtr &node) {
693 MS_EXCEPTION_IF_NULL(node);
694 constexpr auto py_obj_str = "__py_object__";
695 return *node->user_data<py::object>(py_obj_str);
696 }
697
698 // Convert node to pyinterpret with specific function name.
699 // ConvertCNodeToPyInterpretForPrim(prim(x1, x2), func_name)
700 // --->
701 // PyInterpret("func_name(__input1__, __input2__)", global_dict, {"__input1__": x1, "__input2__": x2})
ConvertCNodeToPyInterpretForPrim(const CNodePtr & cnode,const string & name)702 AnfNodePtr ConvertCNodeToPyInterpretForPrim(const CNodePtr &cnode, const string &name) {
703 MS_EXCEPTION_IF_NULL(cnode);
704 const auto &fg = cnode->func_graph();
705 MS_EXCEPTION_IF_NULL(fg);
706 std::stringstream script_buffer;
707 script_buffer << name << "(";
708 const auto &inputs = cnode->inputs();
709 std::vector<AnfNodePtr> keys_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
710 std::vector<AnfNodePtr> values_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
711 for (size_t index = 1; index < inputs.size(); ++index) {
712 const auto &internal_arg = fallback::ConvertRealStrToUnicodeStr(name, index);
713 script_buffer << internal_arg << ", ";
714 auto key_node = NewValueNode(std::make_shared<StringImm>(internal_arg));
715 auto value_node = inputs[index];
716 (void)keys_tuple_node_inputs.emplace_back(key_node);
717 (void)values_tuple_node_inputs.emplace_back(value_node);
718 }
719 script_buffer << ")";
720 const std::string &script = script_buffer.str();
721 auto keys_tuple_node = fg->NewCNodeInOrder(keys_tuple_node_inputs);
722 auto values_tuple_node = fg->NewCNodeInOrder(values_tuple_node_inputs);
723 auto local_dict_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), keys_tuple_node, values_tuple_node});
724 auto pyinterpret_node = CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node, cnode->debug_info());
725 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyinterpret_node->DebugString();
726 return pyinterpret_node;
727 }
728
729 // Convert some CNode to PyExectue, eg:
730 // isinstance(xxx.asnumpy(), np.ndarray) -- > PyExectue("isinstance(arg1, arg2)", local_keys, local_values)
ConvertCNodeToPyExecuteForPrim(const CNodePtr & cnode,const string & name)731 AnfNodePtr ConvertCNodeToPyExecuteForPrim(const CNodePtr &cnode, const string &name) {
732 MS_EXCEPTION_IF_NULL(cnode);
733 const auto &fg = cnode->func_graph();
734 MS_EXCEPTION_IF_NULL(fg);
735 std::string script = name + "(";
736 std::string internal_arg;
737 size_t arg_nums = cnode->size() - 1;
738 std::vector<AnfNodePtr> keys_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
739 std::vector<AnfNodePtr> values_tuple_node_inputs{NewValueNode(prim::kPrimMakeTuple)};
740 for (size_t index = 1; index < arg_nums; ++index) {
741 internal_arg = fallback::ConvertRealStrToUnicodeStr(name, index);
742 script = script + internal_arg + ", ";
743 auto key_node = NewValueNode(std::make_shared<StringImm>(internal_arg));
744 auto value_node = cnode->input(index);
745 (void)keys_tuple_node_inputs.emplace_back(key_node);
746 (void)values_tuple_node_inputs.emplace_back(value_node);
747 }
748 string last_input = fallback::ConvertRealStrToUnicodeStr(name, arg_nums);
749 script = script + last_input + ")";
750 (void)keys_tuple_node_inputs.emplace_back(NewValueNode(std::make_shared<StringImm>(last_input)));
751 (void)values_tuple_node_inputs.emplace_back(cnode->input(arg_nums));
752 auto script_node = NewValueNode(std::make_shared<StringImm>(script));
753 auto keys_tuple_node = fg->NewCNodeInOrder(keys_tuple_node_inputs);
754 auto values_tuple_node = fg->NewCNodeInOrder(values_tuple_node_inputs);
755 auto pyexecute_node =
756 CreatePyExecuteCNodeInOrder(fg, script_node, keys_tuple_node, values_tuple_node, cnode->debug_info());
757 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyexecute_node->DebugString();
758 return pyexecute_node;
759 }
760
GeneratePyInterpretWithAbstract(const FuncGraphPtr & fg,const std::vector<std::string> & funcs_str,const size_t input_size)761 AnfNodePtr GeneratePyInterpretWithAbstract(const FuncGraphPtr &fg, const std::vector<std::string> &funcs_str,
762 const size_t input_size) {
763 AnfNodePtrList node_inputs{NewValueNode(prim::kPrimMakeTuple)};
764 AnfNodePtrList keys_inputs{NewValueNode(prim::kPrimMakeTuple)};
765 std::stringstream script_buffer;
766 for (size_t i = 0; i < funcs_str.size(); ++i) {
767 script_buffer << funcs_str[i] << "(";
768 }
769 for (size_t i = 0; i < input_size; ++i) {
770 const std::string cur_name = "__input_" + std::to_string(i) + "__";
771 script_buffer << cur_name << ",";
772 (void)keys_inputs.emplace_back(NewValueNode(cur_name));
773 (void)node_inputs.emplace_back(fg->add_parameter());
774 }
775 for (size_t i = 0; i < funcs_str.size(); ++i) {
776 script_buffer << ")";
777 }
778 auto script_text = script_buffer.str();
779 auto script = std::make_shared<parse::Script>(script_text);
780 auto script_node = NewValueNode(script);
781 auto global_dict_node = NewValueNode(std::make_shared<parse::InterpretedObject>(py::dict()));
782 auto keys_tuple = fg->NewCNode(keys_inputs);
783 auto values_tuple = fg->NewCNode(node_inputs);
784 auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), keys_tuple, values_tuple});
785 auto ret_node = fg->NewCNode({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
786 return ret_node;
787 }
788
ConvertGetAttrNodeToPyInterpret(const FuncGraphPtr & fg,const CNodePtr & cnode,const std::string & name)789 AnfNodePtr ConvertGetAttrNodeToPyInterpret(const FuncGraphPtr &fg, const CNodePtr &cnode, const std::string &name) {
790 MS_EXCEPTION_IF_NULL(cnode);
791 MS_EXCEPTION_IF_NULL(fg);
792 const std::unordered_map<std::string, std::string> internal_attr_map = {
793 {"__ms_next__", "__import__('mindspore').common._utils._jit_fallback_next_func"}};
794 auto iter = internal_attr_map.find(name);
795 if (iter == internal_attr_map.end()) {
796 return ConvertCNodeToPyInterpretForPrim(cnode, "getattr");
797 }
798 AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
799 AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
800 std::stringstream script_buffer;
801 script_buffer << iter->second << "(";
802
803 const std::string data_str = "__data__";
804 script_buffer << data_str << ")";
805 (void)local_key_inputs.emplace_back(NewValueNode(data_str));
806 constexpr size_t data_index = 1;
807 (void)local_value_inputs.emplace_back(cnode->input(data_index));
808
809 const auto &script = script_buffer.str();
810 auto local_key_node = fg->NewCNode(local_key_inputs);
811 auto local_value_node = fg->NewCNode(local_value_inputs);
812 auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
813
814 auto ret = CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node, cnode->debug_info());
815 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << ret->DebugString();
816 return ret;
817 }
818
GetPyObjForFuncGraphAbstractClosure(const AbstractBasePtr & abs)819 py::object GetPyObjForFuncGraphAbstractClosure(const AbstractBasePtr &abs) {
820 if (!abs->isa<abstract::FuncGraphAbstractClosure>()) {
821 return py::none();
822 }
823 auto abs_func = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
824 auto fg = abs_func->func_graph();
825 MS_EXCEPTION_IF_NULL(fg);
826 auto wrapper_obj = fg->python_obj();
827 if (wrapper_obj != nullptr && wrapper_obj->isa<parse::PyObjectWrapper>()) {
828 auto obj = wrapper_obj->cast_ptr<parse::PyObjectWrapper>()->obj();
829 return obj;
830 }
831 // Handle lambda expression scene. Graph generated from lambda function does not have attached python object.
832 auto fg_debug_info = fg->debug_info();
833 MS_EXCEPTION_IF_NULL(fg_debug_info);
834 const auto &fg_name = fg_debug_info->name();
835 const std::string lambda_suffix = "_lambda_";
836 bool end_with_lambda_suffix =
837 (fg_name.size() >= lambda_suffix.size() && fg_name.substr(fg_name.size() - lambda_suffix.size()) == lambda_suffix);
838 if (end_with_lambda_suffix) {
839 auto location = trace::GetSourceCodeDebugInfo(fg_debug_info)->location();
840 MS_EXCEPTION_IF_NULL(location);
841 const auto &lambda_script = location->expr_src();
842 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
843 return python_adapter::CallPyModFn(mod, "generate_lambda_object", lambda_script);
844 }
845 return py::none();
846 }
847
GeneratePyInterpretNodeFromMetaFuncGraph(const FuncGraphPtr & func_graph,const AnfNodePtrList & node_inputs,const py::object & meta_obj,const TypePtrList & types,const std::string & name)848 AnfNodePtr GeneratePyInterpretNodeFromMetaFuncGraph(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_inputs,
849 const py::object &meta_obj, const TypePtrList &types,
850 const std::string &name) {
851 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
852 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
853 AnfNodePtr call_node = GeneratePyExecuteNodeForCallObj(func_graph, meta_obj, node_inputs[0], name);
854 auto node_inputs_size = node_inputs.size();
855 std::stringstream script_buffer;
856 if (call_node != nullptr) {
857 (void)key_value_list.emplace_back(call_node);
858 std::string uniname = fallback::ConvertRealStrToUnicodeStr(name, 0);
859 (void)key_value_names_list.push_back(NewValueNode(uniname));
860 script_buffer << uniname << "(";
861 } else {
862 script_buffer << "__import__('mindspore').ops.composite.multitype_ops." << name << "(";
863 }
864 for (size_t i = 0; i < node_inputs_size; i++) {
865 if (types[i]->isa<Slice>()) {
866 (void)key_value_names_list.emplace_back(NewValueNode("__start__"));
867 (void)key_value_names_list.emplace_back(NewValueNode("__stop__"));
868 (void)key_value_names_list.emplace_back(NewValueNode("__step__"));
869 auto start_node =
870 func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), node_inputs[i], NewValueNode("start")});
871 auto end_node =
872 func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), node_inputs[i], NewValueNode("stop")});
873 auto step_node =
874 func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), node_inputs[i], NewValueNode("step")});
875 (void)key_value_list.emplace_back(start_node);
876 (void)key_value_list.emplace_back(end_node);
877 (void)key_value_list.emplace_back(step_node);
878 script_buffer << "slice(__start__,__stop__,__step__)";
879 } else {
880 std::stringstream input_key;
881 input_key << "__input_key_" << i << "__";
882 (void)key_value_names_list.push_back(NewValueNode(input_key.str()));
883 (void)key_value_list.emplace_back(node_inputs[i]);
884 script_buffer << input_key.str();
885 }
886 if (i != node_inputs_size) {
887 script_buffer << ",";
888 }
889 }
890 script_buffer << ")";
891 const auto script_str = script_buffer.str();
892 const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
893 const auto key_value_tuple = func_graph->NewCNode(key_value_list);
894
895 // Generate PyInterpret node with
896 auto local_dict = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), key_value_name_tuple, key_value_tuple});
897 auto res = CreatePyInterpretCNode(func_graph, script_str, py::dict(), local_dict, key_value_name_tuple->debug_info());
898 res->set_user_data(kCheckListDictInplace, std::make_shared<bool>(true));
899 MS_LOG(DEBUG) << "Generate PyInterpret node: " << res->DebugString();
900 return res;
901 }
902 } // namespace fallback
903
904 namespace raiseutils {
905 namespace {
CheckIsStr(const AbstractBasePtr & abs)906 bool CheckIsStr(const AbstractBasePtr &abs) {
907 auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
908 MS_EXCEPTION_IF_NULL(scalar);
909 auto scalar_type = scalar->BuildType();
910 MS_EXCEPTION_IF_NULL(scalar_type);
911 if (scalar_type->IsSameTypeId(String::kTypeId)) {
912 return true;
913 }
914 return false;
915 }
916
GetScalarStringValue(const AbstractBasePtr & abs)917 std::string GetScalarStringValue(const AbstractBasePtr &abs) {
918 MS_EXCEPTION_IF_NULL(abs);
919 auto scalar = abs->cast<abstract::AbstractScalarPtr>();
920 MS_EXCEPTION_IF_NULL(scalar);
921 auto scalar_value = scalar->BuildValue();
922 return scalar_value->ToString();
923 }
924
GetVariable(const AnfNodePtr & input,const std::shared_ptr<KeyValueInfo> & key_value,const std::string & exception_str,bool need_symbol)925 std::string GetVariable(const AnfNodePtr &input, const std::shared_ptr<KeyValueInfo> &key_value,
926 const std::string &exception_str, bool need_symbol) {
927 std::string key = MakeRaiseKey(key_value->num_str);
928 std::stringstream script_buffer;
929 key_value->num_str += 1;
930 if (need_symbol) {
931 script_buffer << exception_str << "'+f'{" << key << "}'+'";
932 } else {
933 script_buffer << exception_str << key;
934 }
935 (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
936 (void)key_value->values.emplace_back(input);
937 return script_buffer.str();
938 }
939
GetTupleOrListString(const AbstractBasePtr & arg,const AnfNodePtr & input,const std::shared_ptr<KeyValueInfo> & key_value,bool need_symbol,bool need_comma)940 std::string GetTupleOrListString(const AbstractBasePtr &arg, const AnfNodePtr &input,
941 const std::shared_ptr<KeyValueInfo> &key_value, bool need_symbol, bool need_comma) {
942 MS_EXCEPTION_IF_NULL(arg);
943 bool has_variable = CheckHasVariable(arg);
944 std::stringstream exception_str;
945 bool is_tuple = arg->isa<abstract::AbstractTuple>();
946 // Process raise ValueError("str")
947 auto arg_tuple = arg->cast_ptr<abstract::AbstractSequence>();
948 MS_EXCEPTION_IF_NULL(arg_tuple);
949 const auto &arg_tuple_elements = arg_tuple->elements();
950 if (!input->isa<CNode>() && has_variable) {
951 return GetVariable(input, key_value, exception_str.str(), need_symbol);
952 }
953 if (arg_tuple_elements.size() > 1 && !IsPrimitiveCNode(input, prim::kPrimJoinedStr)) {
954 if (is_tuple) {
955 exception_str << "(";
956 } else {
957 exception_str << "[";
958 }
959 }
960 if (has_variable) {
961 auto cnode = input->cast_ptr<CNode>();
962 MS_EXCEPTION_IF_NULL(cnode);
963 bool not_variable =
964 (!arg->BuildValue()->ContainsValueAny()) || IsValueNode<prim::DoSignaturePrimitive>(cnode->input(0));
965 for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
966 auto &element = arg_tuple_elements[index];
967 const auto &inputs = cnode->inputs();
968 if (arg_tuple_elements.size() >= cnode->size()) {
969 MS_LOG(EXCEPTION) << "Size of cnode should be greater than arg_tuple_elements, "
970 << "but got cnode size: " << cnode->size()
971 << " arg_tuple_elements size: " << arg_tuple_elements.size();
972 }
973 auto inputs_in_tuple = inputs[index + 1];
974 exception_str << GetExceptionString(element, inputs_in_tuple, key_value, need_symbol, need_comma);
975 if (index != arg_tuple_elements.size() - 1 && need_comma && not_variable) {
976 exception_str << ", ";
977 }
978 }
979 } else {
980 for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
981 auto &element = arg_tuple_elements[index];
982 exception_str << GetExceptionString(element, input, key_value, need_symbol, need_comma);
983 if (index != arg_tuple_elements.size() - 1 && need_comma) {
984 exception_str << ", ";
985 }
986 }
987 }
988 if (arg_tuple_elements.size() > 1 && !IsPrimitiveCNode(input, prim::kPrimJoinedStr)) {
989 if (is_tuple) {
990 exception_str << ")";
991 } else {
992 exception_str << "]";
993 }
994 }
995 return exception_str.str();
996 }
997 } // namespace
998
// Build the placeholder name "__internal_error_value<index>__" used to bind raise-message variables.
std::string MakeRaiseKey(const int index) {
  std::string key = "__internal_error_value";
  key += std::to_string(index);
  key += "__";
  return key;
}
1000
CheckNeedSymbol(const AbstractBasePtr & abs)1001 bool CheckNeedSymbol(const AbstractBasePtr &abs) {
1002 MS_EXCEPTION_IF_NULL(abs);
1003 bool need_symbol = false;
1004 if (abs->isa<abstract::AbstractScalar>()) {
1005 need_symbol = CheckIsStr(abs);
1006 } else if (abs->isa<abstract::AbstractSequence>()) {
1007 auto abs_list = abs->cast_ptr<abstract::AbstractSequence>();
1008 MS_EXCEPTION_IF_NULL(abs_list);
1009 const auto &elements = abs_list->elements();
1010 for (auto &element : elements) {
1011 MS_EXCEPTION_IF_NULL(element);
1012 if (element->isa<abstract::AbstractScalar>()) {
1013 need_symbol = CheckIsStr(element);
1014 if (need_symbol) {
1015 return need_symbol;
1016 }
1017 }
1018 }
1019 }
1020 return need_symbol;
1021 }
1022
GetExceptionString(const AbstractBasePtr & arg,const AnfNodePtr & input,const std::shared_ptr<KeyValueInfo> & key_value,bool need_symbol,bool need_comma)1023 std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input,
1024 const std::shared_ptr<KeyValueInfo> &key_value, bool need_symbol, bool need_comma) {
1025 std::string exception_str;
1026 MS_EXCEPTION_IF_NULL(arg);
1027 if (arg->isa<abstract::AbstractSequence>() && !IsPrimitiveCNode(input, prim::kPrimGetAttr)) {
1028 return GetTupleOrListString(arg, input, key_value, need_symbol, need_comma);
1029 } else if (arg->BuildValue()->ContainsValueAny() || arg->isa<abstract::AbstractTensor>() ||
1030 IsPrimitiveCNode(input, prim::kPrimGetAttr)) {
1031 exception_str = GetVariable(input, key_value, exception_str, need_symbol);
1032 } else if (arg->isa<abstract::AbstractDictionary>()) {
1033 MS_LOG(EXCEPTION) << "Dictionary type is currently not supporting";
1034 } else if (arg->isa<abstract::AbstractScalar>()) {
1035 // Process raise ValueError
1036 exception_str += GetScalarStringValue(arg);
1037 } else {
1038 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected abstract: " << arg->ToString();
1039 }
1040 return exception_str;
1041 }
1042
CheckHasVariable(const AbstractBasePtr & arg)1043 bool CheckHasVariable(const AbstractBasePtr &arg) {
1044 if (arg->isa<abstract::AbstractSequence>()) {
1045 auto arg_tuple = arg->cast_ptr<abstract::AbstractSequence>();
1046 MS_EXCEPTION_IF_NULL(arg_tuple);
1047 const auto &arg_tuple_elements = arg_tuple->elements();
1048 if (arg_tuple_elements.size() == 0) {
1049 MS_LOG(INTERNAL_EXCEPTION) << "The arg_tuple_elements can't be empty.";
1050 }
1051 for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
1052 auto &element = arg_tuple_elements[index];
1053 if (CheckHasVariable(element)) {
1054 return true;
1055 }
1056 }
1057 } else if (arg->BuildValue()->ContainsValueAny() || arg->isa<abstract::AbstractTensor>()) {
1058 return true;
1059 }
1060 return false;
1061 }
1062
GetExceptionType(const AbstractBasePtr & abs,const AnfNodePtr & node,const std::shared_ptr<KeyValueInfo> & key_value,bool has_variable)1063 std::string GetExceptionType(const AbstractBasePtr &abs, const AnfNodePtr &node,
1064 const std::shared_ptr<KeyValueInfo> &key_value, bool has_variable) {
1065 MS_EXCEPTION_IF_NULL(node);
1066 auto clt = GetValueNode<ClassTypePtr>(node);
1067 if (clt != nullptr) {
1068 const auto &class_name = clt->name();
1069 auto begin = class_name.find("'") + 1;
1070 auto end = class_name.substr(begin).find("'");
1071 auto class_type = class_name.substr(begin, end);
1072 return class_type;
1073 }
1074 std::string str;
1075 if (abs->isa<abstract::AbstractScalar>()) {
1076 auto scalar = abs->cast_ptr<abstract::AbstractScalar>();
1077 MS_EXCEPTION_IF_NULL(scalar);
1078 auto scalar_value = scalar->BuildValue();
1079 MS_EXCEPTION_IF_NULL(scalar_value);
1080 if (scalar_value->isa<StringImm>()) {
1081 str = GetValue<std::string>(scalar_value);
1082 if (GetValueNode<StringImmPtr>(node) == nullptr && has_variable) {
1083 (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(str)));
1084 (void)key_value->values.emplace_back(node);
1085 }
1086 return str;
1087 }
1088 }
1089 MS_LOG(EXCEPTION) << "The abstract of exception type is not scalar: " << abs->ToString();
1090 }
1091
1092 namespace {
HasVariableCondition(const FuncGraphPtr & cur_graph,std::vector<FuncGraphPtr> * prev_graph)1093 bool HasVariableCondition(const FuncGraphPtr &cur_graph, std::vector<FuncGraphPtr> *prev_graph) {
1094 if (cur_graph == nullptr) {
1095 return false;
1096 }
1097 if (cur_graph->is_tensor_condition_branch()) {
1098 return true;
1099 }
1100 auto cur_fg_map = cur_graph->func_graph_cnodes_index();
1101 for (auto &cur_fg_use : cur_fg_map) {
1102 auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
1103 MS_EXCEPTION_IF_NULL(temp_node);
1104 if (std::find(prev_graph->begin(), prev_graph->end(), cur_graph) != prev_graph->end()) {
1105 continue;
1106 }
1107 prev_graph->push_back(cur_graph);
1108 if (HasVariableCondition(temp_node->func_graph(), prev_graph)) {
1109 return true;
1110 }
1111 }
1112 if (HasVariableCondition(cur_graph->parent(), prev_graph)) {
1113 return true;
1114 }
1115 return false;
1116 }
1117 } // namespace
1118
HasVariableCondition(const FuncGraphPtr & cur_graph)1119 bool HasVariableCondition(const FuncGraphPtr &cur_graph) {
1120 std::vector<FuncGraphPtr> prev_graph;
1121 return HasVariableCondition(cur_graph, &prev_graph);
1122 }
1123 } // namespace raiseutils
1124 } // namespace mindspore
1125