1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_guard/trace.h"
17 #include <map>
18 #include <vector>
19 #include <set>
20 #include <unordered_map>
21 #include <functional>
22 #include <utility>
23 #include <regex>
24 #include "pipeline/jit/pi/graph_guard/guard.h"
25 #include "pipeline/jit/pi/graph_guard/guard_utils.h"
26 #include "pybind11/pybind11.h"
27 #include "pybind_api/ir/primitive_py.h"
28 #include "include/common/utils/convert_utils_py.h"
29 #include "pipeline/jit/pi/graph_guard/infer.h"
30 #include "pipeline/jit/pi/graph_guard/strategy.h"
31 #include "pipeline/jit/pi/utils/utils.h"
32 #include "include/common/utils/python_adapter.h"
33 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
34 #include "pipeline/jit/pi/pi_jit_config.h"
35 #include "pipeline/jit/pi/external.h"
36
37 namespace mindspore {
38 namespace pijit {
39
// Argument counts.
static constexpr size_t kParamCountOne = 1;
static constexpr size_t kParamCountTwo = 2;
static constexpr size_t kParamCountThree = 3;

// Zero-based argument positions.
static constexpr size_t kParamIndexOne = 0;
static constexpr size_t kParamIndexTwo = 1;
static constexpr size_t kParamIndexThree = 2;

// Primitive names.
static constexpr char kCastPrimName[] = "Cast";
static constexpr char kLayerNormPrimName[] = "LayerNorm";
static constexpr char kReshapePrimName[] = "Reshape";
static constexpr char kShapePrimName[] = "Shape";
static constexpr char kRankPrimName[] = "Rank";
static constexpr char kDTypePrimName[] = "DType";

// Attribute, function and flag names.
static constexpr char kShape_Name[] = "shape_";
static constexpr char kShapeName[] = "shape";
static constexpr char kCastToMSTensor[] = "cast_to_ms_tensor";
static constexpr char kCastToAdapterTensor[] = "cast_to_adapter_tensor";
static constexpr char kIsInstance[] = "isinstance";
static constexpr char kTensorName[] = "Tensor";
static constexpr char kDTypeAttrName[] = "dtype";
static constexpr char kDType_AttrName[] = "dtype_";
static constexpr char kCodeName[] = "__code__";
static constexpr char kFuncName[] = "__func__";
static constexpr char kIsSeqValUnknown[] = "is_sequence_value_unknown";
static constexpr char kIsSeqShapeUnknown[] = "is_sequence_shape_unknown";
static constexpr char kMindTorchFlag[] = "mindtorch";
static constexpr char kTrainingFlag[] = "training";

// Package-name prefixes matched against module names.
static constexpr char kMindSporePackPrefix[] = "mindspore.";
static constexpr char kMindtorchPackPrefix[] = "mindtorch.";

// Names recognised by IsCastFunc().
static const std::vector<std::string> kCastFunc = {kCastToMSTensor, kCastToAdapterTensor};

// Module and attribute names of the guard-function white list.
constexpr const char *kFuncWhiteListModuleName = "mindspore._extends.pijit.pijit_func_white_list";
constexpr const char *kGuardFuncMapName = "_guard_func_map";
72
73 static PyObject *RichCompare(PyObject *left, PyObject *right, int oparg);
74
IsCastFunc(std::string name)75 static bool IsCastFunc(std::string name) {
76 return std::find(kCastFunc.begin(), kCastFunc.end(), name) != kCastFunc.end();
77 }
78
OptimizeTrace(TracePtr trace,bool * update)79 static TracePtr OptimizeTrace(TracePtr trace, bool *update) {
80 if (trace != nullptr) {
81 auto new_trace = trace->Optimize();
82 if (new_trace != nullptr) {
83 if (update != nullptr) {
84 *update = true;
85 }
86 return new_trace;
87 }
88 }
89 return trace;
90 }
91
92 template <typename T>
CastTrace(TracePtr trace)93 std::shared_ptr<T> CastTrace(TracePtr trace) {
94 if (trace != nullptr && T::Support(trace->GetTraceType())) {
95 return std::static_pointer_cast<T>(trace);
96 }
97 return nullptr;
98 }
99
CastConstTrace(TracePtr trace)100 static ConstTracePtr CastConstTrace(TracePtr trace) {
101 ConstTracePtr ret = CastTrace<ConstTrace>(trace);
102 if (ret != nullptr && ret->GetIndex() == -1) {
103 return ret;
104 }
105 return nullptr;
106 }
107
CastOpTrace(TracePtr trace,int opcode)108 static OpTracePtr CastOpTrace(TracePtr trace, int opcode) {
109 OpTracePtr ret = CastTrace<OpTrace>(trace);
110 if (ret != nullptr && ret->GetOpCode() == opcode) {
111 return ret;
112 }
113 return nullptr;
114 }
115
CastOpTrace(TracePtr trace,const std::string & name)116 static OpTracePtr CastOpTrace(TracePtr trace, const std::string &name) {
117 OpTracePtr ret = CastTrace<OpTrace>(trace);
118 if (ret != nullptr && ret->GetName() == name) {
119 return ret;
120 }
121 return nullptr;
122 }
123
124 class TracePerf {
125 public:
TracePerf(Trace * trace,bool enable,bool cache)126 TracePerf(Trace *trace, bool enable, bool cache)
127 : trace_(trace), enable_(enable), cache_(cache), perf_(OptGuardPerf::GetGuardPerf()) {
128 if (enable_) {
129 perf_->LogTracePerfStart();
130 }
131 }
~TracePerf()132 ~TracePerf() {
133 if (enable_) {
134 perf_->LogTracePerfEnd(trace_, cache_);
135 }
136 }
137
138 protected:
139 Trace *trace_;
140 bool enable_;
141 bool cache_;
142 OptGuardPerf *perf_;
143 };
144
Trace(PyObject * pObj,std::shared_ptr<Trace> pOrigin)145 Trace::Trace(PyObject *pObj, std::shared_ptr<Trace> pOrigin)
146 : obj_(pObj),
147 origin_(pOrigin),
148 info_(nullptr),
149 is_const_(false),
150 relax_count_(-1),
151 relax_limit_(0),
152 is_specialized_(false),
153 depth_(0) {
154 if (pOrigin != nullptr) {
155 originType_ = pOrigin->GetOriginType();
156 curType_ = pOrigin->GetTraceType();
157 } else {
158 originType_ = Unknown;
159 curType_ = Unknown;
160 }
161 if (obj_ != Py_None && obj_ != NULL) {
162 Py_INCREF(obj_);
163 }
164 }
165
~Trace()166 Trace::~Trace() {
167 if (obj_ != Py_None && obj_ != NULL) {
168 Py_DECREF(obj_);
169 }
170 }
171
GetOrigin()172 TracePtr Trace::GetOrigin() {
173 if (origin_ != nullptr) {
174 return origin_;
175 } else {
176 return nullptr;
177 }
178 }
179
GetObject()180 PyObject *Trace::GetObject() { return obj_; }
181
GetTraceType()182 TraceType Trace::GetTraceType() { return curType_; }
183
GetOriginType()184 TraceType Trace::GetOriginType() { return originType_; }
185
Replace(std::shared_ptr<Trace> dst,std::shared_ptr<Trace> src)186 void Trace::Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src) {
187 if (origin_ != nullptr) {
188 if (*origin_ == *src) {
189 origin_ = dst;
190 } else {
191 origin_->Replace(dst, src);
192 }
193 }
194 }
195
operator ==(const Trace & trace)196 bool Trace::operator==(const Trace &trace) {
197 if (curType_ == trace.curType_ && obj_ == trace.obj_) {
198 return true;
199 } else {
200 return false;
201 }
202 }
203
Detach()204 void Trace::Detach() {
205 if (obj_ != Py_None && obj_ != nullptr && !is_const_ && !PyLong_Check(obj_) && !PyType_Check(obj_)) {
206 Py_DECREF(obj_);
207 obj_ = nullptr;
208 }
209 if (origin_ != nullptr) {
210 origin_->Detach();
211 }
212 }
213
Retrieve(PTraceContext context,bool perf)214 PyObject *Trace::Retrieve(PTraceContext context, bool perf) {
215 TracePerf tp(this, perf, true);
216 if (is_const_) {
217 Py_XINCREF(obj_);
218 return obj_;
219 }
220 if (context->cache != nullptr) {
221 size_t szTrace = this->Info().Id();
222 auto cache = context->cache;
223 auto iter = cache->find(szTrace);
224 if (iter != cache->end()) {
225 auto item = iter->second;
226 Py_XINCREF(item);
227 return item;
228 }
229 }
230 return nullptr;
231 }
232
Cache(PTraceContext context,PyObject * obj)233 void Trace::Cache(PTraceContext context, PyObject *obj) {
234 if (context->cache != nullptr && obj != nullptr) {
235 size_t szTrace = this->Info().Id();
236 Py_XINCREF(obj);
237 auto iter = context->cache->find(szTrace);
238 if (iter != context->cache->end()) {
239 Py_XDECREF(iter->second);
240 iter->second = obj;
241 } else {
242 (*(context->cache))[szTrace] = obj;
243 }
244 }
245 if (RelaxEnabled() && obj != nullptr) {
246 if (obj_ != nullptr) {
247 auto cmp = RichCompare(obj_, obj, Py_EQ);
248 if (cmp != Py_True) {
249 relax_count_ = -1;
250 } else if (relax_count_ < relax_limit_) {
251 relax_count_++;
252 } else {
253 is_const_ = true;
254 }
255 Py_XDECREF(cmp);
256 }
257 if (obj_ == nullptr) {
258 Py_XINCREF(obj);
259 obj_ = obj;
260 }
261 }
262 }
263
IsConst() const264 bool Trace::IsConst() const { return is_const_; }
265
This()266 TracePtr Trace::This() { return shared_from_this(); }
267
SetRelaxCount(int cnt)268 void Trace::SetRelaxCount(int cnt) {
269 relax_count_ = -1;
270 relax_limit_ = cnt;
271 }
272
GetRelaxCount() const273 int Trace::GetRelaxCount() const { return relax_limit_; }
274
EnableRelax()275 void Trace::EnableRelax() { relax_count_ = 0; }
276
RelaxEnabled() const277 bool Trace::RelaxEnabled() const { return relax_count_ >= 0; }
278
IsSpecialized() const279 bool Trace::IsSpecialized() const { return is_specialized_; }
280
GetDepth() const281 int Trace::GetDepth() const { return depth_; }
282
Optimize()283 TracePtr Trace::Optimize() { return nullptr; }
284
FormatString(std::map<Trace *,size_t> * cache)285 std::string Trace::FormatString(std::map<Trace *, size_t> *cache) {
286 cache->insert(std::make_pair(this, cache->size()));
287 return "%" + std::to_string(cache->find(this)->second) + " = " + this->ToString();
288 }
289
RootTrace(PyObject * pObj,TraceType tt,int index,std::string name,std::string module_name)290 RootTrace::RootTrace(PyObject *pObj, TraceType tt, int index, std::string name, std::string module_name)
291 : Trace(pObj, nullptr), idx_(index), name_(name), module_name_(module_name) {
292 depth_ = 1;
293 originType_ = tt;
294 curType_ = tt;
295 for (auto n : kPIJitConfigDefault.allowed_inline_modules()) {
296 if (module_name.find(n) == 0) {
297 is_const_ = true;
298 break;
299 }
300 }
301 if (!is_const_ && (module_name.find(kMindSporePackPrefix) == 0 || module_name.find(kMindtorchPackPrefix) == 0 ||
302 name.find(kCastToAdapterTensor) == 0 || name.find(kCastToMSTensor) == 0)) {
303 is_const_ = true;
304 }
305 if (curType_ == TraceType::Deref) {
306 is_const_ = false;
307 }
308 if (pObj == nullptr) {
309 return;
310 }
311 if (py::isinstance<mindspore::tensor::MetaTensor>(pObj) || py::isinstance<mindspore::tensor::Tensor>(pObj) ||
312 IsStubTensor(py::cast<py::object>(pObj))) {
313 is_specialized_ = false;
314 }
315 }
316
GetParam(int * index,std::string * name,std::string * module_name)317 void RootTrace::GetParam(int *index, std::string *name, std::string *module_name) {
318 *index = idx_;
319 *name = name_;
320 *module_name = module_name_;
321 }
322
Retrieve(PTraceContext context,bool perf)323 PyObject *RootTrace::Retrieve(PTraceContext context, bool perf) {
324 PyObject *ret = Trace::Retrieve(context, perf);
325 if (ret != nullptr) {
326 return ret;
327 }
328 TracePerf tp(this, perf, false);
329 switch (curType_) {
330 case TraceType::Global: {
331 ret = RetrieveGlobal(context);
332 break;
333 }
334 case TraceType::Deref: {
335 ret = RetrieveDeref(context);
336 break;
337 }
338 case TraceType::Closure: {
339 ret = RetrieveClosure(context);
340 break;
341 }
342 case TraceType::BuiltIn: {
343 ret = RetrieveBuiltin(context);
344 break;
345 }
346 case TraceType::Local:
347 ret = RetrieveLocal(context);
348 Py_XINCREF(ret);
349 break;
350 case TraceType::Param:
351 ret = RetrieveParam(context);
352 Py_XINCREF(ret);
353 break;
354 case TraceType::Name: {
355 ret = RetrieveName(context);
356 break;
357 }
358 case TraceType::ClassDeref: {
359 ret = RetrieveClassDeref(context);
360 break;
361 }
362 default:
363 break;
364 }
365 if (ret != Py_None && ret != NULL) {
366 Cache(context, ret);
367 }
368 return ret;
369 }
370
RetrieveGlobal(PTraceContext context)371 PyObject *RootTrace::RetrieveGlobal(PTraceContext context) {
372 MS_EXCEPTION_IF_CHECK_FAIL(name_.size() > 0, "check trace");
373 PyObject *globals = context->f_globals;
374 if (!module_name_.empty()) {
375 PyObject *mn = PyUnicode_FromString(module_name_.c_str());
376 PyObject *mm = PyImport_GetModule(mn); // ensure module is initialized
377 if (mn != nullptr && mm != nullptr) {
378 globals = PyModule_GetDict(mm);
379 }
380 PyErr_Clear();
381 Py_XDECREF(mn);
382 Py_XDECREF(mm);
383 }
384 PyObject *key = PyUnicode_FromString(name_.c_str());
385 PyObject *ret = PyObject_GetItem(globals, key);
386 if (ret == nullptr) {
387 PyErr_Clear();
388 ret = PyObject_GetItem(context->f_builtins, key);
389 if (ret == nullptr) {
390 PyErr_Clear();
391 }
392 }
393 Py_DECREF(key);
394 return ret;
395 }
396
RetrieveDeref(PTraceContext context)397 PyObject *RootTrace::RetrieveDeref(PTraceContext context) {
398 PyObject *ret = nullptr;
399 PyObject *cell = context->f_localsplus[context->f_code->co_nlocals + idx_];
400 if (cell != nullptr && cell != Py_None) {
401 ret = reinterpret_cast<PyObject *>(PyCell_GET(cell));
402 Py_XINCREF(ret);
403 }
404 return ret;
405 }
406
RetrieveClosure(PTraceContext context)407 PyObject *RootTrace::RetrieveClosure(PTraceContext context) {
408 PyObject *ret = context->f_localsplus[context->f_code->co_nlocals + idx_];
409 Py_XINCREF(ret);
410 return ret;
411 }
412
RetrieveBuiltin(PTraceContext context)413 PyObject *RootTrace::RetrieveBuiltin(PTraceContext context) {
414 MS_EXCEPTION_IF_CHECK_FAIL(name_.size() > 0, "check trace");
415 PyObject *key = PyUnicode_FromString(name_.c_str());
416 PyObject *ret = PyObject_GetItem(context->f_builtins, key);
417 if (ret == nullptr) {
418 PyErr_Clear();
419 ret = PyObject_GetItem(context->f_globals, key);
420 if (ret == nullptr) {
421 PyErr_Clear();
422 }
423 }
424 Py_DECREF(key);
425 return ret;
426 }
427
RetrieveLocal(PTraceContext context)428 PyObject *RootTrace::RetrieveLocal(PTraceContext context) { return context->f_locals; }
429
RetrieveParam(PTraceContext context)430 PyObject *RootTrace::RetrieveParam(PTraceContext context) { return context->f_localsplus[idx_]; }
431
RetrieveName(PTraceContext context)432 PyObject *RootTrace::RetrieveName(PTraceContext context) {
433 PyObject *ret = nullptr;
434 PyObject *name = PyTuple_GetItem(context->f_code->co_names, idx_);
435 PyObject *locals = context->f_locals;
436 if (PyDict_CheckExact(locals)) {
437 ret = PyDict_GetItem(locals, name);
438 Py_XINCREF(ret);
439 } else {
440 ret = PyObject_GetItem(locals, name);
441 }
442 if (ret == nullptr) {
443 ret = PyDict_GetItem(context->f_globals, name);
444 Py_XINCREF(ret);
445 }
446 if (ret == nullptr) {
447 if (PyDict_CheckExact(context->f_builtins)) {
448 ret = PyDict_GetItem(context->f_builtins, name);
449 Py_XINCREF(ret);
450 } else {
451 ret = PyObject_GetItem(context->f_builtins, name);
452 }
453 }
454 return ret;
455 }
456
RetrieveClassDeref(PTraceContext context)457 PyObject *RootTrace::RetrieveClassDeref(PTraceContext context) {
458 PyObject *ret = nullptr;
459 Py_ssize_t idx = idx_ - PyTuple_GET_SIZE(context->f_code->co_cellvars);
460 if (idx >= 0 && idx < PyTuple_GET_SIZE(context->f_code->co_freevars)) {
461 PyObject *name = PyTuple_GET_ITEM(context->f_code->co_freevars, idx);
462 if (PyDict_CheckExact(context->f_locals)) {
463 ret = PyDict_GetItem(context->f_locals, name);
464 Py_XINCREF(ret);
465 } else {
466 ret = PyObject_GetItem(context->f_locals, name);
467 }
468 if (!ret) {
469 PyObject *cell = context->f_localsplus[context->f_code->co_nlocals + idx_];
470 ret = PyCell_GET(cell);
471 Py_XINCREF(ret);
472 }
473 }
474 return ret;
475 }
476
ToString(bool include_param)477 std::string RootTrace::ToString(bool include_param) {
478 if (strTrace_.size() > 0) {
479 return strTrace_;
480 }
481 std::string ret;
482 switch (curType_) {
483 case TraceType::Global:
484 if (!module_name_.empty()) {
485 ret = "(global " + module_name_ + ".__dict__[" + name_ + "])";
486 } else {
487 ret = "f_globals[" + name_ + "]";
488 }
489 break;
490 case TraceType::Deref:
491 ret = "f_freevars[" + std::to_string(idx_) + "]";
492 break;
493 case TraceType::Closure:
494 ret = "f_closure[" + std::to_string(idx_) + "]";
495 break;
496 case TraceType::BuiltIn:
497 ret = "f_builtins[" + name_ + "]";
498 break;
499 case TraceType::Local:
500 ret = "f_locals";
501 break;
502 case TraceType::Param:
503 ret = "f_localsplus[";
504 ret += std::to_string(idx_);
505 ret += "]";
506 break;
507 case TraceType::Name:
508 ret = "f->f_code->co_names[";
509 ret += std::to_string(idx_);
510 ret += "]";
511 break;
512 case TraceType::ClassDeref:
513 ret = "f->f_classdef[";
514 ret += std::to_string(idx_);
515 ret += "]";
516 break;
517 default:
518 ret = "unknown_root";
519 break;
520 }
521 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
522 ret = std::regex_replace(ret, std::regex("(\n)"), "");
523 strTrace_ = ret;
524 return ret;
525 }
526
Info()527 const InfoPack &RootTrace::Info() {
528 if (info_ == nullptr) {
529 InfoPack info;
530 info << uint8_t(curType_);
531 info.Begin();
532 switch (curType_) {
533 case TraceType::Global:
534 info << (!module_name_.empty());
535 if (!module_name_.empty()) {
536 info << module_name_ << name_;
537 } else {
538 info << name_;
539 }
540 break;
541 case TraceType::Deref:
542 case TraceType::Closure:
543 case TraceType::Param:
544 case TraceType::Name:
545 case TraceType::ClassDeref:
546 info << idx_;
547 break;
548 case TraceType::BuiltIn:
549 info << name_;
550 break;
551 case TraceType::Local:
552 default:
553 break;
554 }
555 info.End();
556 info_ = std::make_shared<InfoPack>(info);
557 info_->Update();
558 }
559 return *info_;
560 }
561
operator ==(const Trace & trace)562 bool RootTrace::operator==(const Trace &trace) {
563 bool ret = false;
564 if (Trace::operator==(trace)) {
565 const RootTrace &t = (const RootTrace &)trace;
566 ret = idx_ == t.idx_;
567 if (ret && idx_ == -1) {
568 ret = name_ == t.name_ && module_name_ == t.module_name_;
569 }
570 }
571 return ret;
572 }
573
Support(TraceType tt)574 bool RootTrace::Support(TraceType tt) {
575 switch (tt) {
576 case TraceType::Global:
577 case TraceType::Deref:
578 case TraceType::Closure:
579 case TraceType::BuiltIn:
580 case TraceType::Local:
581 case TraceType::Param:
582 case TraceType::Name:
583 case TraceType::ClassDeref:
584 return true;
585 default:
586 return false;
587 }
588 }
589
ItemTrace(PyObject * pObj,TracePtr pOrigin,TracePtr pItem)590 ItemTrace::ItemTrace(PyObject *pObj, TracePtr pOrigin, TracePtr pItem) : Trace(pObj, pOrigin), item_(pItem) {
591 curType_ = TraceType::Item;
592 if (origin_ != nullptr && item_ != nullptr) {
593 if (origin_->IsConst() && item_->IsConst()) {
594 is_const_ = true;
595 }
596 if (!origin_->IsSpecialized() && !item_->IsSpecialized()) {
597 is_specialized_ = false;
598 }
599 }
600 if (origin_ != nullptr) {
601 depth_ = origin_->GetDepth() + 1;
602 } else {
603 depth_ = 1;
604 }
605 if (item_ != nullptr) {
606 auto d = item_->GetDepth() + 1;
607 if (d > depth_) {
608 depth_ = d;
609 }
610 }
611 }
612
GetItem()613 TracePtr ItemTrace::GetItem() { return item_; }
614
Replace(std::shared_ptr<Trace> dst,std::shared_ptr<Trace> src)615 void ItemTrace::Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src) {
616 Trace::Replace(dst, src);
617 if (item_ != nullptr) {
618 if (*item_ == *src) {
619 item_ = dst;
620 } else {
621 item_->Replace(dst, src);
622 }
623 }
624 }
625
Retrieve(PTraceContext context,bool perf)626 PyObject *ItemTrace::Retrieve(PTraceContext context, bool perf) {
627 PyObject *ret = Trace::Retrieve(context, perf);
628 if (ret != nullptr) {
629 return ret;
630 }
631 if (origin_ != nullptr && item_ != nullptr) {
632 PyObject *pSet = origin_->Retrieve(context, perf);
633 PyObject *pItem = item_->Retrieve(context, perf);
634 if (pSet != NULL && pItem != NULL) {
635 TracePerf tp(this, perf, false);
636 if (PyDict_CheckExact(pSet)) {
637 ret = PyDict_GetItem(pSet, pItem);
638 Py_INCREF(ret);
639 } else {
640 ret = PyObject_GetItem(pSet, pItem);
641 }
642 Cache(context, ret);
643 }
644 Py_XDECREF(pSet);
645 Py_XDECREF(pItem);
646 }
647 return ret;
648 }
649
ToString(bool include_param)650 std::string ItemTrace::ToString(bool include_param) {
651 if (strTrace_.size() > 0) {
652 return strTrace_;
653 }
654 std::string ret;
655 if (origin_ != nullptr && item_ != nullptr) {
656 std::string ori = origin_->ToString(include_param);
657 std::string itm = item_->ToString(include_param);
658 ret = ori + "[" + itm + "]";
659 }
660 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
661 ret = std::regex_replace(ret, std::regex("(\n)"), "");
662 strTrace_ = ret;
663 return ret;
664 }
665
Info()666 const InfoPack &ItemTrace::Info() {
667 if (info_ == nullptr) {
668 InfoPack info;
669 info << uint8_t(curType_);
670 info.Begin();
671 info << (origin_ != nullptr && item_ != nullptr);
672 if (origin_ != nullptr && item_ != nullptr) {
673 auto ori = origin_->Info();
674 auto itm = item_->Info();
675 info << ori << itm;
676 }
677 info.End();
678 info_ = std::make_shared<InfoPack>(info);
679 info_->Update();
680 }
681 return *info_;
682 }
683
Optimize()684 TracePtr ItemTrace::Optimize() {
685 bool need_update = false;
686 origin_ = OptimizeTrace(origin_, &need_update);
687 item_ = OptimizeTrace(item_, &need_update);
688 if (need_update) {
689 if (origin_ != nullptr && item_ != nullptr && origin_->IsConst() && item_->IsConst()) {
690 is_const_ = true;
691 }
692 info_ = nullptr;
693 strTrace_ = "";
694 Info();
695 return shared_from_this();
696 } else {
697 return nullptr;
698 }
699 }
700
SetRelaxCount(int cnt)701 void ItemTrace::SetRelaxCount(int cnt) {
702 Trace::SetRelaxCount(cnt);
703 if (origin_ != nullptr) {
704 origin_->SetRelaxCount(cnt);
705 }
706 if (item_ != nullptr) {
707 item_->SetRelaxCount(cnt);
708 }
709 }
710
operator ==(const Trace & trace)711 bool ItemTrace::operator==(const Trace &trace) {
712 if (Trace::operator==(trace)) {
713 const ItemTrace &t = (const ItemTrace &)trace;
714 if (!item_ && !(t.item_)) {
715 return true;
716 } else if (item_ != nullptr && t.item_ != nullptr) {
717 return *item_ == *(t.item_);
718 }
719 }
720 return false;
721 }
722
Detach()723 void ItemTrace::Detach() {
724 Trace::Detach();
725 if (item_ != nullptr) {
726 item_->Detach();
727 }
728 }
729
Support(TraceType tt)730 bool ItemTrace::Support(TraceType tt) { return tt == TraceType::Item; }
731
AttrTrace(PyObject * pObj,TracePtr pOrigin,std::string strAttr)732 AttrTrace::AttrTrace(PyObject *pObj, TracePtr pOrigin, std::string strAttr) : Trace(pObj, pOrigin), attr_(strAttr) {
733 curType_ = TraceType::Attr;
734 if (origin_ != nullptr && origin_->IsConst()) {
735 is_const_ = true;
736 }
737 if (origin_ != nullptr) {
738 depth_ = origin_->GetDepth() + 1;
739 } else {
740 depth_ = 1;
741 }
742 }
743
GetAttribute()744 std::string AttrTrace::GetAttribute() { return attr_; }
745
Retrieve(PTraceContext context,bool perf)746 PyObject *AttrTrace::Retrieve(PTraceContext context, bool perf) {
747 PyObject *ret = Trace::Retrieve(context, perf);
748 if (ret != nullptr) {
749 return ret;
750 }
751 if (origin_ != nullptr) {
752 PyObject *pOrigin = origin_->Retrieve(context, perf);
753 if (pOrigin != NULL) {
754 TracePerf tp(this, perf, false);
755 PyObject *itemName = PyUnicode_FromString(attr_.c_str());
756 if (PyDict_CheckExact(pOrigin)) {
757 ret = PyDict_GetItem(pOrigin, itemName);
758 Py_INCREF(ret);
759 } else {
760 ret = PyObject_GetItem(pOrigin, itemName);
761 }
762 Py_DECREF(itemName);
763 Py_DECREF(pOrigin);
764 Cache(context, ret);
765 }
766 }
767 return ret;
768 }
769
ToString(bool include_param)770 std::string AttrTrace::ToString(bool include_param) {
771 if (strTrace_.size() > 0) {
772 return strTrace_;
773 }
774 std::string ret;
775 if (origin_ != nullptr) {
776 std::string ori = origin_->ToString(include_param);
777 ret = ori + "." + attr_;
778 }
779 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
780 ret = std::regex_replace(ret, std::regex("(\n)"), "");
781 strTrace_ = ret;
782 return ret;
783 }
784
Info()785 const InfoPack &AttrTrace::Info() {
786 if (info_ == nullptr) {
787 InfoPack info;
788 info << uint8_t(curType_);
789 info.Begin();
790 info << (origin_ != nullptr);
791 if (origin_ != nullptr) {
792 auto ori = origin_->Info();
793 info << ori;
794 }
795 info << attr_;
796 info.End();
797 info_ = std::make_shared<InfoPack>(info);
798 info_->Update();
799 }
800 return *info_;
801 }
802
Optimize()803 TracePtr AttrTrace::Optimize() {
804 bool need_update = false;
805 origin_ = OptimizeTrace(origin_, &need_update);
806 if (need_update) {
807 if (origin_ != nullptr && origin_->IsConst()) {
808 is_const_ = true;
809 }
810 info_ = nullptr;
811 strTrace_ = "";
812 Info();
813 return shared_from_this();
814 } else {
815 return nullptr;
816 }
817 }
818
SetRelaxCount(int cnt)819 void AttrTrace::SetRelaxCount(int cnt) {
820 Trace::SetRelaxCount(cnt);
821 if (origin_ != nullptr) {
822 origin_->SetRelaxCount(cnt);
823 }
824 }
825
operator ==(const Trace & trace)826 bool AttrTrace::operator==(const Trace &trace) {
827 if (Trace::operator==(trace)) {
828 const AttrTrace &t = (const AttrTrace &)trace;
829 return attr_ == t.attr_;
830 }
831 return false;
832 }
833
Support(TraceType tt)834 bool AttrTrace::Support(TraceType tt) { return tt == TraceType::Attr; }
835
ConstTrace(PyObject * pObj,int iIndex)836 ConstTrace::ConstTrace(PyObject *pObj, int iIndex) : Trace(pObj, nullptr), index_(iIndex) {
837 curType_ = TraceType::Const;
838 originType_ = TraceType::Const;
839 if (index_ == -1) {
840 is_const_ = true;
841 }
842 depth_ = 1;
843 }
844
GetIndex()845 int ConstTrace::GetIndex() { return index_; }
846
Retrieve(PTraceContext context,bool perf)847 PyObject *ConstTrace::Retrieve(PTraceContext context, bool perf) {
848 PyObject *ret = Trace::Retrieve(context, perf);
849 if (ret != nullptr) {
850 return ret;
851 }
852 if (obj_ != NULL) {
853 Py_INCREF(obj_);
854 return obj_;
855 }
856 if (index_ >= 0 && index_ < PyTuple_GET_SIZE(context->f_code->co_consts)) {
857 TracePerf tp(this, perf, false);
858 ret = PyTuple_GET_ITEM(context->f_code->co_consts, index_);
859 Py_INCREF(ret);
860 Cache(context, ret);
861 } else {
862 ret = obj_;
863 Py_INCREF(ret);
864 }
865 return ret;
866 }
867
ToString(bool include_param)868 std::string ConstTrace::ToString(bool include_param) {
869 if (strTrace_.size() > 0) {
870 return strTrace_;
871 }
872 std::string ret = "co_consts";
873 if (index_ != -1) {
874 ret = ret + "[" + std::to_string(index_) + "]";
875 } else {
876 ret = ret + "[-1](" + std::string(py::str(obj_)) + ")";
877 }
878 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
879 ret = std::regex_replace(ret, std::regex("(\n)"), "");
880 strTrace_ = ret;
881 return ret;
882 }
883
Info()884 const InfoPack &ConstTrace::Info() {
885 if (info_ == nullptr) {
886 InfoPack info;
887 info << uint8_t(curType_);
888 info.Begin();
889 if (index_ != -1) {
890 info << index_;
891 } else {
892 info << index_ << obj_;
893 }
894 info.End();
895 info_ = std::make_shared<InfoPack>(info);
896 info_->Update();
897 }
898 return *info_;
899 }
900
operator ==(const Trace & trace)901 bool ConstTrace::operator==(const Trace &trace) {
902 if (Trace::operator==(trace)) {
903 const ConstTrace &t = (const ConstTrace &)trace;
904 return index_ == t.index_;
905 }
906 return false;
907 }
908
Detach()909 void ConstTrace::Detach() {}
910
Support(TraceType tt)911 bool ConstTrace::Support(TraceType tt) { return tt == TraceType::Const; }
912
TypeTrace(PyObject * pObj,TracePtr pOrigin)913 TypeTrace::TypeTrace(PyObject *pObj, TracePtr pOrigin) : Trace(pObj, pOrigin) {
914 pType_ = Py_TYPE(pObj);
915 curType_ = TraceType::Type;
916 if (origin_ != nullptr && origin_->IsConst()) {
917 is_const_ = true;
918 }
919 if (origin_ != nullptr) {
920 depth_ = origin_->GetDepth() + 1;
921 } else {
922 depth_ = 1;
923 }
924 }
925
GetType()926 PyTypeObject *TypeTrace::GetType() { return pType_; }
927
Retrieve(PTraceContext context,bool perf)928 PyObject *TypeTrace::Retrieve(PTraceContext context, bool perf) {
929 if (is_const_) {
930 auto rt = reinterpret_cast<PyObject *>(pType_);
931 Py_INCREF(rt);
932 return rt;
933 }
934 PyObject *ret = Trace::Retrieve(context, perf);
935 if (ret != nullptr) {
936 return ret;
937 }
938 if (origin_ != NULL) {
939 PyObject *pOrigin = origin_->Retrieve(context, perf);
940 if (pOrigin != NULL) {
941 TracePerf tp(this, perf, false);
942 ret = reinterpret_cast<PyObject *>(Py_TYPE(pOrigin));
943 Py_INCREF(ret);
944 Py_DECREF(pOrigin);
945 Cache(context, ret);
946 return ret;
947 }
948 }
949 return ret;
950 }
951
ToString(bool include_param)952 std::string TypeTrace::ToString(bool include_param) {
953 if (strTrace_.size() > 0) {
954 return strTrace_;
955 }
956 std::string ret = "type(type:";
957 ret += std::string(py::str(reinterpret_cast<PyObject *>(pType_)));
958 if (origin_ != NULL) {
959 ret += ", origin:" + origin_->ToString(include_param);
960 }
961 ret += ")";
962 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
963 ret += std::regex_replace(ret, std::regex("(\n)"), "");
964 strTrace_ = ret;
965 return ret;
966 }
967
Info()968 const InfoPack &TypeTrace::Info() {
969 if (info_ == nullptr) {
970 InfoPack info;
971 info << uint8_t(curType_);
972 info.Begin();
973 info << reinterpret_cast<PyObject *>(pType_);
974 info << (origin_ != nullptr);
975 if (origin_ != nullptr) {
976 info << origin_->Info();
977 }
978 info.End();
979 info_ = std::make_shared<InfoPack>(info);
980 info_->Update();
981 }
982 return *info_;
983 }
984
Optimize()985 TracePtr TypeTrace::Optimize() {
986 bool need_update = false;
987 origin_ = OptimizeTrace(origin_, &need_update);
988 if (need_update) {
989 if (origin_ != nullptr && origin_->IsConst()) {
990 is_const_ = true;
991 }
992 info_ = nullptr;
993 strTrace_ = "";
994 Info();
995 return shared_from_this();
996 } else {
997 return nullptr;
998 }
999 }
1000
SetRelaxCount(int cnt)1001 void TypeTrace::SetRelaxCount(int cnt) {
1002 Trace::SetRelaxCount(cnt);
1003 if (origin_ != nullptr) {
1004 origin_->SetRelaxCount(cnt);
1005 }
1006 }
1007
operator ==(const Trace & trace)1008 bool TypeTrace::operator==(const Trace &trace) {
1009 if (Trace::operator==(trace)) {
1010 const TypeTrace &t = (const TypeTrace &)trace;
1011 return pType_ == t.pType_;
1012 }
1013 return false;
1014 }
1015
Detach()1016 void TypeTrace::Detach() {
1017 if (is_const_) {
1018 is_const_ = false;
1019 Trace::Detach();
1020 is_const_ = true;
1021 } else {
1022 Trace::Detach();
1023 }
1024 }
1025
Support(TraceType tt)1026 bool TypeTrace::Support(TraceType tt) { return tt == TraceType::Type; }
1027
RichCompare(PyObject * left,PyObject * right,int oparg)1028 static PyObject *RichCompare(PyObject *left, PyObject *right, int oparg) {
1029 bool invert;
1030 if (oparg >= Py_LT && oparg <= Py_GE) {
1031 return PyObject_RichCompare(left, right, oparg);
1032 } else if (Opcode(COMPARE_OP).CheckIsOp(oparg, &invert)) {
1033 auto ret = ((left == right) ^ invert) ? Py_True : Py_False;
1034 Py_INCREF(ret);
1035 return ret;
1036 } else if (Opcode(COMPARE_OP).CheckContainsOp(oparg, &invert)) {
1037 auto stat = PySequence_Contains(right, left);
1038 if (stat < 0) {
1039 return nullptr;
1040 }
1041 auto ret = (stat ^ invert) ? Py_True : Py_False;
1042 Py_INCREF(ret);
1043 return ret;
1044 }
1045 return nullptr;
1046 }
1047
support_infer_primitive(PyObject * obj)1048 static bool support_infer_primitive(PyObject *obj) {
1049 if (py::isinstance<mindspore::PrimitivePyAdapter>(obj) || py::isinstance<mindspore::PrimitiveFunctionAdapter>(obj)) {
1050 auto inst = mindspore::pijit::InferEngine::GetInstance();
1051 return inst->SupportInfer(obj);
1052 } else {
1053 return false;
1054 }
1055 }
1056
support_create_primitive(PyObject * obj)1057 static bool support_create_primitive(PyObject *obj) {
1058 if (!obj || !PyType_Check(obj)) {
1059 return false;
1060 }
1061 py::object m = py::reinterpret_steal<py::object>(PyImport_GetModule(py::str("mindspore.ops").ptr()));
1062 if (!m.ptr()) {
1063 PyErr_Clear();
1064 return false;
1065 }
1066 py::object t = py::reinterpret_steal<py::object>(PyObject_GetAttrString(m.ptr(), "Primitive"));
1067 if (PyType_IsSubtype(reinterpret_cast<PyTypeObject *>(obj), reinterpret_cast<PyTypeObject *>((t.ptr())))) {
1068 return true;
1069 } else {
1070 return false;
1071 }
1072 }
1073
1074 extern bool CheckJitConstexpr(const py::object &func);
1075 extern bool CheckMSConstexpr(const py::object &func);
1076 extern bool CheckBuiltinFuncOrMethod(const py::object &func);
SupportCall(PyObject * func,const std::string & name)1077 static bool SupportCall(PyObject *func, const std::string &name) {
1078 /**
1079 * NOTE: exclude method type, it shouldn't be guard
1080 */
1081 static const std::set<PyTypeObject *> support_create_instance_type = {
1082 &PyComplex_Type, &PyMap_Type, &PyBaseObject_Type, &PyRange_Type, &PyZip_Type, &PySlice_Type,
1083 &PyBool_Type, &PyFloat_Type, &PyLong_Type, &PyType_Type, &PyList_Type, &PyTuple_Type,
1084 &PySet_Type, &PyFrozenSet_Type, &PyDict_Type, &PyUnicode_Type, &PyEnum_Type, &PyMethod_Type,
1085 };
1086 if (PyType_CheckExact(func)) {
1087 if (IsMsClass(func)) {
1088 return true;
1089 }
1090 return support_create_instance_type.find(reinterpret_cast<PyTypeObject *>(func)) !=
1091 support_create_instance_type.end();
1092 }
1093
1094 py::object handle = py::cast<py::object>(func);
1095 if (CheckJitConstexpr(handle)) {
1096 return true;
1097 }
1098 if (CheckMSConstexpr(handle)) {
1099 return true;
1100 }
1101 if (CheckBuiltinFuncOrMethod(handle)) {
1102 return true;
1103 }
1104 return support_infer_primitive(func) || support_create_primitive(func) || IsMsClass(func) ||
1105 (name.size() != 0 && PyDict_GetItemString(PyEval_GetBuiltins(), name.c_str()) == func);
1106 }
1107
DoCall(const std::vector<PyObject * > & params,int op,const std::string & name)1108 static PyObject *DoCall(const std::vector<PyObject *> ¶ms, int op, const std::string &name) {
1109 if (!Opcode(op).IsCall() || params.size() < 1) {
1110 return nullptr;
1111 }
1112 if (support_infer_primitive(params[0])) {
1113 std::vector<PyObject *> list;
1114 auto inst = mindspore::pijit::InferEngine::GetInstance();
1115 list.insert(list.begin(), params.begin() + 1, params.end());
1116 bool is_abstract = false;
1117 try {
1118 return inst->InferPrimitive(params[0], list, &is_abstract);
1119 } catch (py::error_already_set &e) {
1120 MS_LOG(ERROR) << "InferPrimitive failed " << std::endl << e.what();
1121 } catch (py::builtin_exception &e) {
1122 MS_LOG(ERROR) << "InferPrimitive failed " << std::endl << e.what();
1123 }
1124 return nullptr;
1125 }
1126
1127 size_t nargs = (params.size() - 1);
1128 size_t kw_cnt;
1129 if (op == CALL_FUNCTION) {
1130 return PyObject_Vectorcall(params[0], params.data() + 1, nargs, NULL);
1131 } else if (op == CALL_FUNCTION_KW) {
1132 kw_cnt = PyTuple_GET_SIZE(params.back());
1133 return PyObject_Vectorcall(params[0], params.data() + 1, nargs - 1 - kw_cnt, params.back());
1134 } else if (op == CALL_FUNCTION_EX) {
1135 return PyObject_Call(params[0], params[1], params.size() > 2 ? params[2] : nullptr);
1136 }
1137 return nullptr;
1138 }
1139
1140 using PyObjectArray = std::vector<PyObject *>;
1141
CheckAndDoBinary(int op,const PyObjectArray & objs,binaryfunc pyfunc)1142 static PyObject *CheckAndDoBinary(int op, const PyObjectArray &objs, binaryfunc pyfunc) {
1143 if (py::isinstance<mindspore::tensor::Tensor>(objs[0])) {
1144 auto arg0 = py::reinterpret_borrow<py::object>(objs[0]);
1145 auto arg1 = py::reinterpret_borrow<py::object>(objs[1]);
1146 auto res = pijit::AbstractTensor::Binary(op, arg0, arg1);
1147 return res.inc_ref().ptr();
1148 } else {
1149 return pyfunc(objs[0], objs[1]);
1150 }
1151 }
1152
// Predicate deciding whether an opcode can be evaluated for the given oparg and operand objects.
using PythonBytecodeSupportCheckFunc = std::function<bool(int opargs, const PyObjectArray &objs)>;
// Evaluator that executes an opcode on the operand objects and returns the resulting object pointer.
using PythonBytecodeExecuteFunc = std::function<PyObject *(int opargs, const PyObjectArray &objs, PTraceContext ctx)>;
// (support check, evaluator) pair; the evaluator may be nullptr.
using PythonBytecodeFuncSet = std::pair<PythonBytecodeSupportCheckFunc, PythonBytecodeExecuteFunc>;
// Support check that rejects every input.
static bool ByteCodeUnsupported(int opargs, const PyObjectArray &objs) { return false; }
// Support check that accepts every input.
static bool ByteCodeSupported(int opargs, const PyObjectArray &objs) { return true; }
// Expands to a support-check lambda: true unless the calc strategy for `bytecode` is kCalcUnsupported.
#define ByteCodeTest(bytecode)                                                                                       \
  [](int opargs, const PyObjectArray &objs) {                                                                        \
    return OptStrategy::MakeCalcStrategyByInputs(bytecode, opargs, objs) != OptStrategy::CalcKind::kCalcUnsupported; \
  }
