1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2022 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ 21 22 #include <memory> 23 #include <string> 24 #include <vector> 25 #include <stack> 26 #include <unordered_map> 27 28 #include "utils/ms_context.h" 29 #include "pipeline/jit/ps/static_analysis/static_analysis.h" 30 #include "pipeline/jit/ps/static_analysis/async_eval_result.h" 31 32 namespace mindspore { 33 namespace abstract { 34 using EvaluatorCacheMgrPtr = std::shared_ptr<EvaluatorCacheMgr>; 35 using EvaluatorAttrMap = 36 std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; 37 using EvaluatorAttrCache = MultiThreadCache<AbstractBasePtrList, AttrValueMapPtr, EvaluatorAttrMap>; 38 using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>; 39 40 class Evaluator : public Base { 41 public: Evaluator(const std::string & id)42 explicit Evaluator(const std::string &id) 43 : identifier_(id), 44 evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()), 45 attr_cache_(std::make_shared<EvaluatorAttrCache>()) {} 46 ~Evaluator() override = default; 47 MS_DECLARE_PARENT(Evaluator, Base); 48 49 // Difference between Run() and Eval(): 50 // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. 51 // Run() will modify cache_ member, so it cannot marked as const; 52 virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 53 const AnfNodeConfigPtr &out_conf); 54 55 virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list, 56 const AnfNodeConfigPtr &out_conf) = 0; 57 58 virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 59 const AnfNodeConfigPtr &out_conf); 60 NormalizeArgs(const AbstractBasePtrList & args_abs_list)61 virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const { return args_abs_list; } 62 BroadenUndeterminedArgs(const AbstractBasePtrList & args_abs_list,const AnalysisEnginePtr &)63 virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list, 64 const AnalysisEnginePtr &) { 65 return args_abs_list; 66 } 67 68 virtual EvalResultPtr EvalUndeterminedArgs(const AbstractBasePtrList &args_abs_list); 69 ToString()70 std::string ToString() const override { return identifier_; } 71 bound_node()72 virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } 73 set_bound_node(const AnfNodePtr & node)74 virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } 75 evaluator_cache_mgr()76 EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; } attr_cache()77 EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; } 78 eval_lock()79 const std::recursive_timed_mutex &eval_lock() const { return eval_lock_; } 80 81 protected: 82 std::string identifier_; 83 AnfNodeWeakPtr bound_node_; 84 EvaluatorCacheMgrPtr evaluator_cache_mgr_; 85 std::recursive_timed_mutex eval_lock_; 86 87 private: 88 EvaluatorAttrCachePtr attr_cache_; 89 }; 90 91 class PrimEvaluator : public Evaluator { 92 public: PrimEvaluator(const std::string & id)93 explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} 94 ~PrimEvaluator() override = default; 95 MS_DECLARE_PARENT(PrimEvaluator, Evaluator); Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)96 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) final { 97 MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called"; 98 } 99 }; 100 101 class TrivialPrimEvaluator : public PrimEvaluator { 102 public: TrivialPrimEvaluator(const std::string & id)103 explicit TrivialPrimEvaluator(const std::string &id) 104 : PrimEvaluator(id), eval_cache_(AnalysisResultCacheMgr::GetInstance().prim_eval_cache()) {} 105 ~TrivialPrimEvaluator() override = default; 106 MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); 107 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final; 108 virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list) = 0; 109 110 protected: 111 virtual bool inplace_prim() const = 0; 112 PrimitiveEvalCachePtr eval_cache_; 113 }; 114 115 class TransitionPrimEvaluator : public PrimEvaluator { 116 public: TransitionPrimEvaluator(const std::string & id)117 explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} 118 ~TransitionPrimEvaluator() override = default; 119 MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); 120 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 121 const AnfNodeConfigPtr &out_conf) final; 122 // Parameter in_conf0 : the first element in args_conf_list; 123 virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, 124 const ConfigPtr &in_conf, const AnfNodeConfigPtr &out_conf) = 0; 125 126 protected: inplace_prim()127 virtual bool inplace_prim() const { return false; } 128 }; 129 130 class SymbolicPrimEvaluator : public PrimEvaluator { 131 public: SymbolicPrimEvaluator(const std::string & id)132 explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} 133 ~SymbolicPrimEvaluator() override = default; 134 MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); 135 EvalResultPtr Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) final; 136 virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; 137 }; 138 139 // Evaluator will be stored in AnalysisEngine.evaluators_ 140 using EvaluatorPtrList = std::vector<EvaluatorPtr>; 141 142 class DummyEvaluator : public Evaluator { 143 public: DummyEvaluator()144 DummyEvaluator() : Evaluator("dummy") {} 145 ~DummyEvaluator() override = default; 146 MS_DECLARE_PARENT(DummyEvaluator, Evaluator); Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)147 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 148 return nullptr; 149 } 150 }; 151 152 // Wrap another evaluator to track a subset of uses. 153 // A TrackedEvaluator has its own cache that maps possible calls to 154 // their results, but is ultimately backed by a different evaluator. 155 // Multiple TrackedEvaluators can be backed by the same Evaluator. 156 class TrackedEvaluator : public Evaluator { 157 public: TrackedEvaluator(const EvaluatorPtr & subinf)158 explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {} 159 ~TrackedEvaluator() override = default; 160 MS_DECLARE_PARENT(TrackedEvaluator, Evaluator); bound_node()161 AnfNodePtr bound_node() const override { 162 if (sub_evaluator_ != nullptr) { 163 return sub_evaluator_->bound_node(); 164 } 165 return bound_node_.lock(); 166 } 167 set_bound_node(const AnfNodePtr & node)168 void set_bound_node(const AnfNodePtr &node) override { 169 if (sub_evaluator_ != nullptr) { 170 sub_evaluator_->set_bound_node(node); 171 } 172 bound_node_ = AnfNodeWeakPtr(node); 173 } 174 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)175 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 176 MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called"; 177 } 178 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 179 const AnfNodeConfigPtr &out_conf) override; ToString()180 std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } 181 182 private: 183 EvaluatorPtr sub_evaluator_; 184 }; 185 186 using FuncGraphCacheMap = 187 std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; 188 class StackFrame; 189 using StackFramePtr = std::shared_ptr<StackFrame>; 190 191 class BaseFuncGraphEvaluator : public Evaluator { 192 public: BaseFuncGraphEvaluator(const AnalysisContextPtr & context)193 explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) 194 : Evaluator("basegraph"), parent_context_(context) {} 195 196 ~BaseFuncGraphEvaluator() override = default; 197 MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); 198 199 EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list, 200 const AnfNodeConfigPtr &out_conf) override; 201 202 virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) = 0; 203 parent_context()204 AnalysisContextPtr parent_context() const { return parent_context_; } set_parent_context(const AnalysisContextPtr & parent_context)205 void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; } 206 PushAlwaysEvalFlag(bool flag)207 void PushAlwaysEvalFlag(bool flag) { always_eval_flags_.push_back(flag); } PopAlwaysEvalFlag()208 void PopAlwaysEvalFlag() { always_eval_flags_.pop_back(); } always_eval_flag()209 bool always_eval_flag() const { 210 if (always_eval_flags_.empty()) { 211 MS_LOG(INTERNAL_EXCEPTION) << "Always_eval_flag should not be empty"; 212 } 213 return always_eval_flags_.back(); 214 } 215 216 virtual void SyncFuncGraphSideEffectFlag(const FuncGraphPtr &func_graph) = 0; 217 218 protected: 219 AnalysisContextPtr parent_context_; 220 221 private: 222 // As evaluator can be recursively called, so use a vector to simulate a stack of flags. 223 std::vector<bool> always_eval_flags_; 224 AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, 225 const AnalysisContextPtr &context) const; 226 // Add functions for stack frame routine. 227 AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, 228 const AnalysisContextPtr &context); 229 static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame, 230 const StackFramePtr &new_stack_frame); 231 static void LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr ¤t_stack_frame); 232 }; 233 234 class FuncGraphEvaluator : public BaseFuncGraphEvaluator { 235 public: FuncGraphEvaluator(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context)236 FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) 237 : BaseFuncGraphEvaluator(context), func_graph_(func_graph) {} 238 239 ~FuncGraphEvaluator() override = default; 240 MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); 241 242 FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) override; 243 func_graph()244 FuncGraphPtr func_graph() { return func_graph_; } 245 246 AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const override; 247 AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list, 248 const AnalysisEnginePtr &engine) override; ToString()249 std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } 250 SyncFuncGraphSideEffectFlag(const FuncGraphPtr & func_graph)251 void SyncFuncGraphSideEffectFlag(const FuncGraphPtr &func_graph) override { 252 if (func_graph->has_side_effect_node()) { 253 func_graph_->set_has_side_effect_node(true); 254 } 255 } 256 257 private: 258 FuncGraphPtr func_graph_; 259 FuncGraphCacheMap func_graph_cache_; 260 std::vector<AbstractBasePtrList> trace_; 261 }; 262 using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>; 263 264 class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { 265 public: 266 // Note: context parameter is not used; MetaFuncGraphEvaluator(const MetaFuncGraphPtr & meta_func_graph,const ScopePtr & scope)267 MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope) 268 : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {} 269 ~MetaFuncGraphEvaluator() override = default; 270 MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator); 271 272 FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) override; 273 274 // Return normalized versions of the arguments. NormalizeArgs(const AbstractBasePtrList & args_abs_list)275 AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const override { 276 return meta_func_graph_->NormalizeArgs(args_abs_list); 277 } ToString()278 std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); } 279 SyncFuncGraphSideEffectFlag(const FuncGraphPtr & func_graph)280 void SyncFuncGraphSideEffectFlag(const FuncGraphPtr &func_graph) override { 281 if (func_graph->has_side_effect_node()) { 282 meta_func_graph_->set_has_side_effect_node(true); 283 } 284 } 285 286 private: 287 MetaFuncGraphPtr meta_func_graph_; 288 FuncGraphCacheMap func_graph_cache_; 289 FuncGraphPtr generated_func_graph_{nullptr}; 290 ScopePtr scope_; 291 }; 292 293 class PartialAppEvaluator : public Evaluator { 294 public: PartialAppEvaluator(const EvaluatorPtr & evaluator,const AbstractBasePtrList & args)295 PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args) 296 : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_abs_list_(args) {} 297 ~PartialAppEvaluator() override = default; 298 MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator); bound_node()299 AnfNodePtr bound_node() const override { 300 if (evaluator_ != nullptr) { 301 return evaluator_->bound_node(); 302 } 303 return bound_node_.lock(); 304 } 305 set_bound_node(const AnfNodePtr & node)306 void set_bound_node(const AnfNodePtr &node) override { 307 if (evaluator_ != nullptr) { 308 evaluator_->set_bound_node(node); 309 } 310 bound_node_ = AnfNodeWeakPtr(node); 311 } 312 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)313 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 314 MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called"; 315 } 316 317 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 318 const AnfNodeConfigPtr &out_conf) override; ToString()319 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 320 321 private: 322 EvaluatorPtr evaluator_; 323 AbstractBasePtrList args_abs_list_; 324 }; 325 326 class VirtualEvaluator : public Evaluator { 327 public: VirtualEvaluator(const AbstractBasePtrList & args_abs_list,const AbstractBasePtr & output)328 VirtualEvaluator(const AbstractBasePtrList &args_abs_list, const AbstractBasePtr &output) 329 : Evaluator("virtual"), args_abs_list_(args_abs_list), output_(output) {} 330 ~VirtualEvaluator() override = default; 331 MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); 332 333 EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list, 334 const AnfNodeConfigPtr &out_conf) override; ToString()335 std::string ToString() const override { return identifier_; } 336 337 private: 338 AbstractBasePtrList args_abs_list_; 339 AbstractBasePtr output_; 340 }; 341 342 class JEvaluator : public Evaluator { 343 public: JEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)344 JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) 345 : Evaluator("JEvaluator"), evaluator_(evaluator), primal_func_(orig_func) {} 346 ~JEvaluator() override = default; 347 MS_DECLARE_PARENT(JEvaluator, Evaluator); bound_node()348 AnfNodePtr bound_node() const override { 349 if (evaluator_ != nullptr) { 350 return evaluator_->bound_node(); 351 } 352 return bound_node_.lock(); 353 } 354 set_bound_node(const AnfNodePtr & node)355 void set_bound_node(const AnfNodePtr &node) override { 356 if (evaluator_ != nullptr) { 357 evaluator_->set_bound_node(node); 358 } 359 bound_node_ = AnfNodeWeakPtr(node); 360 } 361 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)362 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 363 MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called"; 364 } 365 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override; ToString()366 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 367 368 private: 369 EvaluatorPtr evaluator_; 370 AbstractFunctionPtr primal_func_; 371 }; 372 373 class TaylorEvaluator : public Evaluator { 374 public: TaylorEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)375 TaylorEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) 376 : Evaluator("TaylorEvaluator"), evaluator_(evaluator), primal_func_(orig_func) {} 377 ~TaylorEvaluator() override = default; 378 MS_DECLARE_PARENT(TaylorEvaluator, Evaluator); bound_node()379 AnfNodePtr bound_node() const override { 380 if (evaluator_ != nullptr) { 381 return evaluator_->bound_node(); 382 } 383 return bound_node_.lock(); 384 } 385 set_bound_node(const AnfNodePtr & node)386 void set_bound_node(const AnfNodePtr &node) override { 387 if (evaluator_ != nullptr) { 388 evaluator_->set_bound_node(node); 389 } 390 bound_node_ = AnfNodeWeakPtr(node); 391 } 392 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)393 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 394 MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called"; 395 } 396 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override; ToString()397 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 398 399 private: 400 EvaluatorPtr evaluator_; 401 AbstractFunctionPtr primal_func_; 402 }; 403 404 class ShardEvaluator : public Evaluator { 405 public: ShardEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)406 ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) 407 : Evaluator("ShardEvaluator"), evaluator_(evaluator), primal_func_(orig_func) {} 408 ~ShardEvaluator() override = default; 409 MS_DECLARE_PARENT(ShardEvaluator, Evaluator); 410 bound_node()411 AnfNodePtr bound_node() const override { 412 if (evaluator_ != nullptr) { 413 return evaluator_->bound_node(); 414 } 415 return bound_node_.lock(); 416 } 417 set_bound_node(const AnfNodePtr & node)418 void set_bound_node(const AnfNodePtr &node) override { 419 if (evaluator_ != nullptr) { 420 evaluator_->set_bound_node(node); 421 } 422 bound_node_ = AnfNodeWeakPtr(node); 423 } 424 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)425 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 426 MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called"; 427 } 428 429 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override; 430 ToString()431 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 432 433 private: 434 EvaluatorPtr evaluator_; 435 AbstractFunctionPtr primal_func_; 436 }; 437 438 class VmapEvaluator : public Evaluator { 439 public: VmapEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func,const ValuePtr & in_axes,const ValuePtr & out_axes,size_t cell_size)440 VmapEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func, const ValuePtr &in_axes, 441 const ValuePtr &out_axes, size_t cell_size) 442 : Evaluator("VmapEvaluator"), 443 evaluator_(evaluator), 444 primal_func_(orig_func), 445 in_axes_(in_axes), 446 out_axes_(out_axes), 447 cell_size_(cell_size) {} 448 ~VmapEvaluator() override = default; 449 MS_DECLARE_PARENT(VmapEvaluator, Evaluator); bound_node()450 AnfNodePtr bound_node() const override { 451 if (evaluator_ != nullptr) { 452 return evaluator_->bound_node(); 453 } 454 return bound_node_.lock(); 455 } 456 set_bound_node(const AnfNodePtr & node)457 void set_bound_node(const AnfNodePtr &node) override { 458 if (evaluator_ != nullptr) { 459 evaluator_->set_bound_node(node); 460 } 461 bound_node_ = AnfNodeWeakPtr(node); 462 } 463 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)464 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 465 MS_LOG(INTERNAL_EXCEPTION) << "Should not be called, Run() method should be called"; 466 } 467 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) override; ToString()468 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 469 470 private: 471 EvaluatorPtr evaluator_; 472 AbstractFunctionPtr primal_func_; 473 ValuePtr in_axes_; 474 ValuePtr out_axes_; 475 size_t cell_size_; 476 }; 477 478 AbstractBasePtrList EvaluateArguments(const ConfigPtrList &args_conf_list); 479 480 bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg); 481 482 bool ContainsAbstractAny(const AbstractBasePtrList &args_abs_list); 483 } // namespace abstract 484 } // namespace mindspore 485 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ 486