1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_guard/guard.h"
17 #include <chrono>
18 #include <regex>
19 #include "pybind11/pybind11.h"
20 #include "pybind_api/ir/cell_py.h"
21 #include "pybind_api/ir/primitive_py.h"
22 #include "include/common/utils/convert_utils_py.h"
23 #include "pipeline/jit/pi/utils/utils.h"
24 #include "pipeline/jit/pi/graph_guard/strategy.h"
25
26 namespace mindspore {
27 namespace pijit {
// Names of the configuration options understood by OptGuard.
// The three "specialize_*" keys are looked up by GuardOnScalar / GuardOnContainer / GuardOnTensor,
// "relax_guard_count" is forwarded to Trace::SetRelaxCount() in OptGuard::GuardOn().
const char kSpecializeScalar[] = "specialize_scalar";
const char kSpecializeContainer[] = "specialize_container";
const char kSpecializeTensor[] = "specialize_tensor";
const char kGuardRelaxCnt[] = "relax_guard_count";

// Default boolean options. Copied into every OptGuard by its constructor; OptGuard::UpdateConfig()
// only accepts keys that are present in this map.
static std::map<std::string, bool> g_mapBoolDefaultConfig = {
  {kSpecializeScalar, false},
  {kSpecializeContainer, false},
  {kSpecializeTensor, false},
};

// Default integer options. Copied into every OptGuard by its constructor; OptGuard::UpdateConfig()
// only accepts keys that are present in this map.
static std::map<std::string, int> g_mapIntDefaultConfig = {
  {kGuardRelaxCnt, 0},
};
42
// Forward declarations of the file-local helpers that pick a guard item for a trace.
// They are defined further down in this file; the `config` argument is the OptGuard's boolean option map.
static GuardItemPtr GuardOnGDeduce(TracePtr var, PyObject *obj, const std::map<std::string, bool> &config);
static GuardItemPtr GuardOnScalar(TracePtr var, const std::map<std::string, bool> &config);
static GuardItemPtr GuardOnContainer(TracePtr var, const std::map<std::string, bool> &config);
static GuardItemPtr GuardOnLiteral(TracePtr var, const std::map<std::string, bool> &config);
static GuardItemPtr GuardOnTensor(TracePtr var, const std::map<std::string, bool> &config);
static GuardItemPtr GuardOnMutableOrConstObj(TracePtr var);
static GuardItemPtr GuardOnDynamicLenContainer(TracePtr var);
50
CheckLiteral(PyObject * obj)51 static bool CheckLiteral(PyObject *obj) {
52 if (obj == nullptr) {
53 return false;
54 }
55
56 ReprRecursionScope scope(obj);
57 if (scope.ReEnterOrError()) {
58 return scope.ReEnter();
59 }
60 if (CheckScalar(obj)) {
61 return true;
62 } else if (PyList_Check(obj)) {
63 for (Py_ssize_t i = 0; i < PyList_Size(obj); ++i) {
64 PyObject *item = PyList_GetItem(obj, i);
65 if (!CheckLiteral(item)) {
66 return false;
67 }
68 }
69 return true;
70 } else if (PyTuple_Check(obj)) {
71 for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(obj); ++i) {
72 PyObject *item = PyTuple_GET_ITEM(obj, i);
73 if (!CheckLiteral(item)) {
74 return false;
75 }
76 }
77 return true;
78 } else if (PySet_Check(obj) || PyFrozenSet_Check(obj)) {
79 Py_ssize_t pos = 0;
80 PyObject *item;
81 Py_hash_t hash;
82 while (_PySet_NextEntry(obj, &pos, &item, &hash)) {
83 if (!CheckLiteral(item)) {
84 return false;
85 }
86 }
87 return true;
88 } else if (PyDict_Check(obj)) {
89 Py_ssize_t pos = 0;
90 PyObject *key;
91 PyObject *val;
92 while (PyDict_Next(obj, &pos, &key, &val)) {
93 if (!CheckLiteral(key) || !CheckLiteral(val)) {
94 return false;
95 }
96 }
97 return true;
98 }
99 return false;
100 }
101
CheckOwnerIsCell(TracePtr var)102 bool CheckOwnerIsCell(TracePtr var) {
103 if (py::isinstance<mindspore::Cell>(var->GetObject())) {
104 return true;
105 } else if (var->GetOrigin() != NULL) {
106 return CheckOwnerIsCell(var);
107 } else {
108 return false;
109 }
110 }
111
// OptGuardPerf implementation that accumulates, in memory, how often and for how long (microseconds)
// guards, guard-item stages and traces were evaluated. Statistics are keyed by strings that combine
// object addresses with the item/trace description.
class OptGuardPerfImpl : public OptGuardPerf {
 public:
  // Copies the collected statistics into each output map that is not null (the output is replaced).
  virtual void GetGuardPerfInfo(std::map<std::string, std::pair<size_t, size_t>> *guard_info,
                                std::map<std::string, std::pair<size_t, std::vector<size_t>>> *item_info,
                                std::map<std::string, std::pair<size_t, size_t>> *trace_info,
                                std::map<std::string, std::pair<size_t, size_t>> *guard_freq_info) const;
  OptGuardPerfImpl() = default;
  virtual ~OptGuardPerfImpl() = default;
  // Remembers the guard being checked and its owner, and starts the guard timer.
  virtual void LogGuardPerfStart(OptGuard *tag2, GuardItem *item);
  // Stops the guard timer and records duration plus pass (`res` == true) / fail counters.
  virtual void LogGuardPerfEnd(GuardItem *item, bool res);
  // Prepares `total_stage + 1` time stamps and records the first one.
  virtual void LogItemPerfStart(int total_stage);
  // Records the time stamp that closes `stage`; after the last stage the per-stage durations are stored.
  virtual void LogItemPerfEnd(GuardItem *item, int stage);
  // Starts the trace timer.
  virtual void LogTracePerfStart();
  // Stops the trace timer and records the duration; `cache` adds a "cache:" marker to the key.
  virtual void LogTracePerfEnd(Trace *trace, bool cache);

 protected:
  // Owner and guard item passed to the most recent LogGuardPerfStart(); their addresses prefix the keys.
  OptGuard *cur_tag2_ = nullptr;
  GuardItem *cur_guard_ = nullptr;
  // Time stamps taken by LogGuardPerfStart() and LogTracePerfStart().
  std::chrono::steady_clock::time_point guard_start_;
  std::chrono::steady_clock::time_point trace_start_;
  // Stage boundaries of the item currently measured: [0] is the start, [stage + 1] closes `stage`.
  std::vector<std::chrono::steady_clock::time_point> item_stage_;
  // key -> (number of checks, accumulated microseconds)
  std::map<std::string, std::pair<size_t, size_t>> guard_info_;
  // key -> (number of measurements, accumulated microseconds per stage)
  std::map<std::string, std::pair<size_t, std::vector<size_t>>> item_info_;
  // key -> (number of measurements, accumulated microseconds)
  std::map<std::string, std::pair<size_t, size_t>> trace_info_;
  // key -> (number of passed checks, number of failed checks)
  std::map<std::string, std::pair<size_t, size_t>> guard_freq_info_;
};
138
// File-level statistics collector; OptGuard::Check() logs into it when its `perf` flag is set.
static OptGuardPerfImpl g_guard_perf;
// Returns the address of the file-level OptGuardPerfImpl instance.
OptGuardPerf *OptGuardPerf::GetGuardPerf() { return &g_guard_perf; }
141
GetGuardPerfInfo(std::map<std::string,std::pair<size_t,size_t>> * guard_info,std::map<std::string,std::pair<size_t,std::vector<size_t>>> * item_info,std::map<std::string,std::pair<size_t,size_t>> * trace_info,std::map<std::string,std::pair<size_t,size_t>> * guard_freq_info) const142 void OptGuardPerfImpl::GetGuardPerfInfo(std::map<std::string, std::pair<size_t, size_t>> *guard_info,
143 std::map<std::string, std::pair<size_t, std::vector<size_t>>> *item_info,
144 std::map<std::string, std::pair<size_t, size_t>> *trace_info,
145 std::map<std::string, std::pair<size_t, size_t>> *guard_freq_info) const {
146 if (guard_info != nullptr) {
147 guard_info->clear();
148 guard_info->insert(guard_info_.begin(), guard_info_.end());
149 }
150 if (trace_info != nullptr) {
151 trace_info->clear();
152 trace_info->insert(trace_info_.begin(), trace_info_.end());
153 }
154 if (guard_freq_info != nullptr) {
155 guard_freq_info->clear();
156 guard_freq_info->insert(guard_freq_info_.begin(), guard_freq_info_.end());
157 }
158 if (item_info != nullptr) {
159 item_info->clear();
160 item_info->insert(item_info_.begin(), item_info_.end());
161 }
162 }
163
// Remembers which guard item (`item`) of which OptGuard (`tag2`) is about to be checked and
// starts the guard timer. The time stamp is taken last so the bookkeeping above is not measured.
void OptGuardPerfImpl::LogGuardPerfStart(OptGuard *tag2, GuardItem *item) {
  cur_guard_ = item;
  cur_tag2_ = tag2;
  guard_start_ = std::chrono::steady_clock::now();
}
169
LogGuardPerfEnd(GuardItem * item,bool res)170 void OptGuardPerfImpl::LogGuardPerfEnd(GuardItem *item, bool res) {
171 auto duration =
172 std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - guard_start_);
173 size_t dur = (size_t)(duration.count());
174 size_t inc = 1;
175 auto info = item->ToString();
176 std::stringstream s;
177 s << reinterpret_cast<void *>(cur_tag2_) << "=>" << reinterpret_cast<void *>(cur_guard_) << "=>";
178 info = s.str() + info;
179 auto iter = guard_info_.find(info);
180 if (iter != guard_info_.end()) {
181 iter->second.first += inc;
182 iter->second.second += dur;
183 } else {
184 guard_info_[info] = std::make_pair(inc, dur);
185 }
186 iter = guard_freq_info_.find(info);
187 if (iter != guard_freq_info_.end()) {
188 if (res) {
189 iter->second.first += 1;
190 } else {
191 iter->second.second += 1;
192 }
193 } else {
194 if (res) {
195 guard_freq_info_[info] = std::make_pair(1, 0);
196 } else {
197 guard_freq_info_[info] = std::make_pair(0, 1);
198 }
199 }
200 }
201
LogItemPerfStart(int total_stage)202 void OptGuardPerfImpl::LogItemPerfStart(int total_stage) {
203 item_stage_.clear();
204 item_stage_.resize(total_stage + 1);
205 item_stage_[0] = std::chrono::steady_clock::now();
206 }
207
LogItemPerfEnd(GuardItem * item,int stage)208 void OptGuardPerfImpl::LogItemPerfEnd(GuardItem *item, int stage) {
209 size_t cur_stage = static_cast<size_t>(stage + 1);
210 if (item_stage_.size() > cur_stage) {
211 item_stage_[cur_stage] = std::chrono::steady_clock::now();
212 }
213 if (item_stage_.size() == (cur_stage + 1)) {
214 auto info = item->ToString();
215 std::stringstream s;
216 s << reinterpret_cast<void *>(cur_tag2_) << "=>" << reinterpret_cast<void *>(cur_guard_) << "=>";
217 info = s.str() + info;
218 std::vector<size_t> vecDur;
219 for (int idx = 0; idx <= stage; ++idx) {
220 auto duration = std::chrono::duration_cast<std::chrono::microseconds>(item_stage_[idx + 1] - item_stage_[idx]);
221 vecDur.push_back((size_t)(duration.count()));
222 }
223 auto iter = item_info_.find(info);
224 if (iter != item_info_.end()) {
225 iter->second.first += 1;
226 for (size_t i = 0; i < vecDur.size(); ++i) {
227 iter->second.second[i] += vecDur[i];
228 }
229 } else {
230 item_info_[info] = std::make_pair(1, vecDur);
231 }
232 }
233 }
234
// Starts the trace timer; LogTracePerfEnd() measures against this time stamp.
void OptGuardPerfImpl::LogTracePerfStart() { trace_start_ = std::chrono::steady_clock::now(); }
236
LogTracePerfEnd(Trace * trace,bool cache)237 void OptGuardPerfImpl::LogTracePerfEnd(Trace *trace, bool cache) {
238 auto duration =
239 std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - trace_start_);
240 size_t dur = (size_t)(duration.count());
241 size_t inc = 1;
242 auto info = trace->ToString(true);
243 std::stringstream s;
244 s << reinterpret_cast<void *>(cur_guard_) << "=>";
245 if (cache) {
246 s << "cache:";
247 }
248 info = s.str() + info;
249 auto iter = trace_info_.find(info);
250 if (iter != trace_info_.end()) {
251 iter->second.first += inc;
252 iter->second.second += dur;
253 } else {
254 trace_info_[info] = std::make_pair(inc, dur);
255 }
256 }
257
// Creates a guard set whose options start out as copies of the file-level default maps.
OptGuard::OptGuard() {
  bool_config_ = g_mapBoolDefaultConfig;
  int_config_ = g_mapIntDefaultConfig;
}
262
UpdateGuardList(GuardItemPtr item)263 void OptGuard::UpdateGuardList(GuardItemPtr item) {
264 // reorder list to speed up check on next run
265 for (size_t i = 0; i < guardList_.size(); ++i) {
266 if (guardList_[i] == item) {
267 guardList_.erase(guardList_.begin() + i);
268 guardList_.insert(guardList_.begin(), item);
269 }
270 }
271 }
272
Check(const PyFrameObject * frame,bool print,std::map<size_t,PyObject * > * cache,std::map<size_t,bool> * success,std::map<size_t,bool> * fail,bool perf)273 bool OptGuard::Check(const PyFrameObject *frame, bool print, std::map<size_t, PyObject *> *cache,
274 std::map<size_t, bool> *success, std::map<size_t, bool> *fail, bool perf) {
275 // filter failure case
276 if (fail != nullptr) {
277 for (auto item : guardMap_) {
278 if (fail->find(item.first) != fail->end()) {
279 return false;
280 }
281 }
282 }
283 std::vector<GuardItemPtr> list;
284 list.reserve(guardList_.size());
285 // filter success case
286 if (success != nullptr) {
287 for (auto item : guardList_) {
288 if (success->find(item->Info().Id()) == success->end()) {
289 list.push_back(item);
290 }
291 }
292 } else {
293 list = guardList_;
294 }
295 list = OptStrategy::MakeGuardItemListStrategyByFrame(frame, list);
296 for (size_t i = 0; i < list.size(); ++i) {
297 GuardItemPtr item = list[i];
298 if (perf) {
299 g_guard_perf.LogGuardPerfStart(this, item.get());
300 }
301 bool result = item->Check(frame, cache, perf);
302 if (perf) {
303 g_guard_perf.LogGuardPerfEnd(item.get(), result);
304 }
305 if (!result) {
306 UpdateGuardList(item);
307 if (fail != nullptr) {
308 fail->operator[](item->Info().Id()) = false;
309 }
310 if (print) {
311 auto trace = item->GetTrace();
312 auto obj = GetObjectFromTrace(frame, trace);
313 GRAPH_JIT_LOG_F("Guard check fail: %s v.s. %s\n", item->ToString().c_str(),
314 std::string(py::str(py::cast<py::object>(obj))).c_str());
315 Py_XDECREF(obj);
316 } else if (IS_OUTPUT_ON(mindspore::kDebug)) {
317 MS_LOG(DEBUG) << "Guard check fail:" << item->ToString();
318 }
319 return false;
320 } else if (success != nullptr) {
321 success->operator[](item->Info().Id()) = true;
322 }
323 }
324 return true;
325 }
326
GuardOn(TracePtr var,GuardLevel tp,bool needSpecialize,int recurseDepth)327 bool OptGuard::GuardOn(TracePtr var, GuardLevel tp, bool needSpecialize, int recurseDepth) {
328 // Now we have TypeGuard IdGuard NameGuard AttrGuard EqGuard, let's add guard to guardlist based on type
329 PyObject *obj = var->GetObject();
330 if (int_config_.find(kGuardRelaxCnt) != int_config_.end()) {
331 var->SetRelaxCount(int_config_[kGuardRelaxCnt]);
332 }
333 GuardItemPtr item = nullptr;
334 if (obj != nullptr) {
335 py::object py_obj = py::reinterpret_borrow<py::object>(obj);
336 if (IsStubTensor(py_obj)) {
337 py_obj = python_adapter::CallPyObjMethod(py_obj, "stub_sync");
338 obj = py_obj.ptr();
339 }
340 if (tp == GuardLevel::GDeduce) {
341 item = GuardOnGDeduce(var, obj, bool_config_);
342 } else if (tp == GuardLevel::GId) {
343 item = GuardId(var);
344 } else if (tp == GuardLevel::GType) {
345 item = GuardType(var);
346 } else if (tp == GuardLevel::GAttr) {
347 item = GuardAttr(var);
348 } else if (tp == GuardLevel::GEqual) {
349 item = GuardEqual(var, needSpecialize, recurseDepth);
350 }
351 } else {
352 // Check obj == None
353 item = GuardEqual(var, 0);
354 }
355 if (item != nullptr) {
356 size_t szItem = item->Info().Id();
357 if (guardMap_.find(szItem) == guardMap_.end()) {
358 guardList_.push_back(item);
359 guardMap_[szItem] = item;
360 }
361 return true;
362 } else {
363 return false;
364 }
365 }
366
Info()367 const InfoPack &OptGuard::Info() {
368 if (info_ == nullptr) {
369 InfoPack info;
370 info.Begin();
371 for (auto &item : guardList_) {
372 info << item->Info();
373 }
374 info.End();
375 info_ = std::make_shared<InfoPack>(info);
376 info_->Update();
377 }
378 return *info_;
379 }
380
GuardOnGDeduce(TracePtr var,PyObject * obj,const std::map<std::string,bool> & config)381 static GuardItemPtr GuardOnGDeduce(TracePtr var, PyObject *obj, const std::map<std::string, bool> &config) {
382 GuardItemPtr item = nullptr;
383 if (CheckLiteral(obj)) {
384 item = GuardOnLiteral(var, config);
385 } else if (PyFrozenSet_Check(obj)) {
386 item = GuardId(var);
387 } else if (PyFunction_Check(obj) || PyMethod_Check(obj) || PyInstanceMethod_Check(obj)) {
388 item = GuardEqual(var, false, 0);
389 } else if (PyType_Check(obj)) {
390 item = GuardEqual(var, false, 0);
391 } else if (CheckContainer(obj)) {
392 // due to the failure of CheckLiteral, it need check size and element type
393 item = GuardOnContainer(var, config);
394 } else if (PySlice_Check(obj)) {
395 item = GuardType(var);
396 } else if (py::isinstance<py::array>(obj)) {
397 item = GuardId(var);
398 } else if (py::isinstance<mindspore::Type>(obj)) {
399 item = GuardEqual(var, true, INT_MAX);
400 } else if (IsTensorPyObject(obj)) {
401 item = GuardOnTensor(var, config);
402 } else if (py::isinstance<mindspore::PrimitivePyAdapter>(obj)) {
403 if (CheckOwnerIsCell(var)) {
404 item = GuardEqual(var, true, INT_MAX);
405 } else {
406 item = GuardRepr(var);
407 }
408 } else if (py::isinstance<mindspore::Cell>(obj)) {
409 item = GuardRepr(var);
410 } else if (py::isinstance<mindspore::ParamInfo>(obj)) {
411 item = GuardEqual(var, true, INT_MAX);
412 } else {
413 item = GuardType(var);
414 }
415 return item;
416 }
417
GuardOnScalar(TracePtr var,const std::map<std::string,bool> & config)418 static GuardItemPtr GuardOnScalar(TracePtr var, const std::map<std::string, bool> &config) {
419 GuardItemPtr item = GuardOnMutableOrConstObj(var);
420 if (item != nullptr) {
421 return item;
422 }
423 bool need_specialize = false;
424 auto cfg = config.find(kSpecializeScalar);
425 if (cfg != config.end()) {
426 need_specialize = cfg->second;
427 }
428 // need take dynamic symbolic into account
429 if (need_specialize) {
430 if ((var->GetOriginType() == TraceType::Global || var->GetOriginType() == TraceType::BuiltIn) ||
431 var->GetOriginType() == TraceType::Param || var->GetTraceType() == TraceType::Item ||
432 var->GetTraceType() == TraceType::Attr) {
433 item = GuardEqual(var, true, INT_MAX);
434 } else {
435 item = GuardType(var);
436 }
437 } else {
438 item = GuardEqual(var, false, 0);
439 }
440 return item;
441 }
442
GuardOnContainer(TracePtr var,const std::map<std::string,bool> & config)443 static GuardItemPtr GuardOnContainer(TracePtr var, const std::map<std::string, bool> &config) {
444 GuardItemPtr item = GuardOnDynamicLenContainer(var);
445 if (item != nullptr) {
446 return item;
447 } else {
448 item = GuardOnMutableOrConstObj(var);
449 }
450 if (item != nullptr) {
451 return item;
452 }
453 bool need_specialize = false;
454 auto cfg = config.find(kSpecializeContainer);
455 if (cfg != config.end()) {
456 need_specialize = cfg->second;
457 }
458 if (need_specialize) {
459 item = GuardEqual(var, true, INT_MAX);
460 } else {
461 item = GuardEqual(var, false, 0);
462 }
463 return item;
464 }
465
GuardOnLiteral(TracePtr var,const std::map<std::string,bool> & config)466 static GuardItemPtr GuardOnLiteral(TracePtr var, const std::map<std::string, bool> &config) {
467 GuardItemPtr item = nullptr;
468 PyObject *obj = var->GetObject();
469 if (CheckScalar(obj)) {
470 return GuardOnScalar(var, config);
471 } else if (CheckContainer(obj)) {
472 return GuardOnContainer(var, config);
473 } else {
474 item = GuardOnMutableOrConstObj(var);
475 if (item == nullptr) {
476 item = GuardEqual(var, false, 0);
477 }
478 }
479 return item;
480 }
481
GuardOnTensor(TracePtr var,const std::map<std::string,bool> & config)482 static GuardItemPtr GuardOnTensor(TracePtr var, const std::map<std::string, bool> &config) {
483 GuardItemPtr item = nullptr;
484 bool need_specialize = false;
485 auto cfg = config.find(kSpecializeTensor);
486 if (cfg != config.end()) {
487 need_specialize = cfg->second;
488 }
489 item = GuardOnMutableOrConstObj(var);
490 if (item != nullptr) {
491 return item;
492 }
493 if (CheckOwnerIsCell(var)) {
494 if (var->GetOriginType() == TraceType::Const) {
495 item = GuardId(var);
496 } else {
497 item = GuardEqual(var, false, INT_MAX);
498 }
499 } else if (var->GetOriginType() == TraceType::Const) {
500 item = GuardId(var);
501 } else if (need_specialize) {
502 item = GuardEqual(var, true, INT_MAX);
503 } else {
504 item = GuardEqual(var, false, INT_MAX);
505 }
506 return item;
507 }
508
GuardOnMutableOrConstObj(TracePtr var)509 static GuardItemPtr GuardOnMutableOrConstObj(TracePtr var) {
510 PyObject *obj = var->GetObject();
511 GuardItemPtr item = nullptr;
512 if (HasMutableOrConstAttr(obj)) {
513 if (CheckMutableOrNonConstAttr(obj)) {
514 item = GuardEqual(var, false, INT_MAX);
515 } else {
516 item = GuardEqual(var, true, INT_MAX);
517 }
518 }
519 return item;
520 }
521
GuardOnDynamicLenContainer(TracePtr var)522 static GuardItemPtr GuardOnDynamicLenContainer(TracePtr var) {
523 PyObject *obj = var->GetObject();
524 GuardItemPtr item = nullptr;
525 if (HasDynamicLength(obj)) {
526 if (CheckDynamicLength(obj)) {
527 item = GuardType(var);
528 } else {
529 item = GuardEqual(var, false, 0);
530 }
531 }
532 return item;
533 }
534
AddTraceFromGuard(const std::vector<TracePtr> & traces,OptGuardPtr other)535 void OptGuard::AddTraceFromGuard(const std::vector<TracePtr> &traces, OptGuardPtr other) {
536 for (size_t i = 0; i < traces.size(); ++i) {
537 auto dst = traces[i];
538 auto src = std::make_shared<RootTrace>(dst->GetObject(), TraceType::Param, i);
539 for (auto item : other->guardList_) {
540 item->Replace(dst, src);
541 }
542 }
543 for (auto item : other->guardList_) {
544 guardList_.push_back(item);
545 }
546 }
547
GetDescript()548 std::string OptGuard::GetDescript() {
549 std::string ret;
550 for (auto item : guardList_) {
551 ret += ";" + item->ToString();
552 }
553 if (ret.size() > 0) {
554 ret = ret.substr(1);
555 }
556 return ret;
557 }
558
UpdateConfig(const std::map<std::string,bool> & bool_config,const std::map<std::string,int> & int_config)559 void OptGuard::UpdateConfig(const std::map<std::string, bool> &bool_config,
560 const std::map<std::string, int> &int_config) {
561 for (auto item : bool_config) {
562 if (g_mapBoolDefaultConfig.find(item.first) != g_mapBoolDefaultConfig.end()) {
563 bool_config_[item.first] = item.second;
564 }
565 }
566 for (auto item : int_config) {
567 if (g_mapIntDefaultConfig.find(item.first) != g_mapIntDefaultConfig.end()) {
568 int_config_[item.first] = item.second;
569 }
570 }
571 }
572
// Pushes a copy of the current guard list and guard map onto guardStack_ as a checkpoint;
// Rollback() restores the most recent checkpoint, Pop() discards it.
void OptGuard::Backup() { guardStack_.push(std::make_pair(guardList_, guardMap_)); }
574
Rollback()575 void OptGuard::Rollback() {
576 GuardCheckPoint point = guardStack_.top();
577 guardList_.swap(point.first);
578 guardMap_.swap(point.second);
579 guardStack_.pop();
580 }
581
Pop()582 void OptGuard::Pop() { guardStack_.pop(); }
583
MatchDynamicShape(GuardItemPtr item,const std::vector<GuardItemPtr> & list)584 static bool MatchDynamicShape(GuardItemPtr item, const std::vector<GuardItemPtr> &list) {
585 auto trace_type = item->GetTrace()->GetTraceType();
586 auto guard_type = item->GetType();
587 if ((trace_type != TraceType::Deref && trace_type != TraceType::Param) || guard_type != GIType::GTEqual) {
588 return false;
589 }
590 for (auto other : list) {
591 if (item->MatchDynamicShape(other)) {
592 return true;
593 }
594 }
595 return false;
596 }
597
MatchShape(OptGuardPtr other)598 bool OptGuard::MatchShape(OptGuardPtr other) {
599 if (std::any_of(guardList_.begin(), guardList_.end(), [other](auto &item) {
600 return (!std::any_of(other->guardList_.begin(), other->guardList_.end(), [item](GuardItemPtr oi) {
601 return *item == *oi;
602 }) && !MatchDynamicShape(item, other->guardList_));
603 })) {
604 return false;
605 }
606 if (std::any_of(other->guardList_.begin(), other->guardList_.end(), [this](auto &item) {
607 return (!std::any_of(guardList_.begin(), guardList_.end(), [item](GuardItemPtr oi) { return *item == *oi; }));
608 })) {
609 return false;
610 }
611 return true;
612 }
613
FindItem(const std::vector<GuardItemPtr> & guardList,int idx,TraceType type,PyObject * obj)614 static PyObject *FindItem(const std::vector<GuardItemPtr> &guardList, int idx, TraceType type, PyObject *obj) {
615 auto iter = std::find_if(guardList.begin(), guardList.end(), [idx, type](GuardItemPtr item) {
616 if (item->GetTrace()->GetTraceType() == type) {
617 int index;
618 std::string name, module_name;
619 (reinterpret_cast<RootTrace *>(item->GetTrace().get()))->GetParam(&index, &name, &module_name);
620 return (idx == index);
621 } else {
622 return false;
623 }
624 });
625 if (iter != guardList.end()) {
626 GuardItemPtr item = *iter;
627 return item->ApplyDynamicShape(obj);
628 } else {
629 return nullptr;
630 }
631 }
632
ApplyDynamicShape(PyFrameObject * f)633 std::vector<PyObject *> OptGuard::ApplyDynamicShape(PyFrameObject *f) {
634 std::vector<PyObject *> ret;
635 int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
636 PyObject *vargs = NULL;
637 PyObject *kwargs = NULL;
638 if (f->f_code->co_flags & CO_VARARGS) {
639 vargs = f->f_localsplus[argc];
640 }
641 if (f->f_code->co_flags & CO_VARKEYWORDS) {
642 kwargs = f->f_localsplus[argc + (vargs ? 1 : 0)];
643 }
644 for (int i = 0; i < argc; ++i) {
645 auto new_obj = FindItem(guardList_, i, TraceType::Param, f->f_localsplus[i]);
646 if (new_obj == nullptr) {
647 ret.push_back(nullptr);
648 } else {
649 ret.push_back(f->f_localsplus[i]);
650 f->f_localsplus[i] = new_obj;
651 }
652 }
653 if (vargs != NULL) {
654 ret.push_back(nullptr);
655 }
656 if (kwargs != NULL) {
657 ret.push_back(nullptr);
658 }
659 ret.resize(f->f_code->co_nlocals, nullptr);
660 for (int i = 0; f->f_code->co_cell2arg && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
661 Py_ssize_t arg = f->f_code->co_cell2arg[i];
662 if (arg != CO_CELL_NOT_AN_ARG) {
663 auto cell = f->f_localsplus[f->f_code->co_nlocals + i];
664 auto new_obj = FindItem(guardList_, i, TraceType::Deref, PyCell_GET(cell));
665 if (new_obj == nullptr) {
666 ret.push_back(nullptr);
667 } else {
668 ret.push_back(PyCell_GET(cell));
669 PyCell_SET(cell, new_obj);
670 }
671 }
672 }
673 ret.resize(f->f_code->co_nlocals + PyTuple_GET_SIZE(f->f_code->co_cellvars), nullptr);
674 for (int i = 0; i < PyTuple_GET_SIZE(f->f_code->co_freevars); ++i) {
675 Py_ssize_t arg = PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
676 auto cell = f->f_localsplus[f->f_code->co_nlocals + arg];
677 auto new_obj = FindItem(guardList_, arg, TraceType::Deref, PyCell_GET(cell));
678 if (new_obj == nullptr) {
679 ret.push_back(nullptr);
680 } else {
681 ret.push_back(PyCell_GET(cell));
682 PyCell_SET(cell, new_obj);
683 }
684 }
685 return ret;
686 }
687
RevertDynamicShape(PyFrameObject * f,const std::vector<PyObject * > & backup)688 void OptGuard::RevertDynamicShape(PyFrameObject *f, const std::vector<PyObject *> &backup) {
689 int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
690 for (int i = 0; i < argc; ++i) {
691 if (backup[i] != nullptr) {
692 Py_XDECREF(f->f_localsplus[i]);
693 f->f_localsplus[i] = backup[i];
694 }
695 }
696 for (int i = 0; f->f_code->co_cell2arg && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
697 Py_ssize_t arg = f->f_code->co_cell2arg[i];
698 if (arg != CO_CELL_NOT_AN_ARG) {
699 auto cell = f->f_localsplus[f->f_code->co_nlocals + i];
700 if (backup[f->f_code->co_nlocals + i] != nullptr) {
701 Py_XDECREF(PyCell_GET(cell));
702 PyCell_SET(cell, backup[f->f_code->co_nlocals + i]);
703 }
704 }
705 }
706 for (int i = 0; i < PyTuple_GET_SIZE(f->f_code->co_freevars); ++i) {
707 Py_ssize_t arg = PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
708 auto cell = f->f_localsplus[f->f_code->co_nlocals + arg];
709 if (backup[f->f_code->co_nlocals + arg] != nullptr) {
710 Py_XDECREF(PyCell_GET(cell));
711 PyCell_SET(cell, backup[f->f_code->co_nlocals + arg]);
712 }
713 }
714 }
715
ToString() const716 std::string OptGuard::ToString() const {
717 std::stringstream s;
718 for (const auto &i : guardMap_) {
719 s << " guard [" << i.first << "] [" << i.second->ToString() << " ] at [" << i.second.get() << "]\n";
720 }
721 return s.str();
722 }
723
Optimize()724 OptGuardPtr OptGuard::Optimize() {
725 bool need_update = false;
726 for (size_t i = 0; i < guardList_.size(); ++i) {
727 auto old_item = guardList_[i];
728 auto new_item = old_item->Optimize();
729 if (new_item != nullptr) {
730 guardList_[i] = new_item;
731 guardMap_.erase(old_item->Info().Id());
732 guardMap_[new_item->Info().Id()] = new_item;
733 need_update = true;
734 }
735 }
736 if (need_update) {
737 info_ = nullptr;
738 Info();
739 return shared_from_this();
740 } else {
741 return nullptr;
742 }
743 }
744
745 } // namespace pijit
746 } // namespace mindspore
747