• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_capture/constant_info.h"
17 #include <set>
18 #include <vector>
19 #include <functional>
20 #include "pipeline/jit/pi/pydef.h"
21 #include "pipeline/jit/pi/graph_capture/node.h"
22 #include "pipeline/jit/pi/graph_capture/graph.h"
23 
24 namespace mindspore {
25 namespace pijit {
26 
// Name prefix of modules whose attributes and globals are treated as constant.
constexpr char kModuleName[] = "mindspore";
// Attribute keys recorded on a tensor's constant info.
constexpr char kTensorShapeName[] = "shape";
constexpr char kTensorDtypeName[] = "dtype";
// Number of characters in the module name "os".
constexpr size_t CmpSize = 2;
31 
// Forward declaration: MakeConstantBinary calls this helper before its definition further down in this file.
static void MakePrimitiveConstantInfoCommon(ValueNode *node);
33 
set_value(const py::object & op)34 void ConstantInfo::set_value(const py::object &op) {
35   value_ = op;
36   if (op.ptr() == nullptr) {
37     return;
38   }
39   set_type(Py_TYPE(op.ptr()));
40   if (type() == &PyTuple_Type) {
41     set_len(PyTuple_GET_SIZE(op.ptr()));
42   }
43   if (type() == &PyList_Type) {
44     set_len(PyList_GET_SIZE(op.ptr()));
45   }
46 }
47 
ToString() const48 std::string ConstantInfo::ToString() const {
49   std::stringstream s;
50   if (type() != nullptr) {
51     s << "type=" << (type()->tp_name ? type()->tp_name : "<unnamed>") << ", ";
52   }
53   if (value().ptr() != nullptr) {
54     s << "value=" << std::string(py::str(value().ptr())) << ", ";
55   }
56   if (len() != -1) {
57     s << "len=" << len() << ", ";
58   }
59   for (const auto &i : attrs_) {
60     s << i.first << "=" << std::string(py::str(i.second.ptr())) << ", ";
61   }
62   return s.str();
63 }
64 
IsConstantValue(int op,const std::vector<ValueNode * > & inputs)65 bool IsConstantValue(int op, const std::vector<ValueNode *> &inputs) {
66   static const std::set<int> support_constant_op = {
67     BINARY_SUBSCR, COMPARE_OP, IS_OP,     CONTAINS_OP, LOAD_ATTR,           LIST_TO_TUPLE,
68     BUILD_TUPLE,   BUILD_LIST, BUILD_MAP, BUILD_SLICE, BUILD_CONST_KEY_MAP,
69   };
70   Opcode code_info(op);
71   if (code_info.HasConst()) {
72     return true;
73   }
74   auto iter = std::find_if_not(inputs.begin(), inputs.end(), [](ValueNode *i) { return i->IsConstantValue(); });
75   if (iter != inputs.end()) {
76     return false;
77   }
78   if (support_constant_op.find(op) != support_constant_op.end()) {
79     return true;
80   }
81   if (code_info.IsBinaryMath() && code_info.MayDelete()) {
82     return true;
83   }
84   return false;
85 }
86 
MakeConstantFold(ValueNode * node)87 static void MakeConstantFold(ValueNode *node) {
88   node->SetConstantValue(IsConstantValue(node->GetOpcode(), node->getInputs()));
89 }
90 
MakeCodeConstantInfo(ValueNode * node)91 static void MakeCodeConstantInfo(ValueNode *node) {
92   static const std::map<int, PyTypeObject *> constant_type = {
93     {BUILD_TUPLE, &PyTuple_Type},    {BUILD_LIST, &PyList_Type},        {BUILD_SET, &PySet_Type},
94     {BUILD_MAP, &PyDict_Type},       {BUILD_SLICE, &PySlice_Type},      {BUILD_CONST_KEY_MAP, &PyDict_Type},
95     {BUILD_STRING, &PyUnicode_Type}, {LIST_TO_TUPLE, &PyTuple_Type},    {IS_OP, Py_TYPE(Py_True)},
96     {CONTAINS_OP, Py_TYPE(Py_True)}, {MAKE_FUNCTION, &PyFunction_Type},
97   };
98   static const std::set<int> constant_len = {BUILD_TUPLE, BUILD_LIST, BUILD_SET, BUILD_MAP, BUILD_CONST_KEY_MAP};
99 
100   int opcode = node->GetOpcode();
101   int oparg = node->GetOparg();
102   PyTypeObject *tp = nullptr;
103   Py_ssize_t len = -1;
104   auto iter1 = constant_type.find(opcode);
105   if (iter1 != constant_type.end()) {
106     tp = iter1->second;
107   }
108   if (constant_len.find(opcode) != constant_len.end()) {
109     len = oparg;
110   }
111   if (tp != nullptr || len != -1) {
112     node->MakeConstantInfo()->set_type(tp);
113     node->MakeConstantInfo()->set_len(len);
114   }
115 }
116 
MakeShapeInfoOfTensor(ValueNode * node)117 static void MakeShapeInfoOfTensor(ValueNode *node) {
118   // NOTE: MetaTensor shape is list, mindspore._c_expression.Tensor and mindspore.Tensor is tuple
119   node->MakeConstantInfo()->set_type(&PyTuple_Type);
120 }
121 
MakeDimInfoOfTensor(ValueNode * node)122 static void MakeDimInfoOfTensor(ValueNode *node) {
123   const auto &cnst = node->GetConstantInfo();
124   if (cnst == nullptr) {
125     return;
126   }
127   node->SetConstantValue(cnst->HasAttr(kTensorShapeName));
128 }
129 
MakeConstantInfoOfTensorAttr(ValueNode * node)130 static void MakeConstantInfoOfTensorAttr(ValueNode *node) {
131   const std::string &name = node->GetName();
132   if (name == kTensorShapeName) {
133     MakeShapeInfoOfTensor(node);
134   }
135   if (name == "ndim") {
136     MakeDimInfoOfTensor(node);
137   }
138 }
139 
CheckConstantAttr(ValueNode * node)140 bool CheckConstantAttr(ValueNode *node) {
141   const auto &src_cnst_info = node->input(0)->GetConstantInfo();
142   const std::string &name = node->GetName();
143   if (src_cnst_info != nullptr && src_cnst_info->HasAttr(name)) {
144     node->MakeConstantInfo()->set_value(src_cnst_info->GetAttr(name));
145   }
146 
147   if (node->GetVobj() == nullptr || node->input(0)->GetVobj() == nullptr) {
148     return false;
149   }
150   AObject *src_info = node->input(0)->GetVobj();
151   if (src_info->GetType() == AObject::kTypeTensor) {
152     MakeConstantInfoOfTensorAttr(node);
153     return false;
154   }
155   if (src_info->GetType() == AObject::kTypeModule && src_info->GetPyObject().ptr() != nullptr) {
156     // mindspore module attribute
157     const char *module_name = PyModule_GetName(src_info->GetPyObject().ptr());
158     if (module_name == nullptr) {
159       PyErr_Clear();
160       return false;
161     }
162     return strncmp(module_name, kModuleName, sizeof(kModuleName) - 1) == 0;
163   }
164   return false;
165 }
166 
CheckConstantGlobal(ValueNode * node)167 bool CheckConstantGlobal(ValueNode *node) {
168   const char *module_name = node->GetGraph()->GetModuleName();
169   return strncmp(module_name, kModuleName, sizeof(kModuleName) - 1) == 0;
170 }
171 
CheckConstantIs(ValueNode * node)172 bool CheckConstantIs(ValueNode *node) {
173   const auto &l_cnst_info = node->input(0)->GetConstantInfo();
174   const auto &r_cnst_info = node->input(1)->GetConstantInfo();
175   if (l_cnst_info == nullptr || r_cnst_info == nullptr) {
176     return false;
177   }
178   if (l_cnst_info->type() != nullptr && r_cnst_info->type() != nullptr) {
179     // if type not equal, IS_OP always False
180     return l_cnst_info->type() != r_cnst_info->type();
181   }
182   return false;
183 }
184 
MakeConstantBinary(ValueNode * node)185 bool MakeConstantBinary(ValueNode *node) {
186   AObject *res_info = node->GetVobj();
187   if (res_info == nullptr) {
188     return false;
189   }
190   AObject::Type type = res_info->GetType();
191   if (type != AObject::kTypeTensor) {
192     return false;
193   }
194   const auto &l_cnst = node->input(0)->GetConstantInfo();
195   if (l_cnst == nullptr) {
196     return false;
197   }
198   if (l_cnst->type() != nullptr) {
199     MakePrimitiveConstantInfoCommon(node);
200   }
201   return false;
202 }
203 
MakeConstantBinarySubscr(ValueNode * node)204 bool MakeConstantBinarySubscr(ValueNode *node) {
205   const auto &r_cnst = node->input(1)->GetConstantInfo();
206   if (r_cnst == nullptr || r_cnst->type() == nullptr) {
207     return false;
208   }
209   ValueNode *map_node = node->input(0);
210   if (map_node->GetOpcode() == LOAD_ATTR) {
211     ValueNode *src_node = map_node->input(0);
212     bool is_shape = src_node->GetVobj()->GetType() == AObject::kTypeTensor && map_node->GetName() == kTensorShapeName;
213     if (is_shape && r_cnst->type() == &PyLong_Type) {
214       node->MakeConstantInfo()->set_type(&PyLong_Type);
215       return false;
216     }
217   }
218   const auto &l_cnst = node->input(0)->GetConstantInfo();
219   if (l_cnst == nullptr || l_cnst->type() == nullptr) {
220     return false;
221   }
222   if (r_cnst->type() == &PySlice_Type) {
223     if (l_cnst->type() == &PyTuple_Type || l_cnst->type() == &PyList_Type) {
224       node->MakeConstantInfo()->set_type(l_cnst->type());
225       return false;
226     }
227   }
228   return MakeConstantBinary(node);
229 }
230 
MakeSpecializeConstantValue(ValueNode * node)231 static void MakeSpecializeConstantValue(ValueNode *node) {
232   if (node->IsConstantValue()) {
233     return;
234   }
235   if (Opcode(node->GetOpcode()).IsBinaryMath()) {
236     MakeConstantBinary(node);
237   }
238   static const std::map<int, bool (*)(ValueNode *)> specialize = {
239     {LOAD_ATTR, CheckConstantAttr},   {LOAD_GLOBAL, CheckConstantGlobal},
240     {IS_OP, CheckConstantIs},         {BINARY_SUBSCR, MakeConstantBinarySubscr},
241     {COMPARE_OP, MakeConstantBinary},
242   };
243   auto iter = specialize.find(node->GetOpcode());
244   if (iter == specialize.end()) {
245     return;
246   }
247   if (!iter->second(node)) {
248     return;
249   }
250   node->SetConstantValue(true);
251 }
252 
MakeSpecificConstantInfo(ValueNode * node)253 static void MakeSpecificConstantInfo(ValueNode *node) {
254   if (!node) {
255     return;
256   }
257   // os.environ
258   if (node->GetOpcode() == LOAD_ATTR && node->input(0)->GetVobj() &&
259       node->input(0)->GetVobj()->GetType() == AObject::kTypeModule && node->input(0)->GetVobj()->GetPyObject().ptr()) {
260     auto module_obj = node->input(0)->GetVobj()->GetPyObject().ptr();
261     const std::string &name = node->GetName();
262     const char *module_name = PyModule_GetName(module_obj);
263     if (module_name == nullptr) {
264       PyErr_Clear();
265       return;
266     }
267     if (strncmp(module_name, "os", CmpSize) == 0 && name == "environ") {
268       auto env_obj = PyObject_GetAttrString(module_obj, "environ");
269       node->SetConstantValue(true);
270       node->MakeConstantInfo()->set_value(env_obj);
271       node->SetOpcode(LOAD_CONST);
272       node->SetOparg(-1);
273       node->ClearInputs();
274       return;
275     }
276   }
277 }
278 
CollectConstantInfo(ValueNode * node)279 void ConstantInfo::CollectConstantInfo(ValueNode *node) {
280   MakeConstantFold(node);
281   MakeCodeConstantInfo(node);
282   MakeSpecializeConstantValue(node);
283   MakeSpecificConstantInfo(node);
284 }
285 
MakeConstantInfoOfPrimScalarToTensor(ValueNode * node)286 void MakeConstantInfoOfPrimScalarToTensor(ValueNode *node) {
287   node->MakeConstantInfo()->SetAttr(kTensorShapeName, py::tuple());
288 }
289 
MakeConstantInfoOfPrimCast(ValueNode * node)290 void MakeConstantInfoOfPrimCast(ValueNode *node) {
291   ValueNode *dtype = node->input(2);
292   if (dtype->IsConstantValue()) {
293     node->MakeConstantInfo()->SetAttr(kTensorDtypeName, dtype->GetConstantInfo()->value());
294   }
295 }
296 
MakeConstantInfoOfPrimIsShapeUnKnown(ValueNode * node)297 void MakeConstantInfoOfPrimIsShapeUnKnown(ValueNode *node) {
298   // primitive IsShapeUnKnown only accept tuple and list, pynative mode it's always False
299   node->SetVobj(AObject::Convert(Py_False));
300   node->SetConstantValue(true);
301 }
302 
MakeConvertToMsTensorInfo(ValueNode * node)303 static void MakeConvertToMsTensorInfo(ValueNode *node) {
304   const auto &cnst = node->input(1)->GetConstantInfo();
305   if (cnst == nullptr) {
306     return;
307   }
308   *node->MakeConstantInfo() = *cnst;
309 }
310 
MakeReshapeInfo(ValueNode * node)311 static void MakeReshapeInfo(ValueNode *node) {
312   const auto &shape_cnst = node->input(2)->GetConstantInfo();
313   if (shape_cnst == nullptr || shape_cnst->value().ptr() == nullptr) {
314     return;
315   }
316   const auto &cnst_info = node->input(1)->GetConstantInfo();
317 
318   PyObject *out_shape = shape_cnst->value().ptr();
319   PyObject **begin = PyList_Check(out_shape) ? &PyList_GET_ITEM(out_shape, 0) : &PyTuple_GET_ITEM(out_shape, 0);
320   PyObject **end = begin + (PyList_Check(out_shape) ? PyList_GET_SIZE(out_shape) : PyTuple_GET_SIZE(out_shape));
321   bool is_dynamic_shape = std::any_of(begin, end, [](PyObject *op) { return PyLong_AsLong(op) == -1; });
322   bool is_constant_shape = !is_dynamic_shape || (cnst_info != nullptr && cnst_info->HasAttr(kTensorShapeName));
323   if (is_constant_shape) {
324     py::object cnst_shape = node->GetVobj()->GetPyObject().attr(kTensorShapeName);
325     node->MakeConstantInfo()->SetAttr(kTensorShapeName, cnst_shape);
326   }
327 }
328 
GetConstantPrimitiveMap()329 static const std::map<std::string, void (*)(ValueNode *)> &GetConstantPrimitiveMap() {
330   static const std::map<std::string, void (*)(ValueNode *)> cnst_prim = {
331     {"ScalarToTensor", MakeConstantInfoOfPrimScalarToTensor},
332     {"Cast", MakeConstantInfoOfPrimCast},
333     {"IsShapeUnKnown", MakeConstantInfoOfPrimIsShapeUnKnown},
334     {"Shape", MakeShapeInfoOfTensor},
335     {"Reshape", MakeReshapeInfo},
336     {"ConvertToMsTensor", MakeConvertToMsTensorInfo},
337     {"ConvertToAdapterTensor", MakeConvertToMsTensorInfo},
338   };
339   return cnst_prim;
340 }
341 
MakePrimitiveConstantInfoCommon(ValueNode * node)342 static void MakePrimitiveConstantInfoCommon(ValueNode *node) {
343   AObject *info = node->GetVobj();
344   if (info == nullptr) {
345     return;
346   }
347   // assume primitive return type is always constant !!!
348   const auto &cnst = node->MakeConstantInfo();
349   cnst->set_type(info->GetTypeObject());
350 
351   if (info->GetType() != AObject::kTypeTensor) {
352     return;
353   }
354   // check all inputs tensor shape is constant, other inputs is constant
355   const auto &inputs = node->getInputs();
356   bool constant_shape = std::none_of(inputs.begin(), inputs.end(), [](ValueNode *i) {
357     const auto &cnst = i->GetConstantInfo();
358     if (cnst == nullptr) {
359       return true;
360     }
361     if (i->GetVobj()->GetType() == AObject::kTypeTensor) {
362       return !cnst->HasAttr(kTensorShapeName);
363     }
364     return cnst->value().ptr() != nullptr;
365   });
366   if (constant_shape) {
367     cnst->SetAttr(kTensorShapeName, info->GetPyObject().attr(kTensorShapeName));
368   }
369 }
370 
CollectPrimitiveConstantInfo(CallNode * node)371 void ConstantInfo::CollectPrimitiveConstantInfo(CallNode *node) {
372   MakePrimitiveConstantInfoCommon(node);
373 
374   std::string prim_key = node->input(0)->GetVobj()->GetPyObject().attr("name").cast<std::string>();
375   auto iter = GetConstantPrimitiveMap().find(prim_key);
376   if (iter == GetConstantPrimitiveMap().end()) {
377     return;
378   }
379   iter->second(node);
380 }
381 
CheckConstantLen(ValueNode * node)382 static bool CheckConstantLen(ValueNode *node) {
383   const auto &cnst = node->input(1)->GetConstantInfo();
384   if (cnst == nullptr || cnst->len() == -1) {
385     return false;
386   }
387   PyObject *len = node->GetVobj() ? node->GetVobj()->GetPyObject().ptr() : nullptr;
388   if (len != nullptr) {
389     MS_EXCEPTION_IF_CHECK_FAIL(cnst->len() == PyLong_AsSsize_t(len), "error constant len");
390   } else {
391     node->SetVobj(AObject::Convert(py::int_(cnst->len()).ptr()));
392   }
393   return true;
394 }
395 
CheckConstantInstanceCheck(ValueNode * node)396 static bool CheckConstantInstanceCheck(ValueNode *node) {
397   const auto &c1 = node->input(1)->GetConstantInfo();
398   bool cnst = c1 != nullptr && c1->type() != nullptr;
399   constexpr int second_arg = 2;
400   return cnst && node->input(second_arg)->IsConstantValue();
401 }
402 
GetConstantBuiltinFuncMap()403 static const std::map<PyCFunction, bool (*)(ValueNode *)> &GetConstantBuiltinFuncMap() {
404   using Handler = bool (*)(ValueNode *);
405   static std::map<PyCFunction, Handler> cnst_func = {};
406   static auto func_map_init = [](const char *func_name, Handler handler) {
407     auto func = PyDict_GetItemString(PyEval_GetBuiltins(), func_name);
408     auto cfunc = PyCFunction_GET_FUNCTION(func);
409     cnst_func.insert({cfunc, handler});
410   };
411   if (!cnst_func.empty()) {
412     return cnst_func;
413   }
414   func_map_init("len", CheckConstantLen);
415   func_map_init("isinstance", CheckConstantInstanceCheck);
416   return cnst_func;
417 }
418 
CollectBuiltinFuncConstantInfo(CallNode * node)419 void ConstantInfo::CollectBuiltinFuncConstantInfo(CallNode *node) {
420   MS_EXCEPTION_IF_NULL(node->input(0)->GetVobj()->GetPyObject().ptr());
421   PyObject *func = node->input(0)->GetVobj()->GetPyObject().ptr();
422   if (PyMethod_Check(func)) {
423     func = PyMethod_GET_FUNCTION(func);
424   }
425   if (PyInstanceMethod_Check(func)) {
426     func = PyInstanceMethod_GET_FUNCTION(func);
427   }
428   MS_EXCEPTION_IF_CHECK_FAIL(PyCFunction_Check(func), "must be builtin function or method");
429   PyCFunction cfunc = PyCFunction_GET_FUNCTION(func);
430 
431   auto iter = GetConstantBuiltinFuncMap().find(cfunc);
432   if (iter == GetConstantBuiltinFuncMap().end()) {
433     return;
434   }
435   if (iter->second(node)) {
436     node->SetConstantValue(true);
437   }
438 }
439 
440 }  // namespace pijit
441 }  // namespace mindspore
442