1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GUARD_H 17 #define MINDSPORE_PI_JIT_GUARD_H 18 19 #include <memory> 20 #include <vector> 21 #include <map> 22 #include <stack> 23 #include <string> 24 #include <utility> 25 #include <tuple> 26 #include "pybind11/pybind11.h" 27 #include "include/common/utils/python_adapter.h" 28 #include "pipeline/jit/pi/graph_guard/trace.h" 29 #include "pipeline/jit/pi/graph_guard/guard_utils.h" 30 31 namespace mindspore { 32 namespace pijit { 33 34 typedef enum _GuardLevel { 35 GDeduce = 0, 36 GId, 37 GType, 38 GAttr, 39 GEqual, 40 } GuardLevel; 41 42 using GuardItemVector = std::vector<GuardItemPtr>; 43 using GuardItemMap = std::map<size_t, GuardItemPtr>; 44 using GuardCheckPoint = std::pair<GuardItemVector, GuardItemMap>; 45 46 class OptGuard : public std::enable_shared_from_this<OptGuard> { 47 public: 48 OptGuard(); 49 virtual ~OptGuard() = default; 50 /// \brief check whether the variables guarded have been modified 51 /// \param[in] frame python frame 52 /// \param[in] print guard 53 /// \param[in] cache to reuse the guard result 54 /// \param[in] success to record the items to guard successfully 55 /// \param[in] fail to record the items which fail to guard 56 /// \param[in] perf to record the performance of guard 57 /// \param[out] the variables have been modified 58 virtual bool Check(const PyFrameObject *frame, bool print, std::map<size_t, PyObject *> *cache = nullptr, 59 std::map<size_t, bool> *success = nullptr, std::map<size_t, bool> *fail = nullptr, 60 bool perf = false); 61 /// \brief guard the variable which has trace to retrieve 62 /// \param[in] frame python frame 63 /// \param[in] var to trace the path to retrieve the object 64 /// \param[in] tp guard level 65 /// \param[in] needSpecialize to check the content of buffer 66 /// \param[in] recurseDepth to check the hierarchy element access like a.b.c by depth 67 /// \param[out] whether to guard successfully 68 virtual bool GuardOn(TracePtr var, GuardLevel tp = GuardLevel::GDeduce, bool needSpecialize = true, 69 int recurseDepth = 0); 70 /// \brief add trace from guard, traces to replace in other guard 71 /// \param[in] traces to replace in other guard 72 /// \param[in] other guard with traces 73 virtual void AddTraceFromGuard(const std::vector<TracePtr> &traces, std::shared_ptr<OptGuard> other); 74 /// \brief return the description for the guard 75 virtual std::string GetDescript(); 76 virtual void UpdateConfig(const std::map<std::string, bool> &bool_config, 77 const std::map<std::string, int> &int_config); 78 virtual void Backup(); 79 virtual void Rollback(); 80 virtual void Pop(); IsEmpty()81 virtual bool IsEmpty() { return guardList_.size() == 0; } 82 virtual bool MatchShape(std::shared_ptr<OptGuard> other); 83 virtual std::vector<PyObject *> ApplyDynamicShape(PyFrameObject *frame); 84 virtual void RevertDynamicShape(PyFrameObject *frame, const std::vector<PyObject *> &backup); 85 86 std::string ToString() const; 87 virtual const InfoPack &Info(); 88 virtual std::shared_ptr<OptGuard> Optimize(); 89 90 protected: 91 void UpdateGuardList(GuardItemPtr item); 92 std::vector<GuardItemPtr> guardList_; 93 std::map<size_t, GuardItemPtr> guardMap_; 94 std::stack<GuardCheckPoint> guardStack_; 95 std::map<std::string, bool> bool_config_; 96 std::map<std::string, int> int_config_; 97 InfoPackPtr info_; 98 }; 99 using OptGuardPtr = std::shared_ptr<OptGuard>; 100 101 class OptGuardPerf { 102 public: 103 static OptGuardPerf *GetGuardPerf(); 104 virtual void GetGuardPerfInfo(std::map<std::string, std::pair<size_t, size_t>> *guard_info, 105 std::map<std::string, std::pair<size_t, std::vector<size_t>>> *item_info, 106 std::map<std::string, std::pair<size_t, size_t>> *trace_info, 107 std::map<std::string, std::pair<size_t, size_t>> *guard_freq_info) const = 0; 108 virtual void LogTracePerfStart() = 0; 109 virtual void LogTracePerfEnd(Trace *trace, bool cache) = 0; 110 virtual void LogItemPerfStart(int total_stage) = 0; 111 virtual void LogItemPerfEnd(GuardItem *item, int stage) = 0; 112 113 protected: 114 OptGuardPerf() = default; 115 virtual ~OptGuardPerf() = default; 116 }; 117 118 extern const char kSpecializeScalar[]; 119 extern const char kSpecializeTensor[]; 120 extern const char kSpecializeContainer[]; 121 extern const char kGuardRelaxCnt[]; 122 123 } // namespace pijit 124 } // namespace mindspore 125 126 #endif // MINDSPORE_PI_JIT_GUARD_H 127