• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/operator/ops_front_infer_function.h"
17 
18 #include <set>
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include <algorithm>
23 #include <map>
24 
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/math_ops.h"
28 #include "mindspore/core/ops/array_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "abstract/abstract_value.h"
31 #include "pipeline/jit/ps/parse/resolve.h"
32 #include "pipeline/jit/ps/static_analysis/prim.h"
33 #include "pipeline/jit/ps/fallback.h"
34 #include "abstract/param_validator.h"
35 #include "pybind_api/ir/tensor_py.h"
36 #include "frontend/operator/ops.h"
37 #include "abstract/ops/infer_functions.h"
38 #include "include/common/utils/convert_utils_py.h"
39 #include "include/common/utils/utils.h"
40 #include "ops/auto_generate/gen_ops_primitive.h"
41 #include "ops/ops_func_impl/greater_equal.h"
42 #include "ops/ops_func_impl/greater.h"
43 #include "ops/mod.h"
44 #include "ops/strided_slice_v2.h"
45 #include "ops/grad/strided_slice_v2_grad.h"
46 #include "abstract/abstract_function.h"
47 #include "utils/ms_context.h"
48 #include "ops/op_name.h"
49 #ifdef _MSC_VER
50 #include "include/common/pybind_api/api_register.h"
51 #endif
52 
53 namespace mindspore {
54 namespace abstract {
// Classification of one aligned dimension pair (x_i, y_i) used while computing broadcast reduce indices:
//   SAME  -> x_i == y_i
//   X_ONE -> x_i == 1 and differs from y_i (x was broadcast along this axis)
//   Y_ONE -> y_i == 1 and differs from x_i (y was broadcast along this axis)
enum class State { SAME, X_ONE, Y_ONE };
60 
ComputeReduceIndex(const std::vector<int64_t> & reverse_x,const std::vector<int64_t> & reverse_y,std::vector<int64_t> * grad_x_reduce_idx,std::vector<int64_t> * grad_y_reduce_idy)61 void ComputeReduceIndex(const std::vector<int64_t> &reverse_x, const std::vector<int64_t> &reverse_y,
62                         std::vector<int64_t> *grad_x_reduce_idx, std::vector<int64_t> *grad_y_reduce_idy) {
63   MS_EXCEPTION_IF_NULL(grad_x_reduce_idx);
64   MS_EXCEPTION_IF_NULL(grad_y_reduce_idy);
65   const size_t n = reverse_x.size();
66   if (reverse_y.size() < n) {
67     MS_LOG(EXCEPTION) << "The size of reverse_y is less than the size of reverse_x.";
68   }
69   for (size_t i = 0; i < n; ++i) {
70     State curr;
71     const int64_t x_i = reverse_x[i];
72     const int64_t y_i = reverse_y[i];
73     const int64_t reduce_idx = SizeToLong(n - 1 - i);
74     if (x_i == y_i) {
75       curr = State::SAME;
76     } else if (x_i == 1) {
77       grad_x_reduce_idx->push_back(reduce_idx);
78       curr = State::X_ONE;
79     } else if (y_i == 1) {
80       grad_y_reduce_idy->push_back(reduce_idx);
81       curr = State::Y_ONE;
82     } else {
83       MS_LOG(EXCEPTION) << "Not compatible shape input for BroadcastGradientArgs.";
84     }
85     if (curr == State::SAME && x_i == 1) {
86       grad_x_reduce_idx->push_back(reduce_idx);
87       grad_y_reduce_idy->push_back(reduce_idx);
88       continue;
89     }
90   }
91 
92   std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
93   std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
94 }
95 
BroadcastGradientArgsDiff(const std::vector<ValuePtr> & x_shape,const std::vector<ValuePtr> & y_shape)96 AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
97   std::vector<int64_t> reverse_x;
98   std::vector<int64_t> reverse_y;
99 
100   (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
101                        [](const ValuePtr &v) { return v->cast<Int64ImmPtr>()->value(); });
102   (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
103                        [](const ValuePtr &v) { return v->cast<Int64ImmPtr>()->value(); });
104 
105   if (reverse_x.size() > reverse_y.size()) {
106     reverse_y.resize(reverse_x.size(), 1);
107   } else {
108     reverse_x.resize(reverse_y.size(), 1);
109   }
110 
111   std::vector<int64_t> grad_x_reduce_idx;
112   std::vector<int64_t> grad_y_reduce_idy;
113   ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
114 
115   AbstractBasePtrList abs_list_x;
116   AbstractBasePtrList abs_list_y;
117   (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
118                        [](int64_t v) { return abstract::FromValue(v); });
119   (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
120                        [](int64_t v) { return abstract::FromValue(v); });
121   auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
122   auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
123   AbstractBasePtrList elem_list;
124   elem_list.push_back(x_reduce_idx);
125   elem_list.push_back(y_reduce_idx);
126 
127   return std::make_shared<AbstractTuple>(elem_list);
128 }
129 
InferImplTypeof(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)130 AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
131                                 const AbstractBasePtrList &args_abs_list) {
132   // Inputs: a pointer to an AbstractBase object
133   if (args_abs_list.size() != 1) {
134     MS_LOG(EXCEPTION) << "The Typeof operator must requires 1 argument, but the size of arguments is "
135                       << args_abs_list.size() << ".";
136   }
137   AbstractBasePtr abs_base = args_abs_list[0];
138   MS_EXCEPTION_IF_NULL(abs_base);
139   TypePtr type = abs_base->BuildType();
140   return std::make_shared<AbstractType>(type);
141 }
142 
InferImplTopTypeof(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)143 AbstractBasePtr InferImplTopTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
144                                    const AbstractBasePtrList &args_abs_list) {
145   // Inputs: a pointer to an AbstractBase object
146   if (args_abs_list.size() != 1) {
147     MS_LOG(EXCEPTION) << "The Typeof operator must requires 1 argument, but the size of arguments is "
148                       << args_abs_list.size() << ".";
149   }
150   AbstractBasePtr abs_base = args_abs_list[0];
151   MS_EXCEPTION_IF_NULL(abs_base);
152   TypeId type_id = abs_base->BuildType()->type_id();
153   return std::make_shared<AbstractType>(TypeIdToType(type_id));
154 }
155 
InferImplStringUpper(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)156 AbstractBasePtr InferImplStringUpper(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
157                                      const AbstractBasePtrList &args_abs_list) {
158   MS_EXCEPTION_IF_NULL(primitive);
159   if (args_abs_list.size() != 1) {
160     MS_LOG(INTERNAL_EXCEPTION) << "StringUpper takes 1 argument, but got " << args_abs_list.size();
161   }
162   constexpr size_t index_str = 0;
163   auto abs_str = args_abs_list[index_str];
164   MS_EXCEPTION_IF_NULL(abs_str);
165   auto value_str = abs_str->BuildValue();
166   MS_EXCEPTION_IF_NULL(value_str);
167   if (!value_str->isa<StringImm>()) {
168     MS_INTERNAL_EXCEPTION(TypeError) << "StringUpper expected to get a string as input, but got:"
169                                      << value_str->ToString();
170   }
171   auto str = value_str->cast<StringImmPtr>()->value();
172   (void)std::transform(str.begin(), str.end(), str.begin(), ::toupper);
173   auto new_str = MakeValue(str);
174   return new_str->ToAbstract();
175 }
176 
InferImplStringLower(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)177 AbstractBasePtr InferImplStringLower(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
178                                      const AbstractBasePtrList &args_abs_list) {
179   MS_EXCEPTION_IF_NULL(primitive);
180   if (args_abs_list.size() != 1) {
181     MS_LOG(EXCEPTION) << "StringLower takes 1 argument, but got " << args_abs_list.size();
182   }
183   constexpr size_t index_str = 0;
184   auto abs_str = args_abs_list[index_str];
185   MS_EXCEPTION_IF_NULL(abs_str);
186   auto value_str = abs_str->BuildValue();
187   MS_EXCEPTION_IF_NULL(value_str);
188   if (!value_str->isa<StringImm>()) {
189     MS_INTERNAL_EXCEPTION(TypeError) << "StringLower expected to get a string as input, but got:"
190                                      << value_str->ToString();
191   }
192   auto str = value_str->cast<StringImmPtr>()->value();
193   (void)std::transform(str.begin(), str.end(), str.begin(), ::tolower);
194   auto new_str = MakeValue(str);
195   return new_str->ToAbstract();
196 }
197 
InferImplHasType(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)198 AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
199                                  const AbstractBasePtrList &args_abs_list) {
200   MS_EXCEPTION_IF_NULL(primitive);
201   // Inputs: a pointer to an AbstractBase object and a pointer to a Type
202   const std::string op_name = primitive->name();
203   const size_t args_num = 2;
204   CheckArgsSize(op_name, args_abs_list, args_num);
205   AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_abs_list, 1);
206   MS_EXCEPTION_IF_NULL(abs_type);
207   auto mode_v = abs_type->GetValueTrack();
208   MS_EXCEPTION_IF_NULL(mode_v);
209   if (!mode_v->isa<Type>()) {
210     MS_LOG(INTERNAL_EXCEPTION) << "Get the type from AbstractType value failed.";
211   }
212 
213   auto tmpMode = mode_v->cast<TypePtr>();
214   MS_EXCEPTION_IF_NULL(args_abs_list[0]);
215   bool v = IsSubtype(args_abs_list[0], tmpMode);
216   return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
217 }
218 
IsAdapterTensor(const AbstractBasePtr & x)219 bool IsAdapterTensor(const AbstractBasePtr &x) {
220   if (!x->isa<abstract::AbstractTensor>()) {
221     return false;
222   }
223   return x->cast<abstract::AbstractTensorPtr>()->is_adapter();
224 }
225 
CheckIsInstanceForAdapter(const AbstractBasePtr & x,const AbstractBasePtr & cmp)226 bool CheckIsInstanceForAdapter(const AbstractBasePtr &x, const AbstractBasePtr &cmp) {
227   if (cmp->isa<abstract::AbstractTuple>()) {
228     const auto &elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
229     return std::any_of(elements.begin(), elements.end(),
230                        [=](const AbstractBasePtr &element) { return CheckIsInstanceForAdapter(x, element); });
231   }
232   auto cmp_value = cmp->BuildValue();
233   MS_EXCEPTION_IF_NULL(cmp_value);
234   if (cmp_value->isa<parse::ClassType>()) {
235     auto class_obj = cmp_value->cast<parse::ClassTypePtr>()->obj();
236     py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
237     // isinstance(tensor_x, Tensor) -> true, isinstance(tensor_x, Parameter) -> false.
238     // isinstance(parameter_x, Tensor) -> true, isinstance(parameter_x, Parameter) -> true.
239     bool is_cmp_tensor =
240       python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_IS_ADAPTER_TENSOR_CLASS, class_obj).cast<bool>();
241     if (is_cmp_tensor) {
242       return true;
243     }
244     bool is_x_parameter = x->isa<abstract::AbstractRefTensor>();
245     bool is_cmp_parameter =
246       python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_IS_ADAPTER_PARAMETER_CLASS, class_obj).cast<bool>();
247     return is_x_parameter && is_cmp_parameter;
248   }
249   return false;
250 }
251 
CheckPythonIsInstance(const py::object & x,const AbstractBasePtr & cmp,const py::module & mod,bool is_const)252 bool CheckPythonIsInstance(const py::object &x, const AbstractBasePtr &cmp, const py::module &mod, bool is_const) {
253   if (cmp->isa<abstract::AbstractTuple>()) {
254     const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
255     return std::any_of(cmp_tuple_elements.begin(), cmp_tuple_elements.end(),
256                        [&x, &mod, is_const](const AbstractBasePtr &element) {
257                          return CheckPythonIsInstance(x, element, mod, is_const);
258                        });
259   }
260   if (std::find(kSparsePrimStr.begin(), kSparsePrimStr.end(), cmp->ToString()) != kSparsePrimStr.end()) {
261     return false;
262   }
263 
264   py::object cmp_type;
265   if (cmp->isa<abstract::PartialAbstractClosure>()) {
266     const auto &cmp_closure_args = cmp->cast<abstract::PartialAbstractClosurePtr>()->args();
267     // CheckCmpValid ensures size of cmp_closure_args to be 1.
268     auto cmp_closure_first_input = cmp_closure_args[0];
269     cmp_type = ValueToPyData(cmp_closure_first_input->BuildValue());
270   } else {
271     auto cmp_value = cmp->BuildValue();
272     if (cmp_value->ContainsValueAny()) {
273       return false;
274     }
275     cmp_type = ValueToPyData(cmp_value);
276   }
277 
278   py::object result = is_const ? python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_PYTHON_ISINSTANCE, x, cmp_type)
279                                : python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_MS_ISINSTANCE, x, cmp_type);
280   return result.cast<bool>();
281 }
282 
CheckIsInstanceForFunc(const py::object & x_py_obj,const AbstractBasePtr & cmp,const py::module & mod)283 bool CheckIsInstanceForFunc(const py::object &x_py_obj, const AbstractBasePtr &cmp, const py::module &mod) {
284   MS_EXCEPTION_IF_NULL(cmp);
285   if (cmp->isa<abstract::AbstractTuple>()) {
286     const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
287     return std::any_of(
288       cmp_tuple_elements.begin(), cmp_tuple_elements.end(),
289       [&x_py_obj, &mod](const AbstractBasePtr &element) { return CheckIsInstanceForFunc(x_py_obj, element, mod); });
290   }
291 
292   if (!cmp->isa<abstract::PartialAbstractClosure>()) {
293     return false;
294   }
295   const auto &cmp_closure_args = cmp->cast<abstract::PartialAbstractClosurePtr>()->args();
296   // CheckCmpValid ensures size of cmp_closure_args to be 1.
297   auto cmp_closure_first_input = cmp_closure_args[0];
298   auto cmp_py_obj = ValueToPyData(cmp_closure_first_input->BuildValue());
299   auto result = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_PYTHON_ISINSTANCE, x_py_obj, cmp_py_obj);
300   return result.cast<bool>();
301 }
302 
CheckIsInstanceForSparse(const AbstractBasePtr & cmp,const std::string & target)303 bool CheckIsInstanceForSparse(const AbstractBasePtr &cmp, const std::string &target) {
304   MS_EXCEPTION_IF_NULL(cmp);
305   if (!cmp->isa<abstract::AbstractTuple>()) {
306     return cmp->ToString() == target;
307   }
308   const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
309   return std::any_of(cmp_tuple_elements.begin(), cmp_tuple_elements.end(),
310                      [&target](const AbstractBasePtr &element) { return CheckIsInstanceForSparse(element, target); });
311 }
312 
GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr & prim_abs)313 py::object GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr &prim_abs) {
314   MS_EXCEPTION_IF_NULL(prim_abs);
315   auto prim = prim_abs->prim();
316   MS_EXCEPTION_IF_NULL(prim);
317   auto prim_signature = prim->cast<prim::DoSignaturePrimitivePtr>();
318   MS_EXCEPTION_IF_NULL(prim_signature);
319   auto function = prim_signature->function();
320   MS_EXCEPTION_IF_NULL(function);
321   auto primitive_py_function = function->cast<PrimitivePyPtr>();
322   return primitive_py_function->GetPyObj();
323 }
324 
GetMsClassPyObj(const abstract::PartialAbstractClosurePtr & ms_class_abs)325 py::object GetMsClassPyObj(const abstract::PartialAbstractClosurePtr &ms_class_abs) {
326   MS_EXCEPTION_IF_NULL(ms_class_abs);
327   const auto &ms_class_args = ms_class_abs->args();
328   if (ms_class_args.size() != 1) {
329     MS_LOG(INTERNAL_EXCEPTION)
330       << "When the first input to IsInstance is PartialAbstractClosure, its args size should be 1 but "
331       << "got: " << ms_class_args.size() << ".";
332   }
333   auto first_arg = ms_class_args[0];
334   auto class_value = first_arg->BuildValue();
335   MS_EXCEPTION_IF_NULL(class_value);
336   return ValueToPyData(class_value);
337 }
338 
CheckCmpValid(const AbstractBasePtr & cmp)339 bool CheckCmpValid(const AbstractBasePtr &cmp) {
340   MS_EXCEPTION_IF_NULL(cmp);
341   if (cmp->isa<abstract::AbstractSequence>()) {
342     if (!cmp->isa<abstract::AbstractTuple>()) {
343       return false;
344     }
345     const auto &elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
346     return std::all_of(elements.begin(), elements.end(),
347                        [](const AbstractBasePtr &element) { return CheckCmpValid(element); });
348   }
349   if (cmp->isa<abstract::AbstractScalar>()) {
350     auto cmp_type = cmp->BuildType();
351     MS_EXCEPTION_IF_NULL(cmp_type);
352     return cmp_type->type_id() == kMetaTypeTypeType;
353   } else if (cmp->isa<abstract::PartialAbstractClosure>()) {
354     auto cmp_closure = cmp->cast<abstract::PartialAbstractClosurePtr>();
355     const auto &cmp_closure_args = cmp_closure->args();
356     if (cmp_closure_args.size() != 1) {
357       return false;
358     }
359     auto cmp_closure_first_input = cmp_closure_args[0];
360     auto cmp_type = cmp_closure_first_input->BuildType();
361     MS_EXCEPTION_IF_NULL(cmp_type);
362     auto cmp_type_id = cmp_type->type_id();
363     if (cmp_type_id == kObjectTypeClass) {
364       // When cmp type is ms_class, fn should be create_instance.
365       auto cmp_closure_fn = cmp_closure->fn();
366       MS_EXCEPTION_IF_NULL(cmp_closure_fn);
367       const std::string ms_class_type_fn_name = "PrimitiveAbstractClosure: create_instance";
368       return cmp_closure_fn->ToString() == ms_class_type_fn_name;
369     }
370     return cmp_type_id == kMetaTypeTypeType;
371   } else if (cmp->isa<abstract::AbstractAny>()) {
372     return true;
373   }
374   return std::find(kSparsePrimStr.cbegin(), kSparsePrimStr.cend(), cmp->ToString()) != kSparsePrimStr.cend();
375 }
376 
InferImplIsInstance(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)377 AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
378                                     const AbstractBasePtrList &args_abs_list) {
379   MS_EXCEPTION_IF_NULL(primitive);
380   constexpr size_t args_num = 2;
381   CheckArgsSize(primitive->name(), args_abs_list, args_num);
382   py::gil_scoped_acquire gil;
383   py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
384   auto x = args_abs_list[0];
385   MS_EXCEPTION_IF_NULL(x);
386   auto cmp = args_abs_list[1];
387   MS_EXCEPTION_IF_NULL(cmp);
388 
389   if (!CheckCmpValid(cmp)) {
390     auto cmp_type = cmp->BuildType();
391     MS_EXCEPTION_IF_NULL(cmp_type);
392     MS_LOG(ERROR) << "cmp: " << cmp->ToString() << ", cmp_type: " << cmp_type->ToString()
393                   << ", cmp_type_id: " << TypeIdToType(cmp_type->type_id());
394     MS_EXCEPTION(TypeError) << "isinstance() arg 2 must be a type or tuple of types.";
395   }
396 
397   // If x is AbstractAny the result of isinstance can not determined in frontend,
398   // isinstance should be converted to pyexecute later.
399   // So we set the abstract of instance to variable boolean scalar.
400   if (x->isa<abstract::AbstractAny>()) {
401     return std::make_shared<AbstractScalar>(kValueAny, kBool);
402   }
403 
404   MS_EXCEPTION_IF_NULL(x);
405   bool result = false;
406   if (x->isa<abstract::FuncGraphAbstractClosure>()) {
407     // x is Cell object.
408     auto x_fg = x->cast<abstract::FuncGraphAbstractClosurePtr>()->func_graph();
409     MS_EXCEPTION_IF_NULL(x_fg);
410     auto wrapper_obj = x_fg->python_obj();
411     if (wrapper_obj != nullptr) {
412       if (!wrapper_obj->isa<parse::PyObjectWrapper>()) {
413         MS_LOG(INTERNAL_EXCEPTION) << "The wrapper_obj of FuncGraphAbstractClosure must be PyObjectWrapper but got: "
414                                    << wrapper_obj->ToString() << ".";
415       }
416       auto x_py_obj = wrapper_obj->cast<parse::PyObjectWrapperPtr>()->obj();
417       result = CheckIsInstanceForFunc(x_py_obj, cmp, mod);
418     }
419   } else if (x->isa<abstract::PrimitiveAbstractClosure>()) {
420     // x is Primitive.
421     auto x_py_obj = GetPrimitivePyObj(x->cast<abstract::PrimitiveAbstractClosurePtr>());
422     result = CheckIsInstanceForFunc(x_py_obj, cmp, mod);
423   } else if (x->isa<abstract::AbstractClass>()) {
424     // x is ms_class.
425     auto class_value = x->BuildValue();
426     MS_EXCEPTION_IF_NULL(class_value);
427     auto x_py = ValueToPyData(class_value);
428     result = CheckIsInstanceForFunc(x_py, cmp, mod);
429   } else if (x->isa<abstract::AbstractCSRTensor>()) {
430     // x is sparse tensor with type CSRTensor.
431     const size_t csr_index = 0;
432     result = CheckIsInstanceForSparse(cmp, kSparsePrimStr[csr_index]);
433   } else if (x->isa<abstract::AbstractCOOTensor>()) {
434     // x is sparse tensor with type COOTensor.
435     const size_t coo_index = 1;
436     result = CheckIsInstanceForSparse(cmp, kSparsePrimStr[coo_index]);
437   } else if (x->isa<abstract::AbstractRowTensor>()) {
438     // x is sparse tensor with type RowTensor.
439     const size_t row_index = 2;
440     result = CheckIsInstanceForSparse(cmp, kSparsePrimStr[row_index]);
441   } else if (IsAdapterTensor(x)) {
442     // x is adapter tensor.
443     result = CheckIsInstanceForAdapter(x, cmp);
444     return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(result), kBool);
445   } else if (x->BuildValue()->ContainsValueAny()) {
446     // x is variable built-in type.
447     auto x_abs_type = std::make_shared<AbstractType>(x->BuildType());
448     auto py_x_type = ValueToPyData(x_abs_type->BuildValue());
449     result = CheckPythonIsInstance(py_x_type, cmp, mod, false);
450   } else {
451     // x is python built-in constant type or external type.
452     py::object x_py_obj = ValueToPyData(x->BuildValue());
453     result = CheckPythonIsInstance(x_py_obj, cmp, mod, true);
454   }
455 
456   // If no constant type in cmp match the type of x and cmp contains AbstractAny,
457   // the result of isinstance can not determined in frontend, should be converted to pyexecute later.
458   // So we set the abstract of instance to variable boolean scalar.
459   if (!result && fallback::ContainsSequenceAnyType(cmp)) {
460     return std::make_shared<AbstractScalar>(kValueAny, kBool);
461   }
462   return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(result), kBool);
463 }
464 
CompareShape(const std::vector<ValuePtr> & x_shape,const std::vector<ValuePtr> & y_shape)465 bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
466   if (x_shape.size() != y_shape.size()) {
467     return false;
468   }
469 
470   for (size_t i = 0; i < x_shape.size(); ++i) {
471     if (GetValue<int64_t>(x_shape[i]) != GetValue<int64_t>(y_shape[i])) {
472       return false;
473     }
474   }
475 
476   return true;
477 }
478 
DoInferReduceShape(const AbstractTuplePtr & x_shape,const ValuePtr & x_shp_value,const ValueSequencePtr & axis_value_ptr,const PrimitivePtr & primitive)479 AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
480                                    const ValueSequencePtr &axis_value_ptr, const PrimitivePtr &primitive) {
481   size_t x_rank = x_shape->size();
482   std::set<int64_t> axis_set;
483   auto axis_data = axis_value_ptr->value();
484   if (axis_data.empty()) {
485     int64_t size = 1;
486     AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
487     return std::make_shared<AbstractTuple>(values);
488   }
489 
490   for (auto &elem : axis_data) {
491     auto x_rank_tmp = x_rank;
492     if (x_rank_tmp == 0) {
493       x_rank_tmp = 1;
494     }
495     int64_t e_value =
496       CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank_tmp), SizeToLong(x_rank_tmp), "input_x");
497     (void)axis_set.insert(e_value);
498   }
499   MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());
500   auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
501   if (x_shp_data.size() < x_rank) {
502     MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank << ".";
503   }
504   AbstractBasePtrList values;
505   for (size_t i = 0; i < x_rank; i++) {
506     if (axis_set.count(SizeToLong(i)) || axis_set.count(SizeToLong(i) - SizeToLong(x_rank))) {
507       auto axis_v = MakeValue(static_cast<int64_t>(1));
508       values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
509     } else {
510       int64_t dim_value = x_shp_data[i]->cast<Int64ImmPtr>()->value();
511       auto dim = MakeValue(dim_value);
512       values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
513     }
514   }
515 
516   return std::make_shared<AbstractTuple>(values);
517 }
518 
InferImplBroadcastGradientArgs(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)519 AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
520                                                const AbstractBasePtrList &args_abs_list) {
521   // this primitive get the index that need to reduce
522   // input: x's shape and y's shape, inputs should be tuple
523   // output: tuple of x and y 's reduce index, reduce index should be a tuple
524   MS_EXCEPTION_IF_NULL(primitive);
525   const std::string op_name = primitive->name();
526   const size_t inputs_size = 2;
527   CheckArgsSize(op_name, args_abs_list, inputs_size);
528   auto arg_x = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
529   auto arg_y = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
530 
531   auto arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
532   MS_EXCEPTION_IF_NULL(arg_x_value);
533 
534   auto arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
535   MS_EXCEPTION_IF_NULL(arg_y_value);
536 
537   const std::vector<ValuePtr> x_shape = arg_x_value->value();
538   const std::vector<ValuePtr> y_shape = arg_y_value->value();
539   bool is_same_shape = CompareShape(x_shape, y_shape);
540   // if it is the same shape , do not need reduce , return empty tuple
541   if (is_same_shape) {
542     AbstractBasePtrList empty_list;
543     auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
544     auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
545 
546     AbstractBasePtrList elem_list;
547     elem_list.push_back(x_reduce_idx);
548     elem_list.push_back(y_reduce_idx);
549 
550     return std::make_shared<AbstractTuple>(elem_list);
551   }
552   return BroadcastGradientArgsDiff(x_shape, y_shape);
553 }
554 
InferImplListReduce(const AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)555 AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
556                                     const AbstractBasePtrList &args_abs_list) {
557   // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
558   MS_EXCEPTION_IF_NULL(engine);
559   MS_EXCEPTION_IF_NULL(primitive);
560   const std::string op_name = primitive->name();
561   const size_t inputs_size = 3;
562   CheckArgsSize(op_name, args_abs_list, inputs_size);
563   AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_abs_list, 0);
564   AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_abs_list, 1);
565   MS_EXCEPTION_IF_NULL(lst);
566   AbstractBasePtr dflt = args_abs_list[2];
567 
568   AbstractBasePtr list_type = AbstractJoin(lst->elements());
569   auto result1 = engine->Execute(fn, lst->elements());
570   MS_EXCEPTION_IF_NULL(result1);
571   auto result2 = engine->Execute(fn, {dflt, list_type});
572   MS_EXCEPTION_IF_NULL(result2);
573   MS_EXCEPTION_IF_NULL(result1->abstract());
574   MS_EXCEPTION_IF_NULL(result2->abstract());
575   return result1->abstract()->Join(result2->abstract());
576 }
577 
InferImplTupleReversed(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)578 AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
579                                        const AbstractBasePtrList &args_abs_list) {
580   // Inputs: a tuple
581   MS_EXCEPTION_IF_NULL(primitive);
582   const std::string op_name = primitive->name();
583   CheckArgsSize(op_name, args_abs_list, 1);
584   AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
585   MS_EXCEPTION_IF_NULL(input);
586   auto tuple_elements = input->elements();
587   AbstractBasePtrList elem_list;
588   (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
589                        [](const AbstractBasePtr &elem) { return elem->Clone(); });
590   return std::make_shared<AbstractTuple>(elem_list);
591 }
592 
InferImplReduceShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)593 AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
594                                      const AbstractBasePtrList &args_abs_list) {
595   // Inputs: x_shape, axis
596   MS_EXCEPTION_IF_NULL(primitive);
597   const std::string op_name = primitive->name();
598   constexpr size_t arg_size = 2;
599   CheckArgsSize(op_name, args_abs_list, arg_size);
600   AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
601   MS_EXCEPTION_IF_NULL(shape_x);
602   MS_EXCEPTION_IF_NULL(args_abs_list[1]);
603 
604   auto x_shp_value = shape_x->BuildValue();
605   if (x_shp_value->ContainsValueAny()) {
606     MS_LOG(INTERNAL_EXCEPTION) << "The ReduceShape operator's data field can't be anything: "
607                                << args_abs_list[1]->ToString() << ".";
608   }
609 
610   // Axis can be scalar, tuple or list
611   AbstractSequencePtr axis = nullptr;
612   if (args_abs_list[1]->isa<AbstractScalar>()) {
613     MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar.";
614     AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_abs_list[1])};
615     axis = std::make_shared<AbstractTuple>(axis_list);
616   } else if (args_abs_list[1]->isa<AbstractSequence>()) {
617     MS_LOG(DEBUG) << "The type of second argument of ReduceShape operator is sequence.";
618     axis = args_abs_list[1]->cast<AbstractSequencePtr>();
619   } else {
620     MS_LOG(EXCEPTION) << "The second argument of ReduceShape operator should be a scalar or tuple or list, "
621                       << "but got " << args_abs_list[1]->ToString() << ".";
622   }
623 
624   auto axis_value = axis->BuildValue();
625   if (axis_value->ContainsValueAny()) {
626     MS_LOG(INTERNAL_EXCEPTION) << "The ReduceShape operator's data field can't be anything: "
627                                << args_abs_list[1]->ToString() << ".";
628   }
629   auto axis_value_ptr = axis_value->cast<ValueSequencePtr>();
630   MS_EXCEPTION_IF_NULL(axis_value_ptr);
631   return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
632 }
633 
InferImplTupleDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)634 AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
635                                   const AbstractBasePtrList &args_abs_list) {
636   // Inputs: two tuples.
637   MS_EXCEPTION_IF_NULL(primitive);
638   const std::string op_name = primitive->name();
639   constexpr size_t arg_size = 2;
640   CheckArgsSize(op_name, args_abs_list, arg_size);
641   AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
642   AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
643   MS_EXCEPTION_IF_NULL(shape_x);
644   MS_EXCEPTION_IF_NULL(div_shp);
645   MS_LOG(INFO) << "The shape of dividend:" << shape_x->ToString() << ", the shape of divisor:" << div_shp->ToString();
646 
647   auto div_shp_value = div_shp->BuildValue();
648   MS_EXCEPTION_IF_NULL(div_shp_value);
649   if (div_shp_value->ContainsValueAny()) {
650     MS_LOG(INTERNAL_EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
651                                << args_abs_list[0]->ToString() << ".";
652   }
653 
654   auto shape_x_value = shape_x->BuildValue();
655   MS_EXCEPTION_IF_NULL(shape_x_value);
656   if (shape_x_value->ContainsValueAny()) {
657     MS_LOG(INTERNAL_EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
658                                << args_abs_list[1]->ToString() << ".";
659   }
660 
661   if (div_shp->size() != shape_x->size()) {
662     MS_LOG(INTERNAL_EXCEPTION)
663       << "The size of inputs of 'tuple_div' operator must be the same, but the size of divisor tuple is"
664       << " " << div_shp->size() << ", the size of dividend tuple is " << shape_x->size() << ".";
665   }
666   auto shape_x_tuple_value = shape_x_value->cast<ValueTuplePtr>();
667   auto div_shape_tuple_value = div_shp_value->cast<ValueTuplePtr>();
668   MS_EXCEPTION_IF_NULL(shape_x_tuple_value);
669   MS_EXCEPTION_IF_NULL(div_shape_tuple_value);
670   auto shape_x_data = shape_x_tuple_value->value();
671   auto div_shape_data = div_shape_tuple_value->value();
672   AbstractBasePtrList values;
673 
674   for (size_t i = 0; i < div_shape_data.size(); i++) {
675     MS_EXCEPTION_IF_NULL(div_shape_data[i]);
676     if (div_shape_data[i]->cast<Int64ImmPtr>() == nullptr) {
677       auto value_type = div_shape_data[i]->type();
678       std::string str_type;
679       if (value_type) {
680         str_type = value_type->ToString();
681       } else {
682         str_type = "ValueAny";
683       }
684       MS_LOG(EXCEPTION) << "The data type of inputs of 'tuple_div' operator should be an int64 number, but got a "
685                         << str_type << " number " << div_shape_data[i]->ToString() << ".";
686     }
687     auto shapex_value = GetValue<int64_t>(shape_x_data[i]);
688     auto div_value = GetValue<int64_t>(div_shape_data[i]);
689     MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
690     if (div_value == 0) {
691       MS_LOG(EXCEPTION) << "The divisor value should not be 0!";
692     }
693     if ((shapex_value % div_value) != 0) {
694       MS_LOG(EXCEPTION) << "The inputs of 'tuple_div' operator should be divisible, but they are not divisible now, "
695                         << "the dividend is " << shapex_value << ", the divisor is " << div_value << ".";
696     }
697 
698     int64_t result = shapex_value / div_value;
699     auto result_v = MakeValue(result);
700     values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
701   }
702   return std::make_shared<AbstractTuple>(values);
703 }
704 
InferImplTuple2Array(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)705 AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
706                                      const AbstractBasePtrList &args_abs_list) {
707   // Inputs: a tuple
708   const std::string op_name = primitive->name();
709   CheckArgsSize(op_name, args_abs_list, 1);
710   AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
711   MS_EXCEPTION_IF_NULL(input);
712   py::tuple data_tuple = ValueToPyData(input->BuildValue());
713   py::array data = py::array(data_tuple);
714   auto tensor = tensor::TensorPy::MakeTensor(data);
715   auto ret = tensor->ToAbstract();
716   ret->set_value(tensor);
717   MS_LOG(DEBUG) << "The infer result of Tuple2Array operator is tensor, the infer result is " << ret->ToString() << ".";
718   return ret;
719 }
720 
InferImplSliceGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)721 AbstractBasePtr InferImplSliceGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
722                                       const AbstractBasePtrList &args_abs_list) {
723   auto op_name = primitive->name();
724   constexpr auto slice_getitem_input_size = 2;
725   CheckArgsSize(op_name, args_abs_list, slice_getitem_input_size);
726   AbstractSlicePtr slice_abs = CheckArg<AbstractSlice>(op_name, args_abs_list, 0);
727   const std::map<std::string, AbstractBasePtr> result_map = {
728     {kSliceStart, slice_abs->start()}, {kSliceStop, slice_abs->stop()}, {kSliceStep, slice_abs->step()}};
729   auto slice_attr = args_abs_list[1]->BuildValue();
730   MS_EXCEPTION_IF_NULL(slice_attr);
731   if (!slice_attr->isa<StringImm>()) {
732     MS_LOG(EXCEPTION) << "The second argument of SliceGetItem operator should be a string, but got "
733                       << slice_attr->ToString() << ".";
734   }
735   auto slice_str = GetValue<std::string>(slice_attr);
736   auto iter = result_map.find(slice_str);
737   if (iter == result_map.end()) {
738     MS_INTERNAL_EXCEPTION(AttributeError) << "The 'slice' object has no attribute:" << slice_str << ".";
739   }
740   return iter->second;
741 }
742 
InferImplMakeSlice(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)743 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
744                                    const AbstractBasePtrList &args_abs_list) {
745   // Inputs: three scalars whose value is an int32 number.
746   constexpr auto make_slice_input_size = 3;
747   CheckArgsSize(primitive->name(), args_abs_list, make_slice_input_size);
748   size_t args_size = args_abs_list.size();
749   AbstractBasePtrList slice_args;
750   for (size_t index = 0; index < args_size; index++) {
751     MS_EXCEPTION_IF_NULL(args_abs_list[index]);
752     if (args_abs_list[index]->isa<AbstractNone>()) {
753       slice_args.push_back(args_abs_list[index]);
754     } else if (args_abs_list[index]->isa<AbstractScalar>()) {
755       ValuePtr scalar_value = args_abs_list[index]->cast<AbstractScalarPtr>()->BuildValue();
756       MS_EXCEPTION_IF_NULL(scalar_value);
757       if (scalar_value->isa<IntegerImm>() || scalar_value->ContainsValueAny()) {
758         slice_args.push_back(args_abs_list[index]);
759       } else if (scalar_value->isa<BoolImm>()) {
760         ValuePtr scalar_index = MakeValue(static_cast<int64_t>(scalar_value->cast<BoolImmPtr>()->value()));
761         slice_args.push_back(scalar_index->ToAbstract());
762       } else {
763         auto type = scalar_value->type();
764         MS_EXCEPTION_IF_NULL(type);
765         MS_EXCEPTION(TypeError) << "Slice indices must be integers or bool. But got a " << type->ToString()
766                                 << " number.";
767       }
768     } else if (args_abs_list[index]->isa<AbstractTensor>()) {
769       auto arg = args_abs_list[index]->cast<AbstractTensorPtr>();
770       TypePtr tensor_dtype = arg->element()->BuildType();
771       auto build_value = arg->BuildValue();
772       MS_EXCEPTION_IF_NULL(build_value);
773       auto value = build_value->cast<tensor::TensorPtr>();
774       if (value != nullptr) {
775         if (value->DataSize() != 1) {
776           MS_EXCEPTION(TypeError) << "The input tensor of the MakeSlice operator must contain only one element,"
777                                   << "but " << value->ToString() << " has " << value->DataSize() << " elements.";
778         }
779 
780         if (tensor_dtype->isa<Bool>()) {
781           auto *bool_value = static_cast<bool *>(value->data_c());
782           slice_args.push_back(MakeValue((static_cast<int64_t>(*bool_value)))->ToAbstract());
783         } else if (tensor_dtype == kInt64) {
784           auto *int_value = static_cast<int64_t *>(value->data_c());
785           slice_args.push_back(MakeValue((*int_value))->ToAbstract());
786         } else if (tensor_dtype == kInt32) {
787           auto *int_value = static_cast<int32_t *>(value->data_c());
788           slice_args.push_back(MakeValue((*int_value))->ToAbstract());
789         } else {
790           MS_EXCEPTION(TypeError) << "The input tensor type of the MakeSlice operator must be int or bool, but got "
791                                   << tensor_dtype->ToString();
792         }
793       } else {
794         slice_args.push_back(args_abs_list[index]);
795       }
796     } else {
797       MS_EXCEPTION(TypeError) << "The " << index << "th input of MakeSlice operator should be scalar, none or tensor, "
798                               << "but got " << args_abs_list[index]->ToString() << ".";
799     }
800   }
801   // Slice: start, end, step
802   constexpr size_t kMakeSliceInput0 = 0;
803   constexpr size_t kMakeSliceInput1 = 1;
804   constexpr size_t kMakeSliceInput2 = 2;
805   return std::make_shared<AbstractSlice>(slice_args[kMakeSliceInput0], slice_args[kMakeSliceInput1],
806                                          slice_args[kMakeSliceInput2]);
807 }
808 
InferImplStopGradient(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)809 AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
810                                       const AbstractBasePtrList &args_abs_list) {
811   // Inputs: any value;
812   CheckArgsSize(primitive->name(), args_abs_list, 1);
813   return args_abs_list[0]->Clone();
814 }
815 
InferImplDictLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)816 AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
817                                  const AbstractBasePtrList &args_abs_list) {
818   return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_abs_list);
819 }
820 
InferImplJ(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)821 AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
822                            const AbstractBasePtrList &args_abs_list) {
823   // args: An object of AbstractFunction.
824   CheckArgsSize(primitive->name(), args_abs_list, 1);
825   MS_LOG(DEBUG) << "evaluate J: " << args_abs_list[0]->ToString();
826 
827   AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
828   if (x == nullptr) {
829     return std::make_shared<AbstractJTagged>(args_abs_list[0]);
830   }
831 
832   AbstractFuncAtomPtrList jv;
833   auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
834     auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
835     jv.push_back(j_closure);
836   };
837   x->Visit(build_jv);
838 
839   return AbstractFunction::MakeAbstractFunction(jv);
840 }
841 
InferImplTaylor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)842 AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
843                                 const AbstractBasePtrList &args_abs_list) {
844   // args: An object of AbstractFunction.
845   CheckArgsSize(primitive->name(), args_abs_list, 1);
846   MS_LOG(DEBUG) << "evaluate Taylor: " << args_abs_list[0]->ToString();
847 
848   AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
849   MS_EXCEPTION_IF_NULL(x);
850 
851   AbstractFuncAtomPtrList taylor_v;
852   auto build_taylor_v = [&taylor_v](const AbstractFuncAtomPtr &func) {
853     auto taylor_closure = std::make_shared<TaylorTransformedAbstractClosure>(func);
854     taylor_v.push_back(taylor_closure);
855   };
856   x->Visit(build_taylor_v);
857 
858   return AbstractFunction::MakeAbstractFunction(taylor_v);
859 }
860 
InferImplReusing(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)861 AbstractBasePtr InferImplReusing(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
862                                  const AbstractBasePtrList &args_abs_list) {
863   // args: An object of AbstractFunction.
864   CheckArgsSize(primitive->name(), args_abs_list, 1);
865   MS_LOG(DEBUG) << "evaluate Reusing: " << args_abs_list[0]->ToString();
866   AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
867   MS_EXCEPTION_IF_NULL(x);
868   auto set_graph_no_inline = [](const AbstractFuncAtomPtr &func) {
869     auto fg_closure = dyn_cast<FuncGraphAbstractClosure>(func);
870     if (fg_closure != nullptr) {
871       fg_closure->func_graph()->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
872       MS_LOG(DEBUG) << " Reusing: " << func->ToString()
873                     << " no_inline: " << fg_closure->func_graph()->has_flag(FUNC_GRAPH_FLAG_NO_INLINE);
874     }
875   };
876   x->Visit(set_graph_no_inline);
877   return x;
878 }
879 
InferImplShard(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)880 AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
881                                const AbstractBasePtrList &args_abs_list) {
882   // Inputs: func, in_axes, out_axes, device, level.
883   constexpr size_t shard_input_size = 5;
884   CheckArgsSize(primitive->name(), args_abs_list, shard_input_size);
885   MS_LOG(DEBUG) << "Evaluate Shard: " << args_abs_list[0]->ToString();
886 
887   AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
888   MS_EXCEPTION_IF_NULL(x);
889 
890   AbstractFuncAtomPtrList shard_v;
891   auto build_shard_v = [&shard_v](const AbstractFuncAtomPtr &func) {
892     auto shard_closure = std::make_shared<ShardTransformedAbstractClosure>(func);
893     shard_v.push_back(shard_closure);
894   };
895   x->Visit(build_shard_v);
896 
897   return AbstractFunction::MakeAbstractFunction(shard_v);
898 }
899 
InferImplVmap(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)900 AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
901                               const AbstractBasePtrList &args_abs_list) {
902   // args: An object of AbstractFunction.
903   CheckArgsSize(primitive->name(), args_abs_list, 1);
904   auto fn_arg = args_abs_list[0];
905   MS_LOG(DEBUG) << "Evaluate Vmap: " << fn_arg->ToString() << ".";
906 
907   AbstractFuncAtomPtrList vmap_v;
908   ValuePtr in_axes = primitive->GetAttr("in_axes");
909   ValuePtr out_axes = primitive->GetAttr("out_axes");
910   ValuePtr cell_size_value = primitive->GetAttr("cell_size");
911   MS_EXCEPTION_IF_NULL(cell_size_value);
912   auto cell_size = cell_size_value->isa<UInt64Imm>() ? dyn_cast<UInt64Imm>(cell_size_value)->value() : 0;
913 
914   auto traverse_fn = [&vmap_v, &in_axes, &out_axes, &cell_size](const AbstractBasePtr &fn_arg) {
915     AbstractFunctionPtr x = dyn_cast<AbstractFunction>(fn_arg);
916     MS_EXCEPTION_IF_NULL(x);
917     auto build_vmap_v = [&vmap_v, &in_axes, &out_axes, &cell_size](const AbstractFuncAtomPtr &func) {
918       auto vmap_closure = std::make_shared<VmapTransformedAbstractClosure>(func, in_axes, out_axes, cell_size);
919       vmap_v.push_back(vmap_closure);
920     };
921     x->Visit(build_vmap_v);
922   };
923 
924   AbstractTuplePtr cell_list = dyn_cast<AbstractTuple>(fn_arg);
925   if (cell_list != nullptr) {
926     const auto &cell_list_fns = cell_list->elements();
927     for (const auto &fn : cell_list_fns) {
928       traverse_fn(fn);
929     }
930   } else {
931     traverse_fn(fn_arg);
932   }
933 
934   return AbstractFunction::MakeAbstractFunction(vmap_v);
935 }
936 
InferImplFakeBprop(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)937 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
938                                    const AbstractBasePtrList &args_abs_list) {
939   // Inputs: a tensor.
940   CheckArgsSize(primitive->name(), args_abs_list, 1);
941   return args_abs_list[0]->Broaden();
942 }
943 
GetStringAndNumberFromAbstract(const std::string & op_name,const AbstractBasePtrList & args_abs_list,std::string * str,int64_t * num)944 void GetStringAndNumberFromAbstract(const std::string &op_name, const AbstractBasePtrList &args_abs_list,
945                                     std::string *str, int64_t *num) {
946   constexpr size_t args_num = 2;
947   CheckArgsSize(op_name, args_abs_list, args_num);
948   AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
949   AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_abs_list, 1);
950   ValuePtr value_x = scalar_x->BuildValue();
951   ValuePtr value_y = scalar_y->BuildValue();
952 
953   bool is_match = false;
954   if (value_x->isa<StringImm>()) {
955     *str = GetValue<std::string>(value_x);
956     if (value_y->isa<Int32Imm>()) {
957       *num = IntToLong(GetValue<int32_t>(value_y));
958       is_match = true;
959     } else if (value_y->isa<Int64Imm>()) {
960       *num = GetValue<int64_t>(value_y);
961       is_match = true;
962     }
963   } else if (value_y->isa<StringImm>()) {
964     *str = GetValue<std::string>(value_y);
965     if (value_x->isa<Int32Imm>()) {
966       *num = IntToLong(GetValue<int32_t>(value_x));
967       is_match = true;
968     } else if (value_x->isa<Int64Imm>()) {
969       *num = GetValue<int64_t>(value_x);
970       is_match = true;
971     }
972   }
973   if (!is_match) {
974     MS_LOG(EXCEPTION) << op_name << " requires the input to be a string and an integer, but got " << value_x->ToString()
975                       << " and " << value_y->ToString() << ".";
976   }
977 }
978 
InferImplStringMul(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)979 AbstractBasePtr InferImplStringMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
980                                    const AbstractBasePtrList &args_abs_list) {
981   // Inputs: a string and an integer.
982   std::string str;
983   int64_t num = 0;
984   const std::string op_name = primitive->name();
985   GetStringAndNumberFromAbstract(op_name, args_abs_list, &str, &num);
986   std::string res;
987   // If num is less than or equal to 0, return an empty string.
988   if (num > 0) {
989     for (auto i = 0; i < num; i++) {
990       res += str;
991     }
992   }
993   return std::make_shared<AbstractScalar>(res);
994 }
995 
InferImplStringGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)996 AbstractBasePtr InferImplStringGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
997                                        const AbstractBasePtrList &args_abs_list) {
998   // Inputs: a string and an integer.
999   std::string str;
1000   int64_t num = 0;
1001   const std::string op_name = primitive->name();
1002   GetStringAndNumberFromAbstract(op_name, args_abs_list, &str, &num);
1003   int64_t len = SizeToLong(str.length());
1004   if (num >= len || num < -len) {
1005     MS_LOG(EXCEPTION) << "String index out of range, expect:[" << -len << ", " << (len - 1) << "], but got " << num
1006                       << ".";
1007   }
1008   if (num < 0) {
1009     num += len;
1010   }
1011   std::string res;
1012   (void)res.append(1, str.at(num));
1013   return std::make_shared<AbstractScalar>(res);
1014 }
1015 
PrimNeedFrontendInferValue(const PrimitivePtr & primitive)1016 bool PrimNeedFrontendInferValue(const PrimitivePtr &primitive) {
1017   // The operators in this list are registered on the core/ops, which means operators are registered on both frontend
1018   // and backend, affects the infer value of the frontend. We use this list to skip the registration of the backend, so
1019   // that the optimization of the frontend like constant folding, can be carried out smoothly. We need to delete this
1020   // list when the infer value can be mapped to the CPU backend operator.
1021   static std::vector<PrimitivePtr> skip_frontend_registration_list{
1022     prim::kPrimAdd, prim::kPrimMod,          prim::kPrimMul,   prim::kPrimRealDiv,
1023     prim::kPrimSub, prim::kPrimStridedSlice, prim::kPrimStack, prim::kPrimTensorScatterUpdate,
1024     prim::kPrimTile};
1025   if (std::any_of(skip_frontend_registration_list.begin(), skip_frontend_registration_list.end(),
1026                   [&primitive](const PrimitivePtr &item) {
1027                     return IsPrimitiveEquals(primitive, item) && primitive->HasPyEvaluator();
1028                   })) {
1029     return true;
1030   }
1031   return false;
1032 }
1033 
// File-local table of frontend primitive infer implementations. It is declared empty here; entries are
// presumably added by the registrations at the end of this file (confirm against the registration macro).
static PrimitiveEvalImplMap frontend_prim_infer_map{
  // frontend
};
// Mutable access to the table, used by RegPrimitiveFrontEval() to register entries.
PrimitiveEvalImplMap *GetFrontendPrimitiveInferMapPtr() { return &frontend_prim_infer_map; }
// Read-only access to the table, used for lookups.
const PrimitiveEvalImplMap &GetFrontendPrimitiveInferMap() { return frontend_prim_infer_map; }
GetFrontendPrimitiveInferImpl(const PrimitivePtr & primitive)1039 std::optional<StandardPrimitiveImplReg> GetFrontendPrimitiveInferImpl(const PrimitivePtr &primitive) {
1040   auto iter = GetFrontendPrimitiveInferMap().find(primitive);
1041   if (iter != GetFrontendPrimitiveInferMap().end()) {
1042     return iter->second;
1043   }
1044 
1045   // We need to delete this when the infer value can be mapped to the CPU backend operator.
1046   if (PrimNeedFrontendInferValue(primitive)) {
1047     return std::optional<StandardPrimitiveImplReg>();
1048   }
1049 
1050   auto find = abstract::GetPrimitiveInferImpl(primitive);
1051   if (find.has_value()) {
1052     return find.value();
1053   }
1054   return std::optional<StandardPrimitiveImplReg>();
1055 }
1056 
AbstractBasePtr SetAdapterFlag(const std::string &op_name, const AbstractBasePtr &abs_input, bool adapter_flag) {
  // Returns a clone of the (ref-)tensor abstract with its adapter flag set to `adapter_flag`.
  // The input abstract itself is never modified; non-tensor inputs raise an exception naming op_name.
  MS_EXCEPTION_IF_NULL(abs_input);
  if (abs_input->isa<AbstractRefTensor>()) {
    AbstractRefPtr cloned_ref = abs_input->Clone()->cast<AbstractRefPtr>();
    cloned_ref->set_is_adapter(adapter_flag);
    return cloned_ref;
  }
  if (!abs_input->isa<AbstractTensor>()) {
    MS_LOG(EXCEPTION) << op_name << " requires a tensor as the first argument, but got " << abs_input->ToString();
  }
  AbstractTensorPtr cloned_tensor = abs_input->Clone()->cast<AbstractTensorPtr>();
  cloned_tensor->set_is_adapter(adapter_flag);
  return cloned_tensor;
}
1072 
InferImplConvertToAdapterTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)1073 AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1074                                                 const AbstractBasePtrList &args_abs_list) {
1075   // Inputs: a tensor.
1076   constexpr size_t args_num = 1;
1077   constexpr size_t input_index = 0;
1078   const std::string op_name = primitive->name();
1079   CheckArgsSize(op_name, args_abs_list, args_num);
1080   return SetAdapterFlag(op_name, args_abs_list[input_index], true);
1081 }
1082 
InferImplConvertToMsTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)1083 AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1084                                            const AbstractBasePtrList &args_abs_list) {
1085   // Inputs: a tensor.
1086   constexpr size_t args_num = 1;
1087   constexpr size_t input_index = 0;
1088   const std::string op_name = primitive->name();
1089   CheckArgsSize(op_name, args_abs_list, args_num);
1090   return SetAdapterFlag(op_name, args_abs_list[input_index], false);
1091 }
1092 
InferImplDtypeToEnum(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)1093 AbstractBasePtr InferImplDtypeToEnum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1094                                      const AbstractBasePtrList &args_abs_list) {
1095   constexpr size_t args_num = 3;
1096   CheckArgsSize(primitive->name(), args_abs_list, args_num);
1097   auto abs_type = args_abs_list[ops::kInputIndex2]->cast<AbstractTypePtr>();
1098   if (abs_type == nullptr) {
1099     const auto &op_name = GetValue<std::string>(args_abs_list[ops::kInputIndex0]->GetValue());
1100     const auto &arg_name = GetValue<std::string>(args_abs_list[ops::kInputIndex1]->GetValue());
1101     MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input '" << arg_name << "' expect a type, but got "
1102                             << args_abs_list[ops::kInputIndex2]->ToString();
1103   }
1104   auto val_type = abs_type->BuildValue();
1105   MS_EXCEPTION_IF_NULL(val_type);
1106   auto dtype = val_type->cast<TypePtr>();
1107   MS_EXCEPTION_IF_NULL(dtype);
1108   int64_t type_id = GetTypeId(dtype->type_id());
1109   return std::make_shared<AbstractScalar>(type_id);
1110 }
1111 
#ifndef _MSC_VER
// Registration of frontend infer implementations for non-MSVC builds: each macro binds a primitive to its
// InferImpl* function. The trailing nullptr leaves the macro's last slot empty (presumably the infer-value
// function; confirm against the macro definition). On MSVC, RegPrimitiveFrontEval() in the #else branch
// performs the equivalent registrations through RegisterStandardPrimitiveEvalHelper.
// NOTE(review): the #else branch's "Other" section does not appear to register Reusing; confirm whether the
// two lists are meant to stay in sync.
// String
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringMul, prim::kPrimStringMul, InferImplStringMul, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringGetItem, prim::kPrimStringGetItem, InferImplStringGetItem, nullptr);
// Tuple
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
// List
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
// Dict
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
// Slice
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSlice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(SliceGetItem, prim::kPrimSliceGetItem, InferImplSliceGetItem, nullptr);
// Type
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TopTypeOf, prim::kPrimTopTypeOf, InferImplTopTypeof, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(IsInstance, prim::kPrimIsInstance, InferImplIsInstance, nullptr);
// Shape
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr);
// Auto-Grad
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                   InferImplBroadcastGradientArgs, nullptr);
// Other
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Reusing, prim::kPrimReusing, InferImplReusing, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringUpper, prim::kPrimStringUpper, InferImplStringUpper, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringLower, prim::kPrimStringLower, InferImplStringLower, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToAdapterTensor, prim::kPrimConvertToAdapterTensor,
                                   InferImplConvertToAdapterTensor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToMsTensor, prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor,
                                   nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DtypeToEnum, prim::kPrimDtypeToEnum, InferImplDtypeToEnum, nullptr);
