• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
21 
#include <algorithm>
#include <memory>
#include <mutex>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>

#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "utils/ms_context.h"
31 
32 namespace mindspore {
33 namespace abstract {
34 using EvaluatorCacheMgrPtr = std::shared_ptr<EvaluatorCacheMgr>;
35 using EvaluatorAttrMap =
36   std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
37 using EvaluatorAttrCache = MultiThreadCache<AbstractBasePtrList, AttrValueMapPtr, EvaluatorAttrMap>;
38 using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>;
39 
40 class Evaluator : public Base {
41  public:
Evaluator(const std::string & id)42   explicit Evaluator(const std::string &id)
43       : identifier_(id),
44         evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
45         attr_cache_(std::make_shared<EvaluatorAttrCache>()) {}
46   ~Evaluator() override = default;
47   MS_DECLARE_PARENT(Evaluator, Base);
48 
49   // difference between Run() and Eval():
50   // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
51   // Run() will modify cache_ member, so it cannot marked as const;
52   virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
53                             const AnfNodeConfigPtr &out_conf);
54 
55   virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
56                              const AnfNodeConfigPtr &out_conf) = 0;
57   virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
58                                   const AnfNodeConfigPtr &out_conf);
59 
NormalizeArgs(const AbstractBasePtrList & args_spec_list)60   virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
61 
BroadenUndeterminedArgs(const AbstractBasePtrList & args_spec_list)62   virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) {
63     return args_spec_list;
64   }
65 
AbstractEval(const AbstractBasePtrList & args_spec_list)66   virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
67     auto context = MsContext::GetInstance();
68     MS_EXCEPTION_IF_NULL(context);
69     bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
70     if (!enable_sparse) {
71       return nullptr;
72     }
73 
74     auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
75       if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
76         return true;
77       }
78       return false;
79     });
80     if (is_abstract) {
81       MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
82       return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
83     }
84     return nullptr;
85   }
86 
ToString()87   std::string ToString() const override { return identifier_; }
88 
bound_node()89   virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
90 
set_bound_node(const AnfNodePtr & node)91   virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
92 
evaluator_cache_mgr()93   EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
attr_cache()94   EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
95 
eval_lock()96   std::recursive_timed_mutex &eval_lock() { return eval_lock_; }
97 
98  protected:
99   std::string identifier_;
100   AnfNodeWeakPtr bound_node_;
101   EvaluatorCacheMgrPtr evaluator_cache_mgr_;
102   std::recursive_timed_mutex eval_lock_;
103 
104  private:
105   EvaluatorAttrCachePtr attr_cache_;
106 };
107 
108 class PrimEvaluator : public Evaluator {
109  public:
PrimEvaluator(const std::string & id)110   explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
111   ~PrimEvaluator() override = default;
112   MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)113   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) final {
114     MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
115   }
116 };
117 
118 class TrivialPrimEvaluator : public PrimEvaluator {
119  public:
TrivialPrimEvaluator(const std::string & id)120   explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
121   ~TrivialPrimEvaluator() override = default;
122   MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
123   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
124                     const AnfNodeConfigPtr &out_conf) final;
125   virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
126 };
127 
128 class TransitionPrimEvaluator : public PrimEvaluator {
129  public:
TransitionPrimEvaluator(const std::string & id)130   explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
131   ~TransitionPrimEvaluator() override = default;
132   MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
133   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
134                     const AnfNodeConfigPtr &out_conf) final;
135   // Parameter in_conf0 : the first element in args_conf_list;
136   virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
137                                  const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
138 };
139 
140 class SymbolicPrimEvaluator : public PrimEvaluator {
141  public:
SymbolicPrimEvaluator(const std::string & id)142   explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
143   ~SymbolicPrimEvaluator() override = default;
144   MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
145   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
146                     const AnfNodeConfigPtr &out_conf) final;
147   virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
148 };
149 
150 // Evaluator will be stored in AnalysisEngine.evaluators_
151 using EvaluatorPtrList = std::vector<EvaluatorPtr>;
152 
153 class DummyEvaluator : public Evaluator {
154  public:
DummyEvaluator()155   DummyEvaluator() : Evaluator("dummy") {}
156   ~DummyEvaluator() override = default;
157   MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)158   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
159     return nullptr;
160   }
161 };
162 
163 // Wrap another evaluator to track a subset of uses.
164 // A TrackedEvaluator has its own cache that maps possible calls to
165 // their results, but is ultimately backed by a different evaluator.
166 // Multiple TrackedEvaluators can be backed by the same Evaluator.
167 class TrackedEvaluator : public Evaluator {
168  public:
TrackedEvaluator(const EvaluatorPtr & subinf)169   explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {}
170   ~TrackedEvaluator() override = default;
171   MS_DECLARE_PARENT(TrackedEvaluator, Evaluator);
bound_node()172   AnfNodePtr bound_node() const override {
173     if (sub_evaluator_ != nullptr) {
174       return sub_evaluator_->bound_node();
175     }
176     return bound_node_.lock();
177   }
178 
set_bound_node(const AnfNodePtr & node)179   void set_bound_node(const AnfNodePtr &node) override {
180     if (sub_evaluator_ != nullptr) {
181       sub_evaluator_->set_bound_node(node);
182     }
183     bound_node_ = AnfNodeWeakPtr(node);
184   }
185 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)186   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
187     MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
188   }
189   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
190                     const AnfNodeConfigPtr &out_conf) override;
ToString()191   std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
192 
193  private:
194   EvaluatorPtr sub_evaluator_;
195 };
196 
197 using FuncGraphCacheMap =
198   std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
199 class StackFrame;
200 using StackFramePtr = std::shared_ptr<StackFrame>;
201 
202 class BaseFuncGraphEvaluator : public Evaluator {
203  public:
BaseFuncGraphEvaluator(const AnalysisContextPtr & context)204   explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context)
205       : Evaluator("basegraph"), parent_context_(context) {}
206 
207   ~BaseFuncGraphEvaluator() override = default;
208   MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
209 
210   EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
211                      const AnfNodeConfigPtr &out_conf) override;
212 
213   virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
214 
parent_context()215   AnalysisContextPtr parent_context() const { return parent_context_; }
set_parent_context(const AnalysisContextPtr & parent_context)216   void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; }
217 
218  protected:
219   AnalysisContextPtr parent_context_;
220 
221  private:
222   AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
223                                       const AnalysisContextPtr &context);
224   // Add functions for stack frame routine.
225   AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
226                                    const AnalysisContextPtr &context);
227   static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
228                               const StackFramePtr &new_stack_frame);
229   static void LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr &current_stack_frame);
230 };
231 
232 class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
233  public:
FuncGraphEvaluator(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context)234   FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
235       : BaseFuncGraphEvaluator(context), func_graph_(func_graph) {}
236 
237   ~FuncGraphEvaluator() override = default;
238   MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
239 
240   FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
241 
func_graph()242   FuncGraphPtr func_graph() { return func_graph_; }
243 
244   AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override;
245   AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override;
ToString()246   std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); }
247 
248  private:
249   FuncGraphPtr func_graph_;
250   FuncGraphCacheMap func_graph_cache_;
251   std::vector<AbstractBasePtrList> trace_;
252 };
253 using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
254 
255 class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator {
256  public:
257   // Note: context parameter is not used;
MetaFuncGraphEvaluator(const MetaFuncGraphPtr & meta_func_graph,const ScopePtr & scope)258   MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope)
259       : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {}
260   ~MetaFuncGraphEvaluator() override = default;
261   MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator);
262 
263   FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
264 
265   // Return normalized versions of the arguments.
NormalizeArgs(const AbstractBasePtrList & args_spec_list)266   AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override {
267     return meta_func_graph_->NormalizeArgs(args_spec_list);
268   }
ToString()269   std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); }
270 
271  private:
272   MetaFuncGraphPtr meta_func_graph_;
273   FuncGraphCacheMap func_graph_cache_;
274   ScopePtr scope_;
275 };
276 
277 class PartialAppEvaluator : public Evaluator {
278  public:
PartialAppEvaluator(const EvaluatorPtr & evaluator,const AbstractBasePtrList & args)279   PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args)
280       : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {}
281   ~PartialAppEvaluator() override = default;
282   MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator);
bound_node()283   AnfNodePtr bound_node() const override {
284     if (evaluator_ != nullptr) {
285       return evaluator_->bound_node();
286     }
287     return bound_node_.lock();
288   }
289 
set_bound_node(const AnfNodePtr & node)290   void set_bound_node(const AnfNodePtr &node) override {
291     if (evaluator_ != nullptr) {
292       evaluator_->set_bound_node(node);
293     }
294     bound_node_ = AnfNodeWeakPtr(node);
295   }
296 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)297   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
298     MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
299   }
300 
301   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
302                     const AnfNodeConfigPtr &out_conf) override;
ToString()303   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
304 
305  private:
306   EvaluatorPtr evaluator_;
307   AbstractBasePtrList args_spec_list_;
308 };
309 
310 class VirtualEvaluator : public Evaluator {
311  public:
VirtualEvaluator(const AbstractBasePtrList & args_spec_list,const AbstractBasePtr & output)312   VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output)
313       : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {}
314   ~VirtualEvaluator() override = default;
315   MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
316 
317   EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
318                      const AnfNodeConfigPtr &out_conf) override;
ToString()319   std::string ToString() const override { return identifier_; }
320 
321  private:
322   AbstractBasePtrList args_spec_list_;
323   AbstractBasePtr output_;
324 };
325 
326 class JEvaluator : public Evaluator {
327  public:
JEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)328   JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
329       : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
330   ~JEvaluator() override = default;
331   MS_DECLARE_PARENT(JEvaluator, Evaluator);
bound_node()332   AnfNodePtr bound_node() const override {
333     if (evaluator_ != nullptr) {
334       return evaluator_->bound_node();
335     }
336     return bound_node_.lock();
337   }
338 
set_bound_node(const AnfNodePtr & node)339   void set_bound_node(const AnfNodePtr &node) override {
340     if (evaluator_ != nullptr) {
341       evaluator_->set_bound_node(node);
342     }
343     bound_node_ = AnfNodeWeakPtr(node);
344   }
345 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)346   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
347     MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
348   }
349   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
350                     const AnfNodeConfigPtr &out_conf) override;
ToString()351   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
352 
353  private:
354   EvaluatorPtr evaluator_;
355   AbstractFunctionPtr orig_func_;
356 };
357 
358 void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args);
359 }  // namespace abstract
360 }  // namespace mindspore
361 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
362