1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ 21 22 #include <memory> 23 #include <string> 24 #include <stdexcept> 25 #include <unordered_set> 26 #include <unordered_map> 27 #include <utility> 28 #include <vector> 29 30 #include "ir/anf.h" 31 #include "ir/func_graph_cloner.h" 32 #include "pipeline/jit/static_analysis/evaluator.h" 33 34 namespace mindspore { 35 namespace abstract { 36 enum SpecializeStatusCode { 37 kSpecializeSuccess = 0, 38 kSpecializeFindUniqueArgvalDead = 1, // Dead Node 39 kSpecializeFindUniqueArgvalPoly = 2, // Poly Node 40 kSpecializeFailure = 0xFF 41 }; 42 43 class FuncGraphSpecializer; 44 using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>; 45 46 // Specialize a func graph using analyzed abstract values. 47 class ProgramSpecializer { 48 public: ProgramSpecializer(const std::shared_ptr<AnalysisEngine> & engine)49 explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine), top_context_(nullptr) { 50 mng_ = engine_->func_graph_manager(); 51 } 52 ~ProgramSpecializer() = default; 53 // Run the program specializer on the topmost graph in the given context. 54 FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context); seen()55 const std::unordered_set<AnfNodePtr> &seen() const { return seen_; } AddSeen(const AnfNodePtr & node)56 void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); } 57 58 std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context); 59 // Specialze one FuncGraph in a given context. 60 FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context); 61 engine()62 std::shared_ptr<AnalysisEngine> engine() { return engine_; } 63 top_context()64 AnalysisContextPtr top_context() { return top_context_; } 65 66 private: 67 std::shared_ptr<AnalysisEngine> engine_; 68 std::unordered_set<AnfNodePtr> seen_; 69 FuncGraphManagerPtr mng_; 70 std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual> 71 specializations_; 72 AnalysisContextPtr top_context_; 73 }; 74 75 class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> { 76 public: 77 FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context); ~FuncGraphSpecializer()78 virtual ~FuncGraphSpecializer() { 79 specializer_ = nullptr; 80 repl_node_ = nullptr; 81 } 82 void Run(); specialized_func_graph()83 FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; } 84 85 std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node); 86 87 private: 88 ProgramSpecializer *specializer_; 89 FuncGraphPtr func_graph_; 90 FuncGraphPtr specialized_func_graph_; 91 AnalysisContextPtr context_; 92 std::shared_ptr<FuncGraphSpecializer> parent_; 93 std::shared_ptr<AnalysisEngine> engine_; 94 ClonerPtr cloner_; 95 // ProcessNode-> [cloner_->CloneDisconnected] will clone AnfNode again. 96 // So, repl_node_ should pointer to GraphCloner->repl_node_ other than a copy of that. 97 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_node_; 98 std::vector<AnfNodePtr> todo_; 99 std::unordered_set<AnfNodePtr> marked_; 100 std::unordered_map<EvaluatorPtr, EvaluatorCacheMgrPtr> evalcaches_; 101 102 void FirstPass(); 103 void SecondPass(); 104 void ProcessNode(const AnfNodePtr &node); 105 void ProcessCNode(const CNodePtr &new_node); 106 107 inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); 108 inline AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const BaseFuncGraphEvaluatorPtr &evaluator, 109 const AbstractBasePtrList &args_spec_list); 110 AddTodoItem(const AnfNodePtr & node)111 inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } 112 // Get node replicated by Cloner. 113 AnfNodePtr GetReplicatedNode(const AnfNodePtr &node); 114 // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node 115 // (disconnected). 116 AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node); 117 118 // Build a value node from parameter if the function graph has special flag to hint it can be done. 119 AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); 120 121 // Build a value node if ival is constant and not any-value 122 AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, 123 const AttrValueMapPtr &attrs); 124 // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a 125 // replicated node. 126 AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); 127 // Build a specialized node from given argvals; 128 AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, 129 const AbstractBasePtrList &argvals); 130 AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, 131 const AbstractFunctionPtr &func, const AbstractBasePtrList &args, 132 SpecializeStatusCode *errcode); 133 134 // Find the unique argument values which can be used to specialize a primitive or graph function. 135 SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval, 136 const AbstractBasePtrList &argvals, 137 std::pair<AbstractBasePtrList, AbstractBasePtr> *result); 138 // Get cache, it may be eval's cache or cache built from broaded argument values. 139 const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval); 140 // Try to build unique argvals from the broaded arg vals if it is unique. 141 std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval); 142 void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node); 143 }; 144 } // namespace abstract 145 } // namespace mindspore 146 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_ 147