1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ 21 22 #include <memory> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 #include <stack> 27 28 #include "pipeline/jit/static_analysis/static_analysis.h" 29 #include "pipeline/jit/static_analysis/async_eval_result.h" 30 #include "utils/ms_context.h" 31 32 namespace mindspore { 33 namespace abstract { 34 using EvaluatorCacheMgrPtr = std::shared_ptr<EvaluatorCacheMgr>; 35 using EvaluatorAttrMap = 36 std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; 37 using EvaluatorAttrCache = MultiThreadCache<AbstractBasePtrList, AttrValueMapPtr, EvaluatorAttrMap>; 38 using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>; 39 40 class Evaluator : public Base { 41 public: Evaluator(const std::string & id)42 explicit Evaluator(const std::string &id) 43 : identifier_(id), 44 evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()), 45 attr_cache_(std::make_shared<EvaluatorAttrCache>()) {} 46 ~Evaluator() override = default; 47 MS_DECLARE_PARENT(Evaluator, Base); 48 49 // difference between Run() and Eval(): 50 // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. 51 // Run() will modify cache_ member, so it cannot marked as const; 52 virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 53 const AnfNodeConfigPtr &out_conf); 54 55 virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list, 56 const AnfNodeConfigPtr &out_conf) = 0; 57 virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 58 const AnfNodeConfigPtr &out_conf); 59 NormalizeArgs(const AbstractBasePtrList & args_spec_list)60 virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } 61 BroadenUndeterminedArgs(const AbstractBasePtrList & args_spec_list)62 virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { 63 return args_spec_list; 64 } 65 AbstractEval(const AbstractBasePtrList & args_spec_list)66 virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { 67 auto context = MsContext::GetInstance(); 68 MS_EXCEPTION_IF_NULL(context); 69 bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE); 70 if (!enable_sparse) { 71 return nullptr; 72 } 73 74 auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { 75 if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { 76 return true; 77 } 78 return false; 79 }); 80 if (is_abstract) { 81 MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result"; 82 return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>()); 83 } 84 return nullptr; 85 } 86 ToString()87 std::string ToString() const override { return identifier_; } 88 bound_node()89 virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } 90 set_bound_node(const AnfNodePtr & node)91 virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } 92 evaluator_cache_mgr()93 EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; } attr_cache()94 EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; } 95 eval_lock()96 std::recursive_timed_mutex &eval_lock() { return eval_lock_; } 97 98 protected: 99 std::string identifier_; 100 AnfNodeWeakPtr bound_node_; 101 EvaluatorCacheMgrPtr evaluator_cache_mgr_; 102 std::recursive_timed_mutex eval_lock_; 103 104 private: 105 EvaluatorAttrCachePtr attr_cache_; 106 }; 107 108 class PrimEvaluator : public Evaluator { 109 public: PrimEvaluator(const std::string & id)110 explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} 111 ~PrimEvaluator() override = default; 112 MS_DECLARE_PARENT(PrimEvaluator, Evaluator); Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)113 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) final { 114 MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; 115 } 116 }; 117 118 class TrivialPrimEvaluator : public PrimEvaluator { 119 public: TrivialPrimEvaluator(const std::string & id)120 explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} 121 ~TrivialPrimEvaluator() override = default; 122 MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); 123 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 124 const AnfNodeConfigPtr &out_conf) final; 125 virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; 126 }; 127 128 class TransitionPrimEvaluator : public PrimEvaluator { 129 public: TransitionPrimEvaluator(const std::string & id)130 explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} 131 ~TransitionPrimEvaluator() override = default; 132 MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); 133 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 134 const AnfNodeConfigPtr &out_conf) final; 135 // Parameter in_conf0 : the first element in args_conf_list; 136 virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, 137 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; 138 }; 139 140 class SymbolicPrimEvaluator : public PrimEvaluator { 141 public: SymbolicPrimEvaluator(const std::string & id)142 explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} 143 ~SymbolicPrimEvaluator() override = default; 144 MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); 145 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 146 const AnfNodeConfigPtr &out_conf) final; 147 virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; 148 }; 149 150 // Evaluator will be stored in AnalysisEngine.evaluators_ 151 using EvaluatorPtrList = std::vector<EvaluatorPtr>; 152 153 class DummyEvaluator : public Evaluator { 154 public: DummyEvaluator()155 DummyEvaluator() : Evaluator("dummy") {} 156 ~DummyEvaluator() override = default; 157 MS_DECLARE_PARENT(DummyEvaluator, Evaluator); Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)158 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 159 return nullptr; 160 } 161 }; 162 163 // Wrap another evaluator to track a subset of uses. 164 // A TrackedEvaluator has its own cache that maps possible calls to 165 // their results, but is ultimately backed by a different evaluator. 166 // Multiple TrackedEvaluators can be backed by the same Evaluator. 167 class TrackedEvaluator : public Evaluator { 168 public: TrackedEvaluator(const EvaluatorPtr & subinf)169 explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {} 170 ~TrackedEvaluator() override = default; 171 MS_DECLARE_PARENT(TrackedEvaluator, Evaluator); bound_node()172 AnfNodePtr bound_node() const override { 173 if (sub_evaluator_ != nullptr) { 174 return sub_evaluator_->bound_node(); 175 } 176 return bound_node_.lock(); 177 } 178 set_bound_node(const AnfNodePtr & node)179 void set_bound_node(const AnfNodePtr &node) override { 180 if (sub_evaluator_ != nullptr) { 181 sub_evaluator_->set_bound_node(node); 182 } 183 bound_node_ = AnfNodeWeakPtr(node); 184 } 185 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)186 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 187 MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; 188 } 189 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 190 const AnfNodeConfigPtr &out_conf) override; ToString()191 std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } 192 193 private: 194 EvaluatorPtr sub_evaluator_; 195 }; 196 197 using FuncGraphCacheMap = 198 std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; 199 class StackFrame; 200 using StackFramePtr = std::shared_ptr<StackFrame>; 201 202 class BaseFuncGraphEvaluator : public Evaluator { 203 public: BaseFuncGraphEvaluator(const AnalysisContextPtr & context)204 explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) 205 : Evaluator("basegraph"), parent_context_(context) {} 206 207 ~BaseFuncGraphEvaluator() override = default; 208 MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); 209 210 EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list, 211 const AnfNodeConfigPtr &out_conf) override; 212 213 virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; 214 parent_context()215 AnalysisContextPtr parent_context() const { return parent_context_; } set_parent_context(const AnalysisContextPtr & parent_context)216 void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; } 217 218 protected: 219 AnalysisContextPtr parent_context_; 220 221 private: 222 AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, 223 const AnalysisContextPtr &context); 224 // Add functions for stack frame routine. 225 AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, 226 const AnalysisContextPtr &context); 227 static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame, 228 const StackFramePtr &new_stack_frame); 229 static void LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr ¤t_stack_frame); 230 }; 231 232 class FuncGraphEvaluator : public BaseFuncGraphEvaluator { 233 public: FuncGraphEvaluator(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context)234 FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) 235 : BaseFuncGraphEvaluator(context), func_graph_(func_graph) {} 236 237 ~FuncGraphEvaluator() override = default; 238 MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); 239 240 FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; 241 func_graph()242 FuncGraphPtr func_graph() { return func_graph_; } 243 244 AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; 245 AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; ToString()246 std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } 247 248 private: 249 FuncGraphPtr func_graph_; 250 FuncGraphCacheMap func_graph_cache_; 251 std::vector<AbstractBasePtrList> trace_; 252 }; 253 using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>; 254 255 class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { 256 public: 257 // Note: context parameter is not used; MetaFuncGraphEvaluator(const MetaFuncGraphPtr & meta_func_graph,const ScopePtr & scope)258 MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope) 259 : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {} 260 ~MetaFuncGraphEvaluator() override = default; 261 MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator); 262 263 FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; 264 265 // Return normalized versions of the arguments. NormalizeArgs(const AbstractBasePtrList & args_spec_list)266 AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { 267 return meta_func_graph_->NormalizeArgs(args_spec_list); 268 } ToString()269 std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); } 270 271 private: 272 MetaFuncGraphPtr meta_func_graph_; 273 FuncGraphCacheMap func_graph_cache_; 274 ScopePtr scope_; 275 }; 276 277 class PartialAppEvaluator : public Evaluator { 278 public: PartialAppEvaluator(const EvaluatorPtr & evaluator,const AbstractBasePtrList & args)279 PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args) 280 : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {} 281 ~PartialAppEvaluator() override = default; 282 MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator); bound_node()283 AnfNodePtr bound_node() const override { 284 if (evaluator_ != nullptr) { 285 return evaluator_->bound_node(); 286 } 287 return bound_node_.lock(); 288 } 289 set_bound_node(const AnfNodePtr & node)290 void set_bound_node(const AnfNodePtr &node) override { 291 if (evaluator_ != nullptr) { 292 evaluator_->set_bound_node(node); 293 } 294 bound_node_ = AnfNodeWeakPtr(node); 295 } 296 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)297 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 298 MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; 299 } 300 301 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 302 const AnfNodeConfigPtr &out_conf) override; ToString()303 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 304 305 private: 306 EvaluatorPtr evaluator_; 307 AbstractBasePtrList args_spec_list_; 308 }; 309 310 class VirtualEvaluator : public Evaluator { 311 public: VirtualEvaluator(const AbstractBasePtrList & args_spec_list,const AbstractBasePtr & output)312 VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output) 313 : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {} 314 ~VirtualEvaluator() override = default; 315 MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); 316 317 EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list, 318 const AnfNodeConfigPtr &out_conf) override; ToString()319 std::string ToString() const override { return identifier_; } 320 321 private: 322 AbstractBasePtrList args_spec_list_; 323 AbstractBasePtr output_; 324 }; 325 326 class JEvaluator : public Evaluator { 327 public: JEvaluator(const EvaluatorPtr & evaluator,const AbstractFunctionPtr & orig_func)328 JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) 329 : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {} 330 ~JEvaluator() override = default; 331 MS_DECLARE_PARENT(JEvaluator, Evaluator); bound_node()332 AnfNodePtr bound_node() const override { 333 if (evaluator_ != nullptr) { 334 return evaluator_->bound_node(); 335 } 336 return bound_node_.lock(); 337 } 338 set_bound_node(const AnfNodePtr & node)339 void set_bound_node(const AnfNodePtr &node) override { 340 if (evaluator_ != nullptr) { 341 evaluator_->set_bound_node(node); 342 } 343 bound_node_ = AnfNodeWeakPtr(node); 344 } 345 Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)346 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { 347 MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; 348 } 349 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, 350 const AnfNodeConfigPtr &out_conf) override; ToString()351 std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } 352 353 private: 354 EvaluatorPtr evaluator_; 355 AbstractFunctionPtr orig_func_; 356 }; 357 358 void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args); 359 } // namespace abstract 360 } // namespace mindspore 361 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ 362