/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
21 
22 #include <memory>
23 #include <string>
24 #include <stdexcept>
25 #include <utility>
26 #include <vector>
27 #include <unordered_map>
28 #include <stack>
29 
30 #include "utils/hash_map.h"
31 #include "utils/hash_set.h"
32 #include "ir/anf.h"
33 #include "ir/func_graph_cloner.h"
34 #include "pipeline/jit/ps/static_analysis/evaluator.h"
35 
36 namespace mindspore {
37 namespace abstract {
// Status codes used by the specializer; in this header they are the return type of
// FuncGraphSpecializer::AcquireUniqueEvalResult and the 'errcode' out-parameter of BuildSpecializedNodeInner.
enum SpecializeStatusCode {
  kSpecializeSuccess = 0,
  kSpecializeDead = 1,  // Dead Node
  kSpecializePoly = 2,  // Poly Node
  kSpecializeFailure = 0xFF
};
44 
45 class FuncGraphSpecializer;
46 using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
47 
48 // Specialize a func graph using analyzed abstract values.
49 class ProgramSpecializer {
50  public:
51   using DeferSpecializeNodesMap =
52     mindspore::HashMap<AnalysisContextPtr, std::vector<std::pair<FuncGraphSpecializer *, CNodePtr>>>;
ProgramSpecializer(const std::shared_ptr<AnalysisEngine> & engine)53   explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine), top_context_(nullptr) {
54     manager_ = engine_->func_graph_manager();
55   }
56   ~ProgramSpecializer() = default;
57   // Run the program specializer on the topmost graph in the given context.
58   FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
seen()59   const mindspore::HashSet<AnfNodePtr> &seen() const { return seen_; }
AddSeen(const AnfNodePtr & node)60   void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); }
EraseSeen(const AnfNodePtr & node)61   void EraseSeen(const AnfNodePtr &node) { (void)seen_.erase(node); }
62 
63   std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context);
64 
65   std::shared_ptr<FuncGraphSpecializer> NewFuncGraphSpecializer(const AnalysisContextPtr &context,
66                                                                 const FuncGraphPtr &fg);
67 
engine()68   std::shared_ptr<AnalysisEngine> engine() { return engine_; }
manager()69   FuncGraphManagerPtr manager() { return manager_; }
70 
top_context()71   const AnalysisContextPtr &top_context() const { return top_context_; }
72   void SetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func, const AbstractFunctionPtr &new_abs_func,
73                               const CNodePtr &cnode, const AnfNodePtr &func);
74   AbstractFunctionPtr GetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func);
75   void SpecializeCNodeInput0FuncGraph();
76 
sequence_abstract_list()77   std::vector<std::pair<AbstractSequencePtr, AnfNodePtr>> &sequence_abstract_list() { return sequence_abstract_list_; }
dead_node_list()78   std::vector<std::pair<AnfNodePtr, size_t>> &dead_node_list() { return dead_node_list_; }
defer_specialize_nodes()79   const DeferSpecializeNodesMap &defer_specialize_nodes() const { return defer_specialize_nodes_; }
80 
AddDeferSpecializeNode(const AnalysisContextPtr & context,const CNodePtr & cnode,FuncGraphSpecializer * specializer)81   void AddDeferSpecializeNode(const AnalysisContextPtr &context, const CNodePtr &cnode,
82                               FuncGraphSpecializer *specializer) {
83     (void)defer_specialize_nodes_[context].emplace_back(std::make_pair(specializer, cnode));
84   }
85 
86   template <typename Iter>
RemoveDeferSpecializeNode(Iter && iter)87   void RemoveDeferSpecializeNode(Iter &&iter) {
88     (void)defer_specialize_nodes_.erase(std::forward<Iter>(iter));
89   }
90 
SetSpecializedFuncGraphToAbstract(const FuncGraphPtr & before_specialized_fg,const AbstractFunctionPtr & specialized_abs)91   void SetSpecializedFuncGraphToAbstract(const FuncGraphPtr &before_specialized_fg,
92                                          const AbstractFunctionPtr &specialized_abs) {
93     auto iter = func_graph_to_abstract_map_.find(before_specialized_fg);
94     if (iter != func_graph_to_abstract_map_.end()) {
95       iter->second.first = false;
96       return;
97     }
98     (void)func_graph_to_abstract_map_.emplace(before_specialized_fg, std::make_pair(true, specialized_abs));
99   }
100 
GetUniqueFuncGraphAbstract(const FuncGraphPtr & before_specialized_fg)101   AbstractFunctionPtr GetUniqueFuncGraphAbstract(const FuncGraphPtr &before_specialized_fg) {
102     auto iter = func_graph_to_abstract_map_.find(before_specialized_fg);
103     if (iter != func_graph_to_abstract_map_.end()) {
104       if (iter->second.first) {
105         return iter->second.second;
106       }
107     }
108     return nullptr;
109   }
110 
PushFuncGraphTodoItem(const std::shared_ptr<FuncGraphSpecializer> & fg_spec)111   void PushFuncGraphTodoItem(const std::shared_ptr<FuncGraphSpecializer> &fg_spec) {
112     func_graph_todo_items_.push(fg_spec);
113   }
114 
115  private:
116   std::shared_ptr<AnalysisEngine> engine_;
117   mindspore::HashSet<AnfNodePtr> seen_;
118   FuncGraphManagerPtr manager_;
119   mindspore::HashMap<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>> specializations_;
120   // If caller's input0 is a poly func, and the func's parent has not been specialized, then the caller specialization
121   // need to be deferred after parent specialized.
122   DeferSpecializeNodesMap defer_specialize_nodes_;
123   AnalysisContextPtr top_context_;
124   // The list to purify tuple/list elements.
125   std::vector<std::pair<AbstractSequencePtr, AnfNodePtr>> sequence_abstract_list_;
126   // The list to erase the DeadNode in tuple/list elements.
127   std::vector<std::pair<AnfNodePtr, size_t>> dead_node_list_;
128   // Map for unspecialized abstract function to specialized abstract; bool flag is used to indicate if it's unique.
129   std::unordered_map<AbstractFunctionPtr, std::pair<bool, AbstractFunctionPtr>, AbstractFunctionHasher,
130                      AbstractFunctionEqual>
131     specialized_abs_map_;
132   // Map from unspecialized func_graph to <bool, specialized abstract>, bool flag is used to indicate if it's unique.
133   // This is the remedial action to find the specialized abstract function for func_graph if the abstract is unique.
134   mindspore::HashMap<FuncGraphPtr, std::pair<bool, AbstractFunctionPtr>> func_graph_to_abstract_map_;
135 
136   AbstractFunctionPtr SpecializeAbstractFuncRecursively(const AbstractFunctionPtr &old_abs_func);
137 
138   std::stack<std::shared_ptr<FuncGraphSpecializer>> func_graph_todo_items_;
139 };
140 
141 class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
142  public:
143   FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context);
~FuncGraphSpecializer()144   virtual ~FuncGraphSpecializer() { specializer_ = nullptr; }
145   void Run();
specialized_func_graph()146   FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; }
147 
148   std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node);
149 
done()150   bool done() const { return done_; }
151 
152  private:
153   ProgramSpecializer *specializer_;
154   FuncGraphPtr func_graph_;
155   FuncGraphPtr specialized_func_graph_;
156   AnalysisContextPtr context_;
157   std::shared_ptr<FuncGraphSpecializer> parent_;
158   std::shared_ptr<AnalysisEngine> engine_;
159   ClonerPtr cloner_;
160   std::vector<AnfNodePtr> todo_;
161   mindspore::HashSet<AnfNodePtr> marked_;
162   mindspore::HashMap<EvaluatorPtr, EvaluatorCacheMgrPtr> eval_cache_;
163   std::vector<CNodePtr> second_pass_todo_;
164   size_t second_pass_todo_index_{0};
165   bool done_{false};
166 
167   void FirstPass();
168   void SecondPass();
169   void ProcessNode(const AnfNodePtr &node);
170   bool ProcessCNode(const CNodePtr &cnode);
171   void ProcessCNodeEnd(const CNodePtr &cnode, const AnfNodeWeakPtrList &new_weak_inputs);
172   bool ProcessSwitchAppCNode(const CNodePtr &cnode);
173   bool ParentNotSpecialized(const AnalysisContextPtr &context) const;
174 
175   void EliminateUnusedSequenceItem(const CNodePtr &cnode) const;
176 
cloned_nodes()177   const NodeToNodeMap &cloned_nodes() const { return cloner_->cloned_nodes(); }
178 
MakeConfig(const AnfNodePtr & node)179   inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node) {
180     return engine_->MakeConfig(node, context_, func_graph_);  // 'func_graph_' is dummy here.
181   }
182 
183   inline AnalysisContextPtr GetAnalysisContext(const AnalysisEnginePtr &engine,
184                                                const BaseFuncGraphEvaluatorPtr &evaluator,
185                                                const AbstractBasePtrList &args_abs_list) const;
186 
AddTodoItem(const AnfNodePtr & node)187   inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
AddTodoItem(const std::vector<AnfNodePtr> & nodes)188   inline void AddTodoItem(const std::vector<AnfNodePtr> &nodes) {
189     (void)todo_.insert(todo_.cend(), nodes.cbegin(), nodes.cend());
190   }
191   // Get node replicated by Cloner.
192   AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
193   // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node
194   // (disconnected).
195   AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);
196 
197   // Build a value node from parameter if the function graph has special flag to hint it can be done.
198   AnfNodePtr BuildSpecializedParameterCNode(const CNodePtr &cnode);
199   // Build a value node if ival is a function.
200   AnfNodePtr BuildValueNodeForAbstractFunction(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
201                                                const AttrValueMapPtr &attrs, const AnfNodePtr &cnode,
202                                                const AbstractFunctionPtr &abs);
203   // Build a value node if ival is constant and not any-value
204   AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
205                                     const AttrValueMapPtr &attrs, const AnfNodePtr &cnode = nullptr);
206   // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
207   // replicated node.
208   AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
209   // Build a specialized node from given argvals;
210   AnfNodePtr BuildSpecializedNode(const CNodePtr &cnode, const AnfNodePtr &func, const AbstractBasePtr &abs,
211                                   const AbstractBasePtrList &argvals);
212   AnfNodePtr BuildSpecializedNodeInner(const CNodePtr &cnode, const AnfNodePtr &func, const AbstractBasePtr &abs,
213                                        const AbstractFunctionPtr &func_abs, const AbstractBasePtrList &args,
214                                        SpecializeStatusCode *errcode);
215 
216   // Find the unique argument values which can be used to specialize a primitive or graph function.
217   SpecializeStatusCode AcquireUniqueEvalResult(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
218                                                const AbstractBasePtrList &argvals,
219                                                std::pair<AbstractBasePtrList, AbstractBasePtr> *res);
220   // Get cache, it may be eval's cache or cache built from broaded argument values.
221   const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval);
222   // Try to build unique argvals from the broaded arg vals if it is unique.
223   std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgs(const EvaluatorPtr &eval);
224   void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node);
225 };
226 using FuncGraphSpecializerPtr = std::shared_ptr<FuncGraphSpecializer>;
227 }  // namespace abstract
228 }  // namespace mindspore
229 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
230