1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_capture/constant_info.h"
17 #include <set>
18 #include <vector>
19 #include <functional>
20 #include "pipeline/jit/pi/pydef.h"
21 #include "pipeline/jit/pi/graph_capture/node.h"
22 #include "pipeline/jit/pi/graph_capture/graph.h"
23
24 namespace mindspore {
25 namespace pijit {
26
// Module-name prefix: names starting with it are treated as mindspore modules (prefix compare).
constexpr char kModuleName[] = "mindspore";
// Tensor attribute names tracked in ConstantInfo attrs.
constexpr char kTensorShapeName[] = "shape";
constexpr char kTensorDtypeName[] = "dtype";
// Number of characters of the module name "os" compared when recognizing os.environ.
constexpr size_t CmpSize = 2;
31
// Forward declaration: MakeConstantBinary uses this helper before its definition later in the file.
static void MakePrimitiveConstantInfoCommon(ValueNode *node);
33
set_value(const py::object & op)34 void ConstantInfo::set_value(const py::object &op) {
35 value_ = op;
36 if (op.ptr() == nullptr) {
37 return;
38 }
39 set_type(Py_TYPE(op.ptr()));
40 if (type() == &PyTuple_Type) {
41 set_len(PyTuple_GET_SIZE(op.ptr()));
42 }
43 if (type() == &PyList_Type) {
44 set_len(PyList_GET_SIZE(op.ptr()));
45 }
46 }
47
ToString() const48 std::string ConstantInfo::ToString() const {
49 std::stringstream s;
50 if (type() != nullptr) {
51 s << "type=" << (type()->tp_name ? type()->tp_name : "<unnamed>") << ", ";
52 }
53 if (value().ptr() != nullptr) {
54 s << "value=" << std::string(py::str(value().ptr())) << ", ";
55 }
56 if (len() != -1) {
57 s << "len=" << len() << ", ";
58 }
59 for (const auto &i : attrs_) {
60 s << i.first << "=" << std::string(py::str(i.second.ptr())) << ", ";
61 }
62 return s.str();
63 }
64
IsConstantValue(int op,const std::vector<ValueNode * > & inputs)65 bool IsConstantValue(int op, const std::vector<ValueNode *> &inputs) {
66 static const std::set<int> support_constant_op = {
67 BINARY_SUBSCR, COMPARE_OP, IS_OP, CONTAINS_OP, LOAD_ATTR, LIST_TO_TUPLE,
68 BUILD_TUPLE, BUILD_LIST, BUILD_MAP, BUILD_SLICE, BUILD_CONST_KEY_MAP,
69 };
70 Opcode code_info(op);
71 if (code_info.HasConst()) {
72 return true;
73 }
74 auto iter = std::find_if_not(inputs.begin(), inputs.end(), [](ValueNode *i) { return i->IsConstantValue(); });
75 if (iter != inputs.end()) {
76 return false;
77 }
78 if (support_constant_op.find(op) != support_constant_op.end()) {
79 return true;
80 }
81 if (code_info.IsBinaryMath() && code_info.MayDelete()) {
82 return true;
83 }
84 return false;
85 }
86
MakeConstantFold(ValueNode * node)87 static void MakeConstantFold(ValueNode *node) {
88 node->SetConstantValue(IsConstantValue(node->GetOpcode(), node->getInputs()));
89 }
90
MakeCodeConstantInfo(ValueNode * node)91 static void MakeCodeConstantInfo(ValueNode *node) {
92 static const std::map<int, PyTypeObject *> constant_type = {
93 {BUILD_TUPLE, &PyTuple_Type}, {BUILD_LIST, &PyList_Type}, {BUILD_SET, &PySet_Type},
94 {BUILD_MAP, &PyDict_Type}, {BUILD_SLICE, &PySlice_Type}, {BUILD_CONST_KEY_MAP, &PyDict_Type},
95 {BUILD_STRING, &PyUnicode_Type}, {LIST_TO_TUPLE, &PyTuple_Type}, {IS_OP, Py_TYPE(Py_True)},
96 {CONTAINS_OP, Py_TYPE(Py_True)}, {MAKE_FUNCTION, &PyFunction_Type},
97 };
98 static const std::set<int> constant_len = {BUILD_TUPLE, BUILD_LIST, BUILD_SET, BUILD_MAP, BUILD_CONST_KEY_MAP};
99
100 int opcode = node->GetOpcode();
101 int oparg = node->GetOparg();
102 PyTypeObject *tp = nullptr;
103 Py_ssize_t len = -1;
104 auto iter1 = constant_type.find(opcode);
105 if (iter1 != constant_type.end()) {
106 tp = iter1->second;
107 }
108 if (constant_len.find(opcode) != constant_len.end()) {
109 len = oparg;
110 }
111 if (tp != nullptr || len != -1) {
112 node->MakeConstantInfo()->set_type(tp);
113 node->MakeConstantInfo()->set_len(len);
114 }
115 }
116
MakeShapeInfoOfTensor(ValueNode * node)117 static void MakeShapeInfoOfTensor(ValueNode *node) {
118 // NOTE: MetaTensor shape is list, mindspore._c_expression.Tensor and mindspore.Tensor is tuple
119 node->MakeConstantInfo()->set_type(&PyTuple_Type);
120 }
121
MakeDimInfoOfTensor(ValueNode * node)122 static void MakeDimInfoOfTensor(ValueNode *node) {
123 const auto &cnst = node->GetConstantInfo();
124 if (cnst == nullptr) {
125 return;
126 }
127 node->SetConstantValue(cnst->HasAttr(kTensorShapeName));
128 }
129
MakeConstantInfoOfTensorAttr(ValueNode * node)130 static void MakeConstantInfoOfTensorAttr(ValueNode *node) {
131 const std::string &name = node->GetName();
132 if (name == kTensorShapeName) {
133 MakeShapeInfoOfTensor(node);
134 }
135 if (name == "ndim") {
136 MakeDimInfoOfTensor(node);
137 }
138 }
139
CheckConstantAttr(ValueNode * node)140 bool CheckConstantAttr(ValueNode *node) {
141 const auto &src_cnst_info = node->input(0)->GetConstantInfo();
142 const std::string &name = node->GetName();
143 if (src_cnst_info != nullptr && src_cnst_info->HasAttr(name)) {
144 node->MakeConstantInfo()->set_value(src_cnst_info->GetAttr(name));
145 }
146
147 if (node->GetVobj() == nullptr || node->input(0)->GetVobj() == nullptr) {
148 return false;
149 }
150 AObject *src_info = node->input(0)->GetVobj();
151 if (src_info->GetType() == AObject::kTypeTensor) {
152 MakeConstantInfoOfTensorAttr(node);
153 return false;
154 }
155 if (src_info->GetType() == AObject::kTypeModule && src_info->GetPyObject().ptr() != nullptr) {
156 // mindspore module attribute
157 const char *module_name = PyModule_GetName(src_info->GetPyObject().ptr());
158 if (module_name == nullptr) {
159 PyErr_Clear();
160 return false;
161 }
162 return strncmp(module_name, kModuleName, sizeof(kModuleName) - 1) == 0;
163 }
164 return false;
165 }
166
CheckConstantGlobal(ValueNode * node)167 bool CheckConstantGlobal(ValueNode *node) {
168 const char *module_name = node->GetGraph()->GetModuleName();
169 return strncmp(module_name, kModuleName, sizeof(kModuleName) - 1) == 0;
170 }
171
CheckConstantIs(ValueNode * node)172 bool CheckConstantIs(ValueNode *node) {
173 const auto &l_cnst_info = node->input(0)->GetConstantInfo();
174 const auto &r_cnst_info = node->input(1)->GetConstantInfo();
175 if (l_cnst_info == nullptr || r_cnst_info == nullptr) {
176 return false;
177 }
178 if (l_cnst_info->type() != nullptr && r_cnst_info->type() != nullptr) {
179 // if type not equal, IS_OP always False
180 return l_cnst_info->type() != r_cnst_info->type();
181 }
182 return false;
183 }
184
MakeConstantBinary(ValueNode * node)185 bool MakeConstantBinary(ValueNode *node) {
186 AObject *res_info = node->GetVobj();
187 if (res_info == nullptr) {
188 return false;
189 }
190 AObject::Type type = res_info->GetType();
191 if (type != AObject::kTypeTensor) {
192 return false;
193 }
194 const auto &l_cnst = node->input(0)->GetConstantInfo();
195 if (l_cnst == nullptr) {
196 return false;
197 }
198 if (l_cnst->type() != nullptr) {
199 MakePrimitiveConstantInfoCommon(node);
200 }
201 return false;
202 }
203
MakeConstantBinarySubscr(ValueNode * node)204 bool MakeConstantBinarySubscr(ValueNode *node) {
205 const auto &r_cnst = node->input(1)->GetConstantInfo();
206 if (r_cnst == nullptr || r_cnst->type() == nullptr) {
207 return false;
208 }
209 ValueNode *map_node = node->input(0);
210 if (map_node->GetOpcode() == LOAD_ATTR) {
211 ValueNode *src_node = map_node->input(0);
212 bool is_shape = src_node->GetVobj()->GetType() == AObject::kTypeTensor && map_node->GetName() == kTensorShapeName;
213 if (is_shape && r_cnst->type() == &PyLong_Type) {
214 node->MakeConstantInfo()->set_type(&PyLong_Type);
215 return false;
216 }
217 }
218 const auto &l_cnst = node->input(0)->GetConstantInfo();
219 if (l_cnst == nullptr || l_cnst->type() == nullptr) {
220 return false;
221 }
222 if (r_cnst->type() == &PySlice_Type) {
223 if (l_cnst->type() == &PyTuple_Type || l_cnst->type() == &PyList_Type) {
224 node->MakeConstantInfo()->set_type(l_cnst->type());
225 return false;
226 }
227 }
228 return MakeConstantBinary(node);
229 }
230
MakeSpecializeConstantValue(ValueNode * node)231 static void MakeSpecializeConstantValue(ValueNode *node) {
232 if (node->IsConstantValue()) {
233 return;
234 }
235 if (Opcode(node->GetOpcode()).IsBinaryMath()) {
236 MakeConstantBinary(node);
237 }
238 static const std::map<int, bool (*)(ValueNode *)> specialize = {
239 {LOAD_ATTR, CheckConstantAttr}, {LOAD_GLOBAL, CheckConstantGlobal},
240 {IS_OP, CheckConstantIs}, {BINARY_SUBSCR, MakeConstantBinarySubscr},
241 {COMPARE_OP, MakeConstantBinary},
242 };
243 auto iter = specialize.find(node->GetOpcode());
244 if (iter == specialize.end()) {
245 return;
246 }
247 if (!iter->second(node)) {
248 return;
249 }
250 node->SetConstantValue(true);
251 }
252
MakeSpecificConstantInfo(ValueNode * node)253 static void MakeSpecificConstantInfo(ValueNode *node) {
254 if (!node) {
255 return;
256 }
257 // os.environ
258 if (node->GetOpcode() == LOAD_ATTR && node->input(0)->GetVobj() &&
259 node->input(0)->GetVobj()->GetType() == AObject::kTypeModule && node->input(0)->GetVobj()->GetPyObject().ptr()) {
260 auto module_obj = node->input(0)->GetVobj()->GetPyObject().ptr();
261 const std::string &name = node->GetName();
262 const char *module_name = PyModule_GetName(module_obj);
263 if (module_name == nullptr) {
264 PyErr_Clear();
265 return;
266 }
267 if (strncmp(module_name, "os", CmpSize) == 0 && name == "environ") {
268 auto env_obj = PyObject_GetAttrString(module_obj, "environ");
269 node->SetConstantValue(true);
270 node->MakeConstantInfo()->set_value(env_obj);
271 node->SetOpcode(LOAD_CONST);
272 node->SetOparg(-1);
273 node->ClearInputs();
274 return;
275 }
276 }
277 }
278
CollectConstantInfo(ValueNode * node)279 void ConstantInfo::CollectConstantInfo(ValueNode *node) {
280 MakeConstantFold(node);
281 MakeCodeConstantInfo(node);
282 MakeSpecializeConstantValue(node);
283 MakeSpecificConstantInfo(node);
284 }
285
MakeConstantInfoOfPrimScalarToTensor(ValueNode * node)286 void MakeConstantInfoOfPrimScalarToTensor(ValueNode *node) {
287 node->MakeConstantInfo()->SetAttr(kTensorShapeName, py::tuple());
288 }
289
MakeConstantInfoOfPrimCast(ValueNode * node)290 void MakeConstantInfoOfPrimCast(ValueNode *node) {
291 ValueNode *dtype = node->input(2);
292 if (dtype->IsConstantValue()) {
293 node->MakeConstantInfo()->SetAttr(kTensorDtypeName, dtype->GetConstantInfo()->value());
294 }
295 }
296
MakeConstantInfoOfPrimIsShapeUnKnown(ValueNode * node)297 void MakeConstantInfoOfPrimIsShapeUnKnown(ValueNode *node) {
298 // primitive IsShapeUnKnown only accept tuple and list, pynative mode it's always False
299 node->SetVobj(AObject::Convert(Py_False));
300 node->SetConstantValue(true);
301 }
302
MakeConvertToMsTensorInfo(ValueNode * node)303 static void MakeConvertToMsTensorInfo(ValueNode *node) {
304 const auto &cnst = node->input(1)->GetConstantInfo();
305 if (cnst == nullptr) {
306 return;
307 }
308 *node->MakeConstantInfo() = *cnst;
309 }
310
MakeReshapeInfo(ValueNode * node)311 static void MakeReshapeInfo(ValueNode *node) {
312 const auto &shape_cnst = node->input(2)->GetConstantInfo();
313 if (shape_cnst == nullptr || shape_cnst->value().ptr() == nullptr) {
314 return;
315 }
316 const auto &cnst_info = node->input(1)->GetConstantInfo();
317
318 PyObject *out_shape = shape_cnst->value().ptr();
319 PyObject **begin = PyList_Check(out_shape) ? &PyList_GET_ITEM(out_shape, 0) : &PyTuple_GET_ITEM(out_shape, 0);
320 PyObject **end = begin + (PyList_Check(out_shape) ? PyList_GET_SIZE(out_shape) : PyTuple_GET_SIZE(out_shape));
321 bool is_dynamic_shape = std::any_of(begin, end, [](PyObject *op) { return PyLong_AsLong(op) == -1; });
322 bool is_constant_shape = !is_dynamic_shape || (cnst_info != nullptr && cnst_info->HasAttr(kTensorShapeName));
323 if (is_constant_shape) {
324 py::object cnst_shape = node->GetVobj()->GetPyObject().attr(kTensorShapeName);
325 node->MakeConstantInfo()->SetAttr(kTensorShapeName, cnst_shape);
326 }
327 }
328
GetConstantPrimitiveMap()329 static const std::map<std::string, void (*)(ValueNode *)> &GetConstantPrimitiveMap() {
330 static const std::map<std::string, void (*)(ValueNode *)> cnst_prim = {
331 {"ScalarToTensor", MakeConstantInfoOfPrimScalarToTensor},
332 {"Cast", MakeConstantInfoOfPrimCast},
333 {"IsShapeUnKnown", MakeConstantInfoOfPrimIsShapeUnKnown},
334 {"Shape", MakeShapeInfoOfTensor},
335 {"Reshape", MakeReshapeInfo},
336 {"ConvertToMsTensor", MakeConvertToMsTensorInfo},
337 {"ConvertToAdapterTensor", MakeConvertToMsTensorInfo},
338 };
339 return cnst_prim;
340 }
341
MakePrimitiveConstantInfoCommon(ValueNode * node)342 static void MakePrimitiveConstantInfoCommon(ValueNode *node) {
343 AObject *info = node->GetVobj();
344 if (info == nullptr) {
345 return;
346 }
347 // assume primitive return type is always constant !!!
348 const auto &cnst = node->MakeConstantInfo();
349 cnst->set_type(info->GetTypeObject());
350
351 if (info->GetType() != AObject::kTypeTensor) {
352 return;
353 }
354 // check all inputs tensor shape is constant, other inputs is constant
355 const auto &inputs = node->getInputs();
356 bool constant_shape = std::none_of(inputs.begin(), inputs.end(), [](ValueNode *i) {
357 const auto &cnst = i->GetConstantInfo();
358 if (cnst == nullptr) {
359 return true;
360 }
361 if (i->GetVobj()->GetType() == AObject::kTypeTensor) {
362 return !cnst->HasAttr(kTensorShapeName);
363 }
364 return cnst->value().ptr() != nullptr;
365 });
366 if (constant_shape) {
367 cnst->SetAttr(kTensorShapeName, info->GetPyObject().attr(kTensorShapeName));
368 }
369 }
370
CollectPrimitiveConstantInfo(CallNode * node)371 void ConstantInfo::CollectPrimitiveConstantInfo(CallNode *node) {
372 MakePrimitiveConstantInfoCommon(node);
373
374 std::string prim_key = node->input(0)->GetVobj()->GetPyObject().attr("name").cast<std::string>();
375 auto iter = GetConstantPrimitiveMap().find(prim_key);
376 if (iter == GetConstantPrimitiveMap().end()) {
377 return;
378 }
379 iter->second(node);
380 }
381
CheckConstantLen(ValueNode * node)382 static bool CheckConstantLen(ValueNode *node) {
383 const auto &cnst = node->input(1)->GetConstantInfo();
384 if (cnst == nullptr || cnst->len() == -1) {
385 return false;
386 }
387 PyObject *len = node->GetVobj() ? node->GetVobj()->GetPyObject().ptr() : nullptr;
388 if (len != nullptr) {
389 MS_EXCEPTION_IF_CHECK_FAIL(cnst->len() == PyLong_AsSsize_t(len), "error constant len");
390 } else {
391 node->SetVobj(AObject::Convert(py::int_(cnst->len()).ptr()));
392 }
393 return true;
394 }
395
CheckConstantInstanceCheck(ValueNode * node)396 static bool CheckConstantInstanceCheck(ValueNode *node) {
397 const auto &c1 = node->input(1)->GetConstantInfo();
398 bool cnst = c1 != nullptr && c1->type() != nullptr;
399 constexpr int second_arg = 2;
400 return cnst && node->input(second_arg)->IsConstantValue();
401 }
402
GetConstantBuiltinFuncMap()403 static const std::map<PyCFunction, bool (*)(ValueNode *)> &GetConstantBuiltinFuncMap() {
404 using Handler = bool (*)(ValueNode *);
405 static std::map<PyCFunction, Handler> cnst_func = {};
406 static auto func_map_init = [](const char *func_name, Handler handler) {
407 auto func = PyDict_GetItemString(PyEval_GetBuiltins(), func_name);
408 auto cfunc = PyCFunction_GET_FUNCTION(func);
409 cnst_func.insert({cfunc, handler});
410 };
411 if (!cnst_func.empty()) {
412 return cnst_func;
413 }
414 func_map_init("len", CheckConstantLen);
415 func_map_init("isinstance", CheckConstantInstanceCheck);
416 return cnst_func;
417 }
418
CollectBuiltinFuncConstantInfo(CallNode * node)419 void ConstantInfo::CollectBuiltinFuncConstantInfo(CallNode *node) {
420 MS_EXCEPTION_IF_NULL(node->input(0)->GetVobj()->GetPyObject().ptr());
421 PyObject *func = node->input(0)->GetVobj()->GetPyObject().ptr();
422 if (PyMethod_Check(func)) {
423 func = PyMethod_GET_FUNCTION(func);
424 }
425 if (PyInstanceMethod_Check(func)) {
426 func = PyInstanceMethod_GET_FUNCTION(func);
427 }
428 MS_EXCEPTION_IF_CHECK_FAIL(PyCFunction_Check(func), "must be builtin function or method");
429 PyCFunction cfunc = PyCFunction_GET_FUNCTION(func);
430
431 auto iter = GetConstantBuiltinFuncMap().find(cfunc);
432 if (iter == GetConstantBuiltinFuncMap().end()) {
433 return;
434 }
435 if (iter->second(node)) {
436 node->SetConstantValue(true);
437 }
438 }
439
440 } // namespace pijit
441 } // namespace mindspore
442