• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_PRIM_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_PRIM_H_
21 
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "utils/hash_map.h"
#include "pipeline/jit/ps/static_analysis/evaluator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_def.h"
#include "ops/ops_frontend_func_impl.h"
32 
33 namespace mindspore {
34 namespace abstract {
35 class PrimitiveFunctionEvaluator final : public TrivialPrimEvaluator {
36  public:
37   explicit PrimitiveFunctionEvaluator(const PrimitivePtr &primitive);
38   ~PrimitiveFunctionEvaluator() override = default;
39   MS_DECLARE_PARENT(PrimitiveFunctionEvaluator, TrivialPrimEvaluator);
40   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
ToString()41   std::string ToString() const override { return identifier_ + "_PrimitiveFunction_" + prim_func_->name(); }
42 
43  protected:
inplace_prim()44   bool inplace_prim() const override { return prim_func_->inplace_prim(); }
45 
46  private:
47   AbstractBasePtr CheckAndInfer(const AbstractBasePtrList &args);
48   void CheckArgsSizeAndType(const AbstractBasePtrList &args);
49   PrimitivePtr prim_func_;
50   mindspore::ops::OpDefPtr op_def_{nullptr};
51   mindspore::ops::OpFrontendFuncImplPtr frontend_func_impl_{nullptr};
52 };
53 
54 class StandardPrimEvaluator final : public TrivialPrimEvaluator {
55  public:
StandardPrimEvaluator(const PrimitivePtr & primitive,const StandardPrimitiveImplReg & eval_impl)56   StandardPrimEvaluator(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &eval_impl)
57       : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
StandardPrimEvaluator(const PrimitivePtr & primitive)58   explicit StandardPrimEvaluator(const PrimitivePtr &primitive)
59       : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive) {}
60   ~StandardPrimEvaluator() override = default;
61   MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
62   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
prim()63   PrimitivePtr prim() { return prim_; }
64 
ToString()65   std::string ToString() const override { return identifier_ + "_" + prim_->name(); }
66 
67  protected:
inplace_prim()68   bool inplace_prim() const override { return prim_->inplace_prim(); }
69 
70  private:
71   EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
72   EvalResultPtr RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
73                                 const AbstractBasePtrList &args);
74   PrimitivePtr prim_;
75   const StandardPrimitiveImplReg eval_impl_;
76 };
77 
78 using StandardPrimEvaluatorPtr = std::shared_ptr<StandardPrimEvaluator>;
79 
80 class PythonPrimEvaluator final : public TrivialPrimEvaluator {
81  public:
PythonPrimEvaluator(const PrimitivePyPtr primitive)82   explicit PythonPrimEvaluator(const PrimitivePyPtr primitive)
83       : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
84   ~PythonPrimEvaluator() override = default;
85   MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
86   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
prim()87   PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
88 
ToString()89   std::string ToString() const override { return identifier_ + "_" + prim_py_->name(); }
90 
91  protected:
inplace_prim()92   bool inplace_prim() const override { return dyn_cast<Primitive>(prim_py_)->inplace_prim(); }
93 
94  private:
95   PrimitivePyPtr prim_py_;
96 };
97 
98 using ValuePtrList = std::vector<ValuePtr>;
99 using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);
100 
101 class UniformPrimEvaluator final : public TrivialPrimEvaluator {
102  public:
UniformPrimEvaluator(const FunctionPtr func_desc,PrimitiveImpl impl,bool eval_value,const TypePtr specify_out_type)103   UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type)
104       : TrivialPrimEvaluator("UniformPrimEvaluator"),
105         impl_(impl),
106         eval_value_(eval_value),
107         func_desc_(func_desc),
108         nargs_(func_desc_->args().size()),
109         return_value_type_(func_desc_->retval()),
110         specify_out_type_(specify_out_type) {
111     for (size_t i = 0; i < nargs_; ++i) {
112       const TypePtr &type = func_desc_->args()[i];
113       type_map_[type].push_back(i);
114     }
115   }
~UniformPrimEvaluator()116   ~UniformPrimEvaluator() override { impl_ = nullptr; };
117   MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
118 
119   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) override;
120   ValuePtr RunImpl(const ValuePtrList &args) const;
121 
122   // If eval_value_ is False, return broadened arguments.
NormalizeArgs(const AbstractBasePtrList & args_abs_list)123   AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_abs_list) const override {
124     if (!eval_value_) {
125       AbstractBasePtrList broadened_args_abs_list;
126       (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broadened_args_abs_list),
127                            [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
128       return broadened_args_abs_list;
129     }
130     return args_abs_list;
131   }
132 
133  protected:
inplace_prim()134   bool inplace_prim() const override { return false; }
135 
136  private:
137   PrimitiveImpl impl_;
138   bool eval_value_;
139   const FunctionPtr func_desc_;
140   const std::size_t nargs_;
141   const TypePtr return_value_type_;
142   const TypePtr specify_out_type_;
143   mindspore::HashMap<TypePtr, std::vector<size_t>, TypeHashById, TypeEqualById> type_map_;
144 };
145 
146 class DoSignatureEvaluator final : public Evaluator {
147  public:
DoSignatureEvaluator(const PrimitivePtr primitive)148   explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
149   ~DoSignatureEvaluator() override = default;
150   MS_DECLARE_PARENT(DoSignatureEvaluator, Evaluator);
151   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
152                     const AnfNodeConfigPtr &out_conf) override;
153 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)154   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
155     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
156   }
157 
158  private:
159   PrimitivePtr prim_;
160   CNodePtr GenerateNewNodeBySignatures(const ValuePtr &func, const AbstractBasePtrList &args_abs_list,
161                                        const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &out_conf);
162 };
163 
164 class UnpackGraphEvaluator final : public Evaluator {
165  public:
UnpackGraphEvaluator(const PrimitivePtr primitive)166   explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
167   ~UnpackGraphEvaluator() override = default;
168   MS_DECLARE_PARENT(UnpackGraphEvaluator, Evaluator);
169   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
170                     const AnfNodeConfigPtr &out_conf) override;
171 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)172   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
173     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
174   }
175 
176  private:
177   PrimitivePtr prim_;
178 };
179 
180 class MixedPrecisionCastEvaluator final : public Evaluator {
181  public:
MixedPrecisionCastEvaluator(const PrimitivePtr primitive)182   explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive)
183       : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
184   ~MixedPrecisionCastEvaluator() override = default;
185   MS_DECLARE_PARENT(MixedPrecisionCastEvaluator, Evaluator);
186   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
187                     const AnfNodeConfigPtr &out_conf) override;
188 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)189   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
190     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
191   }
192 
193  private:
194   PrimitivePtr prim_;
195 };
196 
197 class SwitchEvaluator final : public Evaluator {
198  public:
SwitchEvaluator()199   SwitchEvaluator() : Evaluator("SwitchEvaluator") {}
200   ~SwitchEvaluator() override = default;
201   MS_DECLARE_PARENT(SwitchEvaluator, Evaluator);
202   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
203                     const AnfNodeConfigPtr &out_conf) override;
204 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)205   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
206     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
207   }
208 };
209 
210 class SwitchLayerEvaluator final : public Evaluator {
211  public:
SwitchLayerEvaluator()212   SwitchLayerEvaluator() : Evaluator("SwitchLayerEvaluator") {}
213   ~SwitchLayerEvaluator() override = default;
214   MS_DECLARE_PARENT(SwitchLayerEvaluator, Evaluator);
215   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
216                     const AnfNodeConfigPtr &out_conf) override;
217 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)218   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
219     MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
220   }
221 };
222 
223 class PrimitiveArgsToInputsEvaluator : public TransitionPrimEvaluator {
224  public:
PrimitiveArgsToInputsEvaluator(const PrimitivePtr primitive)225   explicit PrimitiveArgsToInputsEvaluator(const PrimitivePtr primitive)
226       : TransitionPrimEvaluator("PrimitiveArgsToInputsEvaluator"), prim_(primitive) {}
227   ~PrimitiveArgsToInputsEvaluator() override = default;
228   MS_DECLARE_PARENT(PrimitiveArgsToInputsEvaluator, TransitionPrimEvaluator)
229   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
230                          const AnfNodeConfigPtr &out_conf) override;
231 
232  private:
233   PrimitivePtr prim_;
234 };
235 
236 class DoTransPrimitiveFunctionEvaluator : public TransitionPrimEvaluator {
237  public:
DoTransPrimitiveFunctionEvaluator(const PrimitivePtr primitive)238   explicit DoTransPrimitiveFunctionEvaluator(const PrimitivePtr primitive)
239       : TransitionPrimEvaluator("DoTransPrimitiveFunctionEvaluator"), prim_(primitive) {}
240   ~DoTransPrimitiveFunctionEvaluator() override = default;
241   MS_DECLARE_PARENT(DoTransPrimitiveFunctionEvaluator, TransitionPrimEvaluator)
242   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
243                          const AnfNodeConfigPtr &out_conf) override;
244 
245  private:
246   PrimitivePtr prim_;
247 };
248 
249 class PartialToEndEvaluator : public TransitionPrimEvaluator {
250  public:
PartialToEndEvaluator(const AbstractFunctionPtr & primal_func)251   explicit PartialToEndEvaluator(const AbstractFunctionPtr &primal_func)
252       : TransitionPrimEvaluator("PartialToEndEvaluator"), primal_func_(primal_func) {}
253   ~PartialToEndEvaluator() override = default;
254   MS_DECLARE_PARENT(PartialToEndEvaluator, TransitionPrimEvaluator);
255   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
256                          const AnfNodeConfigPtr &out_conf) override;
257 
258  private:
259   AbstractFunctionPtr primal_func_;
260 };
261 
262 class ConstexprEvaluator : public TransitionPrimEvaluator {
263  public:
ConstexprEvaluator(const PrimitivePyPtr primitive)264   explicit ConstexprEvaluator(const PrimitivePyPtr primitive)
265       : TransitionPrimEvaluator("ConstexprEvaluator"), prim_py_(primitive) {}
266   ~ConstexprEvaluator() override = default;
267   MS_DECLARE_PARENT(ConstexprEvaluator, TransitionPrimEvaluator)
268   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
269                          const AnfNodeConfigPtr &out_conf) override;
270 
271  private:
272   PrimitivePyPtr prim_py_;
273 };
274 
275 class MakeTupleEvaluator : public TransitionPrimEvaluator {
276  public:
MakeTupleEvaluator()277   MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {}
278   ~MakeTupleEvaluator() override = default;
279   MS_DECLARE_PARENT(MakeTupleEvaluator, TransitionPrimEvaluator);
280   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
281                          const AnfNodeConfigPtr &out_conf) override;
282 };
283 
284 class MakeListEvaluator : public TransitionPrimEvaluator {
285  public:
MakeListEvaluator()286   MakeListEvaluator() : TransitionPrimEvaluator("MakeListEvaluator") {}
287   ~MakeListEvaluator() override = default;
288   MS_DECLARE_PARENT(MakeListEvaluator, TransitionPrimEvaluator);
289   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
290                          const AnfNodeConfigPtr &out_conf) override;
291 };
292 
293 class PyExecuteEvaluator : public TransitionPrimEvaluator {
294  public:
PyExecuteEvaluator()295   PyExecuteEvaluator() : TransitionPrimEvaluator("PyExecuteEvaluator") {}
296   ~PyExecuteEvaluator() override = default;
297   MS_DECLARE_PARENT(PyExecuteEvaluator, TransitionPrimEvaluator);
298   EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
299                          const AnfNodeConfigPtr &out_conf) override;
300 };
301 
302 bool IsInWhiteList(const PrimitivePtr &primitive);
303 
304 PrimEvaluatorMap &GetPrimEvaluatorConstructors();
305 
306 // Check whether type x is a subtype of model.
307 bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
308 
309 void ClearPrimEvaluatorMap();
310 
311 py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value = false);
312 py::tuple PreparePyInputs(const AbstractBasePtrList &args);
313 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output);
314 
315 // Get the __init__() arguments of the PrimitivePy object.
316 AnfNodePtrList GetPrimitiveInitArgs(const PrimitivePyPtr &prim_py, const ops::OpDef *op_def);
317 
318 // Process the primitive's arguments (such as dtype auto-cast, add argument with default-value...),
319 // then generate the primitive CNode and add it to graph.
320 // (The returned CNode is without abstract, need to evaluate its abstract manually).
321 CNodePtr GeneratePrimitiveCNode(const PrimitivePtr &primitive, const ops::OpDef *op_def, const FuncGraphPtr &graph,
322                                 const AnfNodePtrList &init_args_nodes, const AnfNodePtrList &call_args_nodes,
323                                 const std::function<AbstractBasePtr(const AnfNodePtr &)> &eval_func);
324 }  // namespace abstract
325 }  // namespace mindspore
326 
327 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_PRIM_H_
328