1 /**
2 * Copyright 2023-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_FALLBACK_H_
18 #define MINDSPORE_CCSRC_PIPELINE_JIT_FALLBACK_H_
19
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/anf.h"
#include "ir/dtype/type.h"
#include "abstract/abstract_value.h"
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/ps/parse/resolve.h"
30
31 namespace mindspore {
32 namespace fallback {
// String constants shared by the fallback (PyExecute / PyInterpret) helpers.
// Each is a `const char *const` with internal linkage, so every translation unit gets its own copy.
constexpr const char *kPyExecPrefix = "__py_exec_index";
constexpr const char *kPyExecSuffix = "__";
constexpr const char *kUnderLine = "_";
constexpr const char *kHexPrefix = "0x";
// Flag / attribute names; NOTE(review): their consumers live outside this header — confirm usage there.
constexpr const char *kObjectAttrChange = "object_attr_change";
constexpr const char *kCheckListDictInplace = "check_list_dict_inplace";
constexpr const char *kLocalDictCheck = "local_dict_check";
constexpr const char *kIsAdapter = "is_adapter";
constexpr const char *kAdapterTensor = "adapter_tensor";
42
// Create a PyExecute CNode from the `script`, `keys` and `values` nodes.
// Two overload shapes exist:
//  - (fg, ..., debug_info): the owning graph and debug info are given explicitly.
//  - (orig_node, ...): an existing node is given instead; presumably its graph and debug info are
//    reused for the new node — confirm against the definitions.
// NOTE(review): the `InOrder` variants presumably also record the new node in the graph's execution
// order; verify in the .cc before relying on it.
CNodePtr CreatePyExecuteCNode(const FuncGraphPtr &fg, const AnfNodePtr &script, const AnfNodePtr &keys,
                              const AnfNodePtr &values, const NodeDebugInfoPtr &debug_info);
CNodePtr CreatePyExecuteCNode(const AnfNodePtr &orig_node, const AnfNodePtr &script, const AnfNodePtr &keys,
                              const AnfNodePtr &values);
CNodePtr CreatePyExecuteCNodeInOrder(const FuncGraphPtr &fg, const AnfNodePtr &script, const AnfNodePtr &keys,
                                     const AnfNodePtr &values, const NodeDebugInfoPtr &debug_info);
CNodePtr CreatePyExecuteCNodeInOrder(const AnfNodePtr &orig_node, const AnfNodePtr &script, const AnfNodePtr &keys,
                                     const AnfNodePtr &values);
// Create a PyInterpret CNode in `fg` for the Python source `script_text`.
// `global_dict_obj` and `local_dict_node` presumably supply the global and local namespaces used to
// evaluate the script — confirm. `debug_info` defaults to nullptr.
CNodePtr CreatePyInterpretCNode(const FuncGraphPtr &fg, const std::string &script_text,
                                const py::object &global_dict_obj, const AnfNodePtr &local_dict_node,
                                const NodeDebugInfoPtr &debug_info = nullptr);
CNodePtr CreatePyInterpretCNodeInOrder(const FuncGraphPtr &fg, const std::string &script_text,
                                       const py::object &global_dict_obj, const AnfNodePtr &local_dict_node,
                                       const NodeDebugInfoPtr &debug_info = nullptr);
59
60 // Create primitive cnode to PyInterpret/PyExecute node with specific function name.
61 AnfNodePtr ConvertCNodeToPyInterpretForPrim(const CNodePtr &cnode, const string &name);
62 AnfNodePtr ConvertCNodeToPyExecuteForPrim(const CNodePtr &cnode, const string &name);
63
64 // Create PyInterpret node according to input abstract size and corresponding function name.
65 AnfNodePtr GeneratePyInterpretWithAbstract(const FuncGraphPtr &fg, const std::vector<std::string> &funcs_str,
66 const size_t input_size);
67
68 // Generate PyInterpret node for meta function graph.
69 AnfNodePtr GeneratePyInterpretNodeFromMetaFuncGraph(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_inputs,
70 const py::object &meta_obj, const TypePtrList &types,
71 const std::string &name);
72
73 // Convert Python object to PyInterpret/PyExecute node.
74 AnfNodePtr ConvertPyObjectToPyExecute(const FuncGraphPtr &fg, const std::string &key, const py::object value,
75 const AnfNodePtr &node, bool replace);
76 AnfNodePtr ConvertPyObjectToPyInterpret(const FuncGraphPtr &fg, const std::string &key, const py::object value,
77 const AnfNodePtr &node, bool replace);
78 AnfNodePtr ConvertMsClassObjectToPyExecute(const FuncGraphPtr &fg, const ValuePtr &value, const AnfNodePtr &node);
79
80 // Convert GetAttr node to PyInterpret/PyExecute.
81 AnfNodePtr ConvertGetAttrNodeToPyInterpret(const FuncGraphPtr &fg, const CNodePtr &cnode, const std::string &name);
82
// Get the Python object associated with the abstract function `abs`.
// NOTE(review): the result when `abs` is not a FuncGraph abstract closure is not visible here — check the
// definition.
py::object GetPyObjForFuncGraphAbstractClosure(const AbstractBasePtr &abs);

// Functions about jit annotations. The names indicate the annotation is parsed from a comment attached
// to `node` — confirm in the definitions.
// Callback that maps a formatted type string to a TypePtr.
// NOTE(review): "Formated" is a misspelling of "Formatted"; the alias is public, so the name is kept.
using FormatedVariableTypeFunc = std::function<TypePtr(const std::string &)>;
// `format_type_func` defaults to an empty std::function.
TypePtr GetJitAnnotationTypeFromComment(const AnfNodePtr &node,
                                        const FormatedVariableTypeFunc &format_type_func = FormatedVariableTypeFunc());
bool GetJitAnnotationSideEffectFromComment(const AnfNodePtr &node);
// Predicates on the sequence abstract `abs`.
bool ContainsSequenceAnyType(const AbstractBasePtr &abs);
bool SequenceAllElementsIsScalar(const AbstractBasePtr &abs);
// NOTE(review): the role of `index` in the conversion is not visible here — confirm in the definition.
std::string ConvertRealStrToUnicodeStr(const std::string &target, size_t index);
// Return a string form of the pointer of `obj` (presumably hex-formatted — confirm).
std::string GetPyObjectPtrStr(const py::object &obj);

// Check whether the node contains PyInterpret input.
bool CheckInterpretInput(const AnfNodePtr &node);
98
// Functions about list/dict in-place operations.
// Return whether the list/dict in-place fallback is enabled (the switch is read in the definition).
bool EnableFallbackListDictInplace();
// Generate a Python object according to the abstract `abs`.
py::object GeneratePyObj(const abstract::AbstractBasePtr &abs);
// Attach / query a Python object and a `create_in_graph` flag on `abs` through its ExtraInfoHolder.
// `create_in_graph` presumably records whether the object was created inside the graph — confirm.
void AttachPyObjToExtraInfoHolder(const abstract::AbstractBasePtr &abs, const py::object &obj, bool create_in_graph);
bool HasObjInExtraInfoHolder(const abstract::AbstractBasePtr &abs);
py::object GetObjFromExtraInfoHolder(const abstract::AbstractBasePtr &abs);
bool HasCreateInGraphInExtraInfoHolder(const abstract::AbstractBasePtr &abs);
bool GetCreateInGraphFromExtraInfoHolder(const abstract::AbstractBasePtr &abs);
// Attach the Python object `obj` to `abs` recursively using ExtraInfoHolder
// (presumably descending into nested element abstracts — confirm in the definition).
void AttachPyObjToAbs(const AbstractBasePtr &abs, const py::object &obj, bool create_in_graph);
// Store / query / fetch a Python object on an AnfNode.
void SetPyObjectToNode(const AnfNodePtr &node, const py::object &obj);
bool HasPyObjectInNode(const AnfNodePtr &node);
// NOTE(review): unlike its neighbours this takes a string key rather than a node; presumably it records
// `value` as the local variable named `key` — confirm in the definition.
void SetPyObjectToLocalVariable(const std::string &key, const py::object &value);
py::object GetPyObjectFromNode(const AnfNodePtr &node);
116
// user_data keys under which the real type / real shape of a PyExecute result are stored on an owner
// object (presumably a node or abstract). They are defined once so the Has/Set/Get accessors below
// cannot drift apart through a typo in one copy of the literal.
constexpr auto kPyExecuteRealTypeKey = "__py_execute_real_type__";
constexpr auto kPyExecuteRealShapeKey = "__py_execute_real_shape__";

// Return whether `owner` carries a real-type entry.
// `owner` must be non-null: it is dereferenced without a check.
template <typename T>
bool HasRealType(const std::shared_ptr<T> &owner) {
  return owner->has_user_data(kPyExecuteRealTypeKey);
}

// Store `data` as the real-type entry of `owner` (non-null), overwriting per the owner's user_data rules.
template <typename T, typename U>
void SetRealType(const std::shared_ptr<T> &owner, const std::shared_ptr<U> &data) {
  owner->template set_user_data<U>(kPyExecuteRealTypeKey, data);
}

// Fetch the real-type entry of `owner` (non-null) as a std::shared_ptr<U>.
// NOTE(review): the result when no entry exists is defined by the owner's `user_data` — confirm.
template <typename T, typename U>
std::shared_ptr<U> GetRealType(const std::shared_ptr<T> &owner) {
  return owner->template user_data<U>(kPyExecuteRealTypeKey);
}

// Return whether `owner` (non-null) carries a real-shape entry.
template <typename T>
bool HasRealShape(const std::shared_ptr<T> &owner) {
  return owner->has_user_data(kPyExecuteRealShapeKey);
}

// Store `data` as the real-shape entry of `owner` (non-null).
template <typename T, typename U>
void SetRealShape(const std::shared_ptr<T> &owner, const std::shared_ptr<U> &data) {
  owner->template set_user_data<U>(kPyExecuteRealShapeKey, data);
}

// Fetch the real-shape entry of `owner` (non-null) as a std::shared_ptr<U>.
template <typename T, typename U>
std::shared_ptr<U> GetRealShape(const std::shared_ptr<T> &owner) {
  return owner->template user_data<U>(kPyExecuteRealShapeKey);
}
146 } // namespace fallback
147
namespace raiseutils {
// Shared pointer to a parsed Python class type.
using ClassTypePtr = std::shared_ptr<parse::ClassType>;

// Key/value nodes collected by the raise helpers below (presumably while building the raise message).
// `num_str` is a counter that starts at 0; presumably it numbers generated keys (cf. MakeRaiseKey) — confirm.
struct KeyValueInfo {
  int num_str = 0;
  std::vector<AnfNodePtr> keys;
  std::vector<AnfNodePtr> values;
};

// Return the exception type name for `abs` / `cnode`, recording any needed key/value pairs in
// `key_value`. `has_variable` defaults to true.
std::string GetExceptionType(const AbstractBasePtr &abs, const AnfNodePtr &cnode,
                             const std::shared_ptr<KeyValueInfo> &key_value, bool has_variable = true);

// Return whether `arg` contains a variable (non-constant) part — NOTE(review): inferred from the name; confirm.
bool CheckHasVariable(const AbstractBasePtr &arg);

// Build the exception-message string for `arg` / `input`, recording key/value pairs in `key_value`.
// `need_symbol` and `need_comma` both default to false.
std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input,
                               const std::shared_ptr<KeyValueInfo> &key_value, bool need_symbol = false,
                               bool need_comma = false);

bool CheckNeedSymbol(const AbstractBasePtr &abs);

// Build the key string for the raise argument at `index`.
std::string MakeRaiseKey(int index);

// Return whether `cur_graph` has a variable condition — NOTE(review): exact criterion lives in the definition.
bool HasVariableCondition(const FuncGraphPtr &cur_graph);
}  // namespace raiseutils
172 } // namespace mindspore
173
174 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_FALLBACK_H_
175