/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_PI_JIT_GUARD_UTILS_H
17 #define MINDSPORE_PI_JIT_GUARD_UTILS_H
18 
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "pipeline/jit/pi/pydef.h"
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/pi/graph_guard/trace.h"
#include "utils/convert_utils_base.h"
29 
30 namespace mindspore {
31 namespace pijit {
32 
/// \brief Kind of check a GuardItem performs; stored in GuardItem::type_ and returned by GuardItem::GetType().
///
/// NOTE(review): the enumerators appear to mirror the GuardEqual / GuardType / GuardId / GuardAttr / GuardRepr
/// factories declared in this header -- confirm the mapping in the implementation file.
///
/// Declared without the former `_GIType` tag: identifiers beginning with an underscore followed by an
/// uppercase letter are reserved for the implementation in C++. The enum stays unscoped so existing
/// unqualified uses (e.g. `GTEqual`) keep compiling, and the enumerator values are unchanged.
enum GIType {
  GTUnknown = 0,
  GTEqual,
  GTType,
  GTId,
  GTAttr,
  GTRepr,
};
41 
42 class GuardItem : public std::enable_shared_from_this<GuardItem> {
43  public:
44   explicit GuardItem(TracePtr var);
45   virtual ~GuardItem() = default;
46   virtual bool Check(const PyFrameObject *frame, std::map<size_t, PyObject *> *cache = nullptr, bool perf = false) = 0;
47   virtual bool Check(PyObject *obj) = 0;
48   virtual std::string ToString() = 0;
49   virtual const InfoPack &Info() = 0;
50   virtual void Replace(TracePtr dst, TracePtr src);
51   virtual TracePtr GetTrace();
52   virtual bool operator==(const GuardItem &obj) const;
GetType()53   virtual GIType GetType() { return type_; }
MatchDynamicShape(std::shared_ptr<GuardItem> other)54   virtual bool MatchDynamicShape(std::shared_ptr<GuardItem> other) { return false; }
ApplyDynamicShape(PyObject * obj)55   virtual PyObject *ApplyDynamicShape(PyObject *obj) { return nullptr; }
56   virtual std::shared_ptr<GuardItem> Optimize();
This()57   virtual std::shared_ptr<GuardItem> This() { return shared_from_this(); }
58 
59  protected:
60   TracePtr var_;
61   GIType type_;
62   InfoPackPtr info_;
63   std::string strGuard_;
64 };
65 using GuardItemPtr = std::shared_ptr<GuardItem>;
66 
67 /// \brief check whether elements are equal
68 /// \param[in] obj
69 /// \param[in] needSpecialize to check the content of buffer
70 /// \param[in] recurseDepth to check the hierarchy element access like a.b.c by depth
71 GuardItemPtr GuardEqual(TracePtr obj, bool needSpecialize = true, int recurseDepth = INT_MAX);
72 GuardItemPtr GuardType(TracePtr obj);
73 GuardItemPtr GuardId(TracePtr obj);
74 GuardItemPtr GuardAttr(TracePtr obj);
75 GuardItemPtr GuardRepr(TracePtr obj);
76 bool IsPyObjectEqual(PyObject *src, PyObject *dst);
77 PyObject *GetMsModule();
78 PyObject *GetMsType();
79 PyObject *GetMsTensorType();
80 
81 }  // namespace pijit
82 }  // namespace mindspore
83 
84 #endif  // MINDSPORE_PI_JIT_GUARD_UTILS_H
85