• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_guard/trace.h"
17 #include <map>
18 #include <vector>
19 #include <set>
20 #include <unordered_map>
21 #include <functional>
22 #include <utility>
23 #include <regex>
24 #include "pipeline/jit/pi/graph_guard/guard.h"
25 #include "pipeline/jit/pi/graph_guard/guard_utils.h"
26 #include "pybind11/pybind11.h"
27 #include "pybind_api/ir/primitive_py.h"
28 #include "include/common/utils/convert_utils_py.h"
29 #include "pipeline/jit/pi/graph_guard/infer.h"
30 #include "pipeline/jit/pi/graph_guard/strategy.h"
31 #include "pipeline/jit/pi/utils/utils.h"
32 #include "include/common/utils/python_adapter.h"
33 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
34 #include "pipeline/jit/pi/pi_jit_config.h"
35 #include "pipeline/jit/pi/external.h"
36 
37 namespace mindspore {
38 namespace pijit {
39 
// Parameter counts and zero-based parameter indices (kParamIndexOne is index 0).
static constexpr size_t kParamCountOne = 1;
static constexpr size_t kParamCountTwo = 2;
static constexpr size_t kParamCountThree = 3;
static constexpr size_t kParamIndexOne = 0;
static constexpr size_t kParamIndexTwo = 1;
static constexpr size_t kParamIndexThree = 2;

// Primitive names.
static constexpr char kCastPrimName[] = "Cast";
static constexpr char kLayerNormPrimName[] = "LayerNorm";
static constexpr char kReshapePrimName[] = "Reshape";
static constexpr char kShapePrimName[] = "Shape";
static constexpr char kRankPrimName[] = "Rank";
static constexpr char kDTypePrimName[] = "DType";

// Attribute / member names.
static constexpr char kShape_Name[] = "shape_";
static constexpr char kShapeName[] = "shape";
static constexpr char kDTypeAttrName[] = "dtype";
static constexpr char kDType_AttrName[] = "dtype_";
static constexpr char kCodeName[] = "__code__";
static constexpr char kFuncName[] = "__func__";
static constexpr char kTrainingFlag[] = "training";

// Tensor conversion helpers; kCastFunc lists both of them.
static constexpr char kCastToMSTensor[] = "cast_to_ms_tensor";
static constexpr char kCastToAdapterTensor[] = "cast_to_adapter_tensor";
static const std::vector<std::string> kCastFunc = {kCastToMSTensor, kCastToAdapterTensor};

// Other function / type names.
static constexpr char kIsInstance[] = "isinstance";
static constexpr char kTensorName[] = "Tensor";
static constexpr char kIsSeqValUnknown[] = "is_sequence_value_unknown";
static constexpr char kIsSeqShapeUnknown[] = "is_sequence_shape_unknown";

// Package markers and module-name prefixes.
static constexpr char kMindTorchFlag[] = "mindtorch";
static constexpr char kMindSporePackPrefix[] = "mindspore.";
static constexpr char kMindtorchPackPrefix[] = "mindtorch.";

// Python module holding the guard function map, and the name of that map.
constexpr const char *kFuncWhiteListModuleName = "mindspore._extends.pijit.pijit_func_white_list";
constexpr const char *kGuardFuncMapName = "_guard_func_map";
72 
73 static PyObject *RichCompare(PyObject *left, PyObject *right, int oparg);
74 
IsCastFunc(std::string name)75 static bool IsCastFunc(std::string name) {
76   return std::find(kCastFunc.begin(), kCastFunc.end(), name) != kCastFunc.end();
77 }
78 
OptimizeTrace(TracePtr trace,bool * update)79 static TracePtr OptimizeTrace(TracePtr trace, bool *update) {
80   if (trace != nullptr) {
81     auto new_trace = trace->Optimize();
82     if (new_trace != nullptr) {
83       if (update != nullptr) {
84         *update = true;
85       }
86       return new_trace;
87     }
88   }
89   return trace;
90 }
91 
92 template <typename T>
CastTrace(TracePtr trace)93 std::shared_ptr<T> CastTrace(TracePtr trace) {
94   if (trace != nullptr && T::Support(trace->GetTraceType())) {
95     return std::static_pointer_cast<T>(trace);
96   }
97   return nullptr;
98 }
99 
CastConstTrace(TracePtr trace)100 static ConstTracePtr CastConstTrace(TracePtr trace) {
101   ConstTracePtr ret = CastTrace<ConstTrace>(trace);
102   if (ret != nullptr && ret->GetIndex() == -1) {
103     return ret;
104   }
105   return nullptr;
106 }
107 
CastOpTrace(TracePtr trace,int opcode)108 static OpTracePtr CastOpTrace(TracePtr trace, int opcode) {
109   OpTracePtr ret = CastTrace<OpTrace>(trace);
110   if (ret != nullptr && ret->GetOpCode() == opcode) {
111     return ret;
112   }
113   return nullptr;
114 }
115 
CastOpTrace(TracePtr trace,const std::string & name)116 static OpTracePtr CastOpTrace(TracePtr trace, const std::string &name) {
117   OpTracePtr ret = CastTrace<OpTrace>(trace);
118   if (ret != nullptr && ret->GetName() == name) {
119     return ret;
120   }
121   return nullptr;
122 }
123 
124 class TracePerf {
125  public:
TracePerf(Trace * trace,bool enable,bool cache)126   TracePerf(Trace *trace, bool enable, bool cache)
127       : trace_(trace), enable_(enable), cache_(cache), perf_(OptGuardPerf::GetGuardPerf()) {
128     if (enable_) {
129       perf_->LogTracePerfStart();
130     }
131   }
~TracePerf()132   ~TracePerf() {
133     if (enable_) {
134       perf_->LogTracePerfEnd(trace_, cache_);
135     }
136   }
137 
138  protected:
139   Trace *trace_;
140   bool enable_;
141   bool cache_;
142   OptGuardPerf *perf_;
143 };
144 
Trace(PyObject * pObj,std::shared_ptr<Trace> pOrigin)145 Trace::Trace(PyObject *pObj, std::shared_ptr<Trace> pOrigin)
146     : obj_(pObj),
147       origin_(pOrigin),
148       info_(nullptr),
149       is_const_(false),
150       relax_count_(-1),
151       relax_limit_(0),
152       is_specialized_(false),
153       depth_(0) {
154   if (pOrigin != nullptr) {
155     originType_ = pOrigin->GetOriginType();
156     curType_ = pOrigin->GetTraceType();
157   } else {
158     originType_ = Unknown;
159     curType_ = Unknown;
160   }
161   if (obj_ != Py_None && obj_ != NULL) {
162     Py_INCREF(obj_);
163   }
164 }
165 
~Trace()166 Trace::~Trace() {
167   if (obj_ != Py_None && obj_ != NULL) {
168     Py_DECREF(obj_);
169   }
170 }
171 
GetOrigin()172 TracePtr Trace::GetOrigin() {
173   if (origin_ != nullptr) {
174     return origin_;
175   } else {
176     return nullptr;
177   }
178 }
179 
GetObject()180 PyObject *Trace::GetObject() { return obj_; }
181 
GetTraceType()182 TraceType Trace::GetTraceType() { return curType_; }
183 
GetOriginType()184 TraceType Trace::GetOriginType() { return originType_; }
185 
Replace(std::shared_ptr<Trace> dst,std::shared_ptr<Trace> src)186 void Trace::Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src) {
187   if (origin_ != nullptr) {
188     if (*origin_ == *src) {
189       origin_ = dst;
190     } else {
191       origin_->Replace(dst, src);
192     }
193   }
194 }
195 
operator ==(const Trace & trace)196 bool Trace::operator==(const Trace &trace) {
197   if (curType_ == trace.curType_ && obj_ == trace.obj_) {
198     return true;
199   } else {
200     return false;
201   }
202 }
203 
Detach()204 void Trace::Detach() {
205   if (obj_ != Py_None && obj_ != nullptr && !is_const_ && !PyLong_Check(obj_) && !PyType_Check(obj_)) {
206     Py_DECREF(obj_);
207     obj_ = nullptr;
208   }
209   if (origin_ != nullptr) {
210     origin_->Detach();
211   }
212 }
213 
Retrieve(PTraceContext context,bool perf)214 PyObject *Trace::Retrieve(PTraceContext context, bool perf) {
215   TracePerf tp(this, perf, true);
216   if (is_const_) {
217     Py_XINCREF(obj_);
218     return obj_;
219   }
220   if (context->cache != nullptr) {
221     size_t szTrace = this->Info().Id();
222     auto cache = context->cache;
223     auto iter = cache->find(szTrace);
224     if (iter != cache->end()) {
225       auto item = iter->second;
226       Py_XINCREF(item);
227       return item;
228     }
229   }
230   return nullptr;
231 }
232 
Cache(PTraceContext context,PyObject * obj)233 void Trace::Cache(PTraceContext context, PyObject *obj) {
234   if (context->cache != nullptr && obj != nullptr) {
235     size_t szTrace = this->Info().Id();
236     Py_XINCREF(obj);
237     auto iter = context->cache->find(szTrace);
238     if (iter != context->cache->end()) {
239       Py_XDECREF(iter->second);
240       iter->second = obj;
241     } else {
242       (*(context->cache))[szTrace] = obj;
243     }
244   }
245   if (RelaxEnabled() && obj != nullptr) {
246     if (obj_ != nullptr) {
247       auto cmp = RichCompare(obj_, obj, Py_EQ);
248       if (cmp != Py_True) {
249         relax_count_ = -1;
250       } else if (relax_count_ < relax_limit_) {
251         relax_count_++;
252       } else {
253         is_const_ = true;
254       }
255       Py_XDECREF(cmp);
256     }
257     if (obj_ == nullptr) {
258       Py_XINCREF(obj);
259       obj_ = obj;
260     }
261   }
262 }
263 
IsConst() const264 bool Trace::IsConst() const { return is_const_; }
265 
This()266 TracePtr Trace::This() { return shared_from_this(); }
267 
SetRelaxCount(int cnt)268 void Trace::SetRelaxCount(int cnt) {
269   relax_count_ = -1;
270   relax_limit_ = cnt;
271 }
272 
GetRelaxCount() const273 int Trace::GetRelaxCount() const { return relax_limit_; }
274 
EnableRelax()275 void Trace::EnableRelax() { relax_count_ = 0; }
276 
RelaxEnabled() const277 bool Trace::RelaxEnabled() const { return relax_count_ >= 0; }
278 
IsSpecialized() const279 bool Trace::IsSpecialized() const { return is_specialized_; }
280 
GetDepth() const281 int Trace::GetDepth() const { return depth_; }
282 
Optimize()283 TracePtr Trace::Optimize() { return nullptr; }
284 
FormatString(std::map<Trace *,size_t> * cache)285 std::string Trace::FormatString(std::map<Trace *, size_t> *cache) {
286   cache->insert(std::make_pair(this, cache->size()));
287   return "%" + std::to_string(cache->find(this)->second) + " = " + this->ToString();
288 }
289 
RootTrace(PyObject * pObj,TraceType tt,int index,std::string name,std::string module_name)290 RootTrace::RootTrace(PyObject *pObj, TraceType tt, int index, std::string name, std::string module_name)
291     : Trace(pObj, nullptr), idx_(index), name_(name), module_name_(module_name) {
292   depth_ = 1;
293   originType_ = tt;
294   curType_ = tt;
295   for (auto n : kPIJitConfigDefault.allowed_inline_modules()) {
296     if (module_name.find(n) == 0) {
297       is_const_ = true;
298       break;
299     }
300   }
301   if (!is_const_ && (module_name.find(kMindSporePackPrefix) == 0 || module_name.find(kMindtorchPackPrefix) == 0 ||
302                      name.find(kCastToAdapterTensor) == 0 || name.find(kCastToMSTensor) == 0)) {
303     is_const_ = true;
304   }
305   if (curType_ == TraceType::Deref) {
306     is_const_ = false;
307   }
308   if (pObj == nullptr) {
309     return;
310   }
311   if (py::isinstance<mindspore::tensor::MetaTensor>(pObj) || py::isinstance<mindspore::tensor::Tensor>(pObj) ||
312       IsStubTensor(py::cast<py::object>(pObj))) {
313     is_specialized_ = false;
314   }
315 }
316 
GetParam(int * index,std::string * name,std::string * module_name)317 void RootTrace::GetParam(int *index, std::string *name, std::string *module_name) {
318   *index = idx_;
319   *name = name_;
320   *module_name = module_name_;
321 }
322 
Retrieve(PTraceContext context,bool perf)323 PyObject *RootTrace::Retrieve(PTraceContext context, bool perf) {
324   PyObject *ret = Trace::Retrieve(context, perf);
325   if (ret != nullptr) {
326     return ret;
327   }
328   TracePerf tp(this, perf, false);
329   switch (curType_) {
330     case TraceType::Global: {
331       ret = RetrieveGlobal(context);
332       break;
333     }
334     case TraceType::Deref: {
335       ret = RetrieveDeref(context);
336       break;
337     }
338     case TraceType::Closure: {
339       ret = RetrieveClosure(context);
340       break;
341     }
342     case TraceType::BuiltIn: {
343       ret = RetrieveBuiltin(context);
344       break;
345     }
346     case TraceType::Local:
347       ret = RetrieveLocal(context);
348       Py_XINCREF(ret);
349       break;
350     case TraceType::Param:
351       ret = RetrieveParam(context);
352       Py_XINCREF(ret);
353       break;
354     case TraceType::Name: {
355       ret = RetrieveName(context);
356       break;
357     }
358     case TraceType::ClassDeref: {
359       ret = RetrieveClassDeref(context);
360       break;
361     }
362     default:
363       break;
364   }
365   if (ret != Py_None && ret != NULL) {
366     Cache(context, ret);
367   }
368   return ret;
369 }
370 
RetrieveGlobal(PTraceContext context)371 PyObject *RootTrace::RetrieveGlobal(PTraceContext context) {
372   MS_EXCEPTION_IF_CHECK_FAIL(name_.size() > 0, "check trace");
373   PyObject *globals = context->f_globals;
374   if (!module_name_.empty()) {
375     PyObject *mn = PyUnicode_FromString(module_name_.c_str());
376     PyObject *mm = PyImport_GetModule(mn);  // ensure module is initialized
377     if (mn != nullptr && mm != nullptr) {
378       globals = PyModule_GetDict(mm);
379     }
380     PyErr_Clear();
381     Py_XDECREF(mn);
382     Py_XDECREF(mm);
383   }
384   PyObject *key = PyUnicode_FromString(name_.c_str());
385   PyObject *ret = PyObject_GetItem(globals, key);
386   if (ret == nullptr) {
387     PyErr_Clear();
388     ret = PyObject_GetItem(context->f_builtins, key);
389     if (ret == nullptr) {
390       PyErr_Clear();
391     }
392   }
393   Py_DECREF(key);
394   return ret;
395 }
396 
RetrieveDeref(PTraceContext context)397 PyObject *RootTrace::RetrieveDeref(PTraceContext context) {
398   PyObject *ret = nullptr;
399   PyObject *cell = context->f_localsplus[context->f_code->co_nlocals + idx_];
400   if (cell != nullptr && cell != Py_None) {
401     ret = reinterpret_cast<PyObject *>(PyCell_GET(cell));
402     Py_XINCREF(ret);
403   }
404   return ret;
405 }
406 
RetrieveClosure(PTraceContext context)407 PyObject *RootTrace::RetrieveClosure(PTraceContext context) {
408   PyObject *ret = context->f_localsplus[context->f_code->co_nlocals + idx_];
409   Py_XINCREF(ret);
410   return ret;
411 }
412 
RetrieveBuiltin(PTraceContext context)413 PyObject *RootTrace::RetrieveBuiltin(PTraceContext context) {
414   MS_EXCEPTION_IF_CHECK_FAIL(name_.size() > 0, "check trace");
415   PyObject *key = PyUnicode_FromString(name_.c_str());
416   PyObject *ret = PyObject_GetItem(context->f_builtins, key);
417   if (ret == nullptr) {
418     PyErr_Clear();
419     ret = PyObject_GetItem(context->f_globals, key);
420     if (ret == nullptr) {
421       PyErr_Clear();
422     }
423   }
424   Py_DECREF(key);
425   return ret;
426 }
427 
RetrieveLocal(PTraceContext context)428 PyObject *RootTrace::RetrieveLocal(PTraceContext context) { return context->f_locals; }
429 
RetrieveParam(PTraceContext context)430 PyObject *RootTrace::RetrieveParam(PTraceContext context) { return context->f_localsplus[idx_]; }
431 
RetrieveName(PTraceContext context)432 PyObject *RootTrace::RetrieveName(PTraceContext context) {
433   PyObject *ret = nullptr;
434   PyObject *name = PyTuple_GetItem(context->f_code->co_names, idx_);
435   PyObject *locals = context->f_locals;
436   if (PyDict_CheckExact(locals)) {
437     ret = PyDict_GetItem(locals, name);
438     Py_XINCREF(ret);
439   } else {
440     ret = PyObject_GetItem(locals, name);
441   }
442   if (ret == nullptr) {
443     ret = PyDict_GetItem(context->f_globals, name);
444     Py_XINCREF(ret);
445   }
446   if (ret == nullptr) {
447     if (PyDict_CheckExact(context->f_builtins)) {
448       ret = PyDict_GetItem(context->f_builtins, name);
449       Py_XINCREF(ret);
450     } else {
451       ret = PyObject_GetItem(context->f_builtins, name);
452     }
453   }
454   return ret;
455 }
456 
RetrieveClassDeref(PTraceContext context)457 PyObject *RootTrace::RetrieveClassDeref(PTraceContext context) {
458   PyObject *ret = nullptr;
459   Py_ssize_t idx = idx_ - PyTuple_GET_SIZE(context->f_code->co_cellvars);
460   if (idx >= 0 && idx < PyTuple_GET_SIZE(context->f_code->co_freevars)) {
461     PyObject *name = PyTuple_GET_ITEM(context->f_code->co_freevars, idx);
462     if (PyDict_CheckExact(context->f_locals)) {
463       ret = PyDict_GetItem(context->f_locals, name);
464       Py_XINCREF(ret);
465     } else {
466       ret = PyObject_GetItem(context->f_locals, name);
467     }
468     if (!ret) {
469       PyObject *cell = context->f_localsplus[context->f_code->co_nlocals + idx_];
470       ret = PyCell_GET(cell);
471       Py_XINCREF(ret);
472     }
473   }
474   return ret;
475 }
476 
ToString(bool include_param)477 std::string RootTrace::ToString(bool include_param) {
478   if (strTrace_.size() > 0) {
479     return strTrace_;
480   }
481   std::string ret;
482   switch (curType_) {
483     case TraceType::Global:
484       if (!module_name_.empty()) {
485         ret = "(global " + module_name_ + ".__dict__[" + name_ + "])";
486       } else {
487         ret = "f_globals[" + name_ + "]";
488       }
489       break;
490     case TraceType::Deref:
491       ret = "f_freevars[" + std::to_string(idx_) + "]";
492       break;
493     case TraceType::Closure:
494       ret = "f_closure[" + std::to_string(idx_) + "]";
495       break;
496     case TraceType::BuiltIn:
497       ret = "f_builtins[" + name_ + "]";
498       break;
499     case TraceType::Local:
500       ret = "f_locals";
501       break;
502     case TraceType::Param:
503       ret = "f_localsplus[";
504       ret += std::to_string(idx_);
505       ret += "]";
506       break;
507     case TraceType::Name:
508       ret = "f->f_code->co_names[";
509       ret += std::to_string(idx_);
510       ret += "]";
511       break;
512     case TraceType::ClassDeref:
513       ret = "f->f_classdef[";
514       ret += std::to_string(idx_);
515       ret += "]";
516       break;
517     default:
518       ret = "unknown_root";
519       break;
520   }
521   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
522   ret = std::regex_replace(ret, std::regex("(\n)"), "");
523   strTrace_ = ret;
524   return ret;
525 }
526 
Info()527 const InfoPack &RootTrace::Info() {
528   if (info_ == nullptr) {
529     InfoPack info;
530     info << uint8_t(curType_);
531     info.Begin();
532     switch (curType_) {
533       case TraceType::Global:
534         info << (!module_name_.empty());
535         if (!module_name_.empty()) {
536           info << module_name_ << name_;
537         } else {
538           info << name_;
539         }
540         break;
541       case TraceType::Deref:
542       case TraceType::Closure:
543       case TraceType::Param:
544       case TraceType::Name:
545       case TraceType::ClassDeref:
546         info << idx_;
547         break;
548       case TraceType::BuiltIn:
549         info << name_;
550         break;
551       case TraceType::Local:
552       default:
553         break;
554     }
555     info.End();
556     info_ = std::make_shared<InfoPack>(info);
557     info_->Update();
558   }
559   return *info_;
560 }
561 
operator ==(const Trace & trace)562 bool RootTrace::operator==(const Trace &trace) {
563   bool ret = false;
564   if (Trace::operator==(trace)) {
565     const RootTrace &t = (const RootTrace &)trace;
566     ret = idx_ == t.idx_;
567     if (ret && idx_ == -1) {
568       ret = name_ == t.name_ && module_name_ == t.module_name_;
569     }
570   }
571   return ret;
572 }
573 
Support(TraceType tt)574 bool RootTrace::Support(TraceType tt) {
575   switch (tt) {
576     case TraceType::Global:
577     case TraceType::Deref:
578     case TraceType::Closure:
579     case TraceType::BuiltIn:
580     case TraceType::Local:
581     case TraceType::Param:
582     case TraceType::Name:
583     case TraceType::ClassDeref:
584       return true;
585     default:
586       return false;
587   }
588 }
589 
ItemTrace(PyObject * pObj,TracePtr pOrigin,TracePtr pItem)590 ItemTrace::ItemTrace(PyObject *pObj, TracePtr pOrigin, TracePtr pItem) : Trace(pObj, pOrigin), item_(pItem) {
591   curType_ = TraceType::Item;
592   if (origin_ != nullptr && item_ != nullptr) {
593     if (origin_->IsConst() && item_->IsConst()) {
594       is_const_ = true;
595     }
596     if (!origin_->IsSpecialized() && !item_->IsSpecialized()) {
597       is_specialized_ = false;
598     }
599   }
600   if (origin_ != nullptr) {
601     depth_ = origin_->GetDepth() + 1;
602   } else {
603     depth_ = 1;
604   }
605   if (item_ != nullptr) {
606     auto d = item_->GetDepth() + 1;
607     if (d > depth_) {
608       depth_ = d;
609     }
610   }
611 }
612 
GetItem()613 TracePtr ItemTrace::GetItem() { return item_; }
614 
Replace(std::shared_ptr<Trace> dst,std::shared_ptr<Trace> src)615 void ItemTrace::Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src) {
616   Trace::Replace(dst, src);
617   if (item_ != nullptr) {
618     if (*item_ == *src) {
619       item_ = dst;
620     } else {
621       item_->Replace(dst, src);
622     }
623   }
624 }
625 
Retrieve(PTraceContext context,bool perf)626 PyObject *ItemTrace::Retrieve(PTraceContext context, bool perf) {
627   PyObject *ret = Trace::Retrieve(context, perf);
628   if (ret != nullptr) {
629     return ret;
630   }
631   if (origin_ != nullptr && item_ != nullptr) {
632     PyObject *pSet = origin_->Retrieve(context, perf);
633     PyObject *pItem = item_->Retrieve(context, perf);
634     if (pSet != NULL && pItem != NULL) {
635       TracePerf tp(this, perf, false);
636       if (PyDict_CheckExact(pSet)) {
637         ret = PyDict_GetItem(pSet, pItem);
638         Py_INCREF(ret);
639       } else {
640         ret = PyObject_GetItem(pSet, pItem);
641       }
642       Cache(context, ret);
643     }
644     Py_XDECREF(pSet);
645     Py_XDECREF(pItem);
646   }
647   return ret;
648 }
649 
ToString(bool include_param)650 std::string ItemTrace::ToString(bool include_param) {
651   if (strTrace_.size() > 0) {
652     return strTrace_;
653   }
654   std::string ret;
655   if (origin_ != nullptr && item_ != nullptr) {
656     std::string ori = origin_->ToString(include_param);
657     std::string itm = item_->ToString(include_param);
658     ret = ori + "[" + itm + "]";
659   }
660   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
661   ret = std::regex_replace(ret, std::regex("(\n)"), "");
662   strTrace_ = ret;
663   return ret;
664 }
665 
Info()666 const InfoPack &ItemTrace::Info() {
667   if (info_ == nullptr) {
668     InfoPack info;
669     info << uint8_t(curType_);
670     info.Begin();
671     info << (origin_ != nullptr && item_ != nullptr);
672     if (origin_ != nullptr && item_ != nullptr) {
673       auto ori = origin_->Info();
674       auto itm = item_->Info();
675       info << ori << itm;
676     }
677     info.End();
678     info_ = std::make_shared<InfoPack>(info);
679     info_->Update();
680   }
681   return *info_;
682 }
683 
Optimize()684 TracePtr ItemTrace::Optimize() {
685   bool need_update = false;
686   origin_ = OptimizeTrace(origin_, &need_update);
687   item_ = OptimizeTrace(item_, &need_update);
688   if (need_update) {
689     if (origin_ != nullptr && item_ != nullptr && origin_->IsConst() && item_->IsConst()) {
690       is_const_ = true;
691     }
692     info_ = nullptr;
693     strTrace_ = "";
694     Info();
695     return shared_from_this();
696   } else {
697     return nullptr;
698   }
699 }
700 
SetRelaxCount(int cnt)701 void ItemTrace::SetRelaxCount(int cnt) {
702   Trace::SetRelaxCount(cnt);
703   if (origin_ != nullptr) {
704     origin_->SetRelaxCount(cnt);
705   }
706   if (item_ != nullptr) {
707     item_->SetRelaxCount(cnt);
708   }
709 }
710 
operator ==(const Trace & trace)711 bool ItemTrace::operator==(const Trace &trace) {
712   if (Trace::operator==(trace)) {
713     const ItemTrace &t = (const ItemTrace &)trace;
714     if (!item_ && !(t.item_)) {
715       return true;
716     } else if (item_ != nullptr && t.item_ != nullptr) {
717       return *item_ == *(t.item_);
718     }
719   }
720   return false;
721 }
722 
Detach()723 void ItemTrace::Detach() {
724   Trace::Detach();
725   if (item_ != nullptr) {
726     item_->Detach();
727   }
728 }
729 
Support(TraceType tt)730 bool ItemTrace::Support(TraceType tt) { return tt == TraceType::Item; }
731 
AttrTrace(PyObject * pObj,TracePtr pOrigin,std::string strAttr)732 AttrTrace::AttrTrace(PyObject *pObj, TracePtr pOrigin, std::string strAttr) : Trace(pObj, pOrigin), attr_(strAttr) {
733   curType_ = TraceType::Attr;
734   if (origin_ != nullptr && origin_->IsConst()) {
735     is_const_ = true;
736   }
737   if (origin_ != nullptr) {
738     depth_ = origin_->GetDepth() + 1;
739   } else {
740     depth_ = 1;
741   }
742 }
743 
GetAttribute()744 std::string AttrTrace::GetAttribute() { return attr_; }
745 
Retrieve(PTraceContext context,bool perf)746 PyObject *AttrTrace::Retrieve(PTraceContext context, bool perf) {
747   PyObject *ret = Trace::Retrieve(context, perf);
748   if (ret != nullptr) {
749     return ret;
750   }
751   if (origin_ != nullptr) {
752     PyObject *pOrigin = origin_->Retrieve(context, perf);
753     if (pOrigin != NULL) {
754       TracePerf tp(this, perf, false);
755       PyObject *itemName = PyUnicode_FromString(attr_.c_str());
756       if (PyDict_CheckExact(pOrigin)) {
757         ret = PyDict_GetItem(pOrigin, itemName);
758         Py_INCREF(ret);
759       } else {
760         ret = PyObject_GetItem(pOrigin, itemName);
761       }
762       Py_DECREF(itemName);
763       Py_DECREF(pOrigin);
764       Cache(context, ret);
765     }
766   }
767   return ret;
768 }
769 
ToString(bool include_param)770 std::string AttrTrace::ToString(bool include_param) {
771   if (strTrace_.size() > 0) {
772     return strTrace_;
773   }
774   std::string ret;
775   if (origin_ != nullptr) {
776     std::string ori = origin_->ToString(include_param);
777     ret = ori + "." + attr_;
778   }
779   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
780   ret = std::regex_replace(ret, std::regex("(\n)"), "");
781   strTrace_ = ret;
782   return ret;
783 }
784 
Info()785 const InfoPack &AttrTrace::Info() {
786   if (info_ == nullptr) {
787     InfoPack info;
788     info << uint8_t(curType_);
789     info.Begin();
790     info << (origin_ != nullptr);
791     if (origin_ != nullptr) {
792       auto ori = origin_->Info();
793       info << ori;
794     }
795     info << attr_;
796     info.End();
797     info_ = std::make_shared<InfoPack>(info);
798     info_->Update();
799   }
800   return *info_;
801 }
802 
Optimize()803 TracePtr AttrTrace::Optimize() {
804   bool need_update = false;
805   origin_ = OptimizeTrace(origin_, &need_update);
806   if (need_update) {
807     if (origin_ != nullptr && origin_->IsConst()) {
808       is_const_ = true;
809     }
810     info_ = nullptr;
811     strTrace_ = "";
812     Info();
813     return shared_from_this();
814   } else {
815     return nullptr;
816   }
817 }
818 
SetRelaxCount(int cnt)819 void AttrTrace::SetRelaxCount(int cnt) {
820   Trace::SetRelaxCount(cnt);
821   if (origin_ != nullptr) {
822     origin_->SetRelaxCount(cnt);
823   }
824 }
825 
operator ==(const Trace & trace)826 bool AttrTrace::operator==(const Trace &trace) {
827   if (Trace::operator==(trace)) {
828     const AttrTrace &t = (const AttrTrace &)trace;
829     return attr_ == t.attr_;
830   }
831   return false;
832 }
833 
Support(TraceType tt)834 bool AttrTrace::Support(TraceType tt) { return tt == TraceType::Attr; }
835 
ConstTrace(PyObject * pObj,int iIndex)836 ConstTrace::ConstTrace(PyObject *pObj, int iIndex) : Trace(pObj, nullptr), index_(iIndex) {
837   curType_ = TraceType::Const;
838   originType_ = TraceType::Const;
839   if (index_ == -1) {
840     is_const_ = true;
841   }
842   depth_ = 1;
843 }
844 
GetIndex()845 int ConstTrace::GetIndex() { return index_; }
846 
Retrieve(PTraceContext context,bool perf)847 PyObject *ConstTrace::Retrieve(PTraceContext context, bool perf) {
848   PyObject *ret = Trace::Retrieve(context, perf);
849   if (ret != nullptr) {
850     return ret;
851   }
852   if (obj_ != NULL) {
853     Py_INCREF(obj_);
854     return obj_;
855   }
856   if (index_ >= 0 && index_ < PyTuple_GET_SIZE(context->f_code->co_consts)) {
857     TracePerf tp(this, perf, false);
858     ret = PyTuple_GET_ITEM(context->f_code->co_consts, index_);
859     Py_INCREF(ret);
860     Cache(context, ret);
861   } else {
862     ret = obj_;
863     Py_INCREF(ret);
864   }
865   return ret;
866 }
867 
ToString(bool include_param)868 std::string ConstTrace::ToString(bool include_param) {
869   if (strTrace_.size() > 0) {
870     return strTrace_;
871   }
872   std::string ret = "co_consts";
873   if (index_ != -1) {
874     ret = ret + "[" + std::to_string(index_) + "]";
875   } else {
876     ret = ret + "[-1](" + std::string(py::str(obj_)) + ")";
877   }
878   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
879   ret = std::regex_replace(ret, std::regex("(\n)"), "");
880   strTrace_ = ret;
881   return ret;
882 }
883 
Info()884 const InfoPack &ConstTrace::Info() {
885   if (info_ == nullptr) {
886     InfoPack info;
887     info << uint8_t(curType_);
888     info.Begin();
889     if (index_ != -1) {
890       info << index_;
891     } else {
892       info << index_ << obj_;
893     }
894     info.End();
895     info_ = std::make_shared<InfoPack>(info);
896     info_->Update();
897   }
898   return *info_;
899 }
900 
operator ==(const Trace & trace)901 bool ConstTrace::operator==(const Trace &trace) {
902   if (Trace::operator==(trace)) {
903     const ConstTrace &t = (const ConstTrace &)trace;
904     return index_ == t.index_;
905   }
906   return false;
907 }
908 
Detach()909 void ConstTrace::Detach() {}
910 
Support(TraceType tt)911 bool ConstTrace::Support(TraceType tt) { return tt == TraceType::Const; }
912 
TypeTrace(PyObject * pObj,TracePtr pOrigin)913 TypeTrace::TypeTrace(PyObject *pObj, TracePtr pOrigin) : Trace(pObj, pOrigin) {
914   pType_ = Py_TYPE(pObj);
915   curType_ = TraceType::Type;
916   if (origin_ != nullptr && origin_->IsConst()) {
917     is_const_ = true;
918   }
919   if (origin_ != nullptr) {
920     depth_ = origin_->GetDepth() + 1;
921   } else {
922     depth_ = 1;
923   }
924 }
925 
GetType()926 PyTypeObject *TypeTrace::GetType() { return pType_; }
927 
Retrieve(PTraceContext context,bool perf)928 PyObject *TypeTrace::Retrieve(PTraceContext context, bool perf) {
929   if (is_const_) {
930     auto rt = reinterpret_cast<PyObject *>(pType_);
931     Py_INCREF(rt);
932     return rt;
933   }
934   PyObject *ret = Trace::Retrieve(context, perf);
935   if (ret != nullptr) {
936     return ret;
937   }
938   if (origin_ != NULL) {
939     PyObject *pOrigin = origin_->Retrieve(context, perf);
940     if (pOrigin != NULL) {
941       TracePerf tp(this, perf, false);
942       ret = reinterpret_cast<PyObject *>(Py_TYPE(pOrigin));
943       Py_INCREF(ret);
944       Py_DECREF(pOrigin);
945       Cache(context, ret);
946       return ret;
947     }
948   }
949   return ret;
950 }
951 
ToString(bool include_param)952 std::string TypeTrace::ToString(bool include_param) {
953   if (strTrace_.size() > 0) {
954     return strTrace_;
955   }
956   std::string ret = "type(type:";
957   ret += std::string(py::str(reinterpret_cast<PyObject *>(pType_)));
958   if (origin_ != NULL) {
959     ret += ", origin:" + origin_->ToString(include_param);
960   }
961   ret += ")";
962   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
963   ret += std::regex_replace(ret, std::regex("(\n)"), "");
964   strTrace_ = ret;
965   return ret;
966 }
967 
Info()968 const InfoPack &TypeTrace::Info() {
969   if (info_ == nullptr) {
970     InfoPack info;
971     info << uint8_t(curType_);
972     info.Begin();
973     info << reinterpret_cast<PyObject *>(pType_);
974     info << (origin_ != nullptr);
975     if (origin_ != nullptr) {
976       info << origin_->Info();
977     }
978     info.End();
979     info_ = std::make_shared<InfoPack>(info);
980     info_->Update();
981   }
982   return *info_;
983 }
984 
Optimize()985 TracePtr TypeTrace::Optimize() {
986   bool need_update = false;
987   origin_ = OptimizeTrace(origin_, &need_update);
988   if (need_update) {
989     if (origin_ != nullptr && origin_->IsConst()) {
990       is_const_ = true;
991     }
992     info_ = nullptr;
993     strTrace_ = "";
994     Info();
995     return shared_from_this();
996   } else {
997     return nullptr;
998   }
999 }
1000 
SetRelaxCount(int cnt)1001 void TypeTrace::SetRelaxCount(int cnt) {
1002   Trace::SetRelaxCount(cnt);
1003   if (origin_ != nullptr) {
1004     origin_->SetRelaxCount(cnt);
1005   }
1006 }
1007 
operator ==(const Trace & trace)1008 bool TypeTrace::operator==(const Trace &trace) {
1009   if (Trace::operator==(trace)) {
1010     const TypeTrace &t = (const TypeTrace &)trace;
1011     return pType_ == t.pType_;
1012   }
1013   return false;
1014 }
1015 
Detach()1016 void TypeTrace::Detach() {
1017   if (is_const_) {
1018     is_const_ = false;
1019     Trace::Detach();
1020     is_const_ = true;
1021   } else {
1022     Trace::Detach();
1023   }
1024 }
1025 
Support(TraceType tt)1026 bool TypeTrace::Support(TraceType tt) { return tt == TraceType::Type; }
1027 
RichCompare(PyObject * left,PyObject * right,int oparg)1028 static PyObject *RichCompare(PyObject *left, PyObject *right, int oparg) {
1029   bool invert;
1030   if (oparg >= Py_LT && oparg <= Py_GE) {
1031     return PyObject_RichCompare(left, right, oparg);
1032   } else if (Opcode(COMPARE_OP).CheckIsOp(oparg, &invert)) {
1033     auto ret = ((left == right) ^ invert) ? Py_True : Py_False;
1034     Py_INCREF(ret);
1035     return ret;
1036   } else if (Opcode(COMPARE_OP).CheckContainsOp(oparg, &invert)) {
1037     auto stat = PySequence_Contains(right, left);
1038     if (stat < 0) {
1039       return nullptr;
1040     }
1041     auto ret = (stat ^ invert) ? Py_True : Py_False;
1042     Py_INCREF(ret);
1043     return ret;
1044   }
1045   return nullptr;
1046 }
1047 
support_infer_primitive(PyObject * obj)1048 static bool support_infer_primitive(PyObject *obj) {
1049   if (py::isinstance<mindspore::PrimitivePyAdapter>(obj) || py::isinstance<mindspore::PrimitiveFunctionAdapter>(obj)) {
1050     auto inst = mindspore::pijit::InferEngine::GetInstance();
1051     return inst->SupportInfer(obj);
1052   } else {
1053     return false;
1054   }
1055 }
1056 
support_create_primitive(PyObject * obj)1057 static bool support_create_primitive(PyObject *obj) {
1058   if (!obj || !PyType_Check(obj)) {
1059     return false;
1060   }
1061   py::object m = py::reinterpret_steal<py::object>(PyImport_GetModule(py::str("mindspore.ops").ptr()));
1062   if (!m.ptr()) {
1063     PyErr_Clear();
1064     return false;
1065   }
1066   py::object t = py::reinterpret_steal<py::object>(PyObject_GetAttrString(m.ptr(), "Primitive"));
1067   if (PyType_IsSubtype(reinterpret_cast<PyTypeObject *>(obj), reinterpret_cast<PyTypeObject *>((t.ptr())))) {
1068     return true;
1069   } else {
1070     return false;
1071   }
1072 }
1073 
1074 extern bool CheckJitConstexpr(const py::object &func);
1075 extern bool CheckMSConstexpr(const py::object &func);
1076 extern bool CheckBuiltinFuncOrMethod(const py::object &func);
SupportCall(PyObject * func,const std::string & name)1077 static bool SupportCall(PyObject *func, const std::string &name) {
1078   /**
1079    * NOTE: exclude method type, it shouldn't be guard
1080    */
1081   static const std::set<PyTypeObject *> support_create_instance_type = {
1082     &PyComplex_Type, &PyMap_Type,       &PyBaseObject_Type, &PyRange_Type,   &PyZip_Type,  &PySlice_Type,
1083     &PyBool_Type,    &PyFloat_Type,     &PyLong_Type,       &PyType_Type,    &PyList_Type, &PyTuple_Type,
1084     &PySet_Type,     &PyFrozenSet_Type, &PyDict_Type,       &PyUnicode_Type, &PyEnum_Type, &PyMethod_Type,
1085   };
1086   if (PyType_CheckExact(func)) {
1087     if (IsMsClass(func)) {
1088       return true;
1089     }
1090     return support_create_instance_type.find(reinterpret_cast<PyTypeObject *>(func)) !=
1091            support_create_instance_type.end();
1092   }
1093 
1094   py::object handle = py::cast<py::object>(func);
1095   if (CheckJitConstexpr(handle)) {
1096     return true;
1097   }
1098   if (CheckMSConstexpr(handle)) {
1099     return true;
1100   }
1101   if (CheckBuiltinFuncOrMethod(handle)) {
1102     return true;
1103   }
1104   return support_infer_primitive(func) || support_create_primitive(func) || IsMsClass(func) ||
1105          (name.size() != 0 && PyDict_GetItemString(PyEval_GetBuiltins(), name.c_str()) == func);
1106 }
1107 
DoCall(const std::vector<PyObject * > & params,int op,const std::string & name)1108 static PyObject *DoCall(const std::vector<PyObject *> &params, int op, const std::string &name) {
1109   if (!Opcode(op).IsCall() || params.size() < 1) {
1110     return nullptr;
1111   }
1112   if (support_infer_primitive(params[0])) {
1113     std::vector<PyObject *> list;
1114     auto inst = mindspore::pijit::InferEngine::GetInstance();
1115     list.insert(list.begin(), params.begin() + 1, params.end());
1116     bool is_abstract = false;
1117     try {
1118       return inst->InferPrimitive(params[0], list, &is_abstract);
1119     } catch (py::error_already_set &e) {
1120       MS_LOG(ERROR) << "InferPrimitive failed " << std::endl << e.what();
1121     } catch (py::builtin_exception &e) {
1122       MS_LOG(ERROR) << "InferPrimitive failed " << std::endl << e.what();
1123     }
1124     return nullptr;
1125   }
1126 
1127   size_t nargs = (params.size() - 1);
1128   size_t kw_cnt;
1129   if (op == CALL_FUNCTION) {
1130     return PyObject_Vectorcall(params[0], params.data() + 1, nargs, NULL);
1131   } else if (op == CALL_FUNCTION_KW) {
1132     kw_cnt = PyTuple_GET_SIZE(params.back());
1133     return PyObject_Vectorcall(params[0], params.data() + 1, nargs - 1 - kw_cnt, params.back());
1134   } else if (op == CALL_FUNCTION_EX) {
1135     return PyObject_Call(params[0], params[1], params.size() > 2 ? params[2] : nullptr);
1136   }
1137   return nullptr;
1138 }
1139 
1140 using PyObjectArray = std::vector<PyObject *>;
1141 
CheckAndDoBinary(int op,const PyObjectArray & objs,binaryfunc pyfunc)1142 static PyObject *CheckAndDoBinary(int op, const PyObjectArray &objs, binaryfunc pyfunc) {
1143   if (py::isinstance<mindspore::tensor::Tensor>(objs[0])) {
1144     auto arg0 = py::reinterpret_borrow<py::object>(objs[0]);
1145     auto arg1 = py::reinterpret_borrow<py::object>(objs[1]);
1146     auto res = pijit::AbstractTensor::Binary(op, arg0, arg1);
1147     return res.inc_ref().ptr();
1148   } else {
1149     return pyfunc(objs[0], objs[1]);
1150   }
1151 }
1152 
1153 using PythonBytecodeSupportCheckFunc = std::function<bool(int opargs, const PyObjectArray &objs)>;
1154 using PythonBytecodeExecuteFunc = std::function<PyObject *(int opargs, const PyObjectArray &objs, PTraceContext ctx)>;
1155 using PythonBytecodeFuncSet = std::pair<PythonBytecodeSupportCheckFunc, PythonBytecodeExecuteFunc>;
ByteCodeUnsupported(int opargs,const PyObjectArray & objs)1156 static bool ByteCodeUnsupported(int opargs, const PyObjectArray &objs) { return false; }
ByteCodeSupported(int opargs,const PyObjectArray & objs)1157 static bool ByteCodeSupported(int opargs, const PyObjectArray &objs) { return true; }
1158 #define ByteCodeTest(bytecode)                                                                                       \
1159   [](int opargs, const PyObjectArray &objs) {                                                                        \
1160     return OptStrategy::MakeCalcStrategyByInputs(bytecode, opargs, objs) != OptStrategy::CalcKind::kCalcUnsupported; \
1161   }
1162 #define ByteCodeCheck(bytecode, opargs, objs) \
1163   (OptStrategy::MakeCalcStrategyByInputs(bytecode, opargs, objs) == OptStrategy::CalcKind::kCalcValue)
1164 static std::unordered_map<int, PythonBytecodeFuncSet> kBytecodeExecuter = {
1165   {POP_TOP, {ByteCodeUnsupported, nullptr}},
1166   {ROT_TWO, {ByteCodeUnsupported, nullptr}},
1167   {ROT_THREE, {ByteCodeUnsupported, nullptr}},
1168   {DUP_TOP, {ByteCodeUnsupported, nullptr}},
1169   {DUP_TOP_TWO, {ByteCodeUnsupported, nullptr}},
1170   {NOP, {ByteCodeUnsupported, nullptr}},
1171   {UNARY_POSITIVE,
1172    {ByteCodeTest(UNARY_POSITIVE),
__anonaa04c86f0102() 1173     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1174       if (ByteCodeCheck(UNARY_POSITIVE, opargs, objs)) {
1175         return PyNumber_Positive(objs[0]);
1176       } else {
1177         Py_XINCREF(objs[0]);
1178         return objs[0];
1179       }
1180     }}},
1181   {UNARY_NEGATIVE,
1182    {ByteCodeTest(UNARY_NEGATIVE),
__anonaa04c86f0202() 1183     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1184       if (ByteCodeCheck(UNARY_NEGATIVE, opargs, objs)) {
1185         return PyNumber_Negative(objs[0]);
1186       } else {
1187         Py_XINCREF(objs[0]);
1188         return objs[0];
1189       }
1190     }}},
1191   {UNARY_NOT,
1192    {ByteCodeTest(UNARY_NOT),
__anonaa04c86f0302() 1193     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1194       if (ByteCodeCheck(UNARY_NOT, opargs, objs)) {
1195         auto ret = PyObject_IsTrue(objs[0]) ? Py_False : Py_True;
1196         Py_INCREF(ret);
1197         return ret;
1198       } else {
1199         Py_INCREF(Py_True);
1200         return Py_True;
1201       }
1202     }}},
1203   {UNARY_INVERT,
1204    {ByteCodeTest(UNARY_INVERT),
__anonaa04c86f0402() 1205     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1206       if (ByteCodeCheck(UNARY_INVERT, opargs, objs)) {
1207         return PyNumber_Invert(objs[0]);
1208       } else {
1209         Py_INCREF(objs[0]);
1210         return objs[0];
1211       }
1212     }}},
1213   {BINARY_MATRIX_MULTIPLY,
1214    {ByteCodeTest(BINARY_MATRIX_MULTIPLY),
__anonaa04c86f0502() 1215     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1216       if (ByteCodeCheck(BINARY_MATRIX_MULTIPLY, opargs, objs)) {
1217         return PyNumber_MatrixMultiply(objs[0], objs[1]);
1218       } else {
1219         Py_INCREF(objs[0]);
1220         return objs[0];
1221       }
1222     }}},
1223   {INPLACE_MATRIX_MULTIPLY,
1224    {ByteCodeTest(INPLACE_MATRIX_MULTIPLY),
__anonaa04c86f0602() 1225     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1226       if (ByteCodeCheck(INPLACE_MATRIX_MULTIPLY, opargs, objs)) {
1227         return PyNumber_InPlaceMatrixMultiply(objs[0], objs[1]);
1228       } else {
1229         Py_INCREF(objs[0]);
1230         return objs[0];
1231       }
1232     }}},
1233   {BINARY_POWER,
1234    {ByteCodeTest(BINARY_POWER),
__anonaa04c86f0702() 1235     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1236       if (ByteCodeCheck(BINARY_POWER, opargs, objs)) {
1237         return PyNumber_Power(objs[0], objs[1], Py_None);
1238       } else {
1239         Py_INCREF(objs[0]);
1240         return objs[0];
1241       }
1242     }}},
1243   {BINARY_MULTIPLY,
1244    {ByteCodeTest(BINARY_MULTIPLY),
__anonaa04c86f0802() 1245     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1246       if (ByteCodeCheck(BINARY_MULTIPLY, opargs, objs)) {
1247         return CheckAndDoBinary(BINARY_MULTIPLY, objs, PyNumber_Multiply);
1248       } else {
1249         Py_INCREF(objs[0]);
1250         return objs[0];
1251       }
1252     }}},
1253   {BINARY_MODULO,
1254    {ByteCodeTest(BINARY_MODULO),
__anonaa04c86f0902() 1255     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1256       if (ByteCodeCheck(BINARY_MODULO, opargs, objs)) {
1257         if (PyUnicode_CheckExact(objs[0]) && (!PyUnicode_Check(objs[1]) || PyUnicode_CheckExact(objs[1]))) {
1258           return PyUnicode_Format(objs[0], objs[1]);
1259         } else {
1260           return PyNumber_Remainder(objs[0], objs[1]);
1261         }
1262       } else {
1263         Py_INCREF(objs[0]);
1264         return objs[0];
1265       }
1266     }}},
1267   {BINARY_ADD,
__anonaa04c86f0a02() 1268    {[](int opargs, const PyObjectArray &objs) -> bool {
1269       return (!PyUnicode_CheckExact(objs[0]) || !PyUnicode_CheckExact(objs[1])) &&
1270              OptStrategy::MakeCalcStrategyByInputs(BINARY_ADD, opargs, objs) != OptStrategy::CalcKind::kCalcUnsupported;
1271     },
1272     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1273       if (ByteCodeCheck(BINARY_ADD, opargs, objs)) {
1274         return CheckAndDoBinary(BINARY_ADD, objs, PyNumber_Add);
1275       } else {
1276         Py_INCREF(objs[0]);
1277         return objs[0];
1278       }
1279     }}},
1280   {BINARY_SUBTRACT,
1281    {ByteCodeTest(BINARY_SUBTRACT),
__anonaa04c86f0c02() 1282     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1283       if (ByteCodeCheck(BINARY_SUBTRACT, opargs, objs)) {
1284         return CheckAndDoBinary(BINARY_SUBTRACT, objs, PyNumber_Subtract);
1285       } else {
1286         Py_INCREF(objs[0]);
1287         return objs[0];
1288       }
1289     }}},
1290   {BINARY_SUBSCR,
1291    {ByteCodeTest(BINARY_SUBSCR),
__anonaa04c86f0d02() 1292     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1293       return PyObject_GetItem(objs[0], objs[1]);
1294     }}},
1295   {BINARY_FLOOR_DIVIDE,
1296    {ByteCodeTest(BINARY_FLOOR_DIVIDE),
__anonaa04c86f0e02() 1297     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1298       if (ByteCodeCheck(BINARY_FLOOR_DIVIDE, opargs, objs)) {
1299         return CheckAndDoBinary(BINARY_FLOOR_DIVIDE, objs, PyNumber_FloorDivide);
1300       } else {
1301         Py_INCREF(objs[0]);
1302         return objs[0];
1303       }
1304     }}},
1305   {BINARY_TRUE_DIVIDE,
1306    {ByteCodeTest(BINARY_TRUE_DIVIDE),
__anonaa04c86f0f02() 1307     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1308       if (ByteCodeCheck(BINARY_TRUE_DIVIDE, opargs, objs)) {
1309         return CheckAndDoBinary(BINARY_TRUE_DIVIDE, objs, PyNumber_TrueDivide);
1310       } else {
1311         Py_INCREF(objs[0]);
1312         return objs[0];
1313       }
1314     }}},
1315   {INPLACE_FLOOR_DIVIDE,
1316    {ByteCodeTest(INPLACE_FLOOR_DIVIDE),
__anonaa04c86f1002() 1317     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1318       if (ByteCodeCheck(INPLACE_FLOOR_DIVIDE, opargs, objs)) {
1319         return PyNumber_InPlaceFloorDivide(objs[0], objs[1]);
1320       } else {
1321         Py_INCREF(objs[0]);
1322         return objs[0];
1323       }
1324     }}},
1325   {INPLACE_TRUE_DIVIDE,
1326    {ByteCodeTest(INPLACE_TRUE_DIVIDE),
__anonaa04c86f1102() 1327     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1328       if (ByteCodeCheck(INPLACE_TRUE_DIVIDE, opargs, objs)) {
1329         return PyNumber_InPlaceTrueDivide(objs[0], objs[1]);
1330       } else {
1331         Py_INCREF(objs[0]);
1332         return objs[0];
1333       }
1334     }}},
1335   {GET_AITER, {ByteCodeUnsupported, nullptr}},
1336   {GET_ANEXT, {ByteCodeUnsupported, nullptr}},
1337   {BEFORE_ASYNC_WITH, {ByteCodeUnsupported, nullptr}},
1338   {INPLACE_ADD,
__anonaa04c86f1202() 1339    {[](int opargs, const PyObjectArray &objs) -> bool {
1340       return (!PyUnicode_CheckExact(objs[0]) || !PyUnicode_CheckExact(objs[1])) &&
1341              OptStrategy::MakeCalcStrategyByInputs(INPLACE_ADD, opargs, objs) !=
1342                OptStrategy::CalcKind::kCalcUnsupported;
1343     },
1344     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1345       if (ByteCodeCheck(INPLACE_ADD, opargs, objs)) {
1346         return PyNumber_InPlaceAdd(objs[0], objs[1]);
1347       } else {
1348         Py_INCREF(objs[0]);
1349         return objs[0];
1350       }
1351     }}},
1352   {INPLACE_SUBTRACT,
1353    {ByteCodeTest(INPLACE_SUBTRACT),
__anonaa04c86f1402() 1354     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1355       if (ByteCodeCheck(INPLACE_SUBTRACT, opargs, objs)) {
1356         return PyNumber_InPlaceSubtract(objs[0], objs[1]);
1357       } else {
1358         Py_INCREF(objs[0]);
1359         return objs[0];
1360       }
1361     }}},
1362   {INPLACE_MULTIPLY,
1363    {ByteCodeTest(INPLACE_MULTIPLY),
__anonaa04c86f1502() 1364     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1365       if (ByteCodeCheck(INPLACE_MULTIPLY, opargs, objs)) {
1366         return PyNumber_InPlaceMultiply(objs[0], objs[1]);
1367       } else {
1368         Py_INCREF(objs[0]);
1369         return objs[0];
1370       }
1371     }}},
1372   {INPLACE_MODULO,
1373    {ByteCodeTest(INPLACE_MODULO),
__anonaa04c86f1602() 1374     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1375       if (ByteCodeCheck(INPLACE_MODULO, opargs, objs)) {
1376         return PyNumber_InPlaceRemainder(objs[0], objs[1]);
1377       } else {
1378         Py_INCREF(objs[0]);
1379         return objs[0];
1380       }
1381     }}},
1382   {STORE_SUBSCR, {ByteCodeUnsupported, nullptr}},
1383   {DELETE_SUBSCR, {ByteCodeUnsupported, nullptr}},
1384   {BINARY_LSHIFT,
1385    {ByteCodeTest(BINARY_LSHIFT),
__anonaa04c86f1702() 1386     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1387       if (ByteCodeCheck(BINARY_LSHIFT, opargs, objs)) {
1388         return PyNumber_Lshift(objs[0], objs[1]);
1389       } else {
1390         Py_INCREF(objs[0]);
1391         return objs[0];
1392       }
1393     }}},
1394   {BINARY_RSHIFT,
1395    {ByteCodeTest(BINARY_RSHIFT),
__anonaa04c86f1802() 1396     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1397       if (ByteCodeCheck(BINARY_RSHIFT, opargs, objs)) {
1398         return PyNumber_Rshift(objs[0], objs[1]);
1399       } else {
1400         Py_INCREF(objs[0]);
1401         return objs[0];
1402       }
1403     }}},
1404   {BINARY_AND,
1405    {ByteCodeTest(BINARY_AND),
__anonaa04c86f1902() 1406     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1407       if (ByteCodeCheck(BINARY_AND, opargs, objs)) {
1408         return PyNumber_And(objs[0], objs[1]);
1409       } else {
1410         Py_INCREF(objs[0]);
1411         return objs[0];
1412       }
1413     }}},
1414   {BINARY_XOR,
1415    {ByteCodeTest(BINARY_XOR),
__anonaa04c86f1a02() 1416     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1417       if (ByteCodeCheck(BINARY_XOR, opargs, objs)) {
1418         return PyNumber_Xor(objs[0], objs[1]);
1419       } else {
1420         Py_INCREF(objs[0]);
1421         return objs[0];
1422       }
1423     }}},
1424   {BINARY_OR,
1425    {ByteCodeTest(BINARY_OR),
__anonaa04c86f1b02() 1426     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1427       if (ByteCodeCheck(BINARY_OR, opargs, objs)) {
1428         return PyNumber_Or(objs[0], objs[1]);
1429       } else {
1430         Py_INCREF(objs[0]);
1431         return objs[0];
1432       }
1433     }}},
1434   {INPLACE_POWER,
1435    {ByteCodeTest(INPLACE_POWER),
__anonaa04c86f1c02() 1436     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1437       if (ByteCodeCheck(INPLACE_POWER, opargs, objs)) {
1438         return PyNumber_InPlacePower(objs[0], objs[1], Py_None);
1439       } else {
1440         Py_INCREF(objs[0]);
1441         return objs[0];
1442       }
1443     }}},
1444   {GET_ITER,
1445    {ByteCodeSupported,
__anonaa04c86f1d02() 1446     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * { return PyObject_GetIter(objs[0]); }}},
1447   {GET_YIELD_FROM_ITER,
1448    {ByteCodeSupported,
__anonaa04c86f1e02() 1449     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1450       PyObject *iterable = objs[0];
1451       if (PyCoro_CheckExact(iterable)) {
1452         if (!(ctx->f_code->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE))) {
1453           return nullptr;
1454         }
1455       } else if (!PyGen_CheckExact(iterable)) {
1456         return PyObject_GetIter(iterable);
1457       }
1458       Py_INCREF(iterable);
1459       return iterable;
1460     }}},
1461   {PRINT_EXPR, {ByteCodeUnsupported, nullptr}},
1462   {LOAD_BUILD_CLASS, {ByteCodeUnsupported, nullptr}},
1463   {YIELD_FROM, {ByteCodeUnsupported, nullptr}},
1464   {GET_AWAITABLE, {ByteCodeUnsupported, nullptr}},
1465   {INPLACE_LSHIFT,
1466    {ByteCodeTest(INPLACE_LSHIFT),
__anonaa04c86f1f02() 1467     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1468       if (ByteCodeCheck(INPLACE_LSHIFT, opargs, objs)) {
1469         return PyNumber_InPlaceLshift(objs[0], objs[1]);
1470       } else {
1471         Py_INCREF(objs[0]);
1472         return objs[0];
1473       }
1474     }}},
1475   {INPLACE_RSHIFT,
1476    {ByteCodeTest(INPLACE_RSHIFT),
__anonaa04c86f2002() 1477     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1478       if (ByteCodeCheck(INPLACE_RSHIFT, opargs, objs)) {
1479         return PyNumber_InPlaceRshift(objs[0], objs[1]);
1480       } else {
1481         Py_INCREF(objs[0]);
1482         return objs[0];
1483       }
1484     }}},
1485   {INPLACE_AND,
1486    {ByteCodeTest(INPLACE_AND),
1487     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2102() 1488                                                                    * {
1489                                                                      if (ByteCodeCheck(INPLACE_AND, opargs, objs)) {
1490                                                                        return PyNumber_InPlaceAnd(objs[0], objs[1]);
1491                                                                      } else {
1492                                                                        Py_INCREF(objs[0]);
1493                                                                        return objs[0];
1494                                                                      }
1495                                                                    }}},
1496   {INPLACE_XOR,
1497    {ByteCodeTest(INPLACE_XOR),
1498     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2202() 1499                                                                    * {
1500                                                                      if (ByteCodeCheck(INPLACE_XOR, opargs, objs)) {
1501                                                                        return PyNumber_InPlaceXor(objs[0], objs[1]);
1502                                                                      } else {
1503                                                                        Py_INCREF(objs[0]);
1504                                                                        return objs[0];
1505                                                                      }
1506                                                                    }}},
1507   {INPLACE_OR,
1508    {ByteCodeTest(INPLACE_OR),
1509     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2302() 1510                                                                    * {
1511                                                                      if (ByteCodeCheck(INPLACE_OR, opargs, objs)) {
1512                                                                        return PyNumber_InPlaceOr(objs[0], objs[1]);
1513                                                                      } else {
1514                                                                        Py_INCREF(objs[0]);
1515                                                                        return objs[0];
1516                                                                      }
1517                                                                    }}},
1518   {RETURN_VALUE, {ByteCodeUnsupported, nullptr}},
1519   {IMPORT_STAR, {ByteCodeUnsupported, nullptr}},
1520   {SETUP_ANNOTATIONS, {ByteCodeUnsupported, nullptr}},
1521   {YIELD_VALUE, {ByteCodeUnsupported, nullptr}},
1522   {POP_BLOCK, {ByteCodeUnsupported, nullptr}},
1523   {POP_EXCEPT, {ByteCodeUnsupported, nullptr}},
1524   {STORE_NAME, {ByteCodeUnsupported, nullptr}},
1525   {DELETE_NAME, {ByteCodeUnsupported, nullptr}},
1526   {UNPACK_SEQUENCE, {ByteCodeUnsupported, nullptr}},
1527   {FOR_ITER, {ByteCodeUnsupported, nullptr}},
1528   {UNPACK_EX, {ByteCodeUnsupported, nullptr}},
1529   {STORE_ATTR, {ByteCodeUnsupported, nullptr}},
1530   {DELETE_ATTR, {ByteCodeUnsupported, nullptr}},
1531   {STORE_GLOBAL, {ByteCodeUnsupported, nullptr}},
1532   {DELETE_GLOBAL, {ByteCodeUnsupported, nullptr}},
1533   {LOAD_CONST, {ByteCodeSupported, nullptr}},
1534   {LOAD_NAME, {ByteCodeSupported, nullptr}},
1535   {BUILD_TUPLE,
1536    {ByteCodeSupported,
1537     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2402() 1538                                                                    * {
1539                                                                      PyObject *tup = PyTuple_New(opargs);
1540                                                                      while (--opargs >= 0) {
1541                                                                        Py_INCREF(objs[opargs]);
1542                                                                        PyTuple_SET_ITEM(tup, opargs, objs[opargs]);
1543                                                                      }
1544                                                                      return tup;
1545                                                                    }}},
1546   {BUILD_LIST,
1547    {ByteCodeSupported,
1548     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2502() 1549                                                                    * {
1550                                                                      PyObject *list = PyList_New(opargs);
1551                                                                      while (--opargs >= 0) {
1552                                                                        Py_INCREF(objs[opargs]);
1553                                                                        PyList_SET_ITEM(list, opargs, objs[opargs]);
1554                                                                      }
1555                                                                      return list;
1556                                                                    }}},
1557   {BUILD_SET,
1558    {ByteCodeSupported,
1559     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2602() 1560                                                                    * {
1561                                                                      PyObject *set = PySet_New(NULL);
1562                                                                      for (int i = opargs; i > 0; i--) {
1563                                                                        PySet_Add(set, objs[opargs - i]);
1564                                                                      }
1565                                                                      return set;
1566                                                                    }}},
1567   {BUILD_MAP,
1568    {ByteCodeSupported,
1569     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2702() 1570                                                                    * {
1571                                                                      PyObject *map =
1572                                                                        _PyDict_NewPresized((Py_ssize_t)opargs);
1573                                                                      for (Py_ssize_t i = opargs; i > 0; i--) {
1574                                                                        PyObject *key = objs[2 * (opargs - i)];
1575                                                                        PyObject *value = objs[2 * (opargs - i) + 1];
1576                                                                        PyDict_SetItem(map, key, value);
1577                                                                      }
1578                                                                      return map;
1579                                                                    }}},
1580   {LOAD_ATTR, {ByteCodeSupported, nullptr}},
1581   {COMPARE_OP,
__anonaa04c86f2802() 1582    {[](int opargs, const PyObjectArray &objs) {
1583       return OptStrategy::MakeCalcStrategyByInputs(COMPARE_OP, opargs, objs) != OptStrategy::CalcKind::kCalcUnsupported;
1584     },
1585     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1586       if (ByteCodeCheck(COMPARE_OP, opargs, objs)) {
1587         return RichCompare(objs[0], objs[1], opargs);
1588       } else {
1589         Py_INCREF(Py_True);
1590         return Py_True;
1591       }
1592     }}},
1593   {IMPORT_NAME, {ByteCodeUnsupported, nullptr}},
1594   {IMPORT_FROM, {ByteCodeUnsupported, nullptr}},
1595   {JUMP_FORWARD, {ByteCodeUnsupported, nullptr}},
1596   {JUMP_IF_FALSE_OR_POP, {ByteCodeUnsupported, nullptr}},
1597   {JUMP_IF_TRUE_OR_POP, {ByteCodeUnsupported, nullptr}},
1598   {JUMP_ABSOLUTE, {ByteCodeUnsupported, nullptr}},
1599   {POP_JUMP_IF_FALSE, {ByteCodeUnsupported, nullptr}},
1600   {POP_JUMP_IF_TRUE, {ByteCodeUnsupported, nullptr}},
1601   {LOAD_GLOBAL, {ByteCodeSupported, nullptr}},
1602   {SETUP_FINALLY, {ByteCodeUnsupported, nullptr}},
1603   {LOAD_FAST, {ByteCodeUnsupported, nullptr}},
1604   {STORE_FAST, {ByteCodeUnsupported, nullptr}},
1605   {DELETE_FAST, {ByteCodeUnsupported, nullptr}},
1606   {RAISE_VARARGS, {ByteCodeUnsupported, nullptr}},
1607   {CALL_FUNCTION, {ByteCodeSupported, nullptr}},
1608   {MAKE_FUNCTION, {ByteCodeUnsupported, nullptr}},
1609   {BUILD_SLICE,
1610    {ByteCodeSupported,
1611     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2a02() 1612                                                                    * {
1613                                                                      PyObject *start;
1614                                                                      PyObject *stop;
1615                                                                      PyObject *step;
1616                                                                      if (opargs == 3)
1617                                                                        step = objs[2];
1618                                                                      else
1619                                                                        step = nullptr;
1620                                                                      stop = objs[1];
1621                                                                      start = objs[0];
1622                                                                      return PySlice_New(start, stop, step);
1623                                                                    }}},
1624   {LOAD_CLOSURE, {ByteCodeSupported, nullptr}},
1625   {LOAD_DEREF, {ByteCodeSupported, nullptr}},
1626   {STORE_DEREF, {ByteCodeUnsupported, nullptr}},
1627   {DELETE_DEREF, {ByteCodeUnsupported, nullptr}},
1628   {CALL_FUNCTION_KW, {ByteCodeSupported, nullptr}},
1629   {CALL_FUNCTION_EX, {ByteCodeSupported, nullptr}},
1630   {SETUP_WITH, {ByteCodeUnsupported, nullptr}},
1631   {EXTENDED_ARG, {ByteCodeUnsupported, nullptr}},
1632   {LIST_APPEND,
1633    {ByteCodeSupported,
1634     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2b02() 1635                                                                    * {
1636                                                                      PyList_Append(objs[0], objs[1]);
1637                                                                      Py_INCREF(objs[0]);
1638                                                                      return objs[0];
1639                                                                    }}},
1640   {SET_ADD,
1641    {ByteCodeSupported,
1642     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2c02() 1643                                                                    * {
1644                                                                      PySet_Add(objs[0], objs[1]);
1645                                                                      Py_INCREF(objs[0]);
1646                                                                      return objs[0];
1647                                                                    }}},
1648   {MAP_ADD,
1649    {ByteCodeSupported,
1650     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2d02() 1651                                                                    * {
1652                                                                      PyDict_SetItem(objs[0], objs[2], objs[1]);
1653                                                                      Py_INCREF(objs[0]);
1654                                                                      return objs[0];
1655                                                                    }}},
1656   {LOAD_CLASSDEREF, {ByteCodeSupported, nullptr}},
1657   {SETUP_ASYNC_WITH, {ByteCodeUnsupported, nullptr}},
1658   {FORMAT_VALUE, {ByteCodeUnsupported, nullptr}},
1659   {BUILD_CONST_KEY_MAP,
__anonaa04c86f2e02() 1660    {[](int opargs, const PyObjectArray &objs) -> bool {
1661       return PyTuple_CheckExact(objs[opargs]) && PyTuple_GET_SIZE(objs[opargs]) == (Py_ssize_t)opargs;
1662     },
1663     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1664       PyObject *keys = objs[opargs];
1665       PyObject *map = _PyDict_NewPresized((Py_ssize_t)opargs);
1666       for (Py_ssize_t i = opargs; i > 0; i--) {
1667         PyObject *key = PyTuple_GET_ITEM(keys, opargs - i);
1668         PyObject *value = objs[opargs - i];
1669         PyDict_SetItem(map, key, value);
1670       }
1671       return map;
1672     }}},
1673   {BUILD_STRING,
1674    {ByteCodeSupported,
1675     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3002() 1676                                                                    * {
1677                                                                      PyObject *empty = PyUnicode_New(0, 0);
1678                                                                      PyObject *str =
1679                                                                        _PyUnicode_JoinArray(empty, objs.data(), opargs);
1680                                                                      Py_DECREF(empty);
1681                                                                      return str;
1682                                                                    }}},
1683   {LOAD_METHOD, {ByteCodeUnsupported, nullptr}},
1684   {CALL_METHOD, {ByteCodeUnsupported, nullptr}},
1685   {ROT_FOUR, {ByteCodeUnsupported, nullptr}},
1686   {RERAISE, {ByteCodeUnsupported, nullptr}},
1687   {WITH_EXCEPT_START, {ByteCodeUnsupported, nullptr}},
1688   {END_ASYNC_FOR, {ByteCodeUnsupported, nullptr}},
1689   {LOAD_ASSERTION_ERROR, {ByteCodeUnsupported, nullptr}},
1690   {LIST_TO_TUPLE,
1691    {ByteCodeSupported,
__anonaa04c86f3102() 1692     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * { return PyList_AsTuple(objs[0]); }}},
1693   {IS_OP,
1694    {ByteCodeSupported,
1695     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3202() 1696                                                                    * {
1697                                                                      auto ret = (objs[0] == objs[1]) ^ opargs
1698                                                                                   ? Py_True
1699                                                                                   : Py_False;
1700                                                                      Py_INCREF(ret);
1701                                                                      return ret;
1702                                                                    }}},
1703   {CONTAINS_OP,
1704    {ByteCodeSupported,
1705     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3302() 1706                                                                    * {
1707                                                                      auto ret =
1708                                                                        (PySequence_Contains(objs[1], objs[0]) ^ opargs)
1709                                                                          ? Py_True
1710                                                                          : Py_False;
1711                                                                      Py_INCREF(ret);
1712                                                                      return ret;
1713                                                                    }}},
1714   {JUMP_IF_NOT_EXC_MATCH, {ByteCodeUnsupported, nullptr}},
1715   {LIST_EXTEND,
1716    {ByteCodeSupported,
1717     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3402() 1718                                                                    * {
1719                                                                      _PyList_Extend(
1720                                                                        reinterpret_cast<PyListObject *>(objs[0]),
1721                                                                        objs[1]);
1722                                                                      Py_INCREF(objs[0]);
1723                                                                      return objs[0];
1724                                                                    }}},
1725   {SET_UPDATE,
1726    {ByteCodeSupported,
1727     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3502() 1728                                                                    * {
1729                                                                      _PySet_Update(objs[0], objs[1]);
1730                                                                      Py_INCREF(objs[0]);
1731                                                                      return objs[0];
1732                                                                    }}},
1733   {DICT_MERGE,
1734    {ByteCodeSupported,
1735     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3602() 1736                                                                    * {
1737                                                                      _PyDict_MergeEx(objs[0], objs[1], 2);
1738                                                                      return objs[0];
1739                                                                    }}},
1740   {DICT_UPDATE,
1741    {ByteCodeSupported,
1742     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3702() 1743                                                                    * {
1744                                                                      PyDict_Update(objs[0], objs[1]);
1745                                                                      return objs[0];
1746                                                                    }}},
1747   {BREAK_LOOP, {ByteCodeUnsupported, nullptr}},
1748   {WITH_CLEANUP_START, {ByteCodeUnsupported, nullptr}},
1749   {WITH_CLEANUP_FINISH, {ByteCodeUnsupported, nullptr}},
1750   {END_FINALLY, {ByteCodeUnsupported, nullptr}},
1751   {CONTINUE_LOOP, {ByteCodeUnsupported, nullptr}},
1752   {SETUP_LOOP, {ByteCodeUnsupported, nullptr}},
1753   {SETUP_EXCEPT, {ByteCodeUnsupported, nullptr}},
1754   {BUILD_LIST_UNPACK,
1755    {ByteCodeSupported,
1756     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3802() 1757                                                                    * {
1758                                                                      PyObject *sum = PyList_New(0);
1759                                                                      for (int i = opargs; i > 0; i--) {
1760                                                                        auto none_val = _PyList_Extend(
1761                                                                          reinterpret_cast<PyListObject *>(sum),
1762                                                                          objs[opargs - i]);
1763                                                                        Py_DECREF(none_val);
1764                                                                      }
1765                                                                      return sum;
1766                                                                    }}},
1767   {BUILD_MAP_UNPACK,
1768    {ByteCodeSupported,
1769     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3902() 1770                                                                    * {
1771                                                                      PyObject *sum = PyDict_New();
1772                                                                      for (int i = opargs; i > 0; i--) {
1773                                                                        PyDict_Update(sum, objs[opargs - i]);
1774                                                                      }
1775                                                                      return sum;
1776                                                                    }}},
1777   {BUILD_MAP_UNPACK_WITH_CALL,
1778    {ByteCodeSupported,
1779     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3a02() 1780                                                                    * {
1781                                                                      PyObject *sum = PyDict_New();
1782                                                                      for (int i = opargs; i > 0; i--) {
1783                                                                        _PyDict_MergeEx(sum, objs[opargs - i], 2);
1784                                                                      }
1785                                                                      return sum;
1786                                                                    }}},
1787   {BUILD_TUPLE_UNPACK,
1788    {ByteCodeSupported,
1789     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3b02() 1790                                                                    * {
1791                                                                      PyObject *sum = PyList_New(0);
1792                                                                      for (int i = opargs; i > 0; i--) {
1793                                                                        auto none_val = _PyList_Extend(
1794                                                                          reinterpret_cast<PyListObject *>(sum),
1795                                                                          objs[opargs - i]);
1796                                                                        Py_DECREF(none_val);
1797                                                                      }
1798                                                                      auto ret = PyList_AsTuple(sum);
1799                                                                      Py_DECREF(sum);
1800                                                                      return ret;
1801                                                                    }}},
1802   {BUILD_SET_UNPACK,
1803    {ByteCodeSupported,
1804     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3c02() 1805                                                                    * {
1806                                                                      PyObject *sum = PySet_New(NULL);
1807                                                                      for (int i = opargs; i > 0; i--) {
1808                                                                        _PySet_Update(sum, objs[opargs - i]);
1809                                                                      }
1810                                                                      return sum;
1811                                                                    }}},
1812   {BUILD_TUPLE_UNPACK_WITH_CALL,
1813    {ByteCodeSupported,
1814     [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3d02() 1815                                                                    * {
1816                                                                      PyObject *sum = PyList_New(0);
1817                                                                      for (int i = opargs; i > 0; i--) {
1818                                                                        auto none_val = _PyList_Extend(
1819                                                                          reinterpret_cast<PyListObject *>(sum),
1820                                                                          objs[opargs - i]);
1821                                                                        Py_DECREF(none_val);
1822                                                                      }
1823                                                                      auto ret = PyList_AsTuple(sum);
1824                                                                      Py_DECREF(sum);
1825                                                                      return ret;
1826                                                                    }}},
1827 };
1828 
OpTrace(PyObject * obj,int opcode,int opargs,TraceVector params,std::string name)1829 OpTrace::OpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, std::string name)
1830     : Trace(obj, nullptr), opcode_(opcode), opargs_(opargs), params_(params), name_(name) {
1831   curType_ = TraceType::Operation;
1832   if (!std::any_of(params.begin(), params.end(), [](const TracePtr &item) { return !item->IsConst(); })) {
1833     is_const_ = true;
1834   } else if (name.find(kIsSeqValUnknown) != std::string::npos || name.find(kIsSeqShapeUnknown) != std::string::npos) {
1835     is_const_ = true;
1836   } else if (kPIJitConfigDefault.getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0 && opcode == LOAD_ATTR &&
1837              name_ == kFuncName) {
1838     is_const_ = true;
1839   }
1840   depth_ = std::accumulate(params.begin(), params.end(), 1, [](int depth, const TracePtr &i) {
1841     int d = i->GetDepth() + 1;
1842     if (d > depth) {
1843       return d;
1844     } else {
1845       return depth;
1846     }
1847   });
1848   CheckSpecialize();
1849 }
1850 
GetOpCode()1851 int OpTrace::GetOpCode() { return opcode_; }
1852 
GetOpArgs()1853 int OpTrace::GetOpArgs() { return opargs_; }
1854 
GetParam(size_t idx)1855 TracePtr OpTrace::GetParam(size_t idx) {
1856   if (params_.size() > idx) {
1857     return params_[idx];
1858   } else {
1859     return nullptr;
1860   }
1861 }
1862 
GetParamCount()1863 size_t OpTrace::GetParamCount() { return params_.size(); }
1864 
GetName()1865 std::string OpTrace::GetName() { return name_; }
1866 
Retrieve(PTraceContext context,bool perf)1867 PyObject *OpTrace::Retrieve(PTraceContext context, bool perf) {
1868   PyObject *ret = Trace::Retrieve(context, perf);
1869   if (ret != nullptr) {
1870     return ret;
1871   }
1872   std::vector<PyObject *> params;
1873   auto iter = std::find_if(params_.begin(), params_.end(), [&params, &context, perf](const TracePtr &p) {
1874     auto param = p->Retrieve(context, perf);
1875     if (param == nullptr) {
1876       return true;
1877     }
1878     if (py::isinstance<mindspore::tensor::Tensor>(param)) {
1879       mindspore::tensor::TensorPtr tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(param);
1880       if (OptStrategy::MakeCalcStrategyByShape(tensor_ptr->shape()) == OptStrategy::CalcKind::kCalcValue) {
1881         tensor_ptr->data_sync(true);
1882       }
1883     }
1884     params.push_back(param);
1885     return params.back() == nullptr;
1886   });
1887   if (iter != params_.end()) {
1888     MS_LOG(DEBUG) << "Guard Check Retrieve fail for " + (*iter)->ToString();
1889     std::for_each(params.begin(), params.end(), [](PyObject *p) { Py_XDECREF(p); });
1890     return nullptr;
1891   }
1892   TracePerf tp(this, perf, false);
1893   if (kBytecodeExecuter.find(opcode_) != kBytecodeExecuter.end() && kBytecodeExecuter[opcode_].first(opargs_, params) &&
1894       kBytecodeExecuter[opcode_].second != nullptr) {
1895     ret = kBytecodeExecuter[opcode_].second(opargs_, params, context);
1896   } else if (opcode_ == LOAD_ATTR) {
1897     MS_EXCEPTION_IF_CHECK_FAIL(name_.size(), "check trace");
1898     ret = PyObject_GetAttrString(params[0], name_.c_str());
1899   } else if (opcode_ == CALL_FUNCTION || opcode_ == CALL_FUNCTION_EX || opcode_ == CALL_FUNCTION_KW) {
1900     ret = DoCall(params, opcode_, name_);
1901   }
1902   for (auto p : params) {
1903     Py_DECREF(p);
1904   }
1905   if (PyErr_Occurred()) {
1906     PyErr_Clear();
1907   }
1908   Cache(context, ret);
1909   return ret;
1910 }
1911 
ToString(bool include_param)1912 std::string OpTrace::ToString(bool include_param) {
1913   if (strTrace_.size() > 0) {
1914     return strTrace_;
1915   }
1916   std::string ret = "operation ";
1917   ret += Opcode(opcode_).name();
1918   ret += "(arg:";
1919   ret += std::to_string(opargs_);
1920   if (name_.size() != 0 || params_.size() > 0) {
1921     ret += ",";
1922   }
1923   if (name_.size() != 0) {
1924     ret += std::string("name:") + name_;
1925     if (params_.size() > 0) {
1926       ret += ",";
1927     }
1928   }
1929   if (include_param && params_.size() > 0) {
1930     for (auto t : params_) {
1931       ret += t->ToString(include_param) + ",";
1932     }
1933     ret = ret.substr(0, ret.size() - 1);
1934   }
1935   ret = ret + ")";
1936   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
1937   ret = std::regex_replace(ret, std::regex("(\n)"), "");
1938   strTrace_ = ret;
1939   return ret;
1940 }
1941 
FormatString(std::map<Trace *,size_t> * cache)1942 std::string OpTrace::FormatString(std::map<Trace *, size_t> *cache) {
1943   std::stringstream s;
1944   std::stringstream params_str;
1945   params_str << "(";
1946   for (auto i : params_) {
1947     if (cache->find(i.get()) == cache->end()) {
1948       s << i->FormatString(cache) << std::endl;
1949     }
1950     params_str << "%" << (cache->find(i.get())->second) << ", ";
1951   }
1952   params_str << ")";
1953 
1954   cache->insert(std::make_pair(this, cache->size()));
1955   s << "%" << cache->find(this)->second << " = operation " << Opcode(opcode_).name() << " " << opargs_;
1956   if (!name_.empty()) {
1957     s << ", name: " << name_;
1958   }
1959   s << ": " << params_str.str();
1960   return s.str();
1961 }
1962 
Info()1963 const InfoPack &OpTrace::Info() {
1964   if (info_ == nullptr) {
1965     InfoPack info;
1966     info << uint8_t(curType_);
1967     info.Begin();
1968     info << opcode_;
1969     info << opargs_;
1970     info << (name_.size() != 0);
1971     if (name_.size() != 0) {
1972       info << name_;
1973     }
1974     info << (params_.size() != 0);
1975     if (params_.size() > 0) {
1976       for (auto t : params_) {
1977         info << t->Info();
1978       }
1979     }
1980     info.End();
1981     info_ = std::make_shared<InfoPack>(info);
1982     info_->Update();
1983   }
1984   return *info_;
1985 }
1986 
RemoveCastDuplicatePatternPass()1987 TracePtr OpTrace::RemoveCastDuplicatePatternPass() {
1988   OpTracePtr cast_op;
1989   TracePtr next_op;
1990   TracePtr this_op;
1991   TracePtr ret_op;
1992   if (opcode_ != CALL_FUNCTION || !IsCastFunc(name_) ||
1993       (cast_op = CastTrace<OpTrace>(GetParam(kParamIndexTwo))) == nullptr || !IsCastFunc(cast_op->GetName()) ||
1994       (next_op = cast_op->GetParam(kParamIndexTwo)) == nullptr) {
1995     return nullptr;
1996   }
1997   // remove duplicate cast or contrary cast
1998   if (name_ == cast_op->GetName()) {
1999     this_op = cast_op;
2000     ret_op = cast_op->Optimize();
2001   } else {
2002     this_op = next_op;
2003     ret_op = next_op->Optimize();
2004   }
2005   if (ret_op != nullptr) {
2006     return ret_op;
2007   } else {
2008     return this_op;
2009   }
2010 }
2011 
RemovePrimOutIsTensorPass()2012 TracePtr OpTrace::RemovePrimOutIsTensorPass() {
2013   RootTracePtr global_op;
2014   OpTracePtr call_op;
2015   TracePtr param_op;
2016   if (opcode_ != CALL_FUNCTION || !(name_ == kIsInstance) ||
2017       (global_op = CastTrace<RootTrace>(GetParam(kParamIndexThree))) == nullptr ||
2018       global_op->GetTraceType() != TraceType::Global ||
2019       (call_op = CastOpTrace(GetParam(kParamIndexTwo), CALL_FUNCTION)) == nullptr ||
2020       (param_op = call_op->GetParam(kParamIndexOne)) == nullptr) {
2021     return nullptr;
2022   }
2023   int idx;
2024   std::string name;
2025   std::string module_name;
2026   global_op->GetParam(&idx, &name, &module_name);
2027   // isinstance(cast_to_mstensor(...) or Primitive) should be Tensor
2028   if ((name == kTensorName) && ((call_op->GetName() == kCastToMSTensor) ||
2029                                 (CastTrace<ConstTrace>(param_op) != nullptr &&
2030                                  py::isinstance<mindspore::PrimitivePyAdapter>(param_op->GetObject())))) {
2031     is_const_ = true;
2032     if (obj_ == nullptr) {
2033       obj_ = Py_True;
2034       Py_INCREF(obj_);
2035     }
2036     return shared_from_this();
2037   }
2038   return nullptr;
2039 }
2040 
RemoveCastPass()2041 TracePtr OpTrace::RemoveCastPass() {
2042   TracePtr next_op;
2043   if (opcode_ == CALL_FUNCTION && name_ == kCastToMSTensor && (next_op = GetParam(kParamIndexTwo)) != nullptr) {
2044     auto new_ret = next_op->Optimize();
2045     if (new_ret != nullptr) {
2046       return new_ret;
2047     } else {
2048       return next_op;
2049     }
2050   }
2051   return nullptr;
2052 }
2053 
RemoveEmptyTensorPass()2054 TracePtr OpTrace::RemoveEmptyTensorPass() {
2055   OpTracePtr subscr_op;
2056   OpTracePtr loadattr_op;
2057   ConstTracePtr const_op;
2058   ConstTracePtr const2_op;
2059   if (opcode_ != COMPARE_OP || params_.size() < kParamCountTwo) {
2060     return nullptr;
2061   }
2062   for (size_t idx = 0; idx < kParamCountTwo; ++idx) {
2063     TracePtr tmp = GetParam(idx);
2064     if (subscr_op == nullptr) {
2065       subscr_op = CastOpTrace(tmp, BINARY_SUBSCR);
2066     }
2067     if (const_op == nullptr) {
2068       const_op = CastConstTrace(tmp);
2069     }
2070   }
2071   if (subscr_op == nullptr || const_op == nullptr ||
2072       (const2_op = CastConstTrace(subscr_op->GetParam(kParamIndexTwo))) == nullptr ||
2073       (loadattr_op = CastOpTrace(subscr_op->GetParam(kParamIndexOne), kShapeName)) == nullptr ||
2074       CastOpTrace(loadattr_op->GetParam(kParamIndexOne), kCastToAdapterTensor) == nullptr) {
2075     return nullptr;
2076   }
2077   // make judgement shape[0] == 0 as const
2078   auto c1 = const_op->GetObject();
2079   auto c2 = const2_op->GetObject();
2080   if (!PyLong_CheckExact(c1) || !PyLong_CheckExact(c2)) {
2081     return nullptr;
2082   }
2083   auto v1 = _PyLong_AsInt(c1);
2084   auto v2 = _PyLong_AsInt(c2);
2085   if (v1 == 0 && v2 == 0) {
2086     is_const_ = true;
2087     return shared_from_this();
2088   }
2089   return nullptr;
2090 }
2091 
JudgeDTypeChangePass()2092 void OpTrace::JudgeDTypeChangePass() {
2093   if (opcode_ != COMPARE_OP) {
2094     return;
2095   }
2096   for (size_t i = 0; i < kParamCountTwo; ++i) {
2097     OpTracePtr trace = CastOpTrace(GetParam(i), CALL_FUNCTION);
2098     ConstTracePtr const_op = trace ? CastConstTrace(trace->GetParam(kParamIndexOne)) : nullptr;
2099     PyObject *const_param = const_op ? const_op->GetObject() : nullptr;
2100     if (trace != nullptr && const_op != nullptr && const_param != nullptr &&
2101         py::isinstance<mindspore::PrimitivePyAdapter>(const_param) &&
2102         py::cast<mindspore::PrimitivePyAdapterPtr>(const_param)->name() == kDTypePrimName) {
2103       // Compare for output of DType primitive
2104       continue;
2105     } else if ((trace = CastOpTrace(GetParam(i), LOAD_ATTR)) != nullptr && trace->GetName() == kDTypeAttrName) {
2106       // Compare for attribute dtype
2107       continue;
2108     }
2109     return;
2110   }
2111   // data type comparison should be kept as const
2112   EnableRelax();
2113 }
2114 
JudgeDTypeScopePass()2115 void OpTrace::JudgeDTypeScopePass() {
2116   if (opcode_ != CONTAINS_OP) {
2117     return;
2118   }
2119   OpTracePtr trace;
2120   if ((trace = CastOpTrace(GetParam(kParamIndexOne), LOAD_ATTR)) != nullptr && trace->GetName() == kDTypeAttrName) {
2121     // data type to check whether to be contained should be const
2122     EnableRelax();
2123   }
2124 }
2125 
JudgeDTypeTensorAttrPass()2126 void OpTrace::JudgeDTypeTensorAttrPass() {
2127   if (opcode_ != CALL_FUNCTION) {
2128     return;
2129   }
2130   RootTracePtr global_op;
2131   OpTracePtr call_op;
2132   if (params_.size() < kParamCountTwo || (global_op = CastTrace<RootTrace>(params_[kParamIndexOne])) == nullptr ||
2133       (call_op = CastOpTrace(params_[kParamIndexTwo], BINARY_SUBSCR)) == nullptr) {
2134     return;
2135   }
2136   int idx;
2137   std::string name;
2138   std::string module_name;
2139   global_op->GetParam(&idx, &name, &module_name);
2140   auto tsr = call_op->GetObject();
2141   if (tsr == nullptr) {
2142     return;
2143   }
2144   std::string type_name = std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(tsr))));
2145   if (name == kDType_AttrName && module_name.find("mindspore") == 0 &&
2146       type_name.find(kTensorName) != std::string::npos) {
2147     EnableRelax();
2148   }
2149 }
2150 
JudgeCodeChangePass()2151 void OpTrace::JudgeCodeChangePass() {
2152   if (opcode_ != LOAD_ATTR || params_.size() < kParamCountOne || name_ != kCodeName) {
2153     return;
2154   }
2155   if (params_[kParamIndexOne]->IsConst()) {
2156     EnableRelax();
2157   } else {
2158     auto p1 = CastOpTrace(params_[kParamIndexOne], LOAD_ATTR);
2159     if (p1 == nullptr) {
2160       return;
2161     }
2162     auto p2 = p1->GetParam(kParamIndexOne)->GetObject();
2163     std::string type_name;
2164     if (p2 != nullptr) {
2165       type_name = std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(p2))));
2166     }
2167     if (type_name.find(kMindTorchFlag) != std::string::npos) {
2168       EnableRelax();
2169     }
2170   }
2171 }
2172 
JudgeTrainFlagPass()2173 void OpTrace::JudgeTrainFlagPass() {
2174   if (opcode_ != LOAD_ATTR || params_.size() < kParamCountOne) {
2175     return;
2176   }
2177   if (name_ == kTrainingFlag) {
2178     // training flag shouldn't be changed frequently
2179     EnableRelax();
2180   }
2181 }
2182 
JudgeCompareConstPass()2183 void OpTrace::JudgeCompareConstPass() {
2184   if (RelaxEnabled()) {
2185     return;
2186   }
2187   if (opcode_ != COMPARE_OP || params_.size() < kParamCountTwo) {
2188     return;
2189   }
2190   if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2191     return;
2192   }
2193   EnableRelax();
2194 }
2195 
JudgeContainsConstPass()2196 void OpTrace::JudgeContainsConstPass() {
2197   if (RelaxEnabled()) {
2198     return;
2199   }
2200   if (opcode_ != CONTAINS_OP || params_.size() < kParamCountTwo) {
2201     return;
2202   }
2203   if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2204     return;
2205   }
2206   EnableRelax();
2207 }
2208 
JudgeInplaceAddConstPass()2209 void OpTrace::JudgeInplaceAddConstPass() {
2210   if (RelaxEnabled()) {
2211     return;
2212   }
2213   if (opcode_ != INPLACE_ADD || params_.size() < kParamCountTwo) {
2214     return;
2215   }
2216   if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2217     return;
2218   }
2219   EnableRelax();
2220 }
2221 
JudgeIsConstPass()2222 void OpTrace::JudgeIsConstPass() {
2223   if (RelaxEnabled()) {
2224     return;
2225   }
2226   if (opcode_ != IS_OP || params_.size() < kParamCountTwo) {
2227     return;
2228   }
2229   if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2230     return;
2231   }
2232   OpTracePtr subscr_op;
2233   if ((subscr_op = CastTrace<OpTrace>(params_[kParamIndexOne])) != nullptr &&
2234       (CastConstTrace(params_[kParamIndexTwo]) != nullptr || params_[kParamIndexTwo]->IsConst())) {
2235     if (subscr_op->params_.size() < kParamCountTwo) {
2236       return;
2237     }
2238     auto tsr = subscr_op->GetParam(kParamIndexOne)->GetObject();
2239     if (tsr == nullptr) {
2240       return;
2241     }
2242     std::string type_name = std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(tsr))));
2243     if (type_name.find(kTensorName) == std::string::npos) {
2244       return;
2245     }
2246     if (subscr_op->GetParam(kParamIndexTwo)->IsConst() && params_[kParamIndexTwo]->IsConst()) {
2247       EnableRelax();
2248     }
2249   }
2250 }
2251 
JudgeBoundMethodPass()2252 void OpTrace::JudgeBoundMethodPass() {
2253   if (RelaxEnabled()) {
2254     return;
2255   }
2256   if (opcode_ != LOAD_ATTR || params_.size() < kParamCountOne) {
2257     return;
2258   }
2259   if (params_[kParamIndexOne]->GetObject() == nullptr) {
2260     return;
2261   }
2262   if (name_ == kFuncName) {
2263     EnableRelax();
2264   }
2265 }
2266 
JudgeSubScrRandPass()2267 void OpTrace::JudgeSubScrRandPass() {
2268   if (RelaxEnabled()) {
2269     return;
2270   }
2271   if (opcode_ != BINARY_SUBTRACT || params_.size() < kParamCountTwo) {
2272     return;
2273   }
2274   auto call_op = CastOpTrace(params_[kParamIndexOne], CALL_FUNCTION);
2275   ConstTracePtr prim;
2276   if (call_op != nullptr && call_op->params_.size() > kParamCountOne) {
2277     if ((prim = CastConstTrace(call_op->params_[kParamIndexOne])) != nullptr) {
2278       std::string prim_name = py::cast<mindspore::PrimitivePyAdapterPtr>(prim->GetObject())->name();
2279       if (prim_name == kRankPrimName) {
2280         EnableRelax();
2281       }
2282     }
2283   }
2284 }
2285 
GetGuardFuncKeyMap()2286 static const std::unordered_map<size_t, size_t> &GetGuardFuncKeyMap() {
2287   static std::unordered_map<size_t, size_t> map = {};
2288   static bool init = false;
2289   if (init) {
2290     return map;
2291   }
2292   init = true;
2293   py::object func_map = Utils::GetModuleAttr(kFuncWhiteListModuleName, kGuardFuncMapName, true, true);
2294   MS_EXCEPTION_IF_CHECK_FAIL(PyDict_CheckExact(func_map.ptr()), "white list func map must be 'dict[int, int]'");
2295   PyObject *key;
2296   PyObject *value;
2297   Py_ssize_t pos = 0;
2298   while (PyDict_Next(func_map.ptr(), &pos, &key, &value)) {
2299     MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(key), "white list func map key must be 'int'");
2300     MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(value), "white list func map value must be 'int'");
2301     map[PyLong_AsSize_t(key)] = PyLong_AsSize_t(value);
2302   }
2303   return map;
2304 }
2305 
CheckRelaxGuardFunc(const py::object & callable)2306 static bool CheckRelaxGuardFunc(const py::object &callable) {
2307   static size_t guard_key_relax_func = 0;
2308   if (guard_key_relax_func == 0) {
2309     py::object key_object = Utils::GetModuleAttr(kFuncWhiteListModuleName, "GUARD_KEY_RELAX_FUNC", true, true);
2310     guard_key_relax_func = py::cast<size_t>(key_object);
2311   }
2312 
2313   auto iter = GetGuardFuncKeyMap().find(FunctionId(callable));
2314   return iter != GetGuardFuncKeyMap().end() && iter->second == guard_key_relax_func;
2315 }
2316 
JudgeRelaxGuardFuncPass()2317 void OpTrace::JudgeRelaxGuardFuncPass() {
2318   if (opcode_ != CALL_FUNCTION || params_.size() < kParamCountOne) {
2319     return;
2320   }
2321   if (CheckRelaxGuardFunc(py::cast<py::object>(params_[0]->GetObject()))) {
2322     EnableRelax();
2323   }
2324 }
2325 
CheckSpecialize()2326 void OpTrace::CheckSpecialize() {
2327   bool any_params_specialized = false;
2328   for (auto param : params_) {
2329     if (!param->IsConst() && param->IsSpecialized()) {
2330       any_params_specialized = true;
2331       break;
2332     }
2333   }
2334   if (opcode_ == CALL_FUNCTION) {
2335     if ((name_ == kShape_Name || name_ == kCastToAdapterTensor || name_ == kCastToMSTensor) &&
2336         !any_params_specialized) {
2337       is_specialized_ = true;
2338     } else if (params_.size() > kParamCountOne &&
2339                py::isinstance<mindspore::PrimitivePyAdapter>(params_[kParamIndexOne]->GetObject())) {
2340       std::string prim_name = py::cast<mindspore::PrimitivePyAdapterPtr>(params_[kParamIndexOne]->GetObject())->name();
2341       if (prim_name == kCastPrimName) {
2342         is_specialized_ = params_[kParamIndexTwo]->IsSpecialized();
2343       } else if (prim_name == kLayerNormPrimName) {
2344         is_specialized_ = params_[kParamIndexTwo]->IsSpecialized();
2345       } else if (prim_name == kReshapePrimName) {
2346         is_specialized_ = any_params_specialized || !params_[kParamIndexThree]->IsConst();
2347       } else if (prim_name == kShapePrimName) {
2348         is_specialized_ = params_[kParamIndexTwo]->IsSpecialized();
2349       }
2350     }
2351   } else {
2352     is_specialized_ = true;
2353   }
2354 }
2355 
Optimize()2356 TracePtr OpTrace::Optimize() {
2357   if (is_const_ || RelaxEnabled()) {
2358     return nullptr;
2359   }
2360   TracePtr ret;
2361   if ((ret = RemoveCastDuplicatePatternPass()) != nullptr || (ret = RemoveEmptyTensorPass()) != nullptr ||
2362       (ret = RemovePrimOutIsTensorPass()) != nullptr || (ret = RemoveCastPass()) != nullptr) {
2363     return ret;
2364   }
2365   if (relax_limit_ > 0) {
2366     JudgeDTypeChangePass();
2367     JudgeDTypeScopePass();
2368     JudgeTrainFlagPass();
2369     JudgeCompareConstPass();
2370     JudgeContainsConstPass();
2371     JudgeInplaceAddConstPass();
2372     JudgeIsConstPass();
2373     JudgeCodeChangePass();
2374     JudgeBoundMethodPass();
2375     JudgeSubScrRandPass();
2376     JudgeDTypeTensorAttrPass();
2377     JudgeRelaxGuardFuncPass();
2378   }
2379   bool need_update = false;
2380   for (size_t i = 0; i < params_.size(); ++i) {
2381     params_[i] = OptimizeTrace(params_[i], &need_update);
2382   }
2383   if (need_update) {
2384     if (!std::any_of(params_.begin(), params_.end(), [](const TracePtr &item) { return !item->IsConst(); })) {
2385       is_const_ = true;
2386     }
2387     info_ = nullptr;
2388     strTrace_ = "";
2389     Info();
2390     return shared_from_this();
2391   } else {
2392     return nullptr;
2393   }
2394 }
2395 
SetRelaxCount(int cnt)2396 void OpTrace::SetRelaxCount(int cnt) {
2397   Trace::SetRelaxCount(cnt);
2398   for (auto param : params_) {
2399     param->SetRelaxCount(cnt);
2400   }
2401 }
2402 
operator ==(const Trace & trace)2403 bool OpTrace::operator==(const Trace &trace) {
2404   bool ret = false;
2405   if (Trace::operator==(trace)) {
2406     const OpTrace &t = (const OpTrace &)trace;
2407     ret = opcode_ == t.opcode_ && opargs_ == t.opargs_ && name_ == t.name_ && params_.size() == t.params_.size();
2408     if (ret) {
2409       for (size_t i = 0; i < params_.size(); i++) {
2410         if (*(params_[i]) == *(t.params_[i])) {
2411           continue;
2412         } else {
2413           ret = false;
2414           break;
2415         }
2416       }
2417     }
2418   }
2419   return ret;
2420 }
2421 
Replace(std::shared_ptr<Trace> dst,std::shared_ptr<Trace> src)2422 void OpTrace::Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src) {
2423   Trace::Replace(dst, src);
2424   for (size_t i = 0; i < params_.size(); ++i) {
2425     if (*params_[i] == *src) {
2426       params_[i] = dst;
2427     } else {
2428       params_[i]->Replace(dst, src);
2429     }
2430   }
2431 }
2432 
Detach()2433 void OpTrace::Detach() {
2434   Trace::Detach();
2435   for (auto t : params_) {
2436     t->Detach();
2437   }
2438 }
2439 
Support(TraceType tt)2440 bool OpTrace::Support(TraceType tt) { return tt == TraceType::Operation; }
2441 
2442 static std::map<int, TraceType> kMapBytecodeToTraceType = {
2443   {LOAD_CLOSURE, TraceType::Closure}, {LOAD_DEREF, TraceType::Deref},           {LOAD_GLOBAL, TraceType::Global},
2444   {LOAD_NAME, TraceType::Name},       {LOAD_CLASSDEREF, TraceType::ClassDeref},
2445 };
2446 
CreateOpTraceByBytecode(PyObject * obj,int opcode,int opargs,TraceVector params,std::string module_name,std::string name,bool strict)2447 TracePtr CreateOpTraceByBytecode(PyObject *obj, int opcode, int opargs, TraceVector params, std::string module_name,
2448                                  std::string name, bool strict) {
2449   static const std::set<int> root_op = {
2450     LOAD_CLOSURE, LOAD_DEREF, LOAD_GLOBAL, LOAD_NAME, LOAD_CLASSDEREF,
2451   };
2452   if (opcode == LOAD_DEREF && opargs < 0) {
2453     return nullptr;
2454   }
2455   if (root_op.find(opcode) != root_op.end()) {
2456     return std::make_shared<RootTrace>(obj, kMapBytecodeToTraceType[opcode], opargs, name, module_name);
2457   }
2458   if (opcode == LOAD_CONST) {
2459     return std::make_shared<ConstTrace>(obj, -1);
2460   }
2461   if (Opcode(opcode).IsCall()) {
2462     if (params.size() < 1 || !SupportCall(params[0]->GetObject(), name)) {
2463       if (strict) {
2464         return nullptr;
2465       } else {
2466         return std::make_shared<UnsupportedTrace>(obj, params, opcode, opargs);
2467       }
2468     }
2469   }
2470   return std::make_shared<OpTrace>(obj, opcode, opargs, params, name);
2471 }
2472 
CreateOpTrace(PyObject * obj,int opcode,int opargs,TraceVector params,const std::string & module_name,const std::string & name,bool strict,bool print)2473 TracePtr CreateOpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, const std::string &module_name,
2474                        const std::string &name, bool strict, bool print) {
2475   std::vector<PyObject *> vparams;
2476   for (auto trace : params) {
2477     if (trace == nullptr) {
2478       return nullptr;
2479     } else if (trace->GetTraceType() == TraceType::Unsupported) {
2480       return std::make_shared<UnsupportedTrace>(obj, params, opcode, opargs);
2481     } else {
2482       vparams.push_back(trace->GetObject());
2483     }
2484   }
2485   if (kBytecodeExecuter.find(opcode) == kBytecodeExecuter.end() || !kBytecodeExecuter[opcode].first(opargs, vparams)) {
2486     if (print) {
2487       GRAPH_JIT_LOG_F("Unsupported bytecode %d args %d!\n", opcode, opargs);
2488     } else {
2489       MS_LOG(DEBUG) << "Unsupported bytecode " << opcode << " args " << opargs << "!";
2490     }
2491     if (strict) {
2492       return nullptr;
2493     } else {
2494       return std::make_shared<UnsupportedTrace>(obj, params, opcode, opargs);
2495     }
2496   }
2497   return CreateOpTraceByBytecode(obj, opcode, opargs, params, module_name, name, strict);
2498 }
2499 
CustomizedTrace(PyObject * obj,RetrieveFunc rfunc,ToStringFunc sfunc)2500 CustomizedTrace::CustomizedTrace(PyObject *obj, RetrieveFunc rfunc, ToStringFunc sfunc)
2501     : Trace(obj, nullptr), retrieve_(rfunc), tostring_(sfunc) {
2502   curType_ = TraceType::Customized;
2503   depth_ = 1;
2504 }
2505 
Retrieve(PTraceContext context,bool perf)2506 PyObject *CustomizedTrace::Retrieve(PTraceContext context, bool perf) {
2507   PyObject *ret = Trace::Retrieve(context, perf);
2508   if (ret != nullptr) {
2509     return ret;
2510   }
2511   TracePerf tp(this, perf, false);
2512   ret = retrieve_(context);
2513   Cache(context, ret);
2514   return ret;
2515 }
2516 
ToString(bool include_param)2517 std::string CustomizedTrace::ToString(bool include_param) {
2518   if (strTrace_.size() > 0) {
2519     return strTrace_;
2520   }
2521   std::string ret = tostring_(false);
2522   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
2523   ret = std::regex_replace(ret, std::regex("(\n)"), "");
2524   strTrace_ = ret;
2525   return ret;
2526 }
2527 
Info()2528 const InfoPack &CustomizedTrace::Info() {
2529   if (info_ == nullptr) {
2530     InfoPack info;
2531     info << uint8_t(curType_);
2532     info.Begin();
2533     info << tostring_(true);
2534     info.End();
2535     info_ = std::make_shared<InfoPack>(info);
2536     info_->Update();
2537   }
2538   return *info_;
2539 }
2540 
Support(TraceType tt)2541 bool CustomizedTrace::Support(TraceType tt) { return tt == TraceType::Customized; }
2542 
UnsupportedTrace(PyObject * obj,TraceVector params,int op,int arg)2543 UnsupportedTrace::UnsupportedTrace(PyObject *obj, TraceVector params, int op, int arg)
2544     : Trace(obj, nullptr), params_(params), op_(op), arg_(arg) {
2545   curType_ = TraceType::Unsupported;
2546   if (!std::any_of(params.begin(), params.end(), [](const TracePtr &item) { return !item->IsConst(); })) {
2547     is_const_ = true;
2548   }
2549   depth_ = std::accumulate(params.begin(), params.end(), 1, [](int depth, const TracePtr &i) {
2550     int d = i->GetDepth() + 1;
2551     if (d > depth) {
2552       return d;
2553     } else {
2554       return depth;
2555     }
2556   });
2557 }
2558 
Retrieve(PTraceContext context,bool perf)2559 PyObject *UnsupportedTrace::Retrieve(PTraceContext context, bool perf) {
2560   PyObject *ret = Trace::Retrieve(context, perf);
2561   if (ret != nullptr) {
2562     return ret;
2563   }
2564   std::vector<PyObject *> params;
2565   bool fail = false;
2566   for (auto p : params_) {
2567     auto obj = p->Retrieve(context, perf);
2568     params.push_back(obj);
2569     if (p->GetTraceType() != TraceType::Unsupported) {
2570       // compare obj with original obj in trace for inputs of unsupported trace
2571       auto orig = p->GetObject();
2572       if (!IsPyObjectEqual(obj, orig)) {
2573         fail = true;
2574         break;
2575       }
2576     }
2577     if (params.back() == nullptr) {
2578       MS_LOG(DEBUG) << "Guard Check Retrieve fail for " + p->ToString();
2579       fail = true;
2580       break;
2581     }
2582   }
2583   TracePerf tp(this, perf, false);
2584   for (auto p : params) {
2585     Py_XDECREF(p);
2586   }
2587   if (fail) {
2588     return nullptr;
2589   } else {
2590     Cache(context, obj_);
2591     return obj_;
2592   }
2593 }
2594 
ToString(bool include_param)2595 std::string UnsupportedTrace::ToString(bool include_param) {
2596   if (strTrace_.size() > 0) {
2597     return strTrace_;
2598   }
2599   std::string ret = "unsupported ";
2600   ret += Opcode(op_).name();
2601   ret += "(arg:";
2602   ret += std::to_string(arg_);
2603   if (include_param && params_.size() > 0) {
2604     ret += ",";
2605     for (auto t : params_) {
2606       ret += t->ToString(include_param) + ",";
2607     }
2608     ret = ret.substr(0, ret.size() - 1);
2609   }
2610   ret = ret + ")";
2611   ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
2612   ret = std::regex_replace(ret, std::regex("(\n)"), "");
2613   strTrace_ = ret;
2614   return ret;
2615 }
2616 
FormatString(std::map<Trace *,size_t> * cache)2617 std::string UnsupportedTrace::FormatString(std::map<Trace *, size_t> *cache) {
2618   std::stringstream s;
2619   std::stringstream params_str;
2620   params_str << "(";
2621   for (auto i : params_) {
2622     if (cache->find(i.get()) == cache->end()) {
2623       s << i->FormatString(cache) << std::endl;
2624     }
2625     params_str << "%" << (cache->find(i.get())->second) << ", ";
2626     if (i->GetTraceType() == TraceType::Unsupported) {
2627       params_str << "...";
2628       break;
2629     }
2630   }
2631   params_str << ")";
2632 
2633   cache->insert(std::make_pair(this, cache->size()));
2634   s << "%" << cache->find(this)->second << " = unsupported " << Opcode(op_).name() << " " << arg_ << ": "
2635     << params_str.str();
2636   return s.str();
2637 }
2638 
Info()2639 const InfoPack &UnsupportedTrace::Info() {
2640   if (info_ == nullptr) {
2641     InfoPack info;
2642     info << uint8_t(curType_);
2643     info.Begin();
2644     info << op_;
2645     info << arg_;
2646     info << uint64_t(params_.size());
2647     for (auto i : params_) {
2648       info << i->Info();
2649     }
2650     info.End();
2651     info_ = std::make_shared<InfoPack>(info);
2652     info_->Update();
2653   }
2654   return *info_;
2655 }
2656 
SetRelaxCount(int cnt)2657 void UnsupportedTrace::SetRelaxCount(int cnt) {
2658   Trace::SetRelaxCount(cnt);
2659   for (auto param : params_) {
2660     param->SetRelaxCount(cnt);
2661   }
2662 }
2663 
GetParams()2664 TraceVector UnsupportedTrace::GetParams() { return params_; }
2665 
Detach()2666 void UnsupportedTrace::Detach() {
2667   Trace::Detach();
2668   for (auto t : params_) {
2669     t->Detach();
2670   }
2671 }
2672 
Support(TraceType tt)2673 bool UnsupportedTrace::Support(TraceType tt) { return tt == TraceType::Unsupported; }
2674 
GetObjectFromTrace(const PyFrameObject * frame,TracePtr trace,std::map<size_t,PyObject * > * cache,bool perf)2675 PyObject *GetObjectFromTrace(const PyFrameObject *frame, TracePtr trace, std::map<size_t, PyObject *> *cache,
2676                              bool perf) {
2677   TraceContext context = {frame->f_globals,    frame->f_builtins, frame->f_locals,
2678                           frame->f_localsplus, frame->f_code,     cache};
2679   if (trace != nullptr) {
2680     return trace->Retrieve(&context, perf);
2681   } else {
2682     return nullptr;
2683   }
2684 }
2685 }  // namespace pijit
2686 }  // namespace mindspore
2687