• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_guard/shape_ctx.h"
17 #include <algorithm>
18 #include <map>
19 #include "ir/tensor.h"
20 
21 namespace py = pybind11;
22 
23 namespace mindspore {
24 namespace pijit {
25 
ShapeContext(PyFrameObject * f,PyObject * signature)26 ShapeContext::ShapeContext(PyFrameObject *f, PyObject *signature)
27     : frame_(f), signature_(signature), is_method_(false), applied_(false) {
28   Py_XINCREF(f);
29   Py_XINCREF(signature);
30   if (signature != nullptr) {
31     if (!PyTuple_Check(signature) && !PyList_Check(signature)) {
32       auto tuple = PyTuple_New(1);
33       PyTuple_SET_ITEM(tuple, 0, signature);
34       Py_DECREF(signature);
35       signature_ = tuple;
36     } else if (!PyTuple_Check(signature)) {
37       int size = PyList_Size(signature);
38       auto tuple = PyTuple_New(size);
39       for (int i = 0; i < size; ++i) {
40         auto item = PyList_GetItem(signature, i);
41         Py_XINCREF(item);
42         PyTuple_SET_ITEM(tuple, i, item);
43       }
44       Py_DECREF(signature);
45       signature_ = tuple;
46     }
47     int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
48     is_method_ = (argc == (PyTuple_GET_SIZE(signature_) + 1)) ? true : false;
49     std::vector<PyObject *> locals(&(frame_->f_localsplus[is_method_ ? 1 : 0]), &(frame_->f_localsplus[argc]));
50     origin_ = locals;
51   }
52 }
53 
~ShapeContext()54 ShapeContext::~ShapeContext() {
55   RevertSignature();
56   Py_XDECREF(frame_);
57   Py_XDECREF(signature_);
58 }
59 
// Wildcard markers recognised in signature shapes by the checks in this file.
// A shape whose single element is kDynamicDim is accepted for a tensor of any shape.
constexpr int64_t kDynamicDim = -2;
// A kDynamicShape entry accepts any extent at that one position.
constexpr int64_t kDynamicShape = -1;
62 
IsShapeUnknown(mindspore::tensor::TensorPtr tensor)63 static bool IsShapeUnknown(mindspore::tensor::TensorPtr tensor) {
64   auto &shape = tensor->shape();
65   if (std::any_of(shape.begin(), shape.end(), [](const auto &element) { return element == kDynamicShape; })) {
66     return true;
67   }
68   if (shape.size() == 1 && shape[0] == kDynamicDim) {
69     return true;
70   }
71   return false;
72 }
73 
CheckDynamicShape(mindspore::tensor::TensorPtr sig,mindspore::tensor::TensorPtr org)74 static bool CheckDynamicShape(mindspore::tensor::TensorPtr sig, mindspore::tensor::TensorPtr org) {
75   if (sig->data_type() != org->data_type()) {
76     return false;
77   }
78   auto &sig_shape = sig->shape();
79   if (sig_shape.size() == 1 && sig_shape[0] == kDynamicDim) {
80     return true;
81   }
82   auto &org_shape = org->shape();
83   if (sig_shape.size() != org_shape.size()) {
84     return false;
85   }
86   for (size_t i = 0; i < sig_shape.size(); ++i) {
87     if (sig_shape[i] != org_shape[i] && sig_shape[i] != kDynamicShape) {
88       return false;
89     }
90   }
91   return true;
92 }
93 
CheckSymbolicShape(PyObject * attr,mindspore::tensor::TensorPtr org)94 static bool CheckSymbolicShape(PyObject *attr, mindspore::tensor::TensorPtr org) {
95   if (attr == nullptr || !PyList_Check(attr) || org == nullptr) {
96     return false;
97   }
98   auto shape = org->shape();
99   std::map<int64_t, int64_t> symbolic_shape_data;
100   for (int i = 0; i < PyList_GET_SIZE(attr); ++i) {
101     auto item = PyList_GetItem(attr, i);
102     if (!PyDict_Check(item)) {
103       continue;
104     }
105     auto id = PyDict_GetItemString(item, "id");
106     if (id != nullptr) {
107       auto idv = PyLong_AsLong(id);
108       if (symbolic_shape_data.find(idv) == symbolic_shape_data.end()) {
109         symbolic_shape_data[idv] = shape[i];
110       } else if (symbolic_shape_data[idv] != shape[i]) {
111         return false;
112       }
113     }
114     auto min = PyDict_GetItemString(item, "min");
115     if (min != nullptr && PyLong_Check(min) && PyLong_AsLong(min) > shape[i]) {
116       return false;
117     }
118     auto max = PyDict_GetItemString(item, "max");
119     if (max != nullptr && PyLong_Check(max) && PyLong_AsLong(max) < shape[i]) {
120       return false;
121     }
122     auto d = PyDict_GetItemString(item, "divisor");
123     int64_t dv = d != nullptr ? PyLong_AsLong(d) : 1;
124     auto r = PyDict_GetItemString(item, "remainder");
125     int64_t rv = r != nullptr ? PyLong_AsLong(r) : 0;
126     if (dv > shape[i] || shape[i] % dv != rv) {
127       return false;
128     }
129   }
130   return true;
131 }
132 
CheckTensorValid(PyObject * sig,PyObject * org)133 static bool CheckTensorValid(PyObject *sig, PyObject *org) {
134   mindspore::tensor::TensorPtr psig = py::cast<mindspore::tensor::TensorPtr>(sig);
135   mindspore::tensor::TensorPtr porg = py::cast<mindspore::tensor::TensorPtr>(org);
136   if (IsShapeUnknown(psig) && !CheckDynamicShape(psig, porg)) {
137     return false;
138   }
139   if (PyObject_HasAttrString(sig, "symbolic_shape")) {
140     PyObject *attr = PyObject_GetAttrString(sig, "symbolic_shape");
141     if (!CheckSymbolicShape(attr, porg)) {
142       Py_DECREF(attr);
143       return false;
144     }
145     Py_DECREF(attr);
146   }
147   return true;
148 }
149 
150 static bool CheckItemValid(PyObject *sig, PyObject *org);
CheckListValid(PyObject * sig,PyObject * org)151 static bool CheckListValid(PyObject *sig, PyObject *org) {
152   if (PyList_Size(sig) != PyList_Size(org)) {
153     return false;
154   }
155   for (Py_ssize_t i = 0; i < PyList_Size(sig); ++i) {
156     PyObject *sig_item = PyList_GetItem(sig, i);
157     PyObject *org_item = PyList_GetItem(org, i);
158     if (!CheckItemValid(sig_item, org_item)) {
159       return false;
160     }
161   }
162   return true;
163 }
164 
CheckTupleValid(PyObject * sig,PyObject * org)165 static bool CheckTupleValid(PyObject *sig, PyObject *org) {
166   if (PyTuple_GET_SIZE(sig) != PyTuple_GET_SIZE(org)) {
167     return false;
168   }
169   for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(sig); ++i) {
170     PyObject *sig_item = PyTuple_GET_ITEM(sig, i);
171     PyObject *org_item = PyTuple_GET_ITEM(org, i);
172     if (!CheckItemValid(sig_item, org_item)) {
173       return false;
174     }
175   }
176   return true;
177 }
178 
CheckItemValid(PyObject * sig,PyObject * org)179 static bool CheckItemValid(PyObject *sig, PyObject *org) {
180   if (sig == nullptr || org == nullptr || sig == Py_None || org == Py_None) {
181     return true;
182   }
183   if (py::isinstance<mindspore::tensor::Tensor>(sig) && py::isinstance<mindspore::tensor::Tensor>(org) &&
184       !CheckTensorValid(sig, org)) {
185     return false;
186   }
187   if (PyList_Check(sig) && PyList_Check(org) && !CheckListValid(sig, org)) {
188     return false;
189   }
190   if (PyTuple_Check(sig) && PyTuple_Check(org) && !CheckTupleValid(sig, org)) {
191     return false;
192   }
193   return true;
194 }
195 
CheckValid()196 bool ShapeContext::CheckValid() {
197   if (signature_ == nullptr) {
198     return false;
199   }
200   int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
201   if ((PyTuple_GET_SIZE(signature_) + (is_method_ ? 1 : 0)) != argc) {
202     return false;
203   }
204   for (int i = 0; i < PyTuple_GET_SIZE(signature_); ++i) {
205     auto sig = PyTuple_GetItem(signature_, i);
206     auto org = origin_[i];
207     if (!CheckItemValid(sig, org)) {
208       return false;
209     }
210   }
211   return true;
212 }
213 
ApplySignature()214 void ShapeContext::ApplySignature() {
215   if (applied_) {
216     return;
217   }
218   if (!CheckValid()) {
219     return;
220   }
221   int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
222   for (int i = (is_method_ ? 1 : 0), j = 0; i < argc; ++i, ++j) {
223     PyObject *sig_item = PyTuple_GetItem(signature_, j);
224     PyObject *org_item = frame_->f_localsplus[i];
225     if (sig_item != nullptr && sig_item != Py_None && org_item != nullptr && org_item != Py_None) {
226       frame_->f_localsplus[i] = sig_item;
227     }
228   }
229   applied_ = true;
230 }
231 
RevertSignature()232 void ShapeContext::RevertSignature() {
233   if (!applied_) {
234     return;
235   }
236   int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
237   for (int i = (is_method_ ? 1 : 0), j = 0; i < argc; ++i, ++j) {
238     frame_->f_localsplus[i] = origin_[j];
239   }
240   applied_ = false;
241 }
242 
243 }  // namespace pijit
244 }  // namespace mindspore
245