• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_guard/guard.h"
17 #include <chrono>
18 #include <regex>
19 #include "pybind11/pybind11.h"
20 #include "pybind_api/ir/cell_py.h"
21 #include "pybind_api/ir/primitive_py.h"
22 #include "include/common/utils/convert_utils_py.h"
23 #include "pipeline/jit/pi/utils/utils.h"
24 #include "pipeline/jit/pi/graph_guard/strategy.h"
25 
26 namespace mindspore {
27 namespace pijit {
// Names of the per-guard configuration switches (see OptGuard::UpdateConfig).
const char kSpecializeScalar[] = "specialize_scalar";
const char kSpecializeContainer[] = "specialize_container";
const char kSpecializeTensor[] = "specialize_tensor";
const char kGuardRelaxCnt[] = "relax_guard_count";

// Default values of the boolean switches. OptGuard::UpdateConfig only accepts keys listed here.
static std::map<std::string, bool> g_mapBoolDefaultConfig{
  {kSpecializeScalar, false},
  {kSpecializeContainer, false},
  {kSpecializeTensor, false},
};

// Default values of the integer switches. OptGuard::UpdateConfig only accepts keys listed here.
static std::map<std::string, int> g_mapIntDefaultConfig{
  {kGuardRelaxCnt, 0},
};
42 
43 static GuardItemPtr GuardOnGDeduce(TracePtr var, PyObject *obj, const std::map<std::string, bool> &config);
44 static GuardItemPtr GuardOnScalar(TracePtr var, const std::map<std::string, bool> &config);
45 static GuardItemPtr GuardOnContainer(TracePtr var, const std::map<std::string, bool> &config);
46 static GuardItemPtr GuardOnLiteral(TracePtr var, const std::map<std::string, bool> &config);
47 static GuardItemPtr GuardOnTensor(TracePtr var, const std::map<std::string, bool> &config);
48 static GuardItemPtr GuardOnMutableOrConstObj(TracePtr var);
49 static GuardItemPtr GuardOnDynamicLenContainer(TracePtr var);
50 
CheckLiteral(PyObject * obj)51 static bool CheckLiteral(PyObject *obj) {
52   if (obj == nullptr) {
53     return false;
54   }
55 
56   ReprRecursionScope scope(obj);
57   if (scope.ReEnterOrError()) {
58     return scope.ReEnter();
59   }
60   if (CheckScalar(obj)) {
61     return true;
62   } else if (PyList_Check(obj)) {
63     for (Py_ssize_t i = 0; i < PyList_Size(obj); ++i) {
64       PyObject *item = PyList_GetItem(obj, i);
65       if (!CheckLiteral(item)) {
66         return false;
67       }
68     }
69     return true;
70   } else if (PyTuple_Check(obj)) {
71     for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(obj); ++i) {
72       PyObject *item = PyTuple_GET_ITEM(obj, i);
73       if (!CheckLiteral(item)) {
74         return false;
75       }
76     }
77     return true;
78   } else if (PySet_Check(obj) || PyFrozenSet_Check(obj)) {
79     Py_ssize_t pos = 0;
80     PyObject *item;
81     Py_hash_t hash;
82     while (_PySet_NextEntry(obj, &pos, &item, &hash)) {
83       if (!CheckLiteral(item)) {
84         return false;
85       }
86     }
87     return true;
88   } else if (PyDict_Check(obj)) {
89     Py_ssize_t pos = 0;
90     PyObject *key;
91     PyObject *val;
92     while (PyDict_Next(obj, &pos, &key, &val)) {
93       if (!CheckLiteral(key) || !CheckLiteral(val)) {
94         return false;
95       }
96     }
97     return true;
98   }
99   return false;
100 }
101 
CheckOwnerIsCell(TracePtr var)102 bool CheckOwnerIsCell(TracePtr var) {
103   if (py::isinstance<mindspore::Cell>(var->GetObject())) {
104     return true;
105   } else if (var->GetOrigin() != NULL) {
106     return CheckOwnerIsCell(var);
107   } else {
108     return false;
109   }
110 }
111 
112 class OptGuardPerfImpl : public OptGuardPerf {
113  public:
114   virtual void GetGuardPerfInfo(std::map<std::string, std::pair<size_t, size_t>> *guard_info,
115                                 std::map<std::string, std::pair<size_t, std::vector<size_t>>> *item_info,
116                                 std::map<std::string, std::pair<size_t, size_t>> *trace_info,
117                                 std::map<std::string, std::pair<size_t, size_t>> *guard_freq_info) const;
118   OptGuardPerfImpl() = default;
119   virtual ~OptGuardPerfImpl() = default;
120   virtual void LogGuardPerfStart(OptGuard *tag2, GuardItem *item);
121   virtual void LogGuardPerfEnd(GuardItem *item, bool res);
122   virtual void LogItemPerfStart(int total_stage);
123   virtual void LogItemPerfEnd(GuardItem *item, int stage);
124   virtual void LogTracePerfStart();
125   virtual void LogTracePerfEnd(Trace *trace, bool cache);
126 
127  protected:
128   OptGuard *cur_tag2_ = nullptr;
129   GuardItem *cur_guard_ = nullptr;
130   std::chrono::steady_clock::time_point guard_start_;
131   std::chrono::steady_clock::time_point trace_start_;
132   std::vector<std::chrono::steady_clock::time_point> item_stage_;
133   std::map<std::string, std::pair<size_t, size_t>> guard_info_;
134   std::map<std::string, std::pair<size_t, std::vector<size_t>>> item_info_;
135   std::map<std::string, std::pair<size_t, size_t>> trace_info_;
136   std::map<std::string, std::pair<size_t, size_t>> guard_freq_info_;
137 };
138 
139 static OptGuardPerfImpl g_guard_perf;
GetGuardPerf()140 OptGuardPerf *OptGuardPerf::GetGuardPerf() { return &g_guard_perf; }
141 
GetGuardPerfInfo(std::map<std::string,std::pair<size_t,size_t>> * guard_info,std::map<std::string,std::pair<size_t,std::vector<size_t>>> * item_info,std::map<std::string,std::pair<size_t,size_t>> * trace_info,std::map<std::string,std::pair<size_t,size_t>> * guard_freq_info) const142 void OptGuardPerfImpl::GetGuardPerfInfo(std::map<std::string, std::pair<size_t, size_t>> *guard_info,
143                                         std::map<std::string, std::pair<size_t, std::vector<size_t>>> *item_info,
144                                         std::map<std::string, std::pair<size_t, size_t>> *trace_info,
145                                         std::map<std::string, std::pair<size_t, size_t>> *guard_freq_info) const {
146   if (guard_info != nullptr) {
147     guard_info->clear();
148     guard_info->insert(guard_info_.begin(), guard_info_.end());
149   }
150   if (trace_info != nullptr) {
151     trace_info->clear();
152     trace_info->insert(trace_info_.begin(), trace_info_.end());
153   }
154   if (guard_freq_info != nullptr) {
155     guard_freq_info->clear();
156     guard_freq_info->insert(guard_freq_info_.begin(), guard_freq_info_.end());
157   }
158   if (item_info != nullptr) {
159     item_info->clear();
160     item_info->insert(item_info_.begin(), item_info_.end());
161   }
162 }
163 
LogGuardPerfStart(OptGuard * tag2,GuardItem * item)164 void OptGuardPerfImpl::LogGuardPerfStart(OptGuard *tag2, GuardItem *item) {
165   cur_guard_ = item;
166   cur_tag2_ = tag2;
167   guard_start_ = std::chrono::steady_clock::now();
168 }
169 
LogGuardPerfEnd(GuardItem * item,bool res)170 void OptGuardPerfImpl::LogGuardPerfEnd(GuardItem *item, bool res) {
171   auto duration =
172     std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - guard_start_);
173   size_t dur = (size_t)(duration.count());
174   size_t inc = 1;
175   auto info = item->ToString();
176   std::stringstream s;
177   s << reinterpret_cast<void *>(cur_tag2_) << "=>" << reinterpret_cast<void *>(cur_guard_) << "=>";
178   info = s.str() + info;
179   auto iter = guard_info_.find(info);
180   if (iter != guard_info_.end()) {
181     iter->second.first += inc;
182     iter->second.second += dur;
183   } else {
184     guard_info_[info] = std::make_pair(inc, dur);
185   }
186   iter = guard_freq_info_.find(info);
187   if (iter != guard_freq_info_.end()) {
188     if (res) {
189       iter->second.first += 1;
190     } else {
191       iter->second.second += 1;
192     }
193   } else {
194     if (res) {
195       guard_freq_info_[info] = std::make_pair(1, 0);
196     } else {
197       guard_freq_info_[info] = std::make_pair(0, 1);
198     }
199   }
200 }
201 
LogItemPerfStart(int total_stage)202 void OptGuardPerfImpl::LogItemPerfStart(int total_stage) {
203   item_stage_.clear();
204   item_stage_.resize(total_stage + 1);
205   item_stage_[0] = std::chrono::steady_clock::now();
206 }
207 
LogItemPerfEnd(GuardItem * item,int stage)208 void OptGuardPerfImpl::LogItemPerfEnd(GuardItem *item, int stage) {
209   size_t cur_stage = static_cast<size_t>(stage + 1);
210   if (item_stage_.size() > cur_stage) {
211     item_stage_[cur_stage] = std::chrono::steady_clock::now();
212   }
213   if (item_stage_.size() == (cur_stage + 1)) {
214     auto info = item->ToString();
215     std::stringstream s;
216     s << reinterpret_cast<void *>(cur_tag2_) << "=>" << reinterpret_cast<void *>(cur_guard_) << "=>";
217     info = s.str() + info;
218     std::vector<size_t> vecDur;
219     for (int idx = 0; idx <= stage; ++idx) {
220       auto duration = std::chrono::duration_cast<std::chrono::microseconds>(item_stage_[idx + 1] - item_stage_[idx]);
221       vecDur.push_back((size_t)(duration.count()));
222     }
223     auto iter = item_info_.find(info);
224     if (iter != item_info_.end()) {
225       iter->second.first += 1;
226       for (size_t i = 0; i < vecDur.size(); ++i) {
227         iter->second.second[i] += vecDur[i];
228       }
229     } else {
230       item_info_[info] = std::make_pair(1, vecDur);
231     }
232   }
233 }
234 
LogTracePerfStart()235 void OptGuardPerfImpl::LogTracePerfStart() { trace_start_ = std::chrono::steady_clock::now(); }
236 
LogTracePerfEnd(Trace * trace,bool cache)237 void OptGuardPerfImpl::LogTracePerfEnd(Trace *trace, bool cache) {
238   auto duration =
239     std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - trace_start_);
240   size_t dur = (size_t)(duration.count());
241   size_t inc = 1;
242   auto info = trace->ToString(true);
243   std::stringstream s;
244   s << reinterpret_cast<void *>(cur_guard_) << "=>";
245   if (cache) {
246     s << "cache:";
247   }
248   info = s.str() + info;
249   auto iter = trace_info_.find(info);
250   if (iter != trace_info_.end()) {
251     iter->second.first += inc;
252     iter->second.second += dur;
253   } else {
254     trace_info_[info] = std::make_pair(inc, dur);
255   }
256 }
257 
OptGuard()258 OptGuard::OptGuard() {
259   bool_config_ = g_mapBoolDefaultConfig;
260   int_config_ = g_mapIntDefaultConfig;
261 }
262 
UpdateGuardList(GuardItemPtr item)263 void OptGuard::UpdateGuardList(GuardItemPtr item) {
264   // reorder list to speed up check on next run
265   for (size_t i = 0; i < guardList_.size(); ++i) {
266     if (guardList_[i] == item) {
267       guardList_.erase(guardList_.begin() + i);
268       guardList_.insert(guardList_.begin(), item);
269     }
270   }
271 }
272 
Check(const PyFrameObject * frame,bool print,std::map<size_t,PyObject * > * cache,std::map<size_t,bool> * success,std::map<size_t,bool> * fail,bool perf)273 bool OptGuard::Check(const PyFrameObject *frame, bool print, std::map<size_t, PyObject *> *cache,
274                      std::map<size_t, bool> *success, std::map<size_t, bool> *fail, bool perf) {
275   // filter failure case
276   if (fail != nullptr) {
277     for (auto item : guardMap_) {
278       if (fail->find(item.first) != fail->end()) {
279         return false;
280       }
281     }
282   }
283   std::vector<GuardItemPtr> list;
284   list.reserve(guardList_.size());
285   // filter success case
286   if (success != nullptr) {
287     for (auto item : guardList_) {
288       if (success->find(item->Info().Id()) == success->end()) {
289         list.push_back(item);
290       }
291     }
292   } else {
293     list = guardList_;
294   }
295   list = OptStrategy::MakeGuardItemListStrategyByFrame(frame, list);
296   for (size_t i = 0; i < list.size(); ++i) {
297     GuardItemPtr item = list[i];
298     if (perf) {
299       g_guard_perf.LogGuardPerfStart(this, item.get());
300     }
301     bool result = item->Check(frame, cache, perf);
302     if (perf) {
303       g_guard_perf.LogGuardPerfEnd(item.get(), result);
304     }
305     if (!result) {
306       UpdateGuardList(item);
307       if (fail != nullptr) {
308         fail->operator[](item->Info().Id()) = false;
309       }
310       if (print) {
311         auto trace = item->GetTrace();
312         auto obj = GetObjectFromTrace(frame, trace);
313         GRAPH_JIT_LOG_F("Guard check fail: %s v.s. %s\n", item->ToString().c_str(),
314                         std::string(py::str(py::cast<py::object>(obj))).c_str());
315         Py_XDECREF(obj);
316       } else if (IS_OUTPUT_ON(mindspore::kDebug)) {
317         MS_LOG(DEBUG) << "Guard check fail:" << item->ToString();
318       }
319       return false;
320     } else if (success != nullptr) {
321       success->operator[](item->Info().Id()) = true;
322     }
323   }
324   return true;
325 }
326 
GuardOn(TracePtr var,GuardLevel tp,bool needSpecialize,int recurseDepth)327 bool OptGuard::GuardOn(TracePtr var, GuardLevel tp, bool needSpecialize, int recurseDepth) {
328   // Now we have TypeGuard IdGuard NameGuard AttrGuard EqGuard, let's add guard to guardlist based on type
329   PyObject *obj = var->GetObject();
330   if (int_config_.find(kGuardRelaxCnt) != int_config_.end()) {
331     var->SetRelaxCount(int_config_[kGuardRelaxCnt]);
332   }
333   GuardItemPtr item = nullptr;
334   if (obj != nullptr) {
335     py::object py_obj = py::reinterpret_borrow<py::object>(obj);
336     if (IsStubTensor(py_obj)) {
337       py_obj = python_adapter::CallPyObjMethod(py_obj, "stub_sync");
338       obj = py_obj.ptr();
339     }
340     if (tp == GuardLevel::GDeduce) {
341       item = GuardOnGDeduce(var, obj, bool_config_);
342     } else if (tp == GuardLevel::GId) {
343       item = GuardId(var);
344     } else if (tp == GuardLevel::GType) {
345       item = GuardType(var);
346     } else if (tp == GuardLevel::GAttr) {
347       item = GuardAttr(var);
348     } else if (tp == GuardLevel::GEqual) {
349       item = GuardEqual(var, needSpecialize, recurseDepth);
350     }
351   } else {
352     // Check obj == None
353     item = GuardEqual(var, 0);
354   }
355   if (item != nullptr) {
356     size_t szItem = item->Info().Id();
357     if (guardMap_.find(szItem) == guardMap_.end()) {
358       guardList_.push_back(item);
359       guardMap_[szItem] = item;
360     }
361     return true;
362   } else {
363     return false;
364   }
365 }
366 
Info()367 const InfoPack &OptGuard::Info() {
368   if (info_ == nullptr) {
369     InfoPack info;
370     info.Begin();
371     for (auto &item : guardList_) {
372       info << item->Info();
373     }
374     info.End();
375     info_ = std::make_shared<InfoPack>(info);
376     info_->Update();
377   }
378   return *info_;
379 }
380 
GuardOnGDeduce(TracePtr var,PyObject * obj,const std::map<std::string,bool> & config)381 static GuardItemPtr GuardOnGDeduce(TracePtr var, PyObject *obj, const std::map<std::string, bool> &config) {
382   GuardItemPtr item = nullptr;
383   if (CheckLiteral(obj)) {
384     item = GuardOnLiteral(var, config);
385   } else if (PyFrozenSet_Check(obj)) {
386     item = GuardId(var);
387   } else if (PyFunction_Check(obj) || PyMethod_Check(obj) || PyInstanceMethod_Check(obj)) {
388     item = GuardEqual(var, false, 0);
389   } else if (PyType_Check(obj)) {
390     item = GuardEqual(var, false, 0);
391   } else if (CheckContainer(obj)) {
392     // due to the failure of CheckLiteral, it need check size and element type
393     item = GuardOnContainer(var, config);
394   } else if (PySlice_Check(obj)) {
395     item = GuardType(var);
396   } else if (py::isinstance<py::array>(obj)) {
397     item = GuardId(var);
398   } else if (py::isinstance<mindspore::Type>(obj)) {
399     item = GuardEqual(var, true, INT_MAX);
400   } else if (IsTensorPyObject(obj)) {
401     item = GuardOnTensor(var, config);
402   } else if (py::isinstance<mindspore::PrimitivePyAdapter>(obj)) {
403     if (CheckOwnerIsCell(var)) {
404       item = GuardEqual(var, true, INT_MAX);
405     } else {
406       item = GuardRepr(var);
407     }
408   } else if (py::isinstance<mindspore::Cell>(obj)) {
409     item = GuardRepr(var);
410   } else if (py::isinstance<mindspore::ParamInfo>(obj)) {
411     item = GuardEqual(var, true, INT_MAX);
412   } else {
413     item = GuardType(var);
414   }
415   return item;
416 }
417 
GuardOnScalar(TracePtr var,const std::map<std::string,bool> & config)418 static GuardItemPtr GuardOnScalar(TracePtr var, const std::map<std::string, bool> &config) {
419   GuardItemPtr item = GuardOnMutableOrConstObj(var);
420   if (item != nullptr) {
421     return item;
422   }
423   bool need_specialize = false;
424   auto cfg = config.find(kSpecializeScalar);
425   if (cfg != config.end()) {
426     need_specialize = cfg->second;
427   }
428   // need take dynamic symbolic into account
429   if (need_specialize) {
430     if ((var->GetOriginType() == TraceType::Global || var->GetOriginType() == TraceType::BuiltIn) ||
431         var->GetOriginType() == TraceType::Param || var->GetTraceType() == TraceType::Item ||
432         var->GetTraceType() == TraceType::Attr) {
433       item = GuardEqual(var, true, INT_MAX);
434     } else {
435       item = GuardType(var);
436     }
437   } else {
438     item = GuardEqual(var, false, 0);
439   }
440   return item;
441 }
442 
GuardOnContainer(TracePtr var,const std::map<std::string,bool> & config)443 static GuardItemPtr GuardOnContainer(TracePtr var, const std::map<std::string, bool> &config) {
444   GuardItemPtr item = GuardOnDynamicLenContainer(var);
445   if (item != nullptr) {
446     return item;
447   } else {
448     item = GuardOnMutableOrConstObj(var);
449   }
450   if (item != nullptr) {
451     return item;
452   }
453   bool need_specialize = false;
454   auto cfg = config.find(kSpecializeContainer);
455   if (cfg != config.end()) {
456     need_specialize = cfg->second;
457   }
458   if (need_specialize) {
459     item = GuardEqual(var, true, INT_MAX);
460   } else {
461     item = GuardEqual(var, false, 0);
462   }
463   return item;
464 }
465 
GuardOnLiteral(TracePtr var,const std::map<std::string,bool> & config)466 static GuardItemPtr GuardOnLiteral(TracePtr var, const std::map<std::string, bool> &config) {
467   GuardItemPtr item = nullptr;
468   PyObject *obj = var->GetObject();
469   if (CheckScalar(obj)) {
470     return GuardOnScalar(var, config);
471   } else if (CheckContainer(obj)) {
472     return GuardOnContainer(var, config);
473   } else {
474     item = GuardOnMutableOrConstObj(var);
475     if (item == nullptr) {
476       item = GuardEqual(var, false, 0);
477     }
478   }
479   return item;
480 }
481 
GuardOnTensor(TracePtr var,const std::map<std::string,bool> & config)482 static GuardItemPtr GuardOnTensor(TracePtr var, const std::map<std::string, bool> &config) {
483   GuardItemPtr item = nullptr;
484   bool need_specialize = false;
485   auto cfg = config.find(kSpecializeTensor);
486   if (cfg != config.end()) {
487     need_specialize = cfg->second;
488   }
489   item = GuardOnMutableOrConstObj(var);
490   if (item != nullptr) {
491     return item;
492   }
493   if (CheckOwnerIsCell(var)) {
494     if (var->GetOriginType() == TraceType::Const) {
495       item = GuardId(var);
496     } else {
497       item = GuardEqual(var, false, INT_MAX);
498     }
499   } else if (var->GetOriginType() == TraceType::Const) {
500     item = GuardId(var);
501   } else if (need_specialize) {
502     item = GuardEqual(var, true, INT_MAX);
503   } else {
504     item = GuardEqual(var, false, INT_MAX);
505   }
506   return item;
507 }
508 
GuardOnMutableOrConstObj(TracePtr var)509 static GuardItemPtr GuardOnMutableOrConstObj(TracePtr var) {
510   PyObject *obj = var->GetObject();
511   GuardItemPtr item = nullptr;
512   if (HasMutableOrConstAttr(obj)) {
513     if (CheckMutableOrNonConstAttr(obj)) {
514       item = GuardEqual(var, false, INT_MAX);
515     } else {
516       item = GuardEqual(var, true, INT_MAX);
517     }
518   }
519   return item;
520 }
521 
GuardOnDynamicLenContainer(TracePtr var)522 static GuardItemPtr GuardOnDynamicLenContainer(TracePtr var) {
523   PyObject *obj = var->GetObject();
524   GuardItemPtr item = nullptr;
525   if (HasDynamicLength(obj)) {
526     if (CheckDynamicLength(obj)) {
527       item = GuardType(var);
528     } else {
529       item = GuardEqual(var, false, 0);
530     }
531   }
532   return item;
533 }
534 
AddTraceFromGuard(const std::vector<TracePtr> & traces,OptGuardPtr other)535 void OptGuard::AddTraceFromGuard(const std::vector<TracePtr> &traces, OptGuardPtr other) {
536   for (size_t i = 0; i < traces.size(); ++i) {
537     auto dst = traces[i];
538     auto src = std::make_shared<RootTrace>(dst->GetObject(), TraceType::Param, i);
539     for (auto item : other->guardList_) {
540       item->Replace(dst, src);
541     }
542   }
543   for (auto item : other->guardList_) {
544     guardList_.push_back(item);
545   }
546 }
547 
GetDescript()548 std::string OptGuard::GetDescript() {
549   std::string ret;
550   for (auto item : guardList_) {
551     ret += ";" + item->ToString();
552   }
553   if (ret.size() > 0) {
554     ret = ret.substr(1);
555   }
556   return ret;
557 }
558 
UpdateConfig(const std::map<std::string,bool> & bool_config,const std::map<std::string,int> & int_config)559 void OptGuard::UpdateConfig(const std::map<std::string, bool> &bool_config,
560                             const std::map<std::string, int> &int_config) {
561   for (auto item : bool_config) {
562     if (g_mapBoolDefaultConfig.find(item.first) != g_mapBoolDefaultConfig.end()) {
563       bool_config_[item.first] = item.second;
564     }
565   }
566   for (auto item : int_config) {
567     if (g_mapIntDefaultConfig.find(item.first) != g_mapIntDefaultConfig.end()) {
568       int_config_[item.first] = item.second;
569     }
570   }
571 }
572 
Backup()573 void OptGuard::Backup() { guardStack_.push(std::make_pair(guardList_, guardMap_)); }
574 
Rollback()575 void OptGuard::Rollback() {
576   GuardCheckPoint point = guardStack_.top();
577   guardList_.swap(point.first);
578   guardMap_.swap(point.second);
579   guardStack_.pop();
580 }
581 
Pop()582 void OptGuard::Pop() { guardStack_.pop(); }
583 
MatchDynamicShape(GuardItemPtr item,const std::vector<GuardItemPtr> & list)584 static bool MatchDynamicShape(GuardItemPtr item, const std::vector<GuardItemPtr> &list) {
585   auto trace_type = item->GetTrace()->GetTraceType();
586   auto guard_type = item->GetType();
587   if ((trace_type != TraceType::Deref && trace_type != TraceType::Param) || guard_type != GIType::GTEqual) {
588     return false;
589   }
590   for (auto other : list) {
591     if (item->MatchDynamicShape(other)) {
592       return true;
593     }
594   }
595   return false;
596 }
597 
MatchShape(OptGuardPtr other)598 bool OptGuard::MatchShape(OptGuardPtr other) {
599   if (std::any_of(guardList_.begin(), guardList_.end(), [other](auto &item) {
600         return (!std::any_of(other->guardList_.begin(), other->guardList_.end(), [item](GuardItemPtr oi) {
601           return *item == *oi;
602         }) && !MatchDynamicShape(item, other->guardList_));
603       })) {
604     return false;
605   }
606   if (std::any_of(other->guardList_.begin(), other->guardList_.end(), [this](auto &item) {
607         return (!std::any_of(guardList_.begin(), guardList_.end(), [item](GuardItemPtr oi) { return *item == *oi; }));
608       })) {
609     return false;
610   }
611   return true;
612 }
613 
FindItem(const std::vector<GuardItemPtr> & guardList,int idx,TraceType type,PyObject * obj)614 static PyObject *FindItem(const std::vector<GuardItemPtr> &guardList, int idx, TraceType type, PyObject *obj) {
615   auto iter = std::find_if(guardList.begin(), guardList.end(), [idx, type](GuardItemPtr item) {
616     if (item->GetTrace()->GetTraceType() == type) {
617       int index;
618       std::string name, module_name;
619       (reinterpret_cast<RootTrace *>(item->GetTrace().get()))->GetParam(&index, &name, &module_name);
620       return (idx == index);
621     } else {
622       return false;
623     }
624   });
625   if (iter != guardList.end()) {
626     GuardItemPtr item = *iter;
627     return item->ApplyDynamicShape(obj);
628   } else {
629     return nullptr;
630   }
631 }
632 
ApplyDynamicShape(PyFrameObject * f)633 std::vector<PyObject *> OptGuard::ApplyDynamicShape(PyFrameObject *f) {
634   std::vector<PyObject *> ret;
635   int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
636   PyObject *vargs = NULL;
637   PyObject *kwargs = NULL;
638   if (f->f_code->co_flags & CO_VARARGS) {
639     vargs = f->f_localsplus[argc];
640   }
641   if (f->f_code->co_flags & CO_VARKEYWORDS) {
642     kwargs = f->f_localsplus[argc + (vargs ? 1 : 0)];
643   }
644   for (int i = 0; i < argc; ++i) {
645     auto new_obj = FindItem(guardList_, i, TraceType::Param, f->f_localsplus[i]);
646     if (new_obj == nullptr) {
647       ret.push_back(nullptr);
648     } else {
649       ret.push_back(f->f_localsplus[i]);
650       f->f_localsplus[i] = new_obj;
651     }
652   }
653   if (vargs != NULL) {
654     ret.push_back(nullptr);
655   }
656   if (kwargs != NULL) {
657     ret.push_back(nullptr);
658   }
659   ret.resize(f->f_code->co_nlocals, nullptr);
660   for (int i = 0; f->f_code->co_cell2arg && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
661     Py_ssize_t arg = f->f_code->co_cell2arg[i];
662     if (arg != CO_CELL_NOT_AN_ARG) {
663       auto cell = f->f_localsplus[f->f_code->co_nlocals + i];
664       auto new_obj = FindItem(guardList_, i, TraceType::Deref, PyCell_GET(cell));
665       if (new_obj == nullptr) {
666         ret.push_back(nullptr);
667       } else {
668         ret.push_back(PyCell_GET(cell));
669         PyCell_SET(cell, new_obj);
670       }
671     }
672   }
673   ret.resize(f->f_code->co_nlocals + PyTuple_GET_SIZE(f->f_code->co_cellvars), nullptr);
674   for (int i = 0; i < PyTuple_GET_SIZE(f->f_code->co_freevars); ++i) {
675     Py_ssize_t arg = PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
676     auto cell = f->f_localsplus[f->f_code->co_nlocals + arg];
677     auto new_obj = FindItem(guardList_, arg, TraceType::Deref, PyCell_GET(cell));
678     if (new_obj == nullptr) {
679       ret.push_back(nullptr);
680     } else {
681       ret.push_back(PyCell_GET(cell));
682       PyCell_SET(cell, new_obj);
683     }
684   }
685   return ret;
686 }
687 
RevertDynamicShape(PyFrameObject * f,const std::vector<PyObject * > & backup)688 void OptGuard::RevertDynamicShape(PyFrameObject *f, const std::vector<PyObject *> &backup) {
689   int argc = f->f_code->co_argcount + f->f_code->co_kwonlyargcount;
690   for (int i = 0; i < argc; ++i) {
691     if (backup[i] != nullptr) {
692       Py_XDECREF(f->f_localsplus[i]);
693       f->f_localsplus[i] = backup[i];
694     }
695   }
696   for (int i = 0; f->f_code->co_cell2arg && i < PyTuple_GET_SIZE(f->f_code->co_cellvars); ++i) {
697     Py_ssize_t arg = f->f_code->co_cell2arg[i];
698     if (arg != CO_CELL_NOT_AN_ARG) {
699       auto cell = f->f_localsplus[f->f_code->co_nlocals + i];
700       if (backup[f->f_code->co_nlocals + i] != nullptr) {
701         Py_XDECREF(PyCell_GET(cell));
702         PyCell_SET(cell, backup[f->f_code->co_nlocals + i]);
703       }
704     }
705   }
706   for (int i = 0; i < PyTuple_GET_SIZE(f->f_code->co_freevars); ++i) {
707     Py_ssize_t arg = PyTuple_GET_SIZE(f->f_code->co_cellvars) + i;
708     auto cell = f->f_localsplus[f->f_code->co_nlocals + arg];
709     if (backup[f->f_code->co_nlocals + arg] != nullptr) {
710       Py_XDECREF(PyCell_GET(cell));
711       PyCell_SET(cell, backup[f->f_code->co_nlocals + arg]);
712     }
713   }
714 }
715 
ToString() const716 std::string OptGuard::ToString() const {
717   std::stringstream s;
718   for (const auto &i : guardMap_) {
719     s << "  guard [" << i.first << "] [" << i.second->ToString() << " ] at [" << i.second.get() << "]\n";
720   }
721   return s.str();
722 }
723 
Optimize()724 OptGuardPtr OptGuard::Optimize() {
725   bool need_update = false;
726   for (size_t i = 0; i < guardList_.size(); ++i) {
727     auto old_item = guardList_[i];
728     auto new_item = old_item->Optimize();
729     if (new_item != nullptr) {
730       guardList_[i] = new_item;
731       guardMap_.erase(old_item->Info().Id());
732       guardMap_[new_item->Info().Id()] = new_item;
733       need_update = true;
734     }
735   }
736   if (need_update) {
737     info_ = nullptr;
738     Info();
739     return shared_from_this();
740   } else {
741     return nullptr;
742   }
743 }
744 
745 }  // namespace pijit
746 }  // namespace mindspore
747