1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_guard/shape_ctx.h"
17 #include <algorithm>
18 #include <map>
19 #include "ir/tensor.h"
20
21 namespace py = pybind11;
22
23 namespace mindspore {
24 namespace pijit {
25
ShapeContext(PyFrameObject * f,PyObject * signature)26 ShapeContext::ShapeContext(PyFrameObject *f, PyObject *signature)
27 : frame_(f), signature_(signature), is_method_(false), applied_(false) {
28 Py_XINCREF(f);
29 Py_XINCREF(signature);
30 if (signature != nullptr) {
31 if (!PyTuple_Check(signature) && !PyList_Check(signature)) {
32 auto tuple = PyTuple_New(1);
33 PyTuple_SET_ITEM(tuple, 0, signature);
34 Py_DECREF(signature);
35 signature_ = tuple;
36 } else if (!PyTuple_Check(signature)) {
37 int size = PyList_Size(signature);
38 auto tuple = PyTuple_New(size);
39 for (int i = 0; i < size; ++i) {
40 auto item = PyList_GetItem(signature, i);
41 Py_XINCREF(item);
42 PyTuple_SET_ITEM(tuple, i, item);
43 }
44 Py_DECREF(signature);
45 signature_ = tuple;
46 }
47 int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
48 is_method_ = (argc == (PyTuple_GET_SIZE(signature_) + 1)) ? true : false;
49 std::vector<PyObject *> locals(&(frame_->f_localsplus[is_method_ ? 1 : 0]), &(frame_->f_localsplus[argc]));
50 origin_ = locals;
51 }
52 }
53
~ShapeContext()54 ShapeContext::~ShapeContext() {
55 RevertSignature();
56 Py_XDECREF(frame_);
57 Py_XDECREF(signature_);
58 }
59
// Sentinels that may appear in a signature tensor's shape:
//   kDynamicDim   (-2): a shape of exactly {kDynamicDim} accepts a tensor of any rank.
//   kDynamicShape (-1): a single dimension whose length may be anything.
static constexpr int64_t kDynamicDim{-2};
static constexpr int64_t kDynamicShape{-1};
62
IsShapeUnknown(mindspore::tensor::TensorPtr tensor)63 static bool IsShapeUnknown(mindspore::tensor::TensorPtr tensor) {
64 auto &shape = tensor->shape();
65 if (std::any_of(shape.begin(), shape.end(), [](const auto &element) { return element == kDynamicShape; })) {
66 return true;
67 }
68 if (shape.size() == 1 && shape[0] == kDynamicDim) {
69 return true;
70 }
71 return false;
72 }
73
CheckDynamicShape(mindspore::tensor::TensorPtr sig,mindspore::tensor::TensorPtr org)74 static bool CheckDynamicShape(mindspore::tensor::TensorPtr sig, mindspore::tensor::TensorPtr org) {
75 if (sig->data_type() != org->data_type()) {
76 return false;
77 }
78 auto &sig_shape = sig->shape();
79 if (sig_shape.size() == 1 && sig_shape[0] == kDynamicDim) {
80 return true;
81 }
82 auto &org_shape = org->shape();
83 if (sig_shape.size() != org_shape.size()) {
84 return false;
85 }
86 for (size_t i = 0; i < sig_shape.size(); ++i) {
87 if (sig_shape[i] != org_shape[i] && sig_shape[i] != kDynamicShape) {
88 return false;
89 }
90 }
91 return true;
92 }
93
CheckSymbolicShape(PyObject * attr,mindspore::tensor::TensorPtr org)94 static bool CheckSymbolicShape(PyObject *attr, mindspore::tensor::TensorPtr org) {
95 if (attr == nullptr || !PyList_Check(attr) || org == nullptr) {
96 return false;
97 }
98 auto shape = org->shape();
99 std::map<int64_t, int64_t> symbolic_shape_data;
100 for (int i = 0; i < PyList_GET_SIZE(attr); ++i) {
101 auto item = PyList_GetItem(attr, i);
102 if (!PyDict_Check(item)) {
103 continue;
104 }
105 auto id = PyDict_GetItemString(item, "id");
106 if (id != nullptr) {
107 auto idv = PyLong_AsLong(id);
108 if (symbolic_shape_data.find(idv) == symbolic_shape_data.end()) {
109 symbolic_shape_data[idv] = shape[i];
110 } else if (symbolic_shape_data[idv] != shape[i]) {
111 return false;
112 }
113 }
114 auto min = PyDict_GetItemString(item, "min");
115 if (min != nullptr && PyLong_Check(min) && PyLong_AsLong(min) > shape[i]) {
116 return false;
117 }
118 auto max = PyDict_GetItemString(item, "max");
119 if (max != nullptr && PyLong_Check(max) && PyLong_AsLong(max) < shape[i]) {
120 return false;
121 }
122 auto d = PyDict_GetItemString(item, "divisor");
123 int64_t dv = d != nullptr ? PyLong_AsLong(d) : 1;
124 auto r = PyDict_GetItemString(item, "remainder");
125 int64_t rv = r != nullptr ? PyLong_AsLong(r) : 0;
126 if (dv > shape[i] || shape[i] % dv != rv) {
127 return false;
128 }
129 }
130 return true;
131 }
132
CheckTensorValid(PyObject * sig,PyObject * org)133 static bool CheckTensorValid(PyObject *sig, PyObject *org) {
134 mindspore::tensor::TensorPtr psig = py::cast<mindspore::tensor::TensorPtr>(sig);
135 mindspore::tensor::TensorPtr porg = py::cast<mindspore::tensor::TensorPtr>(org);
136 if (IsShapeUnknown(psig) && !CheckDynamicShape(psig, porg)) {
137 return false;
138 }
139 if (PyObject_HasAttrString(sig, "symbolic_shape")) {
140 PyObject *attr = PyObject_GetAttrString(sig, "symbolic_shape");
141 if (!CheckSymbolicShape(attr, porg)) {
142 Py_DECREF(attr);
143 return false;
144 }
145 Py_DECREF(attr);
146 }
147 return true;
148 }
149
150 static bool CheckItemValid(PyObject *sig, PyObject *org);
CheckListValid(PyObject * sig,PyObject * org)151 static bool CheckListValid(PyObject *sig, PyObject *org) {
152 if (PyList_Size(sig) != PyList_Size(org)) {
153 return false;
154 }
155 for (Py_ssize_t i = 0; i < PyList_Size(sig); ++i) {
156 PyObject *sig_item = PyList_GetItem(sig, i);
157 PyObject *org_item = PyList_GetItem(org, i);
158 if (!CheckItemValid(sig_item, org_item)) {
159 return false;
160 }
161 }
162 return true;
163 }
164
CheckTupleValid(PyObject * sig,PyObject * org)165 static bool CheckTupleValid(PyObject *sig, PyObject *org) {
166 if (PyTuple_GET_SIZE(sig) != PyTuple_GET_SIZE(org)) {
167 return false;
168 }
169 for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(sig); ++i) {
170 PyObject *sig_item = PyTuple_GET_ITEM(sig, i);
171 PyObject *org_item = PyTuple_GET_ITEM(org, i);
172 if (!CheckItemValid(sig_item, org_item)) {
173 return false;
174 }
175 }
176 return true;
177 }
178
CheckItemValid(PyObject * sig,PyObject * org)179 static bool CheckItemValid(PyObject *sig, PyObject *org) {
180 if (sig == nullptr || org == nullptr || sig == Py_None || org == Py_None) {
181 return true;
182 }
183 if (py::isinstance<mindspore::tensor::Tensor>(sig) && py::isinstance<mindspore::tensor::Tensor>(org) &&
184 !CheckTensorValid(sig, org)) {
185 return false;
186 }
187 if (PyList_Check(sig) && PyList_Check(org) && !CheckListValid(sig, org)) {
188 return false;
189 }
190 if (PyTuple_Check(sig) && PyTuple_Check(org) && !CheckTupleValid(sig, org)) {
191 return false;
192 }
193 return true;
194 }
195
CheckValid()196 bool ShapeContext::CheckValid() {
197 if (signature_ == nullptr) {
198 return false;
199 }
200 int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
201 if ((PyTuple_GET_SIZE(signature_) + (is_method_ ? 1 : 0)) != argc) {
202 return false;
203 }
204 for (int i = 0; i < PyTuple_GET_SIZE(signature_); ++i) {
205 auto sig = PyTuple_GetItem(signature_, i);
206 auto org = origin_[i];
207 if (!CheckItemValid(sig, org)) {
208 return false;
209 }
210 }
211 return true;
212 }
213
ApplySignature()214 void ShapeContext::ApplySignature() {
215 if (applied_) {
216 return;
217 }
218 if (!CheckValid()) {
219 return;
220 }
221 int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
222 for (int i = (is_method_ ? 1 : 0), j = 0; i < argc; ++i, ++j) {
223 PyObject *sig_item = PyTuple_GetItem(signature_, j);
224 PyObject *org_item = frame_->f_localsplus[i];
225 if (sig_item != nullptr && sig_item != Py_None && org_item != nullptr && org_item != Py_None) {
226 frame_->f_localsplus[i] = sig_item;
227 }
228 }
229 applied_ = true;
230 }
231
RevertSignature()232 void ShapeContext::RevertSignature() {
233 if (!applied_) {
234 return;
235 }
236 int argc = frame_->f_code->co_argcount + frame_->f_code->co_kwonlyargcount;
237 for (int i = (is_method_ ? 1 : 0), j = 0; i < argc; ++i, ++j) {
238 frame_->f_localsplus[i] = origin_[j];
239 }
240 applied_ = false;
241 }
242
243 } // namespace pijit
244 } // namespace mindspore
245