/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_PI_JIT_GUARD_H
#define MINDSPORE_PI_JIT_GUARD_H

#include <memory>
#include <vector>
#include <map>
#include <stack>
#include <string>
#include <utility>
#include <tuple>
#include "pybind11/pybind11.h"
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/pi/graph_guard/trace.h"
#include "pipeline/jit/pi/graph_guard/guard_utils.h"

namespace mindspore {
namespace pijit {

typedef enum _GuardLevel {
  GDeduce = 0,
  GId,
  GType,
  GAttr,
  GEqual,
} GuardLevel;
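
// The levels are inferred from their names (the authoritative semantics live
// in the guard implementation): GDeduce lets the framework choose a level,
// GId guards object identity, GType guards the Python type, GAttr guards
// attribute values, and GEqual guards full value equality.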

using GuardItemVector = std::vector<GuardItemPtr>;
using GuardItemMap = std::map<size_t, GuardItemPtr>;
using GuardCheckPoint = std::pair<GuardItemVector, GuardItemMap>;
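// A GuardCheckPoint snapshots a (GuardItemVector, GuardItemMap) pair; judging
// by the members of OptGuard below, Backup() pushes such a snapshot onto
// guardStack_ and Rollback() restores it.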

class OptGuard : public std::enable_shared_from_this<OptGuard> {
 public:
  OptGuard();
  virtual ~OptGuard() = default;
  /// \brief check whether the guarded variables have been modified
  /// \param[in] frame python frame
  /// \param[in] print whether to print the guard
  /// \param[in] cache cache used to reuse guard results
  /// \param[in] success records the items which guard successfully
  /// \param[in] fail records the items which fail to guard
  /// \param[in] perf whether to record the performance of the guard
  /// \return whether the guarded variables have been modified
  virtual bool Check(const PyFrameObject *frame, bool print, std::map<size_t, PyObject *> *cache = nullptr,
                     std::map<size_t, bool> *success = nullptr, std::map<size_t, bool> *fail = nullptr,
                     bool perf = false);
  /// \brief guard a variable that can be retrieved through a trace
  /// \param[in] var trace recording the path used to retrieve the object
  /// \param[in] tp guard level
  /// \param[in] needSpecialize whether to check the content of the buffer
  /// \param[in] recurseDepth depth to which hierarchical accesses such as a.b.c are checked
  /// \return whether the guard succeeded
  virtual bool GuardOn(TracePtr var, GuardLevel tp = GuardLevel::GDeduce, bool needSpecialize = true,
                       int recurseDepth = 0);
  /// \brief add traces from another guard
  /// \param[in] traces the traces to replace in the other guard
  /// \param[in] other the guard providing the traces
  virtual void AddTraceFromGuard(const std::vector<TracePtr> &traces, std::shared_ptr<OptGuard> other);
  /// \brief return the description of the guard
  virtual std::string GetDescript();
  virtual void UpdateConfig(const std::map<std::string, bool> &bool_config,
                            const std::map<std::string, int> &int_config);
  /// \brief push the current guard state as a checkpoint
  virtual void Backup();
  /// \brief restore the most recent checkpoint
  virtual void Rollback();
  /// \brief discard the most recent checkpoint
  virtual void Pop();
  virtual bool IsEmpty() { return guardList_.size() == 0; }
  virtual bool MatchShape(std::shared_ptr<OptGuard> other);
  virtual std::vector<PyObject *> ApplyDynamicShape(PyFrameObject *frame);
  virtual void RevertDynamicShape(PyFrameObject *frame, const std::vector<PyObject *> &backup);

  std::string ToString() const;
  virtual const InfoPack &Info();
  virtual std::shared_ptr<OptGuard> Optimize();

 protected:
  void UpdateGuardList(GuardItemPtr item);
  std::vector<GuardItemPtr> guardList_;      // guard items currently installed
  std::map<size_t, GuardItemPtr> guardMap_;  // guard items indexed by key
  std::stack<GuardCheckPoint> guardStack_;   // checkpoints saved by Backup()
  std::map<std::string, bool> bool_config_;  // boolean options set via UpdateConfig()
  std::map<std::string, int> int_config_;    // integer options set via UpdateConfig()
  InfoPackPtr info_;
};
using OptGuardPtr = std::shared_ptr<OptGuard>;
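
// A minimal usage sketch (illustrative only; `trace` and `frame` stand in for
// values produced by the surrounding JIT pipeline):
//
//   OptGuardPtr guard = std::make_shared<OptGuard>();
//   guard->Backup();                                 // checkpoint the current items
//   if (!guard->GuardOn(trace, GuardLevel::GType)) {
//     guard->Rollback();                             // restore the checkpoint on failure
//   }
//   bool modified = guard->Check(frame, /*print=*/false);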

/// \brief collector for guard/trace performance statistics
class OptGuardPerf {
 public:
  /// \brief return the global performance collector
  static OptGuardPerf *GetGuardPerf();
  virtual void GetGuardPerfInfo(std::map<std::string, std::pair<size_t, size_t>> *guard_info,
                                std::map<std::string, std::pair<size_t, std::vector<size_t>>> *item_info,
                                std::map<std::string, std::pair<size_t, size_t>> *trace_info,
                                std::map<std::string, std::pair<size_t, size_t>> *guard_freq_info) const = 0;
  /// \brief bracket the timing of a trace evaluation
  virtual void LogTracePerfStart() = 0;
  virtual void LogTracePerfEnd(Trace *trace, bool cache) = 0;
  /// \brief bracket the timing of a guard item check
  virtual void LogItemPerfStart(int total_stage) = 0;
  virtual void LogItemPerfEnd(GuardItem *item, int stage) = 0;

 protected:
  OptGuardPerf() = default;
  virtual ~OptGuardPerf() = default;
};
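
// A sketch of how the Start/End pairs are presumably meant to bracket a
// measurement (hypothetical caller; Retrieve() is an assumed Trace method,
// and TracePtr is assumed to be a std::shared_ptr<Trace>):
//
//   OptGuardPerf *perf = OptGuardPerf::GetGuardPerf();
//   perf->LogTracePerfStart();
//   PyObject *obj = trace->Retrieve(frame);
//   perf->LogTracePerfEnd(trace.get(), /*cache=*/false);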

extern const char kSpecializeScalar[];
extern const char kSpecializeTensor[];
extern const char kSpecializeContainer[];
extern const char kGuardRelaxCnt[];
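// Judging by their names, the constants above are configuration keys,
// presumably looked up in the bool_config/int_config maps passed to
// OptGuard::UpdateConfig().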

}  // namespace pijit
}  // namespace mindspore

#endif  // MINDSPORE_PI_JIT_GUARD_H