1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_BASE_H_ 18 #define MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_BASE_H_ 19 #include <set> 20 #include <mutex> 21 #include <memory> 22 #include <string> 23 24 #include "ir/anf.h" 25 26 namespace mindspore { 27 class FuncGraphBase; 28 using FuncGraphBasePtr = std::shared_ptr<FuncGraphBase>; 29 class MS_CORE_API FuncGraphLoopBreaker { 30 public: 31 ~FuncGraphLoopBreaker(); 32 33 static FuncGraphLoopBreaker &Inst(); 34 RegFuncGraphBase(FuncGraphBase * graph)35 void RegFuncGraphBase(FuncGraphBase *graph) { 36 std::lock_guard<std::mutex> lock_set(func_mutex_); 37 (void)func_set_.insert(graph); 38 } UnRegFuncGraphBase(FuncGraphBase * graph)39 void UnRegFuncGraphBase(FuncGraphBase *graph) { 40 std::lock_guard<std::mutex> lock_set(func_mutex_); 41 (void)func_set_.erase(graph); 42 } 43 44 void BreakLoop(); 45 46 void CleanMetaFuncGraphs(); 47 48 void CleanUnusedFuncGraphs(const std::string &phase); 49 50 void ClearCellGraphs(const std::string &phase); 51 52 void Dump() const; 53 54 private: 55 FuncGraphLoopBreaker() = default; 56 std::set<FuncGraphBase *> func_set_; 57 std::mutex func_mutex_; 58 }; 59 60 class FuncGraphChecker { 61 public: 62 FuncGraphChecker() = default; 63 template <typename... Ts> AddCheckFunc(const std::shared_ptr<std::function<bool (const Ts &...args)>> & func)64 void AddCheckFunc(const std::shared_ptr<std::function<bool(const Ts &... args)>> &func) { 65 func_ = func; 66 } 67 68 template <typename... Ts> Execute(const Ts &...args)69 bool Execute(const Ts &... args) const { 70 if (func_ == nullptr) { 71 return true; 72 } 73 auto func = reinterpret_cast<std::function<bool(const Ts &... args)> *>(func_.get()); 74 return (*func)(args...); 75 } 76 77 private: 78 std::shared_ptr<void> func_{nullptr}; 79 }; 80 81 class FuncGraphBase : public Value { 82 public: FuncGraphBase()83 FuncGraphBase() { 84 FuncGraphLoopBreaker::Inst().RegFuncGraphBase(this); 85 reg_flg_ = true; 86 } 87 ~FuncGraphBase()88 ~FuncGraphBase() override { 89 if (reg_flg_) { 90 FuncGraphLoopBreaker::Inst().UnRegFuncGraphBase(this); 91 } 92 } 93 MS_DECLARE_PARENT(FuncGraphBase, Value); 94 95 // Clear the member of FuncGraph to break loop 96 virtual void DoBreakLoop() = 0; 97 has_side_effect_node()98 bool has_side_effect_node() const { return has_side_effect_node_; } set_has_side_effect_node(bool has_side_effect_node)99 void set_has_side_effect_node(bool has_side_effect_node) { has_side_effect_node_ = has_side_effect_node; } 100 101 MS_CORE_API const FuncGraphChecker &GetChecker(const std::string &checker_name); 102 103 MS_CORE_API void AddChecker(const std::string &checker_name, const std::shared_ptr<FuncGraphChecker> &new_checker); 104 105 protected: 106 friend FuncGraphLoopBreaker; 107 bool reg_flg_{false}; 108 // If the subclass (such as FuncGraph) has started destructing. 109 bool subclass_destruct_flag_{false}; 110 111 private: 112 // If the nodes or their callee's nodes contain Depend CNode with isolated side-effect node. 113 bool has_side_effect_node_{false}; 114 HashMap<std::string, std::shared_ptr<FuncGraphChecker>> checkers_; 115 }; 116 } // namespace mindspore 117 #endif // MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_BASE_H_ 118