1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/pi/graph_compiler/utils.h"
20 #include <algorithm>
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #include "include/common/utils/python_adapter.h"
26 #include "pipeline/jit/pi/pydef.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "frontend/operator/ops.h"
29 #include "mindspore/core/ops/sparse_tensor_ops.h"
30 #include "mindspore/core/ops/sequence_ops.h"
31 #include "mindspore/core/ops/comparison_ops.h"
32 #include "mindspore/core/ops/array_ops.h"
33 #include "mindspore/core/ops/math_ops.h"
34 #include "mindspore/core/ops/structure_ops.h"
35 #include "mindspore/core/ops/arithmetic_ops.h"
36 #include "mindspore/core/ops/framework_ops.h"
37 #include "pipeline/jit/ps/parse/data_converter.h"
38 #include "pipeline/jit/ps/resource.h"
39 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
40
41 namespace mindspore {
42 namespace pijit {
43 namespace {
44 // Arg is mutable when it is mutable or it is meta tensor and it is not const
IsMutableArg(const py::object & obj,const ValuePtr & value)45 bool IsMutableArg(const py::object &obj, const ValuePtr &value) {
46 return value->isa<tensor::MetaSparseTensor>() || (value->isa<tensor::MetaTensor>() && !GraphUtils::IsConst(obj)) ||
47 GraphUtils::IsMutable(obj);
48 }
49
IsMetaTensorTuple(const ValuePtr & value)50 bool IsMetaTensorTuple(const ValuePtr &value) {
51 if (!value->isa<ValueTuple>()) {
52 return false;
53 }
54 auto tuple = value->cast<ValueTuplePtr>();
55 for (auto element : tuple->value()) {
56 if (!element->isa<tensor::MetaTensor>() && !IsMetaTensorTuple(element)) {
57 return false;
58 }
59 }
60 return true;
61 }
62
EnableArgBroaden(const py::object & obj,const ValuePtr & value,bool enable_tuple_broaden)63 bool EnableArgBroaden(const py::object &obj, const ValuePtr &value, bool enable_tuple_broaden) {
64 return IsMutableArg(obj, value) || value->isa<tensor::MetaSparseTensor>() ||
65 (value->isa<Scalar>() && MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) ||
66 (enable_tuple_broaden && IsMetaTensorTuple(value));
67 }
68
CheckAndConvertToVariableLenSequence(const py::object & obj,AbstractBasePtr abs)69 void CheckAndConvertToVariableLenSequence(const py::object &obj, AbstractBasePtr abs) {
70 if (!GraphUtils::IsDynamicLength(obj)) {
71 return;
72 }
73 if (!abs->isa<abstract::AbstractSequence>()) {
74 MS_EXCEPTION(TypeError) << "For mutable, when the variable_len the True, the first input should be"
75 << " list or tuple, but got: " << abs->ToString();
76 }
77 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
78 abs_seq->CheckAndConvertToDynamicLenSequence();
79 }
80 } // namespace
81
IsTupleCanBroaden(const py::object & obj)82 bool GraphUtils::IsTupleCanBroaden(const py::object &obj) {
83 if (!py::isinstance<py::tuple>(obj)) {
84 return false;
85 }
86 py::tuple tuple = py::cast<py::tuple>(obj);
87 for (auto item : tuple) {
88 auto elem = py::cast<py::object>(item);
89 if (!py::isinstance<mindspore::tensor::Tensor>(elem) && !IsTupleCanBroaden(elem)) {
90 return false;
91 }
92 }
93 return true;
94 }
95
IsGradForScalar(const py::object & obj)96 bool GraphUtils::IsGradForScalar(const py::object &obj) {
97 return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) &&
98 (py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj));
99 }
100
IsTensor(const py::object & obj)101 bool GraphUtils::IsTensor(const py::object &obj) {
102 return py::isinstance<mindspore::tensor::Tensor>(obj) || py::isinstance<mindspore::tensor::CSRTensor>(obj) ||
103 py::isinstance<mindspore::tensor::COOTensor>(obj) || py::isinstance<mindspore::tensor::RowTensor>(obj);
104 }
105
IsEmptyContainer(const py::object & obj)106 bool GraphUtils::IsEmptyContainer(const py::object &obj) {
107 if (!py::isinstance<py::tuple>(obj) && !py::isinstance<py::list>(obj) && !py::isinstance<py::dict>(obj)) {
108 return false;
109 }
110 return py::len(obj) == 0;
111 }
112
ArgsToAbstract(const py::object & arg,const ValuePtr & value,bool enable_tuple_broaden)113 AbstractBasePtr GraphUtils::ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden) {
114 auto ret = abstract::ToAbstract(value, nullptr, nullptr);
115 if (EnableArgBroaden(arg, value, enable_tuple_broaden)) {
116 ret = AbstractBroaden(ret);
117 }
118 CheckAndConvertToVariableLenSequence(arg, ret);
119 return ret;
120 }
121
GetPrimOrMetaFuncGraph(int op_code)122 AnfNodePtr GraphUtils::GetPrimOrMetaFuncGraph(int op_code) {
123 auto ret = GetPrimitive(op_code);
124 if (ret != nullptr) {
125 return NewValueNode(ret);
126 }
127 return GetMetaFuncGraph(op_code);
128 }
129
GetPrimitive(int op_code)130 PrimitivePtr GraphUtils::GetPrimitive(int op_code) {
131 static std::map<int, PrimitivePtr> op_code_2_prim = {
132 {UNARY_INVERT, prim::kPrimInvert}, {RETURN_VALUE, prim::kPrimReturn},
133 {LIST_TO_TUPLE, prim::kPrimMakeTuple}, {LIST_APPEND, prim::kPrimListAppend},
134 {BUILD_TUPLE, prim::kPrimMakeTuple}, {BUILD_LIST, prim::kPrimMakeList},
135 {BUILD_SET, prim::kPrimMakeList}, {BUILD_MAP, prim::kPrimMakeDict},
136 {BUILD_SLICE, prim::kPrimMakeSlice}, {BUILD_CONST_KEY_MAP, prim::kPrimMakeDict},
137 {BUILD_STRING, prim::kPrimStringConcat}, {LOAD_ATTR, prim::kPrimGetAttr},
138 {LOAD_METHOD, prim::kPrimGetAttr}};
139
140 if (op_code_2_prim.find(op_code) == op_code_2_prim.end()) {
141 return nullptr;
142 }
143
144 return op_code_2_prim.at(op_code);
145 }
146
OpCodeToGraphName(int op_code)147 std::string GraphUtils::OpCodeToGraphName(int op_code) {
148 static std::map<int, std::string> op_code_2_graph_name = {{UNARY_NEGATIVE, "negative"},
149 {UNARY_NOT, "logical_not"},
150 {BINARY_POWER, "pow_"},
151 {BINARY_MULTIPLY, "mul"},
152 {BINARY_MODULO, "mod"},
153 {BINARY_ADD, "add"},
154 {BINARY_SUBTRACT, "sub"},
155 {BINARY_SUBSCR, "getitem"},
156 {BINARY_FLOOR_DIVIDE, "floordiv"},
157 {BINARY_TRUE_DIVIDE, "div"},
158 {INPLACE_FLOOR_DIVIDE, "floordiv"},
159 {INPLACE_TRUE_DIVIDE, "div"},
160 {INPLACE_ADD, "add"},
161 {INPLACE_SUBTRACT, "sub"},
162 {INPLACE_MULTIPLY, "mul"},
163 {INPLACE_MODULO, "mod"},
164 {BINARY_LSHIFT, "left_shift"},
165 {BINARY_RSHIFT, "right_shift"},
166 {BINARY_AND, "bitwise_and"},
167 {BINARY_XOR, "bitwise_xor"},
168 {BINARY_OR, "bitwise_or"},
169 {INPLACE_POWER, "pow"},
170 {INPLACE_LSHIFT, "left_shift"},
171 {INPLACE_RSHIFT, "right_shift"},
172 {INPLACE_AND, "bitwise_and"},
173 {INPLACE_XOR, "bitwise_xor"},
174 {INPLACE_OR, "bitwise_or"},
175 {DICT_MERGE, "add"},
176 {LIST_EXTEND, "add"}};
177 auto iter = op_code_2_graph_name.find(op_code);
178 if (iter == op_code_2_graph_name.end()) {
179 return "";
180 }
181 return iter->second;
182 }
183
OpCompareArgToGraphName(int oparg)184 std::string GraphUtils::OpCompareArgToGraphName(int oparg) {
185 static std::map<int, std::string> compare_arg_2_graph_name = {{Py_LT, "less"}, {Py_LE, "less_equal"},
186 {Py_EQ, "equal"}, {Py_NE, "not_equal"},
187 {Py_GT, "greater"}, {Py_GE, "greater_equal"}};
188 auto iter = compare_arg_2_graph_name.find(oparg);
189 if (iter == compare_arg_2_graph_name.end()) {
190 return "";
191 }
192 return iter->second;
193 }
194
GetMetaFuncGraph(int op_code)195 AnfNodePtr GraphUtils::GetMetaFuncGraph(int op_code) {
196 // MS_EXCEPTION_IF_CHECK_FAIL(op_code_2_graph_name.find(op_code) != op_code_2_graph_name.end(),
197 // "Not find the mutitype ops of OpCode " + std::to_string(op_code) + ".");
198 const auto &graph_name = OpCodeToGraphName(op_code);
199 if (graph_name != "") {
200 return GetMetaFuncGraph(graph_name);
201 }
202 return nullptr;
203 }
204
GetMetaFuncGraph(const std::string & name)205 AnfNodePtr GraphUtils::GetMetaFuncGraph(const std::string &name) {
206 py::object obj = python_adapter::GetPyFn("mindspore.ops.composite.multitype_ops", name);
207 return ConvertPythonObjectToAnfNode(obj);
208 }
209
ConvertPythonObjectToAnfNode(const py::object & object)210 AnfNodePtr GraphUtils::ConvertPythonObjectToAnfNode(const py::object &object) {
211 ValuePtr value = nullptr;
212 bool succ = mindspore::parse::ConvertData(object, &value, python_adapter::UseSignatureInResolve());
213 if (!succ) {
214 MS_LOG(EXCEPTION) << "Convert " << (std::string)py::str(object) << " To AnfNode Fail.";
215 }
216 return NewValueNode(value);
217 }
218
219 } // namespace pijit
220 } // namespace mindspore
221