1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2023 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ 21 22 #include <memory> 23 #include <string> 24 #include <stdexcept> 25 #include <utility> 26 #include <vector> 27 #include <unordered_map> 28 #include <stack> 29 30 #include "utils/hash_map.h" 31 #include "utils/hash_set.h" 32 #include "ir/anf.h" 33 #include "ir/func_graph_cloner.h" 34 #include "pipeline/jit/ps/static_analysis/evaluator.h" 35 36 namespace mindspore { 37 namespace abstract { 38 enum SpecializeStatusCode { 39 kSpecializeSuccess = 0, 40 kSpecializeDead = 1, // Dead Node 41 kSpecializePoly = 2, // Poly Node 42 kSpecializeFailure = 0xFF 43 }; 44 45 class FuncGraphSpecializer; 46 using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>; 47 48 // Specialize a func graph using analyzed abstract values. 49 class ProgramSpecializer { 50 public: 51 using DeferSpecializeNodesMap = 52 mindspore::HashMap<AnalysisContextPtr, std::vector<std::pair<FuncGraphSpecializer *, CNodePtr>>>; ProgramSpecializer(const std::shared_ptr<AnalysisEngine> & engine)53 explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine), top_context_(nullptr) { 54 manager_ = engine_->func_graph_manager(); 55 } 56 ~ProgramSpecializer() = default; 57 // Run the program specializer on the topmost graph in the given context. 58 FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context); seen()59 const mindspore::HashSet<AnfNodePtr> &seen() const { return seen_; } AddSeen(const AnfNodePtr & node)60 void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); } EraseSeen(const AnfNodePtr & node)61 void EraseSeen(const AnfNodePtr &node) { (void)seen_.erase(node); } 62 63 std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context); 64 65 std::shared_ptr<FuncGraphSpecializer> NewFuncGraphSpecializer(const AnalysisContextPtr &context, 66 const FuncGraphPtr &fg); 67 engine()68 std::shared_ptr<AnalysisEngine> engine() { return engine_; } manager()69 FuncGraphManagerPtr manager() { return manager_; } 70 top_context()71 const AnalysisContextPtr &top_context() const { return top_context_; } 72 void SetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func, const AbstractFunctionPtr &new_abs_func, 73 const CNodePtr &cnode, const AnfNodePtr &func); 74 AbstractFunctionPtr GetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func); 75 void SpecializeCNodeInput0FuncGraph(); 76 sequence_abstract_list()77 std::vector<std::pair<AbstractSequencePtr, AnfNodePtr>> &sequence_abstract_list() { return sequence_abstract_list_; } dead_node_list()78 std::vector<std::pair<AnfNodePtr, size_t>> &dead_node_list() { return dead_node_list_; } defer_specialize_nodes()79 const DeferSpecializeNodesMap &defer_specialize_nodes() const { return defer_specialize_nodes_; } 80 AddDeferSpecializeNode(const AnalysisContextPtr & context,const CNodePtr & cnode,FuncGraphSpecializer * specializer)81 void AddDeferSpecializeNode(const AnalysisContextPtr &context, const CNodePtr &cnode, 82 FuncGraphSpecializer *specializer) { 83 (void)defer_specialize_nodes_[context].emplace_back(std::make_pair(specializer, cnode)); 84 } 85 86 template <typename Iter> RemoveDeferSpecializeNode(Iter && iter)87 void RemoveDeferSpecializeNode(Iter &&iter) { 88 (void)defer_specialize_nodes_.erase(std::forward<Iter>(iter)); 89 } 90 SetSpecializedFuncGraphToAbstract(const FuncGraphPtr & before_specialized_fg,const AbstractFunctionPtr & specialized_abs)91 void SetSpecializedFuncGraphToAbstract(const FuncGraphPtr &before_specialized_fg, 92 const AbstractFunctionPtr &specialized_abs) { 93 auto iter = func_graph_to_abstract_map_.find(before_specialized_fg); 94 if (iter != func_graph_to_abstract_map_.end()) { 95 iter->second.first = false; 96 return; 97 } 98 (void)func_graph_to_abstract_map_.emplace(before_specialized_fg, std::make_pair(true, specialized_abs)); 99 } 100 GetUniqueFuncGraphAbstract(const FuncGraphPtr & before_specialized_fg)101 AbstractFunctionPtr GetUniqueFuncGraphAbstract(const FuncGraphPtr &before_specialized_fg) { 102 auto iter = func_graph_to_abstract_map_.find(before_specialized_fg); 103 if (iter != func_graph_to_abstract_map_.end()) { 104 if (iter->second.first) { 105 return iter->second.second; 106 } 107 } 108 return nullptr; 109 } 110 PushFuncGraphTodoItem(const std::shared_ptr<FuncGraphSpecializer> & fg_spec)111 void PushFuncGraphTodoItem(const std::shared_ptr<FuncGraphSpecializer> &fg_spec) { 112 func_graph_todo_items_.push(fg_spec); 113 } 114 115 private: 116 std::shared_ptr<AnalysisEngine> engine_; 117 mindspore::HashSet<AnfNodePtr> seen_; 118 FuncGraphManagerPtr manager_; 119 mindspore::HashMap<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>> specializations_; 120 // If caller's input0 is a poly func, and the func's parent has not been specialized, then the caller specialization 121 // need to be deferred after parent specialized. 122 DeferSpecializeNodesMap defer_specialize_nodes_; 123 AnalysisContextPtr top_context_; 124 // The list to purify tuple/list elements. 125 std::vector<std::pair<AbstractSequencePtr, AnfNodePtr>> sequence_abstract_list_; 126 // The list to erase the DeadNode in tuple/list elements. 127 std::vector<std::pair<AnfNodePtr, size_t>> dead_node_list_; 128 // Map for unspecialized abstract function to specialized abstract; bool flag is used to indicate if it's unique. 129 std::unordered_map<AbstractFunctionPtr, std::pair<bool, AbstractFunctionPtr>, AbstractFunctionHasher, 130 AbstractFunctionEqual> 131 specialized_abs_map_; 132 // Map from unspecialized func_graph to <bool, specialized abstract>, bool flag is used to indicate if it's unique. 133 // This is the remedial action to find the specialized abstract function for func_graph if the abstract is unique. 134 mindspore::HashMap<FuncGraphPtr, std::pair<bool, AbstractFunctionPtr>> func_graph_to_abstract_map_; 135 136 AbstractFunctionPtr SpecializeAbstractFuncRecursively(const AbstractFunctionPtr &old_abs_func); 137 138 std::stack<std::shared_ptr<FuncGraphSpecializer>> func_graph_todo_items_; 139 }; 140 141 class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> { 142 public: 143 FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context); ~FuncGraphSpecializer()144 virtual ~FuncGraphSpecializer() { specializer_ = nullptr; } 145 void Run(); specialized_func_graph()146 FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; } 147 148 std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node); 149 done()150 bool done() const { return done_; } 151 152 private: 153 ProgramSpecializer *specializer_; 154 FuncGraphPtr func_graph_; 155 FuncGraphPtr specialized_func_graph_; 156 AnalysisContextPtr context_; 157 std::shared_ptr<FuncGraphSpecializer> parent_; 158 std::shared_ptr<AnalysisEngine> engine_; 159 ClonerPtr cloner_; 160 std::vector<AnfNodePtr> todo_; 161 mindspore::HashSet<AnfNodePtr> marked_; 162 mindspore::HashMap<EvaluatorPtr, EvaluatorCacheMgrPtr> eval_cache_; 163 std::vector<CNodePtr> second_pass_todo_; 164 size_t second_pass_todo_index_{0}; 165 bool done_{false}; 166 167 void FirstPass(); 168 void SecondPass(); 169 void ProcessNode(const AnfNodePtr &node); 170 bool ProcessCNode(const CNodePtr &cnode); 171 void ProcessCNodeEnd(const CNodePtr &cnode, const AnfNodeWeakPtrList &new_weak_inputs); 172 bool ProcessSwitchAppCNode(const CNodePtr &cnode); 173 bool ParentNotSpecialized(const AnalysisContextPtr &context) const; 174 175 void EliminateUnusedSequenceItem(const CNodePtr &cnode) const; 176 cloned_nodes()177 const NodeToNodeMap &cloned_nodes() const { return cloner_->cloned_nodes(); } 178 MakeConfig(const AnfNodePtr & node)179 inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node) { 180 return engine_->MakeConfig(node, context_, func_graph_); // 'func_graph_' is dummy here. 181 } 182 183 inline AnalysisContextPtr GetAnalysisContext(const AnalysisEnginePtr &engine, 184 const BaseFuncGraphEvaluatorPtr &evaluator, 185 const AbstractBasePtrList &args_abs_list) const; 186 AddTodoItem(const AnfNodePtr & node)187 inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } AddTodoItem(const std::vector<AnfNodePtr> & nodes)188 inline void AddTodoItem(const std::vector<AnfNodePtr> &nodes) { 189 (void)todo_.insert(todo_.cend(), nodes.cbegin(), nodes.cend()); 190 } 191 // Get node replicated by Cloner. 192 AnfNodePtr GetReplicatedNode(const AnfNodePtr &node); 193 // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node 194 // (disconnected). 195 AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node); 196 197 // Build a value node from parameter if the function graph has special flag to hint it can be done. 198 AnfNodePtr BuildSpecializedParameterCNode(const CNodePtr &cnode); 199 // Build a value node if ival is a function. 200 AnfNodePtr BuildValueNodeForAbstractFunction(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, 201 const AttrValueMapPtr &attrs, const AnfNodePtr &cnode, 202 const AbstractFunctionPtr &abs); 203 // Build a value node if ival is constant and not any-value 204 AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, 205 const AttrValueMapPtr &attrs, const AnfNodePtr &cnode = nullptr); 206 // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a 207 // replicated node. 208 AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); 209 // Build a specialized node from given argvals; 210 AnfNodePtr BuildSpecializedNode(const CNodePtr &cnode, const AnfNodePtr &func, const AbstractBasePtr &abs, 211 const AbstractBasePtrList &argvals); 212 AnfNodePtr BuildSpecializedNodeInner(const CNodePtr &cnode, const AnfNodePtr &func, const AbstractBasePtr &abs, 213 const AbstractFunctionPtr &func_abs, const AbstractBasePtrList &args, 214 SpecializeStatusCode *errcode); 215 216 // Find the unique argument values which can be used to specialize a primitive or graph function. 217 SpecializeStatusCode AcquireUniqueEvalResult(const AbstractFunctionPtr &func, const EvaluatorPtr &eval, 218 const AbstractBasePtrList &argvals, 219 std::pair<AbstractBasePtrList, AbstractBasePtr> *res); 220 // Get cache, it may be eval's cache or cache built from broaded argument values. 221 const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval); 222 // Try to build unique argvals from the broaded arg vals if it is unique. 223 std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgs(const EvaluatorPtr &eval); 224 void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node); 225 }; 226 using FuncGraphSpecializerPtr = std::shared_ptr<FuncGraphSpecializer>; 227 } // namespace abstract 228 } // namespace mindspore 229 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ 230