• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
21 
22 #include <memory>
23 #include <string>
24 #include <stdexcept>
25 #include <unordered_set>
26 #include <unordered_map>
27 #include <utility>
28 #include <vector>
29 
30 #include "ir/anf.h"
31 #include "ir/func_graph_cloner.h"
32 #include "pipeline/jit/static_analysis/evaluator.h"
33 
34 namespace mindspore {
35 namespace abstract {
// Status codes produced while specializing (returned by FindUniqueArgvals and reported through the
// `errcode` out-parameter of BuildSpecializedNodeInner).
enum SpecializeStatusCode {
  // Specialization succeeded.
  kSpecializeSuccess = 0x00,
  // Unique-argval search hit a dead node.
  kSpecializeFindUniqueArgvalDead = 0x01,
  // Unique-argval search hit a polymorphic node.
  kSpecializeFindUniqueArgvalPoly = 0x02,
  // Generic failure.
  kSpecializeFailure = 0xFF
};
42 
43 class FuncGraphSpecializer;
44 using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
45 
46 // Specialize a func graph using analyzed abstract values.
47 class ProgramSpecializer {
48  public:
ProgramSpecializer(const std::shared_ptr<AnalysisEngine> & engine)49   explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine), top_context_(nullptr) {
50     mng_ = engine_->func_graph_manager();
51   }
52   ~ProgramSpecializer() = default;
53   // Run the program specializer on the topmost graph in the given context.
54   FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
seen()55   const std::unordered_set<AnfNodePtr> &seen() const { return seen_; }
AddSeen(const AnfNodePtr & node)56   void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); }
57 
58   std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context);
59   // Specialze one FuncGraph in a given context.
60   FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
61 
engine()62   std::shared_ptr<AnalysisEngine> engine() { return engine_; }
63 
top_context()64   AnalysisContextPtr top_context() { return top_context_; }
65 
66  private:
67   std::shared_ptr<AnalysisEngine> engine_;
68   std::unordered_set<AnfNodePtr> seen_;
69   FuncGraphManagerPtr mng_;
70   std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual>
71     specializations_;
72   AnalysisContextPtr top_context_;
73 };
74 
75 class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
76  public:
77   FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context);
~FuncGraphSpecializer()78   virtual ~FuncGraphSpecializer() {
79     specializer_ = nullptr;
80     repl_node_ = nullptr;
81   }
82   void Run();
specialized_func_graph()83   FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; }
84 
85   std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node);
86 
87  private:
88   ProgramSpecializer *specializer_;
89   FuncGraphPtr func_graph_;
90   FuncGraphPtr specialized_func_graph_;
91   AnalysisContextPtr context_;
92   std::shared_ptr<FuncGraphSpecializer> parent_;
93   std::shared_ptr<AnalysisEngine> engine_;
94   ClonerPtr cloner_;
95   // ProcessNode-> [cloner_->CloneDisconnected] will clone AnfNode again.
96   // So, repl_node_ should pointer to GraphCloner->repl_node_ other than a copy of that.
97   std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_node_;
98   std::vector<AnfNodePtr> todo_;
99   std::unordered_set<AnfNodePtr> marked_;
100   std::unordered_map<EvaluatorPtr, EvaluatorCacheMgrPtr> evalcaches_;
101 
102   void FirstPass();
103   void SecondPass();
104   void ProcessNode(const AnfNodePtr &node);
105   void ProcessCNode(const CNodePtr &new_node);
106 
107   inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
108   inline AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const BaseFuncGraphEvaluatorPtr &evaluator,
109                                         const AbstractBasePtrList &args_spec_list);
110 
AddTodoItem(const AnfNodePtr & node)111   inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
112   // Get node replicated by Cloner.
113   AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
114   // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node
115   // (disconnected).
116   AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);
117 
118   // Build a value node from parameter if the function graph has special flag to hint it can be done.
119   AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
120 
121   // Build a value node if ival is constant and not any-value
122   AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
123                                     const AttrValueMapPtr &attrs);
124   // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
125   // replicated node.
126   AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
127   // Build a specialized node from given argvals;
128   AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
129                                   const AbstractBasePtrList &argvals);
130   AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
131                                        const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
132                                        SpecializeStatusCode *errcode);
133 
134   // Find the unique argument values which can be used to specialize a primitive or graph function.
135   SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,
136                                          const AbstractBasePtrList &argvals,
137                                          std::pair<AbstractBasePtrList, AbstractBasePtr> *result);
138   // Get cache, it may be eval's cache or cache built from broaded argument values.
139   const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval);
140   // Try to build unique argvals from the broaded arg vals if it is unique.
141   std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval);
142   void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node);
143 };
144 }  // namespace abstract
145 }  // namespace mindspore
146 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
147