1152 #else
RegPrimitiveFrontEval()1153 void RegPrimitiveFrontEval() {
1154   // String
1155   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringMul,
1156                                                 InferImplStringMul, nullptr);
1157   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringGetItem,
1158                                                 InferImplStringGetItem, nullptr);
1159   // Tuple
1160   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleReversed,
1161                                                 InferImplTupleReversed, nullptr);
1162   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleDiv,
1163                                                 InferImplTupleDiv, nullptr);
1164   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleToArray,
1165                                                 InferImplTuple2Array, nullptr);
1166   // List
1167   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListReduce,
1168                                                 InferImplListReduce, nullptr);
1169   // Dict
1170   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDictLen,
1171                                                 InferImplDictLen, nullptr);
1172   // Slice
1173   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimMakeSlice,
1174                                                 InferImplMakeSlice, nullptr);
1175   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimSliceGetItem,
1176                                                 InferImplSliceGetItem, nullptr);
1177   // Type
1178   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTypeOf,
1179                                                 InferImplTypeof, nullptr);
1180   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTopTypeOf,
1181                                                 InferImplTopTypeof, nullptr);
1182   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimHasType,
1183                                                 InferImplHasType, nullptr);
1184   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimIsInstance,
1185                                                 InferImplIsInstance, nullptr);
1186   // Shape
1187   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimReducedShape,
1188                                                 InferImplReduceShape, nullptr);
1189   // Auto-Grad
1190   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStopGradient,
1191                                                 InferImplStopGradient, nullptr);
1192   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimFakeBprop,
1193                                                 InferImplFakeBprop, nullptr);
1194   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimJ, InferImplJ,
1195                                                 nullptr);
1196   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
1197                                                 prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
1198                                                 nullptr);
1199   // Other
1200   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTaylor,
1201                                                 InferImplTaylor, nullptr);
1202   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimShard,
1203                                                 InferImplShard, nullptr);
1204   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimVmap,
1205                                                 InferImplVmap, nullptr);
1206   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringUpper,
1207                                                 InferImplStringUpper, nullptr);
1208   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringLower,
1209                                                 InferImplStringLower, nullptr);
1210   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
1211                                                 prim::kPrimConvertToAdapterTensor, InferImplConvertToAdapterTensor,
1212                                                 nullptr);
1213   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
1214                                                 prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor, nullptr);
1215   abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDtypeToEnum,
1216                                                 InferImplDtypeToEnum, nullptr);
}  // end of registration function (namespace abstract is closed after the #endif below)
1218 #endif
1219 }  // namespace abstract
1220 }  // namespace mindspore
1221