1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2023 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ 21 22 #include <list> 23 #include <memory> 24 #include <string> 25 #include <vector> 26 #include <utility> 27 #include <map> 28 #include <set> 29 #include <unordered_map> 30 #include <unordered_set> 31 #include <mutex> 32 #include "utils/ms_utils.h" 33 #include "utils/hash_map.h" 34 #include "utils/hash_set.h" 35 #include "utils/log_adapter.h" 36 #include "utils/compile_config.h" 37 #include "utils/trace_base.h" 38 #include "ir/anf.h" 39 #include "pybind_api/ir/primitive_py.h" 40 #include "abstract/abstract_value.h" 41 #include "abstract/analysis_context.h" 42 #include "abstract/abstract_function.h" 43 #include "pipeline/jit/ps/parse/parse.h" 44 #include "include/common/debug/anf_ir_dump.h" 45 46 namespace mindspore { 47 namespace abstract { 48 void ResetFunctionCallDepth(); 49 void IncreaseFunctionCallDepth(); 50 void DecreaseFunctionCallDepth(); 51 size_t FunctionCallDepth(); 52 53 void ResetStackFrameDepth(); 54 void IncreaseStackFrameDepth(); 55 void DecreaseStackFrameDepth(); 56 size_t StackFrameDepth(); 57 58 // Extend current context with values for another graph. 59 AnalysisContextPtr NewContext(const AnalysisContextPtr ¤t_context, const FuncGraphPtr &fg, 60 const AbstractBasePtrList &args_abs_list); 61 62 // Config to a certain node in a certain context. 63 class AnfNodeConfig final : public Config { 64 public: AnfNodeConfig(const AnalysisEnginePtr & engine,const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)65 AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context, 66 const FuncGraphPtr &func_graph) 67 : Config(), 68 engine_(std::weak_ptr<AnalysisEngine>(engine)), 69 node_(node), 70 context_(nullptr), 71 func_graph_(func_graph) { 72 if (context == nullptr) { 73 return; 74 } 75 auto fg = GetValueNode<FuncGraphPtr>(node); 76 if (fg == nullptr && node != nullptr) { 77 fg = node->func_graph(); 78 } 79 if (context->func_graph() == fg) { 80 // Usually `node` is CNode and not a FV, or top graph's ValueNodes. 81 context_ = context; 82 } else { 83 // If `node` is FV, FuncGraph, or other graph ValueNodes. 84 // Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null. 85 context_ = context->FindOwnOrParentContext(fg.get()); 86 if (context_ == nullptr) { 87 FuncGraphPtr parent_graph = fg->parent(); 88 #ifdef ENABLE_DUMP_IR 89 const auto no_parent = parent_graph == nullptr; 90 DumpIR(std::string("EXCEPTION_NEW_CONTEXT_CURRENT_") + (no_parent ? "0" : "1") + "_" + fg->ToString() + ".ir", 91 fg); 92 if (!no_parent) { 93 DumpIR("EXCEPTION_NEW_CONTEXT_PARENT_" + parent_graph->ToString() + ".ir", parent_graph); 94 } 95 #endif 96 // Context not found, it would be a bug in code so we raise exception. 97 std::ostringstream oss; 98 oss << "BUG: Failed to find context for: " << fg->ToString() 99 << ", parent: " << (parent_graph == nullptr ? "null" : parent_graph->ToString()) << " from contexts: [" 100 << context->ToString(); 101 for (auto p = context->parent(); p != nullptr; p = p->parent()) { 102 oss << ", " << p->ToString(); 103 } 104 oss << "] " 105 << ", node: " << node->DebugString() << ", " << trace::GetDebugInfoStr(fg->debug_info()); 106 MS_LOG(INTERNAL_EXCEPTION) << oss.str(); 107 } 108 } 109 } 110 111 ~AnfNodeConfig() override = default; 112 MS_DECLARE_PARENT(AnfNodeConfig, Config); 113 114 EvalResultPtr ObtainEvalResult() override; 115 context()116 const AnalysisContextPtr &context() const { return context_; } 117 node()118 const AnfNodePtr &node() const { return node_; } 119 func_graph()120 const FuncGraphPtr &func_graph() const { return func_graph_; } 121 engine()122 AnalysisEnginePtr engine() const { return engine_.lock(); } 123 hash()124 size_t hash() const override { 125 std::size_t node_hash = PointerHash<AnfNodePtr>{}(node_); 126 return hash_combine(node_hash, PointerHash<AnalysisContextPtr>{}(context_)); 127 } 128 129 bool operator==(const AnfNodeConfig &other) const { 130 if (this == &other) { 131 return true; 132 } 133 // Compare node with pointer. 134 if (node_ != other.node_) { 135 return false; 136 } 137 // Compare context with pointer. 138 return context_ == other.context_; 139 } 140 ToString()141 std::string ToString() const override { 142 std::ostringstream buffer; 143 constexpr int recursive_level = 2; 144 buffer << "Node: " << node_ << "/" << node_->DebugString(recursive_level) << "-uid(" << node_->UniqueId() 145 << "), Context: " << context_ << "/" << context_->ToString() << ", FuncGraph: " << func_graph_ << "/" 146 << func_graph_->ToString(); 147 return buffer.str(); 148 } 149 150 private: 151 // AnalysisEngine is global. 152 // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use 153 // weak_ptr to break Config cycle. 154 std::weak_ptr<AnalysisEngine> engine_; 155 AnfNodePtr node_; 156 // Which context the node would be called, usually in owner func graph context. 157 AnalysisContextPtr context_; 158 // Where to call the node. 159 FuncGraphPtr func_graph_; 160 }; 161 162 using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>; 163 164 struct AnfNodeConfigHasher { operatorAnfNodeConfigHasher165 std::size_t operator()(const AnfNodeConfigPtr &conf) const { 166 MS_EXCEPTION_IF_NULL(conf); 167 return conf->hash(); 168 } 169 }; 170 171 struct AnfNodeConfigEqual { operatorAnfNodeConfigEqual172 bool operator()(const AnfNodeConfigPtr &lhs, const AnfNodeConfigPtr &rhs) const { 173 if (lhs == nullptr || rhs == nullptr) { 174 return false; 175 } 176 if (lhs == rhs) { 177 return true; 178 } 179 return (*lhs == *rhs); 180 } 181 }; 182 183 class VirtualConfig final : public Config { 184 public: VirtualConfig(const AbstractBasePtr & abstract)185 explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {} 186 187 ~VirtualConfig() override = default; 188 MS_DECLARE_PARENT(VirtualConfig, Config); ObtainEvalResult()189 EvalResultPtr ObtainEvalResult() override { 190 return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>()); 191 } 192 193 private: 194 AbstractBasePtr abstract_; 195 }; 196 197 using PrimEvaluatorMap = mindspore::HashMap<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; 198 using AnfNodeConfigMap = 199 mindspore::HashMap<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; 200 201 struct AnalysisResult { 202 EvalResultPtr eval_result; 203 AnalysisContextPtr context; 204 }; 205 206 struct PartialAppHasher { operatorPartialAppHasher207 std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const { 208 auto hash_value = PointerHash<AbstractFunctionPtr>{}(p.first); 209 for (const auto &abs : p.second) { 210 hash_value = hash_combine(hash_value, PointerHash<AbstractBasePtr>{}(abs)); 211 } 212 return hash_value; 213 } 214 }; 215 216 // Should compare Args based on value other than pointer; 217 struct EvaluatorArgs { EvaluatorArgsEvaluatorArgs218 EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {} 219 bool operator==(const EvaluatorArgs &other) const { 220 return (this == &other) || ((evaluator_ == other.evaluator_) && AbstractBasePtrListDeepEqual(args_, other.args_)); 221 } 222 bool operator!=(const EvaluatorArgs &other) const { return !(*this == other); } 223 224 EvaluatorPtr evaluator_; 225 AbstractBasePtrList args_; 226 }; 227 using EvalTraceRevIter = std::list<EvaluatorArgs>::const_reverse_iterator; 228 struct EvaluatorArgsHasher { operatorEvaluatorArgsHasher229 std::size_t operator()(const EvaluatorArgs &eval_args) const { 230 return hash_combine(PointerHash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_)); 231 } 232 }; 233 struct EvaluatorArgsEqual { operatorEvaluatorArgsEqual234 bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; } 235 }; 236 237 struct PrimitiveEvalCacheKey { 238 AttrValueMap attrs; 239 AbstractBasePtrList args; 240 }; 241 242 struct PrimitiveEvalCacheHash { operatorPrimitiveEvalCacheHash243 std::size_t operator()(const PrimitiveEvalCacheKey &key) const { 244 auto hash_value = key.attrs.size(); 245 for (const auto &attr : key.attrs) { 246 hash_value = hash_combine(hash_value, std::hash<std::string>{}(attr.first)); 247 if (attr.second != nullptr) { 248 hash_value = hash_combine(hash_value, attr.second->hash()); 249 } 250 } 251 return hash_combine(hash_value, AbstractBasePtrListHash(key.args)); 252 } 253 }; 254 255 struct PrimitiveEvalCacheEqual { operatorPrimitiveEvalCacheEqual256 bool operator()(const PrimitiveEvalCacheKey &a, const PrimitiveEvalCacheKey &b) const { 257 if (!common::IsAttrsEqual(a.attrs, b.attrs)) { 258 return false; 259 } 260 return AbstractBasePtrListDeepEqual(a.args, b.args); 261 } 262 }; 263 264 class PrimitiveEvalCache { 265 public: 266 using EvalCache = 267 std::unordered_map<PrimitiveEvalCacheKey, EvalResultPtr, PrimitiveEvalCacheHash, PrimitiveEvalCacheEqual>; 268 using PrimToEvalCache = mindspore::HashMap<std::string, EvalCache>; 269 EvalResultPtr Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const; 270 void Put(const PrimitivePtr &prim, AttrValueMap &&attrs, const AbstractBasePtrList &args, 271 const EvalResultPtr &result); 272 void Clear(); 273 274 private: 275 mutable std::mutex mutex_; 276 PrimToEvalCache prim_cache_; 277 }; 278 279 using PrimitiveEvalCachePtr = std::shared_ptr<PrimitiveEvalCache>; 280 281 class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { 282 public: AnalysisEngine(const PrimEvaluatorMap & prim_evaluator_map,const FuncGraphManagerPtr & func_graph_manager)283 AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) 284 : prim_constructors_(prim_evaluator_map), 285 func_graph_manager_(func_graph_manager), 286 forward_count_(0), 287 enable_recursive_eval_(common::GetCompileConfig("RECURSIVE_EVAL") == "1"), 288 check_side_effect_(false) {} 289 virtual ~AnalysisEngine() = default; 290 291 // func_graph: The func_graph to analyze. 292 // args_abs_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. 293 AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_abs_list); 294 void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) const; 295 EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf); 296 // Evaluate a CNode without look up cache. 297 EvalResultPtr ObtainEvalResultWithoutCache(const AnfNodeConfigPtr &conf); 298 // Return the Evaluator for the given function. 299 EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &func); 300 301 AnfNodeConfigPtr GetForwardConfig(const AnfNodeConfigPtr &conf) const; 302 EvalResultPtr InterpretedNodeCall(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); 303 AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context, 304 const FuncGraphPtr &func_graph); 305 AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) const; 306 EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); 307 EvalResultPtr ConvertClassTypeToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf); 308 // Infer the result of fn(args). 309 EvalResultPtr Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_abs_list); 310 void Clear(); 311 void ClearEvaluatorCache(); MakeConfig(const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)312 AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context, 313 const FuncGraphPtr &func_graph) { 314 return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context, func_graph); 315 } 316 func_graph_manager()317 FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; } anfnode_config_map()318 const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; } 319 320 // Set the analysis result for orig to the result for new. 321 // This sets an entry in anfnode_config_map from orig to new. 322 EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf); PrimConstructors()323 const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } 324 root_func_graph()325 FuncGraphPtr root_func_graph() const { return root_func_graph_.lock(); } root_context()326 AnalysisContextPtr root_context() const { return root_context_; } set_root_context(const AnalysisContextPtr & context)327 void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; } 328 329 mindspore::HashMap<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; 330 enable_recursive_eval()331 bool enable_recursive_eval() const { return enable_recursive_eval_; } 332 static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_abs_list, const AnfNodePtr &node); 333 check_side_effect()334 bool check_side_effect() const { return check_side_effect_; } set_check_side_effect(bool check_side_effect)335 void set_check_side_effect(bool check_side_effect) { check_side_effect_ = check_side_effect; } 336 void SetUndeterminedFlag(const std::string &thread_id, const FuncGraph &fg); 337 void SetIgnoreValueFlag(const std::string &thread_id, FuncGraph *fg); 338 339 private: 340 // Overloaded function. 341 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func); 342 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func); 343 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func); 344 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func); 345 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func); 346 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func); 347 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func); 348 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func); 349 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &func); 350 351 EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, 352 const AbstractBasePtrList &args_abs_list, const EvalTraceRevIter &it, 353 bool *continue_flag); 354 355 const PrimEvaluatorMap &prim_constructors_; 356 FuncGraphManagerPtr func_graph_manager_; 357 std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_; 358 // Record the func_graph which should be set as undetermined and the setting thread id. 359 mindspore::HashMap<const FuncGraph *, std::forward_list<std::string>> func_graph_undetermined_flags_; 360 std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher> 361 constructors_app_; 362 363 AnfNodeConfigMap anfnode_config_map_; 364 // Use a list to trace multiple evaluators. 365 std::list<EvaluatorArgs> eval_trace_; 366 std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_; 367 std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_; 368 // Root or top func_graph for static analysis; 369 FuncGraphWeakPtr root_func_graph_; 370 AnalysisContextPtr root_context_{nullptr}; 371 372 AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, 373 const ConfigPtrList &args_conf_list); 374 EvalResultPtr Eval(const AnfNodeConfigPtr &conf); 375 EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, 376 const ConfigPtrList &args_conf_list); 377 EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, 378 const ConfigPtrList &args_conf_list); 379 EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators, 380 const AnfNodeConfigPtr &out_conf, 381 const ConfigPtrList &args_conf_list); 382 383 std::atomic_long forward_count_; 384 385 bool enable_recursive_eval_; 386 387 bool check_side_effect_; 388 389 #ifdef DEBUG 390 std::vector<AnfNodePtr> compute_conf_stack_; 391 #endif 392 }; 393 394 // Translate the value to an abstract value. 395 // Arguments: 396 // value: The value to convert. 397 // context: The context in which the value was found, used if the value is a Graph. 398 // conf: The Config to the valuenode we are converting, if there is one, 399 // so that we can generate a tracking_id. 400 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr, 401 const AnfNodeConfigPtr &conf = nullptr); 402 403 // Convert a value to an abstract value. 404 // Arguments: 405 // v: The value to convert. 406 // broaden: If True, concrete values will be made more abstract, so e.g. 407 // the value 1234 would become ANYTHING. 408 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false); 409 410 EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs); 411 412 template <typename T> 413 AbstractBasePtr FromValue(const T &value, bool broaden = false) { 414 return FromValueInside(MakeValue(value), broaden); 415 } 416 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine); 417 418 EvalResultPtr ObtainEvalResultFromCache(const AnfNodeConfigPtr &conf); 419 // If the config of CNode(funcgraph/metafuncgraph) can be found in the cache, evaluation of the config of that CNode 420 // is not required, but the use flags of arguments should be synchronized as if the flags will be set when the 421 // evaluation is executed. 422 void SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, 423 const CNodePtr &cnode, 424 const AbstractFunctionPtr &base_func_graph_func, 425 const AnalysisContextPtr &fg_context); 426 } // namespace abstract 427 } // namespace mindspore 428 429 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ 430