/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/pi/graph_compiler/utils.h"
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/pi/pydef.h"
#include "abstract/ops/primitive_infer_map.h"
#include "frontend/operator/ops.h"
#include "mindspore/core/ops/sparse_tensor_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/arithmetic_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "pipeline/jit/ps/resource.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"

namespace mindspore {
namespace pijit {
namespace {
// An argument is treated as mutable when it is explicitly marked mutable, is a meta sparse tensor,
// or is a meta tensor that is not marked const.
bool IsMutableArg(const py::object &obj, const ValuePtr &value) {
  return value->isa<tensor::MetaSparseTensor>() || (value->isa<tensor::MetaTensor>() && !GraphUtils::IsConst(obj)) ||
         GraphUtils::IsMutable(obj);
}

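// Returns true when the value is a tuple whose elements are all meta tensors or nested tuples of meta tensors.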
bool IsMetaTensorTuple(const ValuePtr &value) {
  if (!value->isa<ValueTuple>()) {
    return false;
  }
  auto tuple = value->cast<ValueTuplePtr>();
  for (auto element : tuple->value()) {
    if (!element->isa<tensor::MetaTensor>() && !IsMetaTensorTuple(element)) {
      return false;
    }
  }
  return true;
}

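// An argument's abstract is broadened when the argument is mutable, is a sparse tensor, is a scalar
// while grad-for-scalar is enabled, or is a meta tensor tuple with tuple broadening enabled.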
bool EnableArgBroaden(const py::object &obj, const ValuePtr &value, bool enable_tuple_broaden) {
  return IsMutableArg(obj, value) || value->isa<tensor::MetaSparseTensor>() ||
         (value->isa<Scalar>() && MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) ||
         (enable_tuple_broaden && IsMetaTensorTuple(value));
}

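// If the object is marked as dynamic length, its abstract must be a sequence; convert it to a
// variable-length sequence abstract.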
void CheckAndConvertToVariableLenSequence(const py::object &obj, AbstractBasePtr abs) {
  if (!GraphUtils::IsDynamicLength(obj)) {
    return;
  }
  if (!abs->isa<abstract::AbstractSequence>()) {
    MS_EXCEPTION(TypeError) << "For mutable, when the variable_len is True, the first input should be"
                            << " list or tuple, but got: " << abs->ToString();
  }
  auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
  abs_seq->CheckAndConvertToDynamicLenSequence();
}
}  // namespace

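// A tuple can be broadened only when every element is a Tensor or a tuple that can itself be broadened.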
bool GraphUtils::IsTupleCanBroaden(const py::object &obj) {
  if (!py::isinstance<py::tuple>(obj)) {
    return false;
  }
  py::tuple tuple = py::cast<py::tuple>(obj);
  for (auto item : tuple) {
    auto elem = py::cast<py::object>(item);
    if (!py::isinstance<mindspore::tensor::Tensor>(elem) && !IsTupleCanBroaden(elem)) {
      return false;
    }
  }
  return true;
}

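// Scalar int/float inputs are treated as gradient inputs only when the grad-for-scalar context flag is set.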
bool GraphUtils::IsGradForScalar(const py::object &obj) {
  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) &&
         (py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj));
}

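// Covers dense tensors as well as CSR, COO and Row sparse tensors.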
bool GraphUtils::IsTensor(const py::object &obj) {
  return py::isinstance<mindspore::tensor::Tensor>(obj) || py::isinstance<mindspore::tensor::CSRTensor>(obj) ||
         py::isinstance<mindspore::tensor::COOTensor>(obj) || py::isinstance<mindspore::tensor::RowTensor>(obj);
}

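// Returns true only for empty tuples, lists and dicts; other objects are never treated as empty containers.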
bool GraphUtils::IsEmptyContainer(const py::object &obj) {
  if (!py::isinstance<py::tuple>(obj) && !py::isinstance<py::list>(obj) && !py::isinstance<py::dict>(obj)) {
    return false;
  }
  return py::len(obj) == 0;
}

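// Converts an argument value to its abstract, broadening it when required and converting
// dynamic-length sequences to variable-length sequence abstracts.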
AbstractBasePtr GraphUtils::ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden) {
  auto ret = abstract::ToAbstract(value, nullptr, nullptr);
  if (EnableArgBroaden(arg, value, enable_tuple_broaden)) {
    ret = AbstractBroaden(ret);
  }
  CheckAndConvertToVariableLenSequence(arg, ret);
  return ret;
}

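// Prefers a primitive for the given op code; falls back to the corresponding meta FuncGraph.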
AnfNodePtr GraphUtils::GetPrimOrMetaFuncGraph(int op_code) {
  auto ret = GetPrimitive(op_code);
  if (ret != nullptr) {
    return NewValueNode(ret);
  }
  return GetMetaFuncGraph(op_code);
}

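// Bytecode op codes that map directly to MindSpore primitives; returns nullptr for unmapped op codes.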
PrimitivePtr GraphUtils::GetPrimitive(int op_code) {
  static std::map<int, PrimitivePtr> op_code_2_prim = {
    {UNARY_INVERT, prim::kPrimInvert},       {RETURN_VALUE, prim::kPrimReturn},
    {LIST_TO_TUPLE, prim::kPrimMakeTuple},   {LIST_APPEND, prim::kPrimListAppend},
    {BUILD_TUPLE, prim::kPrimMakeTuple},     {BUILD_LIST, prim::kPrimMakeList},
    {BUILD_SET, prim::kPrimMakeList},        {BUILD_MAP, prim::kPrimMakeDict},
    {BUILD_SLICE, prim::kPrimMakeSlice},     {BUILD_CONST_KEY_MAP, prim::kPrimMakeDict},
    {BUILD_STRING, prim::kPrimStringConcat}, {LOAD_ATTR, prim::kPrimGetAttr},
    {LOAD_METHOD, prim::kPrimGetAttr}};

  if (op_code_2_prim.find(op_code) == op_code_2_prim.end()) {
    return nullptr;
  }

  return op_code_2_prim.at(op_code);
}

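// Maps unary, binary and inplace bytecode op codes to the names of the multitype ops graphs that implement them.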
std::string GraphUtils::OpCodeToGraphName(int op_code) {
  static std::map<int, std::string> op_code_2_graph_name = {{UNARY_NEGATIVE, "negative"},
                                                            {UNARY_NOT, "logical_not"},
                                                            {BINARY_POWER, "pow_"},
                                                            {BINARY_MULTIPLY, "mul"},
                                                            {BINARY_MODULO, "mod"},
                                                            {BINARY_ADD, "add"},
                                                            {BINARY_SUBTRACT, "sub"},
                                                            {BINARY_SUBSCR, "getitem"},
                                                            {BINARY_FLOOR_DIVIDE, "floordiv"},
                                                            {BINARY_TRUE_DIVIDE, "div"},
                                                            {INPLACE_FLOOR_DIVIDE, "floordiv"},
                                                            {INPLACE_TRUE_DIVIDE, "div"},
                                                            {INPLACE_ADD, "add"},
                                                            {INPLACE_SUBTRACT, "sub"},
                                                            {INPLACE_MULTIPLY, "mul"},
                                                            {INPLACE_MODULO, "mod"},
                                                            {BINARY_LSHIFT, "left_shift"},
                                                            {BINARY_RSHIFT, "right_shift"},
                                                            {BINARY_AND, "bitwise_and"},
                                                            {BINARY_XOR, "bitwise_xor"},
                                                            {BINARY_OR, "bitwise_or"},
                                                            {INPLACE_POWER, "pow"},
                                                            {INPLACE_LSHIFT, "left_shift"},
                                                            {INPLACE_RSHIFT, "right_shift"},
                                                            {INPLACE_AND, "bitwise_and"},
                                                            {INPLACE_XOR, "bitwise_xor"},
                                                            {INPLACE_OR, "bitwise_or"},
                                                            {DICT_MERGE, "add"},
                                                            {LIST_EXTEND, "add"}};
  auto iter = op_code_2_graph_name.find(op_code);
  if (iter == op_code_2_graph_name.end()) {
    return "";
  }
  return iter->second;
}

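// Maps COMPARE_OP arguments (Py_LT, Py_LE, ...) to the names of the corresponding comparison graphs.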
std::string GraphUtils::OpCompareArgToGraphName(int oparg) {
  static std::map<int, std::string> compare_arg_2_graph_name = {{Py_LT, "less"},    {Py_LE, "less_equal"},
                                                                {Py_EQ, "equal"},   {Py_NE, "not_equal"},
                                                                {Py_GT, "greater"}, {Py_GE, "greater_equal"}};
  auto iter = compare_arg_2_graph_name.find(oparg);
  if (iter == compare_arg_2_graph_name.end()) {
    return "";
  }
  return iter->second;
}

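// Resolves the meta FuncGraph for an op code via its multitype ops graph name; returns nullptr when unmapped.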
AnfNodePtr GraphUtils::GetMetaFuncGraph(int op_code) {
  // MS_EXCEPTION_IF_CHECK_FAIL(op_code_2_graph_name.find(op_code) != op_code_2_graph_name.end(),
  //                            "Cannot find the multitype ops of OpCode " + std::to_string(op_code) + ".");
  const auto &graph_name = OpCodeToGraphName(op_code);
  if (graph_name != "") {
    return GetMetaFuncGraph(graph_name);
  }
  return nullptr;
}

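// Loads the named graph from mindspore.ops.composite.multitype_ops and wraps it as a value node.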
AnfNodePtr GraphUtils::GetMetaFuncGraph(const std::string &name) {
  py::object obj = python_adapter::GetPyFn("mindspore.ops.composite.multitype_ops", name);
  return ConvertPythonObjectToAnfNode(obj);
}

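// Converts a Python object to a MindSpore Value and wraps it in a ValueNode; raises if the conversion fails.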
AnfNodePtr GraphUtils::ConvertPythonObjectToAnfNode(const py::object &object) {
  ValuePtr value = nullptr;
  bool succ = mindspore::parse::ConvertData(object, &value, python_adapter::UseSignatureInResolve());
  if (!succ) {
    MS_LOG(EXCEPTION) << "Failed to convert " << (std::string)py::str(object) << " to an AnfNode.";
  }
  return NewValueNode(value);
}

}  // namespace pijit
}  // namespace mindspore