// Expands to an expression that is true when the calc strategy for `bytecode` is kCalcValue.
#define ByteCodeCheck(bytecode, opargs, objs) \
  (OptStrategy::MakeCalcStrategyByInputs(bytecode, opargs, objs) == OptStrategy::CalcKind::kCalcValue)
1164 static std::unordered_map<int, PythonBytecodeFuncSet> kBytecodeExecuter = {
1165 {POP_TOP, {ByteCodeUnsupported, nullptr}},
1166 {ROT_TWO, {ByteCodeUnsupported, nullptr}},
1167 {ROT_THREE, {ByteCodeUnsupported, nullptr}},
1168 {DUP_TOP, {ByteCodeUnsupported, nullptr}},
1169 {DUP_TOP_TWO, {ByteCodeUnsupported, nullptr}},
1170 {NOP, {ByteCodeUnsupported, nullptr}},
1171 {UNARY_POSITIVE,
1172 {ByteCodeTest(UNARY_POSITIVE),
__anonaa04c86f0102() 1173 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1174 if (ByteCodeCheck(UNARY_POSITIVE, opargs, objs)) {
1175 return PyNumber_Positive(objs[0]);
1176 } else {
1177 Py_XINCREF(objs[0]);
1178 return objs[0];
1179 }
1180 }}},
1181 {UNARY_NEGATIVE,
1182 {ByteCodeTest(UNARY_NEGATIVE),
__anonaa04c86f0202() 1183 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1184 if (ByteCodeCheck(UNARY_NEGATIVE, opargs, objs)) {
1185 return PyNumber_Negative(objs[0]);
1186 } else {
1187 Py_XINCREF(objs[0]);
1188 return objs[0];
1189 }
1190 }}},
1191 {UNARY_NOT,
1192 {ByteCodeTest(UNARY_NOT),
__anonaa04c86f0302() 1193 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1194 if (ByteCodeCheck(UNARY_NOT, opargs, objs)) {
1195 auto ret = PyObject_IsTrue(objs[0]) ? Py_False : Py_True;
1196 Py_INCREF(ret);
1197 return ret;
1198 } else {
1199 Py_INCREF(Py_True);
1200 return Py_True;
1201 }
1202 }}},
1203 {UNARY_INVERT,
1204 {ByteCodeTest(UNARY_INVERT),
__anonaa04c86f0402() 1205 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1206 if (ByteCodeCheck(UNARY_INVERT, opargs, objs)) {
1207 return PyNumber_Invert(objs[0]);
1208 } else {
1209 Py_INCREF(objs[0]);
1210 return objs[0];
1211 }
1212 }}},
1213 {BINARY_MATRIX_MULTIPLY,
1214 {ByteCodeTest(BINARY_MATRIX_MULTIPLY),
__anonaa04c86f0502() 1215 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1216 if (ByteCodeCheck(BINARY_MATRIX_MULTIPLY, opargs, objs)) {
1217 return PyNumber_MatrixMultiply(objs[0], objs[1]);
1218 } else {
1219 Py_INCREF(objs[0]);
1220 return objs[0];
1221 }
1222 }}},
1223 {INPLACE_MATRIX_MULTIPLY,
1224 {ByteCodeTest(INPLACE_MATRIX_MULTIPLY),
__anonaa04c86f0602() 1225 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1226 if (ByteCodeCheck(INPLACE_MATRIX_MULTIPLY, opargs, objs)) {
1227 return PyNumber_InPlaceMatrixMultiply(objs[0], objs[1]);
1228 } else {
1229 Py_INCREF(objs[0]);
1230 return objs[0];
1231 }
1232 }}},
1233 {BINARY_POWER,
1234 {ByteCodeTest(BINARY_POWER),
__anonaa04c86f0702() 1235 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1236 if (ByteCodeCheck(BINARY_POWER, opargs, objs)) {
1237 return PyNumber_Power(objs[0], objs[1], Py_None);
1238 } else {
1239 Py_INCREF(objs[0]);
1240 return objs[0];
1241 }
1242 }}},
1243 {BINARY_MULTIPLY,
1244 {ByteCodeTest(BINARY_MULTIPLY),
__anonaa04c86f0802() 1245 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1246 if (ByteCodeCheck(BINARY_MULTIPLY, opargs, objs)) {
1247 return CheckAndDoBinary(BINARY_MULTIPLY, objs, PyNumber_Multiply);
1248 } else {
1249 Py_INCREF(objs[0]);
1250 return objs[0];
1251 }
1252 }}},
1253 {BINARY_MODULO,
1254 {ByteCodeTest(BINARY_MODULO),
__anonaa04c86f0902() 1255 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1256 if (ByteCodeCheck(BINARY_MODULO, opargs, objs)) {
1257 if (PyUnicode_CheckExact(objs[0]) && (!PyUnicode_Check(objs[1]) || PyUnicode_CheckExact(objs[1]))) {
1258 return PyUnicode_Format(objs[0], objs[1]);
1259 } else {
1260 return PyNumber_Remainder(objs[0], objs[1]);
1261 }
1262 } else {
1263 Py_INCREF(objs[0]);
1264 return objs[0];
1265 }
1266 }}},
1267 {BINARY_ADD,
__anonaa04c86f0a02() 1268 {[](int opargs, const PyObjectArray &objs) -> bool {
1269 return (!PyUnicode_CheckExact(objs[0]) || !PyUnicode_CheckExact(objs[1])) &&
1270 OptStrategy::MakeCalcStrategyByInputs(BINARY_ADD, opargs, objs) != OptStrategy::CalcKind::kCalcUnsupported;
1271 },
1272 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1273 if (ByteCodeCheck(BINARY_ADD, opargs, objs)) {
1274 return CheckAndDoBinary(BINARY_ADD, objs, PyNumber_Add);
1275 } else {
1276 Py_INCREF(objs[0]);
1277 return objs[0];
1278 }
1279 }}},
1280 {BINARY_SUBTRACT,
1281 {ByteCodeTest(BINARY_SUBTRACT),
__anonaa04c86f0c02() 1282 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1283 if (ByteCodeCheck(BINARY_SUBTRACT, opargs, objs)) {
1284 return CheckAndDoBinary(BINARY_SUBTRACT, objs, PyNumber_Subtract);
1285 } else {
1286 Py_INCREF(objs[0]);
1287 return objs[0];
1288 }
1289 }}},
1290 {BINARY_SUBSCR,
1291 {ByteCodeTest(BINARY_SUBSCR),
__anonaa04c86f0d02() 1292 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1293 return PyObject_GetItem(objs[0], objs[1]);
1294 }}},
1295 {BINARY_FLOOR_DIVIDE,
1296 {ByteCodeTest(BINARY_FLOOR_DIVIDE),
__anonaa04c86f0e02() 1297 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1298 if (ByteCodeCheck(BINARY_FLOOR_DIVIDE, opargs, objs)) {
1299 return CheckAndDoBinary(BINARY_FLOOR_DIVIDE, objs, PyNumber_FloorDivide);
1300 } else {
1301 Py_INCREF(objs[0]);
1302 return objs[0];
1303 }
1304 }}},
1305 {BINARY_TRUE_DIVIDE,
1306 {ByteCodeTest(BINARY_TRUE_DIVIDE),
__anonaa04c86f0f02() 1307 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1308 if (ByteCodeCheck(BINARY_TRUE_DIVIDE, opargs, objs)) {
1309 return CheckAndDoBinary(BINARY_TRUE_DIVIDE, objs, PyNumber_TrueDivide);
1310 } else {
1311 Py_INCREF(objs[0]);
1312 return objs[0];
1313 }
1314 }}},
1315 {INPLACE_FLOOR_DIVIDE,
1316 {ByteCodeTest(INPLACE_FLOOR_DIVIDE),
__anonaa04c86f1002() 1317 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1318 if (ByteCodeCheck(INPLACE_FLOOR_DIVIDE, opargs, objs)) {
1319 return PyNumber_InPlaceFloorDivide(objs[0], objs[1]);
1320 } else {
1321 Py_INCREF(objs[0]);
1322 return objs[0];
1323 }
1324 }}},
1325 {INPLACE_TRUE_DIVIDE,
1326 {ByteCodeTest(INPLACE_TRUE_DIVIDE),
__anonaa04c86f1102() 1327 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1328 if (ByteCodeCheck(INPLACE_TRUE_DIVIDE, opargs, objs)) {
1329 return PyNumber_InPlaceTrueDivide(objs[0], objs[1]);
1330 } else {
1331 Py_INCREF(objs[0]);
1332 return objs[0];
1333 }
1334 }}},
1335 {GET_AITER, {ByteCodeUnsupported, nullptr}},
1336 {GET_ANEXT, {ByteCodeUnsupported, nullptr}},
1337 {BEFORE_ASYNC_WITH, {ByteCodeUnsupported, nullptr}},
1338 {INPLACE_ADD,
__anonaa04c86f1202() 1339 {[](int opargs, const PyObjectArray &objs) -> bool {
1340 return (!PyUnicode_CheckExact(objs[0]) || !PyUnicode_CheckExact(objs[1])) &&
1341 OptStrategy::MakeCalcStrategyByInputs(INPLACE_ADD, opargs, objs) !=
1342 OptStrategy::CalcKind::kCalcUnsupported;
1343 },
1344 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1345 if (ByteCodeCheck(INPLACE_ADD, opargs, objs)) {
1346 return PyNumber_InPlaceAdd(objs[0], objs[1]);
1347 } else {
1348 Py_INCREF(objs[0]);
1349 return objs[0];
1350 }
1351 }}},
1352 {INPLACE_SUBTRACT,
1353 {ByteCodeTest(INPLACE_SUBTRACT),
__anonaa04c86f1402() 1354 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1355 if (ByteCodeCheck(INPLACE_SUBTRACT, opargs, objs)) {
1356 return PyNumber_InPlaceSubtract(objs[0], objs[1]);
1357 } else {
1358 Py_INCREF(objs[0]);
1359 return objs[0];
1360 }
1361 }}},
1362 {INPLACE_MULTIPLY,
1363 {ByteCodeTest(INPLACE_MULTIPLY),
__anonaa04c86f1502() 1364 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1365 if (ByteCodeCheck(INPLACE_MULTIPLY, opargs, objs)) {
1366 return PyNumber_InPlaceMultiply(objs[0], objs[1]);
1367 } else {
1368 Py_INCREF(objs[0]);
1369 return objs[0];
1370 }
1371 }}},
1372 {INPLACE_MODULO,
1373 {ByteCodeTest(INPLACE_MODULO),
__anonaa04c86f1602() 1374 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1375 if (ByteCodeCheck(INPLACE_MODULO, opargs, objs)) {
1376 return PyNumber_InPlaceRemainder(objs[0], objs[1]);
1377 } else {
1378 Py_INCREF(objs[0]);
1379 return objs[0];
1380 }
1381 }}},
1382 {STORE_SUBSCR, {ByteCodeUnsupported, nullptr}},
1383 {DELETE_SUBSCR, {ByteCodeUnsupported, nullptr}},
1384 {BINARY_LSHIFT,
1385 {ByteCodeTest(BINARY_LSHIFT),
__anonaa04c86f1702() 1386 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1387 if (ByteCodeCheck(BINARY_LSHIFT, opargs, objs)) {
1388 return PyNumber_Lshift(objs[0], objs[1]);
1389 } else {
1390 Py_INCREF(objs[0]);
1391 return objs[0];
1392 }
1393 }}},
1394 {BINARY_RSHIFT,
1395 {ByteCodeTest(BINARY_RSHIFT),
__anonaa04c86f1802() 1396 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1397 if (ByteCodeCheck(BINARY_RSHIFT, opargs, objs)) {
1398 return PyNumber_Rshift(objs[0], objs[1]);
1399 } else {
1400 Py_INCREF(objs[0]);
1401 return objs[0];
1402 }
1403 }}},
1404 {BINARY_AND,
1405 {ByteCodeTest(BINARY_AND),
__anonaa04c86f1902() 1406 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1407 if (ByteCodeCheck(BINARY_AND, opargs, objs)) {
1408 return PyNumber_And(objs[0], objs[1]);
1409 } else {
1410 Py_INCREF(objs[0]);
1411 return objs[0];
1412 }
1413 }}},
1414 {BINARY_XOR,
1415 {ByteCodeTest(BINARY_XOR),
__anonaa04c86f1a02() 1416 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1417 if (ByteCodeCheck(BINARY_XOR, opargs, objs)) {
1418 return PyNumber_Xor(objs[0], objs[1]);
1419 } else {
1420 Py_INCREF(objs[0]);
1421 return objs[0];
1422 }
1423 }}},
1424 {BINARY_OR,
1425 {ByteCodeTest(BINARY_OR),
__anonaa04c86f1b02() 1426 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1427 if (ByteCodeCheck(BINARY_OR, opargs, objs)) {
1428 return PyNumber_Or(objs[0], objs[1]);
1429 } else {
1430 Py_INCREF(objs[0]);
1431 return objs[0];
1432 }
1433 }}},
1434 {INPLACE_POWER,
1435 {ByteCodeTest(INPLACE_POWER),
__anonaa04c86f1c02() 1436 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1437 if (ByteCodeCheck(INPLACE_POWER, opargs, objs)) {
1438 return PyNumber_InPlacePower(objs[0], objs[1], Py_None);
1439 } else {
1440 Py_INCREF(objs[0]);
1441 return objs[0];
1442 }
1443 }}},
1444 {GET_ITER,
1445 {ByteCodeSupported,
__anonaa04c86f1d02() 1446 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * { return PyObject_GetIter(objs[0]); }}},
1447 {GET_YIELD_FROM_ITER,
1448 {ByteCodeSupported,
__anonaa04c86f1e02() 1449 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1450 PyObject *iterable = objs[0];
1451 if (PyCoro_CheckExact(iterable)) {
1452 if (!(ctx->f_code->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE))) {
1453 return nullptr;
1454 }
1455 } else if (!PyGen_CheckExact(iterable)) {
1456 return PyObject_GetIter(iterable);
1457 }
1458 Py_INCREF(iterable);
1459 return iterable;
1460 }}},
1461 {PRINT_EXPR, {ByteCodeUnsupported, nullptr}},
1462 {LOAD_BUILD_CLASS, {ByteCodeUnsupported, nullptr}},
1463 {YIELD_FROM, {ByteCodeUnsupported, nullptr}},
1464 {GET_AWAITABLE, {ByteCodeUnsupported, nullptr}},
1465 {INPLACE_LSHIFT,
1466 {ByteCodeTest(INPLACE_LSHIFT),
__anonaa04c86f1f02() 1467 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1468 if (ByteCodeCheck(INPLACE_LSHIFT, opargs, objs)) {
1469 return PyNumber_InPlaceLshift(objs[0], objs[1]);
1470 } else {
1471 Py_INCREF(objs[0]);
1472 return objs[0];
1473 }
1474 }}},
1475 {INPLACE_RSHIFT,
1476 {ByteCodeTest(INPLACE_RSHIFT),
__anonaa04c86f2002() 1477 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1478 if (ByteCodeCheck(INPLACE_RSHIFT, opargs, objs)) {
1479 return PyNumber_InPlaceRshift(objs[0], objs[1]);
1480 } else {
1481 Py_INCREF(objs[0]);
1482 return objs[0];
1483 }
1484 }}},
1485 {INPLACE_AND,
1486 {ByteCodeTest(INPLACE_AND),
1487 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2102() 1488 * {
1489 if (ByteCodeCheck(INPLACE_AND, opargs, objs)) {
1490 return PyNumber_InPlaceAnd(objs[0], objs[1]);
1491 } else {
1492 Py_INCREF(objs[0]);
1493 return objs[0];
1494 }
1495 }}},
1496 {INPLACE_XOR,
1497 {ByteCodeTest(INPLACE_XOR),
1498 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2202() 1499 * {
1500 if (ByteCodeCheck(INPLACE_XOR, opargs, objs)) {
1501 return PyNumber_InPlaceXor(objs[0], objs[1]);
1502 } else {
1503 Py_INCREF(objs[0]);
1504 return objs[0];
1505 }
1506 }}},
1507 {INPLACE_OR,
1508 {ByteCodeTest(INPLACE_OR),
1509 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2302() 1510 * {
1511 if (ByteCodeCheck(INPLACE_OR, opargs, objs)) {
1512 return PyNumber_InPlaceOr(objs[0], objs[1]);
1513 } else {
1514 Py_INCREF(objs[0]);
1515 return objs[0];
1516 }
1517 }}},
1518 {RETURN_VALUE, {ByteCodeUnsupported, nullptr}},
1519 {IMPORT_STAR, {ByteCodeUnsupported, nullptr}},
1520 {SETUP_ANNOTATIONS, {ByteCodeUnsupported, nullptr}},
1521 {YIELD_VALUE, {ByteCodeUnsupported, nullptr}},
1522 {POP_BLOCK, {ByteCodeUnsupported, nullptr}},
1523 {POP_EXCEPT, {ByteCodeUnsupported, nullptr}},
1524 {STORE_NAME, {ByteCodeUnsupported, nullptr}},
1525 {DELETE_NAME, {ByteCodeUnsupported, nullptr}},
1526 {UNPACK_SEQUENCE, {ByteCodeUnsupported, nullptr}},
1527 {FOR_ITER, {ByteCodeUnsupported, nullptr}},
1528 {UNPACK_EX, {ByteCodeUnsupported, nullptr}},
1529 {STORE_ATTR, {ByteCodeUnsupported, nullptr}},
1530 {DELETE_ATTR, {ByteCodeUnsupported, nullptr}},
1531 {STORE_GLOBAL, {ByteCodeUnsupported, nullptr}},
1532 {DELETE_GLOBAL, {ByteCodeUnsupported, nullptr}},
1533 {LOAD_CONST, {ByteCodeSupported, nullptr}},
1534 {LOAD_NAME, {ByteCodeSupported, nullptr}},
1535 {BUILD_TUPLE,
1536 {ByteCodeSupported,
1537 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2402() 1538 * {
1539 PyObject *tup = PyTuple_New(opargs);
1540 while (--opargs >= 0) {
1541 Py_INCREF(objs[opargs]);
1542 PyTuple_SET_ITEM(tup, opargs, objs[opargs]);
1543 }
1544 return tup;
1545 }}},
1546 {BUILD_LIST,
1547 {ByteCodeSupported,
1548 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2502() 1549 * {
1550 PyObject *list = PyList_New(opargs);
1551 while (--opargs >= 0) {
1552 Py_INCREF(objs[opargs]);
1553 PyList_SET_ITEM(list, opargs, objs[opargs]);
1554 }
1555 return list;
1556 }}},
1557 {BUILD_SET,
1558 {ByteCodeSupported,
1559 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2602() 1560 * {
1561 PyObject *set = PySet_New(NULL);
1562 for (int i = opargs; i > 0; i--) {
1563 PySet_Add(set, objs[opargs - i]);
1564 }
1565 return set;
1566 }}},
1567 {BUILD_MAP,
1568 {ByteCodeSupported,
1569 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2702() 1570 * {
1571 PyObject *map =
1572 _PyDict_NewPresized((Py_ssize_t)opargs);
1573 for (Py_ssize_t i = opargs; i > 0; i--) {
1574 PyObject *key = objs[2 * (opargs - i)];
1575 PyObject *value = objs[2 * (opargs - i) + 1];
1576 PyDict_SetItem(map, key, value);
1577 }
1578 return map;
1579 }}},
1580 {LOAD_ATTR, {ByteCodeSupported, nullptr}},
1581 {COMPARE_OP,
__anonaa04c86f2802() 1582 {[](int opargs, const PyObjectArray &objs) {
1583 return OptStrategy::MakeCalcStrategyByInputs(COMPARE_OP, opargs, objs) != OptStrategy::CalcKind::kCalcUnsupported;
1584 },
1585 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1586 if (ByteCodeCheck(COMPARE_OP, opargs, objs)) {
1587 return RichCompare(objs[0], objs[1], opargs);
1588 } else {
1589 Py_INCREF(Py_True);
1590 return Py_True;
1591 }
1592 }}},
1593 {IMPORT_NAME, {ByteCodeUnsupported, nullptr}},
1594 {IMPORT_FROM, {ByteCodeUnsupported, nullptr}},
1595 {JUMP_FORWARD, {ByteCodeUnsupported, nullptr}},
1596 {JUMP_IF_FALSE_OR_POP, {ByteCodeUnsupported, nullptr}},
1597 {JUMP_IF_TRUE_OR_POP, {ByteCodeUnsupported, nullptr}},
1598 {JUMP_ABSOLUTE, {ByteCodeUnsupported, nullptr}},
1599 {POP_JUMP_IF_FALSE, {ByteCodeUnsupported, nullptr}},
1600 {POP_JUMP_IF_TRUE, {ByteCodeUnsupported, nullptr}},
1601 {LOAD_GLOBAL, {ByteCodeSupported, nullptr}},
1602 {SETUP_FINALLY, {ByteCodeUnsupported, nullptr}},
1603 {LOAD_FAST, {ByteCodeUnsupported, nullptr}},
1604 {STORE_FAST, {ByteCodeUnsupported, nullptr}},
1605 {DELETE_FAST, {ByteCodeUnsupported, nullptr}},
1606 {RAISE_VARARGS, {ByteCodeUnsupported, nullptr}},
1607 {CALL_FUNCTION, {ByteCodeSupported, nullptr}},
1608 {MAKE_FUNCTION, {ByteCodeUnsupported, nullptr}},
1609 {BUILD_SLICE,
1610 {ByteCodeSupported,
1611 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2a02() 1612 * {
1613 PyObject *start;
1614 PyObject *stop;
1615 PyObject *step;
1616 if (opargs == 3)
1617 step = objs[2];
1618 else
1619 step = nullptr;
1620 stop = objs[1];
1621 start = objs[0];
1622 return PySlice_New(start, stop, step);
1623 }}},
1624 {LOAD_CLOSURE, {ByteCodeSupported, nullptr}},
1625 {LOAD_DEREF, {ByteCodeSupported, nullptr}},
1626 {STORE_DEREF, {ByteCodeUnsupported, nullptr}},
1627 {DELETE_DEREF, {ByteCodeUnsupported, nullptr}},
1628 {CALL_FUNCTION_KW, {ByteCodeSupported, nullptr}},
1629 {CALL_FUNCTION_EX, {ByteCodeSupported, nullptr}},
1630 {SETUP_WITH, {ByteCodeUnsupported, nullptr}},
1631 {EXTENDED_ARG, {ByteCodeUnsupported, nullptr}},
1632 {LIST_APPEND,
1633 {ByteCodeSupported,
1634 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2b02() 1635 * {
1636 PyList_Append(objs[0], objs[1]);
1637 Py_INCREF(objs[0]);
1638 return objs[0];
1639 }}},
1640 {SET_ADD,
1641 {ByteCodeSupported,
1642 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2c02() 1643 * {
1644 PySet_Add(objs[0], objs[1]);
1645 Py_INCREF(objs[0]);
1646 return objs[0];
1647 }}},
1648 {MAP_ADD,
1649 {ByteCodeSupported,
1650 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f2d02() 1651 * {
1652 PyDict_SetItem(objs[0], objs[2], objs[1]);
1653 Py_INCREF(objs[0]);
1654 return objs[0];
1655 }}},
1656 {LOAD_CLASSDEREF, {ByteCodeSupported, nullptr}},
1657 {SETUP_ASYNC_WITH, {ByteCodeUnsupported, nullptr}},
1658 {FORMAT_VALUE, {ByteCodeUnsupported, nullptr}},
1659 {BUILD_CONST_KEY_MAP,
__anonaa04c86f2e02() 1660 {[](int opargs, const PyObjectArray &objs) -> bool {
1661 return PyTuple_CheckExact(objs[opargs]) && PyTuple_GET_SIZE(objs[opargs]) == (Py_ssize_t)opargs;
1662 },
1663 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * {
1664 PyObject *keys = objs[opargs];
1665 PyObject *map = _PyDict_NewPresized((Py_ssize_t)opargs);
1666 for (Py_ssize_t i = opargs; i > 0; i--) {
1667 PyObject *key = PyTuple_GET_ITEM(keys, opargs - i);
1668 PyObject *value = objs[opargs - i];
1669 PyDict_SetItem(map, key, value);
1670 }
1671 return map;
1672 }}},
1673 {BUILD_STRING,
1674 {ByteCodeSupported,
1675 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3002() 1676 * {
1677 PyObject *empty = PyUnicode_New(0, 0);
1678 PyObject *str =
1679 _PyUnicode_JoinArray(empty, objs.data(), opargs);
1680 Py_DECREF(empty);
1681 return str;
1682 }}},
1683 {LOAD_METHOD, {ByteCodeUnsupported, nullptr}},
1684 {CALL_METHOD, {ByteCodeUnsupported, nullptr}},
1685 {ROT_FOUR, {ByteCodeUnsupported, nullptr}},
1686 {RERAISE, {ByteCodeUnsupported, nullptr}},
1687 {WITH_EXCEPT_START, {ByteCodeUnsupported, nullptr}},
1688 {END_ASYNC_FOR, {ByteCodeUnsupported, nullptr}},
1689 {LOAD_ASSERTION_ERROR, {ByteCodeUnsupported, nullptr}},
1690 {LIST_TO_TUPLE,
1691 {ByteCodeSupported,
__anonaa04c86f3102() 1692 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject * { return PyList_AsTuple(objs[0]); }}},
1693 {IS_OP,
1694 {ByteCodeSupported,
1695 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3202() 1696 * {
1697 auto ret = (objs[0] == objs[1]) ^ opargs
1698 ? Py_True
1699 : Py_False;
1700 Py_INCREF(ret);
1701 return ret;
1702 }}},
1703 {CONTAINS_OP,
1704 {ByteCodeSupported,
1705 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3302() 1706 * {
1707 auto ret =
1708 (PySequence_Contains(objs[1], objs[0]) ^ opargs)
1709 ? Py_True
1710 : Py_False;
1711 Py_INCREF(ret);
1712 return ret;
1713 }}},
1714 {JUMP_IF_NOT_EXC_MATCH, {ByteCodeUnsupported, nullptr}},
1715 {LIST_EXTEND,
1716 {ByteCodeSupported,
1717 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3402() 1718 * {
1719 _PyList_Extend(
1720 reinterpret_cast<PyListObject *>(objs[0]),
1721 objs[1]);
1722 Py_INCREF(objs[0]);
1723 return objs[0];
1724 }}},
1725 {SET_UPDATE,
1726 {ByteCodeSupported,
1727 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3502() 1728 * {
1729 _PySet_Update(objs[0], objs[1]);
1730 Py_INCREF(objs[0]);
1731 return objs[0];
1732 }}},
1733 {DICT_MERGE,
1734 {ByteCodeSupported,
1735 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3602() 1736 * {
1737 _PyDict_MergeEx(objs[0], objs[1], 2);
1738 return objs[0];
1739 }}},
1740 {DICT_UPDATE,
1741 {ByteCodeSupported,
1742 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3702() 1743 * {
1744 PyDict_Update(objs[0], objs[1]);
1745 return objs[0];
1746 }}},
1747 {BREAK_LOOP, {ByteCodeUnsupported, nullptr}},
1748 {WITH_CLEANUP_START, {ByteCodeUnsupported, nullptr}},
1749 {WITH_CLEANUP_FINISH, {ByteCodeUnsupported, nullptr}},
1750 {END_FINALLY, {ByteCodeUnsupported, nullptr}},
1751 {CONTINUE_LOOP, {ByteCodeUnsupported, nullptr}},
1752 {SETUP_LOOP, {ByteCodeUnsupported, nullptr}},
1753 {SETUP_EXCEPT, {ByteCodeUnsupported, nullptr}},
1754 {BUILD_LIST_UNPACK,
1755 {ByteCodeSupported,
1756 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3802() 1757 * {
1758 PyObject *sum = PyList_New(0);
1759 for (int i = opargs; i > 0; i--) {
1760 auto none_val = _PyList_Extend(
1761 reinterpret_cast<PyListObject *>(sum),
1762 objs[opargs - i]);
1763 Py_DECREF(none_val);
1764 }
1765 return sum;
1766 }}},
1767 {BUILD_MAP_UNPACK,
1768 {ByteCodeSupported,
1769 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3902() 1770 * {
1771 PyObject *sum = PyDict_New();
1772 for (int i = opargs; i > 0; i--) {
1773 PyDict_Update(sum, objs[opargs - i]);
1774 }
1775 return sum;
1776 }}},
1777 {BUILD_MAP_UNPACK_WITH_CALL,
1778 {ByteCodeSupported,
1779 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3a02() 1780 * {
1781 PyObject *sum = PyDict_New();
1782 for (int i = opargs; i > 0; i--) {
1783 _PyDict_MergeEx(sum, objs[opargs - i], 2);
1784 }
1785 return sum;
1786 }}},
1787 {BUILD_TUPLE_UNPACK,
1788 {ByteCodeSupported,
1789 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3b02() 1790 * {
1791 PyObject *sum = PyList_New(0);
1792 for (int i = opargs; i > 0; i--) {
1793 auto none_val = _PyList_Extend(
1794 reinterpret_cast<PyListObject *>(sum),
1795 objs[opargs - i]);
1796 Py_DECREF(none_val);
1797 }
1798 auto ret = PyList_AsTuple(sum);
1799 Py_DECREF(sum);
1800 return ret;
1801 }}},
1802 {BUILD_SET_UNPACK,
1803 {ByteCodeSupported,
1804 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3c02() 1805 * {
1806 PyObject *sum = PySet_New(NULL);
1807 for (int i = opargs; i > 0; i--) {
1808 _PySet_Update(sum, objs[opargs - i]);
1809 }
1810 return sum;
1811 }}},
1812 {BUILD_TUPLE_UNPACK_WITH_CALL,
1813 {ByteCodeSupported,
1814 [](int opargs, const PyObjectArray &objs, PTraceContext ctx) -> PyObject
__anonaa04c86f3d02() 1815 * {
1816 PyObject *sum = PyList_New(0);
1817 for (int i = opargs; i > 0; i--) {
1818 auto none_val = _PyList_Extend(
1819 reinterpret_cast<PyListObject *>(sum),
1820 objs[opargs - i]);
1821 Py_DECREF(none_val);
1822 }
1823 auto ret = PyList_AsTuple(sum);
1824 Py_DECREF(sum);
1825 return ret;
1826 }}},
1827 };
1828
OpTrace(PyObject * obj,int opcode,int opargs,TraceVector params,std::string name)1829 OpTrace::OpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, std::string name)
1830 : Trace(obj, nullptr), opcode_(opcode), opargs_(opargs), params_(params), name_(name) {
1831 curType_ = TraceType::Operation;
1832 if (!std::any_of(params.begin(), params.end(), [](const TracePtr &item) { return !item->IsConst(); })) {
1833 is_const_ = true;
1834 } else if (name.find(kIsSeqValUnknown) != std::string::npos || name.find(kIsSeqShapeUnknown) != std::string::npos) {
1835 is_const_ = true;
1836 } else if (kPIJitConfigDefault.getIntConfig(GraphJitConfig::kGuardRelaxCount) > 0 && opcode == LOAD_ATTR &&
1837 name_ == kFuncName) {
1838 is_const_ = true;
1839 }
1840 depth_ = std::accumulate(params.begin(), params.end(), 1, [](int depth, const TracePtr &i) {
1841 int d = i->GetDepth() + 1;
1842 if (d > depth) {
1843 return d;
1844 } else {
1845 return depth;
1846 }
1847 });
1848 CheckSpecialize();
1849 }
1850
GetOpCode()1851 int OpTrace::GetOpCode() { return opcode_; }
1852
GetOpArgs()1853 int OpTrace::GetOpArgs() { return opargs_; }
1854
GetParam(size_t idx)1855 TracePtr OpTrace::GetParam(size_t idx) {
1856 if (params_.size() > idx) {
1857 return params_[idx];
1858 } else {
1859 return nullptr;
1860 }
1861 }
1862
GetParamCount()1863 size_t OpTrace::GetParamCount() { return params_.size(); }
1864
GetName()1865 std::string OpTrace::GetName() { return name_; }
1866
Retrieve(PTraceContext context,bool perf)1867 PyObject *OpTrace::Retrieve(PTraceContext context, bool perf) {
1868 PyObject *ret = Trace::Retrieve(context, perf);
1869 if (ret != nullptr) {
1870 return ret;
1871 }
1872 std::vector<PyObject *> params;
1873 auto iter = std::find_if(params_.begin(), params_.end(), [¶ms, &context, perf](const TracePtr &p) {
1874 auto param = p->Retrieve(context, perf);
1875 if (param == nullptr) {
1876 return true;
1877 }
1878 if (py::isinstance<mindspore::tensor::Tensor>(param)) {
1879 mindspore::tensor::TensorPtr tensor_ptr = py::cast<mindspore::tensor::TensorPtr>(param);
1880 if (OptStrategy::MakeCalcStrategyByShape(tensor_ptr->shape()) == OptStrategy::CalcKind::kCalcValue) {
1881 tensor_ptr->data_sync(true);
1882 }
1883 }
1884 params.push_back(param);
1885 return params.back() == nullptr;
1886 });
1887 if (iter != params_.end()) {
1888 MS_LOG(DEBUG) << "Guard Check Retrieve fail for " + (*iter)->ToString();
1889 std::for_each(params.begin(), params.end(), [](PyObject *p) { Py_XDECREF(p); });
1890 return nullptr;
1891 }
1892 TracePerf tp(this, perf, false);
1893 if (kBytecodeExecuter.find(opcode_) != kBytecodeExecuter.end() && kBytecodeExecuter[opcode_].first(opargs_, params) &&
1894 kBytecodeExecuter[opcode_].second != nullptr) {
1895 ret = kBytecodeExecuter[opcode_].second(opargs_, params, context);
1896 } else if (opcode_ == LOAD_ATTR) {
1897 MS_EXCEPTION_IF_CHECK_FAIL(name_.size(), "check trace");
1898 ret = PyObject_GetAttrString(params[0], name_.c_str());
1899 } else if (opcode_ == CALL_FUNCTION || opcode_ == CALL_FUNCTION_EX || opcode_ == CALL_FUNCTION_KW) {
1900 ret = DoCall(params, opcode_, name_);
1901 }
1902 for (auto p : params) {
1903 Py_DECREF(p);
1904 }
1905 if (PyErr_Occurred()) {
1906 PyErr_Clear();
1907 }
1908 Cache(context, ret);
1909 return ret;
1910 }
1911
ToString(bool include_param)1912 std::string OpTrace::ToString(bool include_param) {
1913 if (strTrace_.size() > 0) {
1914 return strTrace_;
1915 }
1916 std::string ret = "operation ";
1917 ret += Opcode(opcode_).name();
1918 ret += "(arg:";
1919 ret += std::to_string(opargs_);
1920 if (name_.size() != 0 || params_.size() > 0) {
1921 ret += ",";
1922 }
1923 if (name_.size() != 0) {
1924 ret += std::string("name:") + name_;
1925 if (params_.size() > 0) {
1926 ret += ",";
1927 }
1928 }
1929 if (include_param && params_.size() > 0) {
1930 for (auto t : params_) {
1931 ret += t->ToString(include_param) + ",";
1932 }
1933 ret = ret.substr(0, ret.size() - 1);
1934 }
1935 ret = ret + ")";
1936 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
1937 ret = std::regex_replace(ret, std::regex("(\n)"), "");
1938 strTrace_ = ret;
1939 return ret;
1940 }
1941
FormatString(std::map<Trace *,size_t> * cache)1942 std::string OpTrace::FormatString(std::map<Trace *, size_t> *cache) {
1943 std::stringstream s;
1944 std::stringstream params_str;
1945 params_str << "(";
1946 for (auto i : params_) {
1947 if (cache->find(i.get()) == cache->end()) {
1948 s << i->FormatString(cache) << std::endl;
1949 }
1950 params_str << "%" << (cache->find(i.get())->second) << ", ";
1951 }
1952 params_str << ")";
1953
1954 cache->insert(std::make_pair(this, cache->size()));
1955 s << "%" << cache->find(this)->second << " = operation " << Opcode(opcode_).name() << " " << opargs_;
1956 if (!name_.empty()) {
1957 s << ", name: " << name_;
1958 }
1959 s << ": " << params_str.str();
1960 return s.str();
1961 }
1962
Info()1963 const InfoPack &OpTrace::Info() {
1964 if (info_ == nullptr) {
1965 InfoPack info;
1966 info << uint8_t(curType_);
1967 info.Begin();
1968 info << opcode_;
1969 info << opargs_;
1970 info << (name_.size() != 0);
1971 if (name_.size() != 0) {
1972 info << name_;
1973 }
1974 info << (params_.size() != 0);
1975 if (params_.size() > 0) {
1976 for (auto t : params_) {
1977 info << t->Info();
1978 }
1979 }
1980 info.End();
1981 info_ = std::make_shared<InfoPack>(info);
1982 info_->Update();
1983 }
1984 return *info_;
1985 }
1986
RemoveCastDuplicatePatternPass()1987 TracePtr OpTrace::RemoveCastDuplicatePatternPass() {
1988 OpTracePtr cast_op;
1989 TracePtr next_op;
1990 TracePtr this_op;
1991 TracePtr ret_op;
1992 if (opcode_ != CALL_FUNCTION || !IsCastFunc(name_) ||
1993 (cast_op = CastTrace<OpTrace>(GetParam(kParamIndexTwo))) == nullptr || !IsCastFunc(cast_op->GetName()) ||
1994 (next_op = cast_op->GetParam(kParamIndexTwo)) == nullptr) {
1995 return nullptr;
1996 }
1997 // remove duplicate cast or contrary cast
1998 if (name_ == cast_op->GetName()) {
1999 this_op = cast_op;
2000 ret_op = cast_op->Optimize();
2001 } else {
2002 this_op = next_op;
2003 ret_op = next_op->Optimize();
2004 }
2005 if (ret_op != nullptr) {
2006 return ret_op;
2007 } else {
2008 return this_op;
2009 }
2010 }
2011
RemovePrimOutIsTensorPass()2012 TracePtr OpTrace::RemovePrimOutIsTensorPass() {
2013 RootTracePtr global_op;
2014 OpTracePtr call_op;
2015 TracePtr param_op;
2016 if (opcode_ != CALL_FUNCTION || !(name_ == kIsInstance) ||
2017 (global_op = CastTrace<RootTrace>(GetParam(kParamIndexThree))) == nullptr ||
2018 global_op->GetTraceType() != TraceType::Global ||
2019 (call_op = CastOpTrace(GetParam(kParamIndexTwo), CALL_FUNCTION)) == nullptr ||
2020 (param_op = call_op->GetParam(kParamIndexOne)) == nullptr) {
2021 return nullptr;
2022 }
2023 int idx;
2024 std::string name;
2025 std::string module_name;
2026 global_op->GetParam(&idx, &name, &module_name);
2027 // isinstance(cast_to_mstensor(...) or Primitive) should be Tensor
2028 if ((name == kTensorName) && ((call_op->GetName() == kCastToMSTensor) ||
2029 (CastTrace<ConstTrace>(param_op) != nullptr &&
2030 py::isinstance<mindspore::PrimitivePyAdapter>(param_op->GetObject())))) {
2031 is_const_ = true;
2032 if (obj_ == nullptr) {
2033 obj_ = Py_True;
2034 Py_INCREF(obj_);
2035 }
2036 return shared_from_this();
2037 }
2038 return nullptr;
2039 }
2040
RemoveCastPass()2041 TracePtr OpTrace::RemoveCastPass() {
2042 TracePtr next_op;
2043 if (opcode_ == CALL_FUNCTION && name_ == kCastToMSTensor && (next_op = GetParam(kParamIndexTwo)) != nullptr) {
2044 auto new_ret = next_op->Optimize();
2045 if (new_ret != nullptr) {
2046 return new_ret;
2047 } else {
2048 return next_op;
2049 }
2050 }
2051 return nullptr;
2052 }
2053
RemoveEmptyTensorPass()2054 TracePtr OpTrace::RemoveEmptyTensorPass() {
2055 OpTracePtr subscr_op;
2056 OpTracePtr loadattr_op;
2057 ConstTracePtr const_op;
2058 ConstTracePtr const2_op;
2059 if (opcode_ != COMPARE_OP || params_.size() < kParamCountTwo) {
2060 return nullptr;
2061 }
2062 for (size_t idx = 0; idx < kParamCountTwo; ++idx) {
2063 TracePtr tmp = GetParam(idx);
2064 if (subscr_op == nullptr) {
2065 subscr_op = CastOpTrace(tmp, BINARY_SUBSCR);
2066 }
2067 if (const_op == nullptr) {
2068 const_op = CastConstTrace(tmp);
2069 }
2070 }
2071 if (subscr_op == nullptr || const_op == nullptr ||
2072 (const2_op = CastConstTrace(subscr_op->GetParam(kParamIndexTwo))) == nullptr ||
2073 (loadattr_op = CastOpTrace(subscr_op->GetParam(kParamIndexOne), kShapeName)) == nullptr ||
2074 CastOpTrace(loadattr_op->GetParam(kParamIndexOne), kCastToAdapterTensor) == nullptr) {
2075 return nullptr;
2076 }
2077 // make judgement shape[0] == 0 as const
2078 auto c1 = const_op->GetObject();
2079 auto c2 = const2_op->GetObject();
2080 if (!PyLong_CheckExact(c1) || !PyLong_CheckExact(c2)) {
2081 return nullptr;
2082 }
2083 auto v1 = _PyLong_AsInt(c1);
2084 auto v2 = _PyLong_AsInt(c2);
2085 if (v1 == 0 && v2 == 0) {
2086 is_const_ = true;
2087 return shared_from_this();
2088 }
2089 return nullptr;
2090 }
2091
JudgeDTypeChangePass()2092 void OpTrace::JudgeDTypeChangePass() {
2093 if (opcode_ != COMPARE_OP) {
2094 return;
2095 }
2096 for (size_t i = 0; i < kParamCountTwo; ++i) {
2097 OpTracePtr trace = CastOpTrace(GetParam(i), CALL_FUNCTION);
2098 ConstTracePtr const_op = trace ? CastConstTrace(trace->GetParam(kParamIndexOne)) : nullptr;
2099 PyObject *const_param = const_op ? const_op->GetObject() : nullptr;
2100 if (trace != nullptr && const_op != nullptr && const_param != nullptr &&
2101 py::isinstance<mindspore::PrimitivePyAdapter>(const_param) &&
2102 py::cast<mindspore::PrimitivePyAdapterPtr>(const_param)->name() == kDTypePrimName) {
2103 // Compare for output of DType primitive
2104 continue;
2105 } else if ((trace = CastOpTrace(GetParam(i), LOAD_ATTR)) != nullptr && trace->GetName() == kDTypeAttrName) {
2106 // Compare for attribute dtype
2107 continue;
2108 }
2109 return;
2110 }
2111 // data type comparison should be kept as const
2112 EnableRelax();
2113 }
2114
JudgeDTypeScopePass()2115 void OpTrace::JudgeDTypeScopePass() {
2116 if (opcode_ != CONTAINS_OP) {
2117 return;
2118 }
2119 OpTracePtr trace;
2120 if ((trace = CastOpTrace(GetParam(kParamIndexOne), LOAD_ATTR)) != nullptr && trace->GetName() == kDTypeAttrName) {
2121 // data type to check whether to be contained should be const
2122 EnableRelax();
2123 }
2124 }
2125
JudgeDTypeTensorAttrPass()2126 void OpTrace::JudgeDTypeTensorAttrPass() {
2127 if (opcode_ != CALL_FUNCTION) {
2128 return;
2129 }
2130 RootTracePtr global_op;
2131 OpTracePtr call_op;
2132 if (params_.size() < kParamCountTwo || (global_op = CastTrace<RootTrace>(params_[kParamIndexOne])) == nullptr ||
2133 (call_op = CastOpTrace(params_[kParamIndexTwo], BINARY_SUBSCR)) == nullptr) {
2134 return;
2135 }
2136 int idx;
2137 std::string name;
2138 std::string module_name;
2139 global_op->GetParam(&idx, &name, &module_name);
2140 auto tsr = call_op->GetObject();
2141 if (tsr == nullptr) {
2142 return;
2143 }
2144 std::string type_name = std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(tsr))));
2145 if (name == kDType_AttrName && module_name.find("mindspore") == 0 &&
2146 type_name.find(kTensorName) != std::string::npos) {
2147 EnableRelax();
2148 }
2149 }
2150
JudgeCodeChangePass()2151 void OpTrace::JudgeCodeChangePass() {
2152 if (opcode_ != LOAD_ATTR || params_.size() < kParamCountOne || name_ != kCodeName) {
2153 return;
2154 }
2155 if (params_[kParamIndexOne]->IsConst()) {
2156 EnableRelax();
2157 } else {
2158 auto p1 = CastOpTrace(params_[kParamIndexOne], LOAD_ATTR);
2159 if (p1 == nullptr) {
2160 return;
2161 }
2162 auto p2 = p1->GetParam(kParamIndexOne)->GetObject();
2163 std::string type_name;
2164 if (p2 != nullptr) {
2165 type_name = std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(p2))));
2166 }
2167 if (type_name.find(kMindTorchFlag) != std::string::npos) {
2168 EnableRelax();
2169 }
2170 }
2171 }
2172
JudgeTrainFlagPass()2173 void OpTrace::JudgeTrainFlagPass() {
2174 if (opcode_ != LOAD_ATTR || params_.size() < kParamCountOne) {
2175 return;
2176 }
2177 if (name_ == kTrainingFlag) {
2178 // training flag shouldn't be changed frequently
2179 EnableRelax();
2180 }
2181 }
2182
JudgeCompareConstPass()2183 void OpTrace::JudgeCompareConstPass() {
2184 if (RelaxEnabled()) {
2185 return;
2186 }
2187 if (opcode_ != COMPARE_OP || params_.size() < kParamCountTwo) {
2188 return;
2189 }
2190 if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2191 return;
2192 }
2193 EnableRelax();
2194 }
2195
JudgeContainsConstPass()2196 void OpTrace::JudgeContainsConstPass() {
2197 if (RelaxEnabled()) {
2198 return;
2199 }
2200 if (opcode_ != CONTAINS_OP || params_.size() < kParamCountTwo) {
2201 return;
2202 }
2203 if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2204 return;
2205 }
2206 EnableRelax();
2207 }
2208
JudgeInplaceAddConstPass()2209 void OpTrace::JudgeInplaceAddConstPass() {
2210 if (RelaxEnabled()) {
2211 return;
2212 }
2213 if (opcode_ != INPLACE_ADD || params_.size() < kParamCountTwo) {
2214 return;
2215 }
2216 if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2217 return;
2218 }
2219 EnableRelax();
2220 }
2221
JudgeIsConstPass()2222 void OpTrace::JudgeIsConstPass() {
2223 if (RelaxEnabled()) {
2224 return;
2225 }
2226 if (opcode_ != IS_OP || params_.size() < kParamCountTwo) {
2227 return;
2228 }
2229 if (params_[kParamIndexOne]->GetObject() == nullptr || params_[kParamIndexTwo]->GetObject() == nullptr) {
2230 return;
2231 }
2232 OpTracePtr subscr_op;
2233 if ((subscr_op = CastTrace<OpTrace>(params_[kParamIndexOne])) != nullptr &&
2234 (CastConstTrace(params_[kParamIndexTwo]) != nullptr || params_[kParamIndexTwo]->IsConst())) {
2235 if (subscr_op->params_.size() < kParamCountTwo) {
2236 return;
2237 }
2238 auto tsr = subscr_op->GetParam(kParamIndexOne)->GetObject();
2239 if (tsr == nullptr) {
2240 return;
2241 }
2242 std::string type_name = std::string(py::str(reinterpret_cast<PyObject *>(Py_TYPE(tsr))));
2243 if (type_name.find(kTensorName) == std::string::npos) {
2244 return;
2245 }
2246 if (subscr_op->GetParam(kParamIndexTwo)->IsConst() && params_[kParamIndexTwo]->IsConst()) {
2247 EnableRelax();
2248 }
2249 }
2250 }
2251
JudgeBoundMethodPass()2252 void OpTrace::JudgeBoundMethodPass() {
2253 if (RelaxEnabled()) {
2254 return;
2255 }
2256 if (opcode_ != LOAD_ATTR || params_.size() < kParamCountOne) {
2257 return;
2258 }
2259 if (params_[kParamIndexOne]->GetObject() == nullptr) {
2260 return;
2261 }
2262 if (name_ == kFuncName) {
2263 EnableRelax();
2264 }
2265 }
2266
JudgeSubScrRandPass()2267 void OpTrace::JudgeSubScrRandPass() {
2268 if (RelaxEnabled()) {
2269 return;
2270 }
2271 if (opcode_ != BINARY_SUBTRACT || params_.size() < kParamCountTwo) {
2272 return;
2273 }
2274 auto call_op = CastOpTrace(params_[kParamIndexOne], CALL_FUNCTION);
2275 ConstTracePtr prim;
2276 if (call_op != nullptr && call_op->params_.size() > kParamCountOne) {
2277 if ((prim = CastConstTrace(call_op->params_[kParamIndexOne])) != nullptr) {
2278 std::string prim_name = py::cast<mindspore::PrimitivePyAdapterPtr>(prim->GetObject())->name();
2279 if (prim_name == kRankPrimName) {
2280 EnableRelax();
2281 }
2282 }
2283 }
2284 }
2285
GetGuardFuncKeyMap()2286 static const std::unordered_map<size_t, size_t> &GetGuardFuncKeyMap() {
2287 static std::unordered_map<size_t, size_t> map = {};
2288 static bool init = false;
2289 if (init) {
2290 return map;
2291 }
2292 init = true;
2293 py::object func_map = Utils::GetModuleAttr(kFuncWhiteListModuleName, kGuardFuncMapName, true, true);
2294 MS_EXCEPTION_IF_CHECK_FAIL(PyDict_CheckExact(func_map.ptr()), "white list func map must be 'dict[int, int]'");
2295 PyObject *key;
2296 PyObject *value;
2297 Py_ssize_t pos = 0;
2298 while (PyDict_Next(func_map.ptr(), &pos, &key, &value)) {
2299 MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(key), "white list func map key must be 'int'");
2300 MS_EXCEPTION_IF_CHECK_FAIL(PyLong_CheckExact(value), "white list func map value must be 'int'");
2301 map[PyLong_AsSize_t(key)] = PyLong_AsSize_t(value);
2302 }
2303 return map;
2304 }
2305
CheckRelaxGuardFunc(const py::object & callable)2306 static bool CheckRelaxGuardFunc(const py::object &callable) {
2307 static size_t guard_key_relax_func = 0;
2308 if (guard_key_relax_func == 0) {
2309 py::object key_object = Utils::GetModuleAttr(kFuncWhiteListModuleName, "GUARD_KEY_RELAX_FUNC", true, true);
2310 guard_key_relax_func = py::cast<size_t>(key_object);
2311 }
2312
2313 auto iter = GetGuardFuncKeyMap().find(FunctionId(callable));
2314 return iter != GetGuardFuncKeyMap().end() && iter->second == guard_key_relax_func;
2315 }
2316
JudgeRelaxGuardFuncPass()2317 void OpTrace::JudgeRelaxGuardFuncPass() {
2318 if (opcode_ != CALL_FUNCTION || params_.size() < kParamCountOne) {
2319 return;
2320 }
2321 if (CheckRelaxGuardFunc(py::cast<py::object>(params_[0]->GetObject()))) {
2322 EnableRelax();
2323 }
2324 }
2325
CheckSpecialize()2326 void OpTrace::CheckSpecialize() {
2327 bool any_params_specialized = false;
2328 for (auto param : params_) {
2329 if (!param->IsConst() && param->IsSpecialized()) {
2330 any_params_specialized = true;
2331 break;
2332 }
2333 }
2334 if (opcode_ == CALL_FUNCTION) {
2335 if ((name_ == kShape_Name || name_ == kCastToAdapterTensor || name_ == kCastToMSTensor) &&
2336 !any_params_specialized) {
2337 is_specialized_ = true;
2338 } else if (params_.size() > kParamCountOne &&
2339 py::isinstance<mindspore::PrimitivePyAdapter>(params_[kParamIndexOne]->GetObject())) {
2340 std::string prim_name = py::cast<mindspore::PrimitivePyAdapterPtr>(params_[kParamIndexOne]->GetObject())->name();
2341 if (prim_name == kCastPrimName) {
2342 is_specialized_ = params_[kParamIndexTwo]->IsSpecialized();
2343 } else if (prim_name == kLayerNormPrimName) {
2344 is_specialized_ = params_[kParamIndexTwo]->IsSpecialized();
2345 } else if (prim_name == kReshapePrimName) {
2346 is_specialized_ = any_params_specialized || !params_[kParamIndexThree]->IsConst();
2347 } else if (prim_name == kShapePrimName) {
2348 is_specialized_ = params_[kParamIndexTwo]->IsSpecialized();
2349 }
2350 }
2351 } else {
2352 is_specialized_ = true;
2353 }
2354 }
2355
Optimize()2356 TracePtr OpTrace::Optimize() {
2357 if (is_const_ || RelaxEnabled()) {
2358 return nullptr;
2359 }
2360 TracePtr ret;
2361 if ((ret = RemoveCastDuplicatePatternPass()) != nullptr || (ret = RemoveEmptyTensorPass()) != nullptr ||
2362 (ret = RemovePrimOutIsTensorPass()) != nullptr || (ret = RemoveCastPass()) != nullptr) {
2363 return ret;
2364 }
2365 if (relax_limit_ > 0) {
2366 JudgeDTypeChangePass();
2367 JudgeDTypeScopePass();
2368 JudgeTrainFlagPass();
2369 JudgeCompareConstPass();
2370 JudgeContainsConstPass();
2371 JudgeInplaceAddConstPass();
2372 JudgeIsConstPass();
2373 JudgeCodeChangePass();
2374 JudgeBoundMethodPass();
2375 JudgeSubScrRandPass();
2376 JudgeDTypeTensorAttrPass();
2377 JudgeRelaxGuardFuncPass();
2378 }
2379 bool need_update = false;
2380 for (size_t i = 0; i < params_.size(); ++i) {
2381 params_[i] = OptimizeTrace(params_[i], &need_update);
2382 }
2383 if (need_update) {
2384 if (!std::any_of(params_.begin(), params_.end(), [](const TracePtr &item) { return !item->IsConst(); })) {
2385 is_const_ = true;
2386 }
2387 info_ = nullptr;
2388 strTrace_ = "";
2389 Info();
2390 return shared_from_this();
2391 } else {
2392 return nullptr;
2393 }
2394 }
2395
SetRelaxCount(int cnt)2396 void OpTrace::SetRelaxCount(int cnt) {
2397 Trace::SetRelaxCount(cnt);
2398 for (auto param : params_) {
2399 param->SetRelaxCount(cnt);
2400 }
2401 }
2402
operator ==(const Trace & trace)2403 bool OpTrace::operator==(const Trace &trace) {
2404 bool ret = false;
2405 if (Trace::operator==(trace)) {
2406 const OpTrace &t = (const OpTrace &)trace;
2407 ret = opcode_ == t.opcode_ && opargs_ == t.opargs_ && name_ == t.name_ && params_.size() == t.params_.size();
2408 if (ret) {
2409 for (size_t i = 0; i < params_.size(); i++) {
2410 if (*(params_[i]) == *(t.params_[i])) {
2411 continue;
2412 } else {
2413 ret = false;
2414 break;
2415 }
2416 }
2417 }
2418 }
2419 return ret;
2420 }
2421
Replace(std::shared_ptr<Trace> dst,std::shared_ptr<Trace> src)2422 void OpTrace::Replace(std::shared_ptr<Trace> dst, std::shared_ptr<Trace> src) {
2423 Trace::Replace(dst, src);
2424 for (size_t i = 0; i < params_.size(); ++i) {
2425 if (*params_[i] == *src) {
2426 params_[i] = dst;
2427 } else {
2428 params_[i]->Replace(dst, src);
2429 }
2430 }
2431 }
2432
Detach()2433 void OpTrace::Detach() {
2434 Trace::Detach();
2435 for (auto t : params_) {
2436 t->Detach();
2437 }
2438 }
2439
Support(TraceType tt)2440 bool OpTrace::Support(TraceType tt) { return tt == TraceType::Operation; }
2441
2442 static std::map<int, TraceType> kMapBytecodeToTraceType = {
2443 {LOAD_CLOSURE, TraceType::Closure}, {LOAD_DEREF, TraceType::Deref}, {LOAD_GLOBAL, TraceType::Global},
2444 {LOAD_NAME, TraceType::Name}, {LOAD_CLASSDEREF, TraceType::ClassDeref},
2445 };
2446
CreateOpTraceByBytecode(PyObject * obj,int opcode,int opargs,TraceVector params,std::string module_name,std::string name,bool strict)2447 TracePtr CreateOpTraceByBytecode(PyObject *obj, int opcode, int opargs, TraceVector params, std::string module_name,
2448 std::string name, bool strict) {
2449 static const std::set<int> root_op = {
2450 LOAD_CLOSURE, LOAD_DEREF, LOAD_GLOBAL, LOAD_NAME, LOAD_CLASSDEREF,
2451 };
2452 if (opcode == LOAD_DEREF && opargs < 0) {
2453 return nullptr;
2454 }
2455 if (root_op.find(opcode) != root_op.end()) {
2456 return std::make_shared<RootTrace>(obj, kMapBytecodeToTraceType[opcode], opargs, name, module_name);
2457 }
2458 if (opcode == LOAD_CONST) {
2459 return std::make_shared<ConstTrace>(obj, -1);
2460 }
2461 if (Opcode(opcode).IsCall()) {
2462 if (params.size() < 1 || !SupportCall(params[0]->GetObject(), name)) {
2463 if (strict) {
2464 return nullptr;
2465 } else {
2466 return std::make_shared<UnsupportedTrace>(obj, params, opcode, opargs);
2467 }
2468 }
2469 }
2470 return std::make_shared<OpTrace>(obj, opcode, opargs, params, name);
2471 }
2472
CreateOpTrace(PyObject * obj,int opcode,int opargs,TraceVector params,const std::string & module_name,const std::string & name,bool strict,bool print)2473 TracePtr CreateOpTrace(PyObject *obj, int opcode, int opargs, TraceVector params, const std::string &module_name,
2474 const std::string &name, bool strict, bool print) {
2475 std::vector<PyObject *> vparams;
2476 for (auto trace : params) {
2477 if (trace == nullptr) {
2478 return nullptr;
2479 } else if (trace->GetTraceType() == TraceType::Unsupported) {
2480 return std::make_shared<UnsupportedTrace>(obj, params, opcode, opargs);
2481 } else {
2482 vparams.push_back(trace->GetObject());
2483 }
2484 }
2485 if (kBytecodeExecuter.find(opcode) == kBytecodeExecuter.end() || !kBytecodeExecuter[opcode].first(opargs, vparams)) {
2486 if (print) {
2487 GRAPH_JIT_LOG_F("Unsupported bytecode %d args %d!\n", opcode, opargs);
2488 } else {
2489 MS_LOG(DEBUG) << "Unsupported bytecode " << opcode << " args " << opargs << "!";
2490 }
2491 if (strict) {
2492 return nullptr;
2493 } else {
2494 return std::make_shared<UnsupportedTrace>(obj, params, opcode, opargs);
2495 }
2496 }
2497 return CreateOpTraceByBytecode(obj, opcode, opargs, params, module_name, name, strict);
2498 }
2499
CustomizedTrace(PyObject * obj,RetrieveFunc rfunc,ToStringFunc sfunc)2500 CustomizedTrace::CustomizedTrace(PyObject *obj, RetrieveFunc rfunc, ToStringFunc sfunc)
2501 : Trace(obj, nullptr), retrieve_(rfunc), tostring_(sfunc) {
2502 curType_ = TraceType::Customized;
2503 depth_ = 1;
2504 }
2505
Retrieve(PTraceContext context,bool perf)2506 PyObject *CustomizedTrace::Retrieve(PTraceContext context, bool perf) {
2507 PyObject *ret = Trace::Retrieve(context, perf);
2508 if (ret != nullptr) {
2509 return ret;
2510 }
2511 TracePerf tp(this, perf, false);
2512 ret = retrieve_(context);
2513 Cache(context, ret);
2514 return ret;
2515 }
2516
ToString(bool include_param)2517 std::string CustomizedTrace::ToString(bool include_param) {
2518 if (strTrace_.size() > 0) {
2519 return strTrace_;
2520 }
2521 std::string ret = tostring_(false);
2522 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
2523 ret = std::regex_replace(ret, std::regex("(\n)"), "");
2524 strTrace_ = ret;
2525 return ret;
2526 }
2527
Info()2528 const InfoPack &CustomizedTrace::Info() {
2529 if (info_ == nullptr) {
2530 InfoPack info;
2531 info << uint8_t(curType_);
2532 info.Begin();
2533 info << tostring_(true);
2534 info.End();
2535 info_ = std::make_shared<InfoPack>(info);
2536 info_->Update();
2537 }
2538 return *info_;
2539 }
2540
Support(TraceType tt)2541 bool CustomizedTrace::Support(TraceType tt) { return tt == TraceType::Customized; }
2542
UnsupportedTrace(PyObject * obj,TraceVector params,int op,int arg)2543 UnsupportedTrace::UnsupportedTrace(PyObject *obj, TraceVector params, int op, int arg)
2544 : Trace(obj, nullptr), params_(params), op_(op), arg_(arg) {
2545 curType_ = TraceType::Unsupported;
2546 if (!std::any_of(params.begin(), params.end(), [](const TracePtr &item) { return !item->IsConst(); })) {
2547 is_const_ = true;
2548 }
2549 depth_ = std::accumulate(params.begin(), params.end(), 1, [](int depth, const TracePtr &i) {
2550 int d = i->GetDepth() + 1;
2551 if (d > depth) {
2552 return d;
2553 } else {
2554 return depth;
2555 }
2556 });
2557 }
2558
Retrieve(PTraceContext context,bool perf)2559 PyObject *UnsupportedTrace::Retrieve(PTraceContext context, bool perf) {
2560 PyObject *ret = Trace::Retrieve(context, perf);
2561 if (ret != nullptr) {
2562 return ret;
2563 }
2564 std::vector<PyObject *> params;
2565 bool fail = false;
2566 for (auto p : params_) {
2567 auto obj = p->Retrieve(context, perf);
2568 params.push_back(obj);
2569 if (p->GetTraceType() != TraceType::Unsupported) {
2570 // compare obj with original obj in trace for inputs of unsupported trace
2571 auto orig = p->GetObject();
2572 if (!IsPyObjectEqual(obj, orig)) {
2573 fail = true;
2574 break;
2575 }
2576 }
2577 if (params.back() == nullptr) {
2578 MS_LOG(DEBUG) << "Guard Check Retrieve fail for " + p->ToString();
2579 fail = true;
2580 break;
2581 }
2582 }
2583 TracePerf tp(this, perf, false);
2584 for (auto p : params) {
2585 Py_XDECREF(p);
2586 }
2587 if (fail) {
2588 return nullptr;
2589 } else {
2590 Cache(context, obj_);
2591 return obj_;
2592 }
2593 }
2594
ToString(bool include_param)2595 std::string UnsupportedTrace::ToString(bool include_param) {
2596 if (strTrace_.size() > 0) {
2597 return strTrace_;
2598 }
2599 std::string ret = "unsupported ";
2600 ret += Opcode(op_).name();
2601 ret += "(arg:";
2602 ret += std::to_string(arg_);
2603 if (include_param && params_.size() > 0) {
2604 ret += ",";
2605 for (auto t : params_) {
2606 ret += t->ToString(include_param) + ",";
2607 }
2608 ret = ret.substr(0, ret.size() - 1);
2609 }
2610 ret = ret + ")";
2611 ret = (is_const_ ? std::string("const:") : std::string("var:")) + ret;
2612 ret = std::regex_replace(ret, std::regex("(\n)"), "");
2613 strTrace_ = ret;
2614 return ret;
2615 }
2616
FormatString(std::map<Trace *,size_t> * cache)2617 std::string UnsupportedTrace::FormatString(std::map<Trace *, size_t> *cache) {
2618 std::stringstream s;
2619 std::stringstream params_str;
2620 params_str << "(";
2621 for (auto i : params_) {
2622 if (cache->find(i.get()) == cache->end()) {
2623 s << i->FormatString(cache) << std::endl;
2624 }
2625 params_str << "%" << (cache->find(i.get())->second) << ", ";
2626 if (i->GetTraceType() == TraceType::Unsupported) {
2627 params_str << "...";
2628 break;
2629 }
2630 }
2631 params_str << ")";
2632
2633 cache->insert(std::make_pair(this, cache->size()));
2634 s << "%" << cache->find(this)->second << " = unsupported " << Opcode(op_).name() << " " << arg_ << ": "
2635 << params_str.str();
2636 return s.str();
2637 }
2638
Info()2639 const InfoPack &UnsupportedTrace::Info() {
2640 if (info_ == nullptr) {
2641 InfoPack info;
2642 info << uint8_t(curType_);
2643 info.Begin();
2644 info << op_;
2645 info << arg_;
2646 info << uint64_t(params_.size());
2647 for (auto i : params_) {
2648 info << i->Info();
2649 }
2650 info.End();
2651 info_ = std::make_shared<InfoPack>(info);
2652 info_->Update();
2653 }
2654 return *info_;
2655 }
2656
SetRelaxCount(int cnt)2657 void UnsupportedTrace::SetRelaxCount(int cnt) {
2658 Trace::SetRelaxCount(cnt);
2659 for (auto param : params_) {
2660 param->SetRelaxCount(cnt);
2661 }
2662 }
2663
GetParams()2664 TraceVector UnsupportedTrace::GetParams() { return params_; }
2665
Detach()2666 void UnsupportedTrace::Detach() {
2667 Trace::Detach();
2668 for (auto t : params_) {
2669 t->Detach();
2670 }
2671 }
2672
Support(TraceType tt)2673 bool UnsupportedTrace::Support(TraceType tt) { return tt == TraceType::Unsupported; }
2674
GetObjectFromTrace(const PyFrameObject * frame,TracePtr trace,std::map<size_t,PyObject * > * cache,bool perf)2675 PyObject *GetObjectFromTrace(const PyFrameObject *frame, TracePtr trace, std::map<size_t, PyObject *> *cache,
2676 bool perf) {
2677 TraceContext context = {frame->f_globals, frame->f_builtins, frame->f_locals,
2678 frame->f_localsplus, frame->f_code, cache};
2679 if (trace != nullptr) {
2680 return trace->Retrieve(&context, perf);
2681 } else {
2682 return nullptr;
2683 }
2684 }
2685 } // namespace pijit
2686 } // namespace mindspore
2687