• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
21 
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #include <stack>
26 #include <unordered_map>
27 
28 #include "utils/ms_context.h"
29 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
30 #include "pipeline/jit/ps/static_analysis/async_eval_result.h"
31 
32 namespace mindspore {
33 namespace abstract {
34 using EvaluatorCacheMgrPtr = std::shared_ptr<EvaluatorCacheMgr>;
35 using EvaluatorAttrMap =
36   std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
37 using EvaluatorAttrCache = MultiThreadCache<AbstractBasePtrList, AttrValueMapPtr, EvaluatorAttrMap>;
38 using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>;
39 
40 class Evaluator : public Base {
41  public:
Evaluator(const std::string & id)42   explicit Evaluator(const std::string &id)
43       : identifier_(id),
44         evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
45         attr_cache_(std::make_shared<EvaluatorAttrCache>()) {}
46   ~Evaluator() override = default;
47   MS_DECLARE_PARENT(Evaluator, Base);
48 
49   // Difference between Run() and Eval():
50   // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
51   // Run() will modify cache_ member, so it cannot marked as const;
52   virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
53                             const AnfNodeConfigPtr &out_conf);
54 
55   virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
56                              const AnfNodeConfigPtr &out_conf) = 0;
57 
58   virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
59                                   const AnfNodeConfigPtr &out_conf);
60 
NormalizeArgs(const AbstractBasePtrList & args_abs_list)61   virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const { return args_abs_list; }
62 
BroadenUndeterminedArgs(const AbstractBasePtrList & args_abs_list,const AnalysisEnginePtr &)63   virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list,
64                                                       const AnalysisEnginePtr &) {
65     return args_abs_list;
66   }
67 
68   virtual EvalResultPtr EvalUndeterminedArgs(const AbstractBasePtrList &args_abs_list);
69 
ToString()70   std::string ToString() const override { return identifier_; }
71 
bound_node()72   virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
73 
set_bound_node(const AnfNodePtr & node)74   virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
75 
evaluator_cache_mgr()76   EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
attr_cache()77   EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
78 
eval_lock()79   const std::recursive_timed_mutex &eval_lock() const { return eval_lock_; }
80 
81  protected:
82   std::string identifier_;
83   AnfNodeWeakPtr bound_node_;
84   EvaluatorCacheMgrPtr evaluator_cache_mgr_;
85   std::recursive_timed_mutex eval_lock_;
86 
87  private:
88   EvaluatorAttrCachePtr attr_cache_;
89 };
90 
91 class PrimEvaluator : public Evaluator {
92  public:
PrimEvaluator(const std::string & id)93   explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
94   ~PrimEvaluator() override = default;
95   MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)96   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) final {
97     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
98   }
99 };
100 
101 class TrivialPrimEvaluator : public PrimEvaluator {
102  public:
TrivialPrimEvaluator(const std::string & id)103   explicit TrivialPrimEvaluator(const std::string &id)
104       : PrimEvaluator(id), eval_cache_(AnalysisResultCacheMgr::GetInstance().prim_eval_cache()) {}
105   ~TrivialPrimEvaluator() override = default;
106   MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
107   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final;
108   virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list) = 0;
109 
110  protected:
111   virtual bool inplace_prim() const = 0;
112   PrimitiveEvalCachePtr eval_cache_;
113 };
114 
115 class TransitionPrimEvaluator : public PrimEvaluator {
116  public:
TransitionPrimEvaluator(const std::string & id)117   explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
118   ~TransitionPrimEvaluator() override = default;
119   MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
120   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
121                     const AnfNodeConfigPtr &out_conf) final;
122   // Parameter in_conf0 : the first element in args_conf_list;
123   virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
124                                  const ConfigPtr &in_conf, const AnfNodeConfigPtr &out_conf) = 0;
125 
126  protected:
inplace_prim()127   virtual bool inplace_prim() const { return false; }
128 };
129 
130 class SymbolicPrimEvaluator : public PrimEvaluator {
131  public:
SymbolicPrimEvaluator(const std::string & id)132   explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
133   ~SymbolicPrimEvaluator() override = default;
134   MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
135   EvalResultPtr Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final;
136   virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
137 };
138 
139 // Evaluator will be stored in AnalysisEngine.evaluators_
140 using EvaluatorPtrList = std::vector<EvaluatorPtr>;
141 
142 class DummyEvaluator : public Evaluator {
143  public:
DummyEvaluator()144   DummyEvaluator() : Evaluator("dummy") {}
145   ~DummyEvaluator() override = default;
146   MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)147   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
148     return nullptr;
149   }
150 };
151 
152 // Wrap another evaluator to track a subset of uses.
153 // A TrackedEvaluator has its own cache that maps possible calls to
154 // their results, but is ultimately backed by a different evaluator.
155 // Multiple TrackedEvaluators can be backed by the same Evaluator.
156 class TrackedEvaluator : public Evaluator {
157  public:
TrackedEvaluator(const EvaluatorPtr & subinf)158   explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {}
159   ~TrackedEvaluator() override = default;
160   MS_DECLARE_PARENT(TrackedEvaluator, Evaluator);
bound_node()161   AnfNodePtr bound_node() const override {
162     if (sub_evaluator_ != nullptr) {
163       return sub_evaluator_->bound_node();
164     }
165     return bound_node_.lock();
166   }
167 
set_bound_node(const AnfNodePtr & node)168   void set_bound_node(const AnfNodePtr &node) override {
169     if (sub_evaluator_ != nullptr) {
170       sub_evaluator_->set_bound_node(node);
171     }
172     bound_node_ = AnfNodeWeakPtr(node);
173   }
174 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)175   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
176     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
177   }
178   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
179                     const AnfNodeConfigPtr &out_conf) override;
ToString()180   std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
181 
182  private:
183   EvaluatorPtr sub_evaluator_;
184 };
185 
186 using FuncGraphCacheMap =
187   std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
188 class StackFrame;
189 using StackFramePtr = std::shared_ptr<StackFrame>;
190 
191 class BaseFuncGraphEvaluator : public Evaluator {
192  public:
BaseFuncGraphEvaluator(const AnalysisContextPtr & context)193   explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context)
194       : Evaluator("basegraph"), parent_context_(context) {}
195 
196   ~BaseFuncGraphEvaluator() override = default;
197   MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
198 
199   EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
200                      const AnfNodeConfigPtr &out_conf) override;
201 
202   virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) = 0;
203 
parent_context()204   AnalysisContextPtr parent_context() const { return parent_context_; }
set_parent_context(const AnalysisContextPtr & parent_context)205   void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; }
206 
PushAlwaysEvalFlag(bool flag)207   void PushAlwaysEvalFlag(bool flag) { always_eval_flags_.push_back(flag); }
PopAlwaysEvalFlag()208   void PopAlwaysEvalFlag() { always_eval_flags_.pop_back(); }
always_eval_flag()209   bool always_eval_flag() const {
210     if (always_eval_flags_.empty()) {
211       MS_LOG(INTERNAL_EXCEPTION) << "Always_eval_flag should not be empty";
212     }
213     return always_eval_flags_.back();
214   }
215 
216   virtual void SyncFuncGraphSideEffectFlag(const FuncGraphPtr &func_graph) = 0;
217 
218  protected:
219   AnalysisContextPtr parent_context_;
220 
221  private:
222   // As evaluator can be recursively called, so use a vector to simulate a stack of flags.
223   std::vector<bool> always_eval_flags_;
224   AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
225                                       const AnalysisContextPtr &context) const;
226   // Add functions for stack frame routine.
227   AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
228                                    const AnalysisContextPtr &context);
229   static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
230                               const StackFramePtr &new_stack_frame);
231   static void LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr &current_stack_frame);
232 };
233 
234 class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
235  public:
FuncGraphEvaluator(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context)236   FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
237       : BaseFuncGraphEvaluator(context), func_graph_(func_graph) {}
238 
239   ~FuncGraphEvaluator() override = default;
240   MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
241 
242   FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) override;
243 
func_graph()244   FuncGraphPtr func_graph() { return func_graph_; }
245 
246   AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const override;
247   AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list,
248                                               const AnalysisEnginePtr &engine) override;
ToString()249   std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); }
250 
SyncFuncGraphSideEffectFlag(const FuncGraphPtr & func_graph)251   void SyncFuncGraphSideEffectFlag(const FuncGraphPtr &func_graph) override {
252     if (func_graph->has_side_effect_node()) {
253       func_graph_->set_has_side_effect_node(true);
254     }
255   }
256 
257  private:
258   FuncGraphPtr func_graph_;
259   FuncGraphCacheMap func_graph_cache_;
260   std::vector<AbstractBasePtrList> trace_;
261 };
262 using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
263 
264 class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator {
265  public:
266   // Note: context parameter is not used;
MetaFuncGraphEvaluator(const MetaFuncGraphPtr & meta_func_graph,const ScopePtr & scope)267   MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope)
268       : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {}
269   ~MetaFuncGraphEvaluator() override = default;
270   MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator);
271 
272   FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) override;
273 
274   // Return normalized versions of the arguments.
NormalizeArgs(const AbstractBasePtrList & args_abs_list)275   AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const override {
276     return meta_func_graph_->NormalizeArgs(args_abs_list);
277   }
ToString()278   std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); }
279 
SyncFuncGraphSideEffectFlag(const FuncGraphPtr & func_graph)280   void SyncFuncGraphSideEffectFlag(const FuncGraphPtr &func_graph) override {
281     if (func_graph->has_side_effect_node()) {
282       meta_func_graph_->set_has_side_effect_node(true);
283     }
284   }
285 
286  private:
287   MetaFuncGraphPtr meta_func_graph_;
288   FuncGraphCacheMap func_graph_cache_;
289   FuncGraphPtr generated_func_graph_{nullptr};
290   ScopePtr scope_;
291 };
292 
293 class PartialAppEvaluator : public Evaluator {
294  public:
PartialAppEvaluator(const EvaluatorPtr & evaluator,const AbstractBasePtrList & args)295   PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args)
296       : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_abs_list_(args) {}
297   ~PartialAppEvaluator() override = default;
298   MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator);
bound_node()299   AnfNodePtr bound_node() const override {
300     if (evaluator_ != nullptr) {
301       return evaluator_->bound_node();
302     }
303     return bound_node_.lock();
304   }
305 
set_bound_node(const AnfNodePtr & node)306   void set_bound_node(const AnfNodePtr &node) override {
307     if (evaluator_ != nullptr) {
308       evaluator_->set_bound_node(node);
309     }
310     bound_node_ = AnfNodeWeakPtr(node);
311   }
312 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)313   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
314     MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called";
315   }
316 
317   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
318                     const AnfNodeConfigPtr &out_conf) override;
ToString()319   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
320 
321  private:
322   EvaluatorPtr evaluator_;
323   AbstractBasePtrList args_abs_list_;
324 };
325 
326 class VirtualEvaluator : public Evaluator {
327  public:
VirtualEvaluator(const AbstractBasePtrList & args_abs_list,const AbstractBasePtr & output)328   VirtualEvaluator(const AbstractBasePtrList &args_abs_list, const AbstractBasePtr &output)
329       : Evaluator("virtual"), args_abs_list_(args_abs_list), output_(output) {}
330   ~VirtualEvaluator() override = default;
331   MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
332 
333   EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
334                      const AnfNodeConfigPtr &out_conf) override;
ToString()335   std::string ToString() const override { return identifier_; }
336 
337  private:
338   AbstractBasePtrList args_abs_list_;
339   AbstractBasePtr output_;
340 };
341 
342 class JEvaluator : public Evaluator {
343  public:
JEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)344   JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
345       : Evaluator("JEvaluator"), evaluator_(evaluator), primal_func_(orig_func) {}
346   ~JEvaluator() override = default;
347   MS_DECLARE_PARENT(JEvaluator, Evaluator);
bound_node()348   AnfNodePtr bound_node() const override {
349     if (evaluator_ != nullptr) {
350       return evaluator_->bound_node();
351     }
352     return bound_node_.lock();
353   }
354 
set_bound_node(const AnfNodePtr & node)355   void set_bound_node(const AnfNodePtr &node) override {
356     if (evaluator_ != nullptr) {
357       evaluator_->set_bound_node(node);
358     }
359     bound_node_ = AnfNodeWeakPtr(node);
360   }
361 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)362   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
363     MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called";
364   }
365   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
ToString()366   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
367 
368  private:
369   EvaluatorPtr evaluator_;
370   AbstractFunctionPtr primal_func_;
371 };
372 
373 class TaylorEvaluator : public Evaluator {
374  public:
TaylorEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)375   TaylorEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
376       : Evaluator("TaylorEvaluator"), evaluator_(evaluator), primal_func_(orig_func) {}
377   ~TaylorEvaluator() override = default;
378   MS_DECLARE_PARENT(TaylorEvaluator, Evaluator);
bound_node()379   AnfNodePtr bound_node() const override {
380     if (evaluator_ != nullptr) {
381       return evaluator_->bound_node();
382     }
383     return bound_node_.lock();
384   }
385 
set_bound_node(const AnfNodePtr & node)386   void set_bound_node(const AnfNodePtr &node) override {
387     if (evaluator_ != nullptr) {
388       evaluator_->set_bound_node(node);
389     }
390     bound_node_ = AnfNodeWeakPtr(node);
391   }
392 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)393   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
394     MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called";
395   }
396   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
ToString()397   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
398 
399  private:
400   EvaluatorPtr evaluator_;
401   AbstractFunctionPtr primal_func_;
402 };
403 
404 class ShardEvaluator : public Evaluator {
405  public:
ShardEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)406   ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
407       : Evaluator("ShardEvaluator"), evaluator_(evaluator), primal_func_(orig_func) {}
408   ~ShardEvaluator() override = default;
409   MS_DECLARE_PARENT(ShardEvaluator, Evaluator);
410 
bound_node()411   AnfNodePtr bound_node() const override {
412     if (evaluator_ != nullptr) {
413       return evaluator_->bound_node();
414     }
415     return bound_node_.lock();
416   }
417 
set_bound_node(const AnfNodePtr & node)418   void set_bound_node(const AnfNodePtr &node) override {
419     if (evaluator_ != nullptr) {
420       evaluator_->set_bound_node(node);
421     }
422     bound_node_ = AnfNodeWeakPtr(node);
423   }
424 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)425   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
426     MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called";
427   }
428 
429   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
430 
ToString()431   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
432 
433  private:
434   EvaluatorPtr evaluator_;
435   AbstractFunctionPtr primal_func_;
436 };
437 
438 class VmapEvaluator : public Evaluator {
439  public:
VmapEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func,const ValuePtr & in_axes,const ValuePtr & out_axes,size_t cell_size)440   VmapEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func, const ValuePtr &in_axes,
441                 const ValuePtr &out_axes, size_t cell_size)
442       : Evaluator("VmapEvaluator"),
443         evaluator_(evaluator),
444         primal_func_(orig_func),
445         in_axes_(in_axes),
446         out_axes_(out_axes),
447         cell_size_(cell_size) {}
448   ~VmapEvaluator() override = default;
449   MS_DECLARE_PARENT(VmapEvaluator, Evaluator);
bound_node()450   AnfNodePtr bound_node() const override {
451     if (evaluator_ != nullptr) {
452       return evaluator_->bound_node();
453     }
454     return bound_node_.lock();
455   }
456 
set_bound_node(const AnfNodePtr & node)457   void set_bound_node(const AnfNodePtr &node) override {
458     if (evaluator_ != nullptr) {
459       evaluator_->set_bound_node(node);
460     }
461     bound_node_ = AnfNodeWeakPtr(node);
462   }
463 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)464   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
465     MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called";
466   }
467   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override;
ToString()468   std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
469 
470  private:
471   EvaluatorPtr evaluator_;
472   AbstractFunctionPtr primal_func_;
473   ValuePtr in_axes_;
474   ValuePtr out_axes_;
475   size_t cell_size_;
476 };
477 
478 AbstractBasePtrList EvaluateArguments(const ConfigPtrList &args_conf_list);
479 
480 bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
481 
482 bool ContainsAbstractAny(const AbstractBasePtrList &args_abs_list);
483 }  // namespace abstract
484 }  // namespace mindspore
485 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
486