1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ 21 22 #include <list> 23 #include <memory> 24 #include <string> 25 #include <unordered_map> 26 #include <vector> 27 #include <utility> 28 #include <map> 29 #include <set> 30 #include <unordered_set> 31 #include <mutex> 32 33 #ifdef DEBUG 34 #include <stack> 35 #endif 36 37 #include "utils/log_adapter.h" 38 #include "ir/anf.h" 39 #include "pybind_api/ir/primitive_py.h" 40 #include "abstract/analysis_context.h" 41 #include "abstract/abstract_function.h" 42 #include "pipeline/jit/parse/parse.h" 43 44 namespace mindspore { 45 namespace abstract { 46 void ResetFunctionCallDepth(); 47 void IncreaseFunctionCallDepth(); 48 void DecreaseFunctionCallDepth(); 49 size_t FunctionCallDepth(); 50 size_t FunctionCallMaxDepth(); 51 52 void ResetStackFrameDepth(); 53 void IncreaseStackFrameDepth(); 54 void DecreaseStackFrameDepth(); 55 size_t StackFrameDepth(); 56 size_t StackFrameMaxDepth(); 57 58 // define attribute value map 59 using AttrValueMap = std::unordered_map<std::string, ValuePtr>; 60 using AttrValueMapPtr = std::shared_ptr<AttrValueMap>; 61 62 // the class to save evaluated result: abstract value and modified attribute 63 class EvalResult : public Base { 64 public: EvalResult(AbstractBasePtr abs,AttrValueMapPtr attr)65 EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} 66 ~EvalResult() override = default; 67 MS_DECLARE_PARENT(EvalResult, Base); abstract()68 AbstractBasePtr abstract() { return abstract_; } attribute()69 AttrValueMapPtr attribute() { return attribute_; } 70 71 private: 72 AbstractBasePtr abstract_; 73 // Attribute related to PrimEvaluator; 74 AttrValueMapPtr attribute_; 75 }; 76 using EvalResultPtr = std::shared_ptr<EvalResult>; 77 78 // Superclass for AnfNodeConfig and VirtualConfig. 79 class Config : public Base { 80 public: 81 Config() = default; 82 ~Config() override = default; 83 MS_DECLARE_PARENT(Config, Base); 84 virtual EvalResultPtr ObtainEvalResult() = 0; 85 }; 86 87 // Config will be stored in AnalysisCache 88 using ConfigPtr = std::shared_ptr<Config>; 89 using ConfigPtrList = std::vector<ConfigPtr>; 90 91 // Config to a certain node in a certain context. 92 class AnfNodeConfig : public Config { 93 public: AnfNodeConfig(const AnalysisEnginePtr & engine,const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)94 AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context, 95 const FuncGraphPtr &func_graph) 96 : Config(), 97 engine_(std::weak_ptr<AnalysisEngine>(engine)), 98 node_(node), 99 context_(nullptr), 100 func_graph_(func_graph) { 101 FuncGraphPtr fg; 102 if (IsValueNode<FuncGraph>(node)) { 103 auto v = node->cast<ValueNodePtr>(); 104 fg = v->value()->cast<FuncGraphPtr>(); 105 } else { 106 fg = node->func_graph(); 107 } 108 109 if (context == nullptr) { 110 return; 111 } 112 if (context->func_graph() == fg) { 113 // Usually `node` is CNode and not a FV, or top graph's ValueNodes. 114 context_ = context; 115 } else { 116 // If `node` is FV, FuncGraph, or other graph ValueNodes. 117 // Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null. 118 context_ = context->FindOwnOrParentContext(fg); 119 } 120 } 121 122 ~AnfNodeConfig() override = default; 123 MS_DECLARE_PARENT(AnfNodeConfig, Config); 124 125 EvalResultPtr ObtainEvalResult() override; 126 context()127 AnalysisContextPtr context() const { return context_; } 128 node()129 AnfNodePtr node() const { return node_; } 130 func_graph()131 FuncGraphPtr func_graph() const { return func_graph_; } 132 engine()133 AnalysisEnginePtr engine() const { return engine_.lock(); } 134 135 // used by unordered_map; 136 bool operator==(const AnfNodeConfig &other) const { 137 // compare node with pointer, context with pointer except DummyContext as it's created by make_shared; 138 // context should not be nullptr; 139 if (context_->IsDummyContext() && other.context_->IsDummyContext()) { 140 return true; 141 } 142 // Don't check `func_graph_` equality. 143 return (node_ == other.node_) && (context_ == other.context_); 144 } 145 ToString()146 std::string ToString() const override { 147 std::ostringstream buffer; 148 buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId() 149 << "), Context: " << context_ << "/" << context_->ToString() << ", FuncGraph: " << func_graph_->ToString(); 150 return buffer.str(); 151 } 152 153 private: 154 // AnalysisEngine is global. 155 // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use 156 // weak_ptr to break Config cycle. 157 std::weak_ptr<AnalysisEngine> engine_; 158 AnfNodePtr node_; 159 // Which context the node would be called, usually in owner func graph context. 160 AnalysisContextPtr context_; 161 // Where to call the node. 162 FuncGraphPtr func_graph_; 163 }; 164 165 using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>; 166 167 struct AnfNodeConfigHasher { 168 std::size_t operator()(const AnfNodeConfigPtr conf) const; 169 }; 170 171 struct AnfNodeConfigEqual { 172 bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const; 173 }; 174 175 class VirtualConfig : public Config { 176 public: VirtualConfig(const AbstractBasePtr & abstract)177 explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {} 178 179 ~VirtualConfig() override = default; 180 MS_DECLARE_PARENT(VirtualConfig, Config); ObtainEvalResult()181 EvalResultPtr ObtainEvalResult() override { 182 return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>()); 183 } 184 185 private: 186 AbstractBasePtr abstract_; 187 }; 188 189 using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; 190 using AnfNodeConfigMap = 191 std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; 192 193 struct AnalysisResult { 194 EvalResultPtr inferred; 195 AnalysisContextPtr context; 196 }; 197 198 struct PartialAppHasher { operatorPartialAppHasher199 std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const { 200 auto h1 = std::hash<AbstractFunctionPtr>{}(p.first); 201 auto h2 = AbstractBasePtrListHash(p.second); 202 return h1 ^ h2; 203 } 204 }; 205 206 // Should compare Args based on value other than pointer; 207 struct EvaluatorArgs { EvaluatorArgsEvaluatorArgs208 EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {} 209 bool operator==(const EvaluatorArgs &other) const { 210 if (evaluator_ != other.evaluator_) { 211 return false; 212 } 213 if (AbstractBasePtrListDeepEqual(args_, other.args_)) { 214 return true; 215 } 216 return false; 217 } 218 bool operator!=(const EvaluatorArgs &other) { return !(*this == other); } 219 220 EvaluatorPtr evaluator_; 221 AbstractBasePtrList args_; 222 }; 223 using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator; 224 struct EvaluatorArgsHasher { operatorEvaluatorArgsHasher225 std::size_t operator()(const EvaluatorArgs &eval_args) const { 226 return hash_combine(std::hash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_)); 227 } 228 }; 229 struct EvaluatorArgsEqual { operatorEvaluatorArgsEqual230 bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; } 231 }; 232 233 class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { 234 public: AnalysisEngine(const PrimEvaluatorMap & prim_evaluator_map,const FuncGraphManagerPtr & func_graph_manager)235 AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) 236 : prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) { 237 forward_count_ = 0; 238 enable_recursive_eval_ = (common::GetEnv("ENV_RECURSIVE_EVAL") == "1"); 239 } 240 ~AnalysisEngine() = default; 241 242 // func_graph: The func_graph to analyze. 243 // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. 244 AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); 245 void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result); 246 EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf); 247 // Return the Evaluator for the given function. 248 EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); 249 250 AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context, 251 const FuncGraphPtr &func_graph); 252 AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); 253 EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); 254 // Infer the result of fn(args). 255 EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); 256 void Clear(); 257 void ClearEvaluatorCache(); MakeConfig(const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)258 AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context, 259 const FuncGraphPtr &func_graph) { 260 return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context, func_graph); 261 } 262 // Overloaded function. 263 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn); 264 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &fn); 265 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &fn); 266 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &fn); 267 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn); 268 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &); 269 EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn); 270 func_graph_manager()271 FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; } anfnode_config_map()272 const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; } 273 274 // Set the analysis result for orig to the result for new. 275 // This sets an entry in anfnode_config_map from orig to new. 276 EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf); PrimConstructors()277 const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } 278 root_func_graph()279 FuncGraphPtr root_func_graph() const { return root_func_graph_; } root_context()280 AnalysisContextPtr root_context() const { return root_context_; } set_root_context(const AnalysisContextPtr & context)281 void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; } 282 283 std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; 284 enable_recursive_eval()285 bool enable_recursive_eval() const { return enable_recursive_eval_; } 286 static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node); 287 288 private: 289 void SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg); 290 EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, 291 const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, 292 bool *continue_flag); 293 294 const PrimEvaluatorMap &prim_constructors_; 295 FuncGraphManagerPtr func_graph_manager_; 296 std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_; 297 std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher> 298 constructors_app_; 299 300 AnfNodeConfigMap anfnode_config_map_; 301 // Use a list to trace multiple evaluators. 302 std::list<EvaluatorArgs> eval_trace_; 303 std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_; 304 std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_; 305 // root or top func_graph for static analysis; 306 FuncGraphPtr root_func_graph_{nullptr}; 307 AnalysisContextPtr root_context_{nullptr}; 308 309 AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, 310 const ConfigPtrList &args_conf_list); 311 EvalResultPtr Eval(const AnfNodeConfigPtr &conf); 312 EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); 313 EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, 314 const ConfigPtrList &args_conf_list); 315 EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, 316 const ConfigPtrList &args_conf_list); 317 EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators, 318 const AnfNodeConfigPtr &out_conf, 319 const ConfigPtrList &args_conf_list); 320 321 std::atomic_long forward_count_; 322 323 bool enable_recursive_eval_; 324 325 #ifdef DEBUG 326 std::vector<AnfNodePtr> compute_conf_stack_; 327 #endif 328 }; 329 330 // Translate the value to an abstract value. 331 // Arguments: 332 // value: The value to convert. 333 // context: The context in which the value was found, used if the value is a Graph. 334 // conf: The Config to the valuenode we are converting, if there is one, 335 // so that we can generate a tracking_id. 336 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr, 337 const AnfNodeConfigPtr &conf = nullptr); 338 339 // Convert a value to an abstract value. 340 // Arguments: 341 // v: The value to convert. 342 // broaden: If True, concrete values will be made more abstract, so e.g. 343 // the value 1234 would become ANYTHING. 344 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false); 345 346 template <typename T> 347 AbstractBasePtr FromValue(const T &value, bool broaden = false) { 348 return FromValueInside(MakeValue(value), broaden); 349 } 350 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine); 351 EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); 352 } // namespace abstract 353 } // namespace mindspore 354 355 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ 356