• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
21 
#include <atomic>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#ifdef DEBUG
#include <stack>
#endif

#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "pybind_api/ir/primitive_py.h"
#include "abstract/analysis_context.h"
#include "abstract/abstract_function.h"
#include "pipeline/jit/parse/parse.h"
43 
44 namespace mindspore {
45 namespace abstract {
// ---- Function-call depth bookkeeping (definitions live outside this header) ----
// Mutators: reset the counter, or step it up / down by one level.
void ResetFunctionCallDepth();
void IncreaseFunctionCallDepth();
void DecreaseFunctionCallDepth();
// Queries: the current depth, and (judging by the name) the maximum depth -- confirm against the definition.
size_t FunctionCallDepth();
size_t FunctionCallMaxDepth();

// ---- Stack-frame depth bookkeeping (definitions live outside this header) ----
// Mutators: reset the counter, or step it up / down by one frame.
void ResetStackFrameDepth();
void IncreaseStackFrameDepth();
void DecreaseStackFrameDepth();
// Queries: the current depth, and (judging by the name) the maximum depth -- confirm against the definition.
size_t StackFrameDepth();
size_t StackFrameMaxDepth();
57 
58 // define attribute value map
59 using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
60 using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
61 
62 // the class to save evaluated result: abstract value and modified attribute
63 class EvalResult : public Base {
64  public:
EvalResult(AbstractBasePtr abs,AttrValueMapPtr attr)65   EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
66   ~EvalResult() override = default;
67   MS_DECLARE_PARENT(EvalResult, Base);
abstract()68   AbstractBasePtr abstract() { return abstract_; }
attribute()69   AttrValueMapPtr attribute() { return attribute_; }
70 
71  private:
72   AbstractBasePtr abstract_;
73   // Attribute related to PrimEvaluator;
74   AttrValueMapPtr attribute_;
75 };
76 using EvalResultPtr = std::shared_ptr<EvalResult>;
77 
78 // Superclass for AnfNodeConfig and VirtualConfig.
79 class Config : public Base {
80  public:
81   Config() = default;
82   ~Config() override = default;
83   MS_DECLARE_PARENT(Config, Base);
84   virtual EvalResultPtr ObtainEvalResult() = 0;
85 };
86 
87 // Config will be stored in AnalysisCache
88 using ConfigPtr = std::shared_ptr<Config>;
89 using ConfigPtrList = std::vector<ConfigPtr>;
90 
91 // Config to a certain node in a certain context.
92 class AnfNodeConfig : public Config {
93  public:
AnfNodeConfig(const AnalysisEnginePtr & engine,const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)94   AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context,
95                 const FuncGraphPtr &func_graph)
96       : Config(),
97         engine_(std::weak_ptr<AnalysisEngine>(engine)),
98         node_(node),
99         context_(nullptr),
100         func_graph_(func_graph) {
101     FuncGraphPtr fg;
102     if (IsValueNode<FuncGraph>(node)) {
103       auto v = node->cast<ValueNodePtr>();
104       fg = v->value()->cast<FuncGraphPtr>();
105     } else {
106       fg = node->func_graph();
107     }
108 
109     if (context == nullptr) {
110       return;
111     }
112     if (context->func_graph() == fg) {
113       // Usually `node` is CNode and not a FV, or top graph's ValueNodes.
114       context_ = context;
115     } else {
116       // If `node` is FV, FuncGraph, or other graph ValueNodes.
117       // Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null.
118       context_ = context->FindOwnOrParentContext(fg);
119     }
120   }
121 
122   ~AnfNodeConfig() override = default;
123   MS_DECLARE_PARENT(AnfNodeConfig, Config);
124 
125   EvalResultPtr ObtainEvalResult() override;
126 
context()127   AnalysisContextPtr context() const { return context_; }
128 
node()129   AnfNodePtr node() const { return node_; }
130 
func_graph()131   FuncGraphPtr func_graph() const { return func_graph_; }
132 
engine()133   AnalysisEnginePtr engine() const { return engine_.lock(); }
134 
135   // used by unordered_map;
136   bool operator==(const AnfNodeConfig &other) const {
137     // compare node with pointer, context with pointer except DummyContext as it's created by make_shared;
138     // context should not be nullptr;
139     if (context_->IsDummyContext() && other.context_->IsDummyContext()) {
140       return true;
141     }
142     // Don't check `func_graph_` equality.
143     return (node_ == other.node_) && (context_ == other.context_);
144   }
145 
ToString()146   std::string ToString() const override {
147     std::ostringstream buffer;
148     buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId()
149            << "), Context: " << context_ << "/" << context_->ToString() << ", FuncGraph: " << func_graph_->ToString();
150     return buffer.str();
151   }
152 
153  private:
154   // AnalysisEngine is global.
155   // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use
156   // weak_ptr to break Config cycle.
157   std::weak_ptr<AnalysisEngine> engine_;
158   AnfNodePtr node_;
159   // Which context the node would be called, usually in owner func graph context.
160   AnalysisContextPtr context_;
161   // Where to call the node.
162   FuncGraphPtr func_graph_;
163 };
164 
165 using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
166 
167 struct AnfNodeConfigHasher {
168   std::size_t operator()(const AnfNodeConfigPtr conf) const;
169 };
170 
171 struct AnfNodeConfigEqual {
172   bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const;
173 };
174 
175 class VirtualConfig : public Config {
176  public:
VirtualConfig(const AbstractBasePtr & abstract)177   explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {}
178 
179   ~VirtualConfig() override = default;
180   MS_DECLARE_PARENT(VirtualConfig, Config);
ObtainEvalResult()181   EvalResultPtr ObtainEvalResult() override {
182     return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
183   }
184 
185  private:
186   AbstractBasePtr abstract_;
187 };
188 
189 using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
190 using AnfNodeConfigMap =
191   std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
192 
193 struct AnalysisResult {
194   EvalResultPtr inferred;
195   AnalysisContextPtr context;
196 };
197 
198 struct PartialAppHasher {
operatorPartialAppHasher199   std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
200     auto h1 = std::hash<AbstractFunctionPtr>{}(p.first);
201     auto h2 = AbstractBasePtrListHash(p.second);
202     return h1 ^ h2;
203   }
204 };
205 
206 // Should compare Args based on value other than pointer;
207 struct EvaluatorArgs {
EvaluatorArgsEvaluatorArgs208   EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {}
209   bool operator==(const EvaluatorArgs &other) const {
210     if (evaluator_ != other.evaluator_) {
211       return false;
212     }
213     if (AbstractBasePtrListDeepEqual(args_, other.args_)) {
214       return true;
215     }
216     return false;
217   }
218   bool operator!=(const EvaluatorArgs &other) { return !(*this == other); }
219 
220   EvaluatorPtr evaluator_;
221   AbstractBasePtrList args_;
222 };
223 using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator;
224 struct EvaluatorArgsHasher {
operatorEvaluatorArgsHasher225   std::size_t operator()(const EvaluatorArgs &eval_args) const {
226     return hash_combine(std::hash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_));
227   }
228 };
229 struct EvaluatorArgsEqual {
operatorEvaluatorArgsEqual230   bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; }
231 };
232 
233 class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
234  public:
AnalysisEngine(const PrimEvaluatorMap & prim_evaluator_map,const FuncGraphManagerPtr & func_graph_manager)235   AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
236       : prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {
237     forward_count_ = 0;
238     enable_recursive_eval_ = (common::GetEnv("ENV_RECURSIVE_EVAL") == "1");
239   }
240   ~AnalysisEngine() = default;
241 
242   // func_graph: The func_graph to analyze.
243   // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
244   AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
245   void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result);
246   EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
247   // Return the Evaluator for the given function.
248   EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
249 
250   AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
251                                            const FuncGraphPtr &func_graph);
252   AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
253   EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
254   // Infer the result of fn(args).
255   EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
256   void Clear();
257   void ClearEvaluatorCache();
MakeConfig(const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)258   AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context,
259                               const FuncGraphPtr &func_graph) {
260     return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context, func_graph);
261   }
262   // Overloaded function.
263   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn);
264   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &fn);
265   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &fn);
266   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &fn);
267   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
268   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
269   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
270 
func_graph_manager()271   FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
anfnode_config_map()272   const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }
273 
274   // Set the analysis result for orig to the result for new.
275   // This sets an entry in anfnode_config_map from orig to new.
276   EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
PrimConstructors()277   const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
278 
root_func_graph()279   FuncGraphPtr root_func_graph() const { return root_func_graph_; }
root_context()280   AnalysisContextPtr root_context() const { return root_context_; }
set_root_context(const AnalysisContextPtr & context)281   void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; }
282 
283   std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
284 
enable_recursive_eval()285   bool enable_recursive_eval() const { return enable_recursive_eval_; }
286   static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
287 
288  private:
289   void SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
290   EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
291                                      const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
292                                      bool *continue_flag);
293 
294   const PrimEvaluatorMap &prim_constructors_;
295   FuncGraphManagerPtr func_graph_manager_;
296   std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_;
297   std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
298     constructors_app_;
299 
300   AnfNodeConfigMap anfnode_config_map_;
301   // Use a list to trace multiple evaluators.
302   std::list<EvaluatorArgs> eval_trace_;
303   std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
304   std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_;
305   // root or top func_graph for static analysis;
306   FuncGraphPtr root_func_graph_{nullptr};
307   AnalysisContextPtr root_context_{nullptr};
308 
309   AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
310                          const ConfigPtrList &args_conf_list);
311   EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
312   EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
313   EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
314                                   const ConfigPtrList &args_conf_list);
315   EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
316                                           const ConfigPtrList &args_conf_list);
317   EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
318                                                      const AnfNodeConfigPtr &out_conf,
319                                                      const ConfigPtrList &args_conf_list);
320 
321   std::atomic_long forward_count_;
322 
323   bool enable_recursive_eval_;
324 
325 #ifdef DEBUG
326   std::vector<AnfNodePtr> compute_conf_stack_;
327 #endif
328 };
329 
330 // Translate the value to an abstract value.
331 // Arguments:
332 // value:   The value to convert.
333 // context: The context in which the value was found, used if the value is a Graph.
334 // conf:     The Config to the valuenode we are converting, if there is one,
335 // so that we can generate a tracking_id.
336 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr,
337                            const AnfNodeConfigPtr &conf = nullptr);
338 
339 // Convert a value to an abstract value.
340 // Arguments:
341 // v:       The value to convert.
342 // broaden: If True, concrete values will be made more abstract, so e.g.
343 // the value 1234 would become ANYTHING.
344 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false);
345 
346 template <typename T>
347 AbstractBasePtr FromValue(const T &value, bool broaden = false) {
348   return FromValueInside(MakeValue(value), broaden);
349 }
350 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine);
351 EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
352 }  // namespace abstract
353 }  // namespace mindspore
354 
355 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
356