• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
21 
22 #include <list>
23 #include <memory>
24 #include <string>
25 #include <vector>
26 #include <utility>
27 #include <map>
28 #include <set>
29 #include <unordered_map>
30 #include <unordered_set>
31 #include <mutex>
32 #include "utils/ms_utils.h"
33 #include "utils/hash_map.h"
34 #include "utils/hash_set.h"
35 #include "utils/log_adapter.h"
36 #include "utils/compile_config.h"
37 #include "utils/trace_base.h"
38 #include "ir/anf.h"
39 #include "pybind_api/ir/primitive_py.h"
40 #include "abstract/abstract_value.h"
41 #include "abstract/analysis_context.h"
42 #include "abstract/abstract_function.h"
43 #include "pipeline/jit/ps/parse/parse.h"
44 #include "include/common/debug/anf_ir_dump.h"
45 
46 namespace mindspore {
47 namespace abstract {
48 void ResetFunctionCallDepth();
49 void IncreaseFunctionCallDepth();
50 void DecreaseFunctionCallDepth();
51 size_t FunctionCallDepth();
52 
53 void ResetStackFrameDepth();
54 void IncreaseStackFrameDepth();
55 void DecreaseStackFrameDepth();
56 size_t StackFrameDepth();
57 
58 // Extend current context with values for another graph.
59 AnalysisContextPtr NewContext(const AnalysisContextPtr &current_context, const FuncGraphPtr &fg,
60                               const AbstractBasePtrList &args_abs_list);
61 
62 // Config to a certain node in a certain context.
63 class AnfNodeConfig final : public Config {
64  public:
AnfNodeConfig(const AnalysisEnginePtr & engine,const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)65   AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context,
66                 const FuncGraphPtr &func_graph)
67       : Config(),
68         engine_(std::weak_ptr<AnalysisEngine>(engine)),
69         node_(node),
70         context_(nullptr),
71         func_graph_(func_graph) {
72     if (context == nullptr) {
73       return;
74     }
75     auto fg = GetValueNode<FuncGraphPtr>(node);
76     if (fg == nullptr && node != nullptr) {
77       fg = node->func_graph();
78     }
79     if (context->func_graph() == fg) {
80       // Usually `node` is CNode and not a FV, or top graph's ValueNodes.
81       context_ = context;
82     } else {
83       // If `node` is FV, FuncGraph, or other graph ValueNodes.
84       // Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null.
85       context_ = context->FindOwnOrParentContext(fg.get());
86       if (context_ == nullptr) {
87         FuncGraphPtr parent_graph = fg->parent();
88 #ifdef ENABLE_DUMP_IR
89         const auto no_parent = parent_graph == nullptr;
90         DumpIR(std::string("EXCEPTION_NEW_CONTEXT_CURRENT_") + (no_parent ? "0" : "1") + "_" + fg->ToString() + ".ir",
91                fg);
92         if (!no_parent) {
93           DumpIR("EXCEPTION_NEW_CONTEXT_PARENT_" + parent_graph->ToString() + ".ir", parent_graph);
94         }
95 #endif
96         // Context not found, it would be a bug in code so we raise exception.
97         std::ostringstream oss;
98         oss << "BUG: Failed to find context for: " << fg->ToString()
99             << ", parent: " << (parent_graph == nullptr ? "null" : parent_graph->ToString()) << " from contexts: ["
100             << context->ToString();
101         for (auto p = context->parent(); p != nullptr; p = p->parent()) {
102           oss << ", " << p->ToString();
103         }
104         oss << "] "
105             << ", node: " << node->DebugString() << ", " << trace::GetDebugInfoStr(fg->debug_info());
106         MS_LOG(INTERNAL_EXCEPTION) << oss.str();
107       }
108     }
109   }
110 
111   ~AnfNodeConfig() override = default;
112   MS_DECLARE_PARENT(AnfNodeConfig, Config);
113 
114   EvalResultPtr ObtainEvalResult() override;
115 
context()116   const AnalysisContextPtr &context() const { return context_; }
117 
node()118   const AnfNodePtr &node() const { return node_; }
119 
func_graph()120   const FuncGraphPtr &func_graph() const { return func_graph_; }
121 
engine()122   AnalysisEnginePtr engine() const { return engine_.lock(); }
123 
hash()124   size_t hash() const override {
125     std::size_t node_hash = PointerHash<AnfNodePtr>{}(node_);
126     return hash_combine(node_hash, PointerHash<AnalysisContextPtr>{}(context_));
127   }
128 
129   bool operator==(const AnfNodeConfig &other) const {
130     if (this == &other) {
131       return true;
132     }
133     // Compare node with pointer.
134     if (node_ != other.node_) {
135       return false;
136     }
137     // Compare context with pointer.
138     return context_ == other.context_;
139   }
140 
ToString()141   std::string ToString() const override {
142     std::ostringstream buffer;
143     constexpr int recursive_level = 2;
144     buffer << "Node: " << node_ << "/" << node_->DebugString(recursive_level) << "-uid(" << node_->UniqueId()
145            << "), Context: " << context_ << "/" << context_->ToString() << ", FuncGraph: " << func_graph_ << "/"
146            << func_graph_->ToString();
147     return buffer.str();
148   }
149 
150  private:
151   // AnalysisEngine is global.
152   // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use
153   // weak_ptr to break Config cycle.
154   std::weak_ptr<AnalysisEngine> engine_;
155   AnfNodePtr node_;
156   // Which context the node would be called, usually in owner func graph context.
157   AnalysisContextPtr context_;
158   // Where to call the node.
159   FuncGraphPtr func_graph_;
160 };
161 
162 using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
163 
164 struct AnfNodeConfigHasher {
operatorAnfNodeConfigHasher165   std::size_t operator()(const AnfNodeConfigPtr &conf) const {
166     MS_EXCEPTION_IF_NULL(conf);
167     return conf->hash();
168   }
169 };
170 
171 struct AnfNodeConfigEqual {
operatorAnfNodeConfigEqual172   bool operator()(const AnfNodeConfigPtr &lhs, const AnfNodeConfigPtr &rhs) const {
173     if (lhs == nullptr || rhs == nullptr) {
174       return false;
175     }
176     if (lhs == rhs) {
177       return true;
178     }
179     return (*lhs == *rhs);
180   }
181 };
182 
183 class VirtualConfig final : public Config {
184  public:
VirtualConfig(const AbstractBasePtr & abstract)185   explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {}
186 
187   ~VirtualConfig() override = default;
188   MS_DECLARE_PARENT(VirtualConfig, Config);
ObtainEvalResult()189   EvalResultPtr ObtainEvalResult() override {
190     return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
191   }
192 
193  private:
194   AbstractBasePtr abstract_;
195 };
196 
197 using PrimEvaluatorMap = mindspore::HashMap<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
198 using AnfNodeConfigMap =
199   mindspore::HashMap<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
200 
201 struct AnalysisResult {
202   EvalResultPtr eval_result;
203   AnalysisContextPtr context;
204 };
205 
206 struct PartialAppHasher {
operatorPartialAppHasher207   std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
208     auto hash_value = PointerHash<AbstractFunctionPtr>{}(p.first);
209     for (const auto &abs : p.second) {
210       hash_value = hash_combine(hash_value, PointerHash<AbstractBasePtr>{}(abs));
211     }
212     return hash_value;
213   }
214 };
215 
216 // Should compare Args based on value other than pointer;
217 struct EvaluatorArgs {
EvaluatorArgsEvaluatorArgs218   EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {}
219   bool operator==(const EvaluatorArgs &other) const {
220     return (this == &other) || ((evaluator_ == other.evaluator_) && AbstractBasePtrListDeepEqual(args_, other.args_));
221   }
222   bool operator!=(const EvaluatorArgs &other) const { return !(*this == other); }
223 
224   EvaluatorPtr evaluator_;
225   AbstractBasePtrList args_;
226 };
227 using EvalTraceRevIter = std::list<EvaluatorArgs>::const_reverse_iterator;
228 struct EvaluatorArgsHasher {
operatorEvaluatorArgsHasher229   std::size_t operator()(const EvaluatorArgs &eval_args) const {
230     return hash_combine(PointerHash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_));
231   }
232 };
233 struct EvaluatorArgsEqual {
operatorEvaluatorArgsEqual234   bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; }
235 };
236 
237 struct PrimitiveEvalCacheKey {
238   AttrValueMap attrs;
239   AbstractBasePtrList args;
240 };
241 
242 struct PrimitiveEvalCacheHash {
operatorPrimitiveEvalCacheHash243   std::size_t operator()(const PrimitiveEvalCacheKey &key) const {
244     auto hash_value = key.attrs.size();
245     for (const auto &attr : key.attrs) {
246       hash_value = hash_combine(hash_value, std::hash<std::string>{}(attr.first));
247       if (attr.second != nullptr) {
248         hash_value = hash_combine(hash_value, attr.second->hash());
249       }
250     }
251     return hash_combine(hash_value, AbstractBasePtrListHash(key.args));
252   }
253 };
254 
255 struct PrimitiveEvalCacheEqual {
operatorPrimitiveEvalCacheEqual256   bool operator()(const PrimitiveEvalCacheKey &a, const PrimitiveEvalCacheKey &b) const {
257     if (!common::IsAttrsEqual(a.attrs, b.attrs)) {
258       return false;
259     }
260     return AbstractBasePtrListDeepEqual(a.args, b.args);
261   }
262 };
263 
264 class PrimitiveEvalCache {
265  public:
266   using EvalCache =
267     std::unordered_map<PrimitiveEvalCacheKey, EvalResultPtr, PrimitiveEvalCacheHash, PrimitiveEvalCacheEqual>;
268   using PrimToEvalCache = mindspore::HashMap<std::string, EvalCache>;
269   EvalResultPtr Get(const PrimitivePtr &prim, const AbstractBasePtrList &args) const;
270   void Put(const PrimitivePtr &prim, AttrValueMap &&attrs, const AbstractBasePtrList &args,
271            const EvalResultPtr &result);
272   void Clear();
273 
274  private:
275   mutable std::mutex mutex_;
276   PrimToEvalCache prim_cache_;
277 };
278 
279 using PrimitiveEvalCachePtr = std::shared_ptr<PrimitiveEvalCache>;
280 
281 class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
282  public:
AnalysisEngine(const PrimEvaluatorMap & prim_evaluator_map,const FuncGraphManagerPtr & func_graph_manager)283   AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
284       : prim_constructors_(prim_evaluator_map),
285         func_graph_manager_(func_graph_manager),
286         forward_count_(0),
287         enable_recursive_eval_(common::GetCompileConfig("RECURSIVE_EVAL") == "1"),
288         check_side_effect_(false) {}
289   virtual ~AnalysisEngine() = default;
290 
291   // func_graph: The func_graph to analyze.
292   // args_abs_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
293   AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_abs_list);
294   void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) const;
295   EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
296   // Evaluate a CNode without look up cache.
297   EvalResultPtr ObtainEvalResultWithoutCache(const AnfNodeConfigPtr &conf);
298   // Return the Evaluator for the given function.
299   EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &func);
300 
301   AnfNodeConfigPtr GetForwardConfig(const AnfNodeConfigPtr &conf) const;
302   EvalResultPtr InterpretedNodeCall(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
303   AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
304                                            const FuncGraphPtr &func_graph);
305   AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) const;
306   EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
307   EvalResultPtr ConvertClassTypeToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf);
308   // Infer the result of fn(args).
309   EvalResultPtr Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_abs_list);
310   void Clear();
311   void ClearEvaluatorCache();
MakeConfig(const AnfNodePtr & node,const AnalysisContextPtr & context,const FuncGraphPtr & func_graph)312   AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context,
313                               const FuncGraphPtr &func_graph) {
314     return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context, func_graph);
315   }
316 
func_graph_manager()317   FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
anfnode_config_map()318   const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }
319 
320   // Set the analysis result for orig to the result for new.
321   // This sets an entry in anfnode_config_map from orig to new.
322   EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
PrimConstructors()323   const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
324 
root_func_graph()325   FuncGraphPtr root_func_graph() const { return root_func_graph_.lock(); }
root_context()326   AnalysisContextPtr root_context() const { return root_context_; }
set_root_context(const AnalysisContextPtr & context)327   void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; }
328 
329   mindspore::HashMap<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
330 
enable_recursive_eval()331   bool enable_recursive_eval() const { return enable_recursive_eval_; }
332   static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_abs_list, const AnfNodePtr &node);
333 
check_side_effect()334   bool check_side_effect() const { return check_side_effect_; }
set_check_side_effect(bool check_side_effect)335   void set_check_side_effect(bool check_side_effect) { check_side_effect_ = check_side_effect; }
336   void SetUndeterminedFlag(const std::string &thread_id, const FuncGraph &fg);
337   void SetIgnoreValueFlag(const std::string &thread_id, FuncGraph *fg);
338 
339  private:
340   // Overloaded function.
341   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func);
342   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func);
343   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func);
344   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func);
345   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func);
346   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func);
347   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func);
348   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func);
349   EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &func);
350 
351   EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
352                                      const AbstractBasePtrList &args_abs_list, const EvalTraceRevIter &it,
353                                      bool *continue_flag);
354 
355   const PrimEvaluatorMap &prim_constructors_;
356   FuncGraphManagerPtr func_graph_manager_;
357   std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_;
358   // Record the func_graph which should be set as undetermined and the setting thread id.
359   mindspore::HashMap<const FuncGraph *, std::forward_list<std::string>> func_graph_undetermined_flags_;
360   std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
361     constructors_app_;
362 
363   AnfNodeConfigMap anfnode_config_map_;
364   // Use a list to trace multiple evaluators.
365   std::list<EvaluatorArgs> eval_trace_;
366   std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
367   std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_;
368   // Root or top func_graph for static analysis;
369   FuncGraphWeakPtr root_func_graph_;
370   AnalysisContextPtr root_context_{nullptr};
371 
372   AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
373                          const ConfigPtrList &args_conf_list);
374   EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
375   EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
376                                   const ConfigPtrList &args_conf_list);
377   EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
378                                           const ConfigPtrList &args_conf_list);
379   EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
380                                                      const AnfNodeConfigPtr &out_conf,
381                                                      const ConfigPtrList &args_conf_list);
382 
383   std::atomic_long forward_count_;
384 
385   bool enable_recursive_eval_;
386 
387   bool check_side_effect_;
388 
389 #ifdef DEBUG
390   std::vector<AnfNodePtr> compute_conf_stack_;
391 #endif
392 };
393 
394 // Translate the value to an abstract value.
395 // Arguments:
396 // value:   The value to convert.
397 // context: The context in which the value was found, used if the value is a Graph.
398 // conf:     The Config to the valuenode we are converting, if there is one,
399 // so that we can generate a tracking_id.
400 AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr,
401                            const AnfNodeConfigPtr &conf = nullptr);
402 
403 // Convert a value to an abstract value.
404 // Arguments:
405 // v:       The value to convert.
406 // broaden: If True, concrete values will be made more abstract, so e.g.
407 // the value 1234 would become ANYTHING.
408 AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false);
409 
410 EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs);
411 
412 template <typename T>
413 AbstractBasePtr FromValue(const T &value, bool broaden = false) {
414   return FromValueInside(MakeValue(value), broaden);
415 }
416 EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine);
417 
418 EvalResultPtr ObtainEvalResultFromCache(const AnfNodeConfigPtr &conf);
419 // If the config of CNode(funcgraph/metafuncgraph) can be found in the cache, evaluation of the config of that CNode
420 // is not required, but the use flags of arguments should be synchronized as if the flags will be set when the
421 // evaluation is executed.
422 void SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
423                                                          const CNodePtr &cnode,
424                                                          const AbstractFunctionPtr &base_func_graph_func,
425                                                          const AnalysisContextPtr &fg_context);
426 }  // namespace abstract
427 }  // namespace mindspore
428 
429 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
430