1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/operator/ops_front_infer_function.h"
17
18 #include <set>
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include <algorithm>
23 #include <map>
24
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/math_ops.h"
28 #include "mindspore/core/ops/array_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "abstract/abstract_value.h"
31 #include "pipeline/jit/ps/parse/resolve.h"
32 #include "pipeline/jit/ps/static_analysis/prim.h"
33 #include "pipeline/jit/ps/fallback.h"
34 #include "abstract/param_validator.h"
35 #include "pybind_api/ir/tensor_py.h"
36 #include "frontend/operator/ops.h"
37 #include "abstract/ops/infer_functions.h"
38 #include "include/common/utils/convert_utils_py.h"
39 #include "include/common/utils/utils.h"
40 #include "ops/auto_generate/gen_ops_primitive.h"
41 #include "ops/ops_func_impl/greater_equal.h"
42 #include "ops/ops_func_impl/greater.h"
43 #include "ops/mod.h"
44 #include "ops/strided_slice_v2.h"
45 #include "ops/grad/strided_slice_v2_grad.h"
46 #include "abstract/abstract_function.h"
47 #include "utils/ms_context.h"
48 #include "ops/op_name.h"
49 #ifdef _MSC_VER
50 #include "include/common/pybind_api/api_register.h"
51 #endif
52
53 namespace mindspore {
54 namespace abstract {
// How a pair of aligned dimensions compares while the reduce axes of a broadcast gradient are derived.
enum class State {
  SAME,   // Both dimensions hold the same value.
  X_ONE,  // The dimensions differ and the x dimension equals 1.
  Y_ONE   // The dimensions differ and the y dimension equals 1.
};
60
ComputeReduceIndex(const std::vector<int64_t> & reverse_x,const std::vector<int64_t> & reverse_y,std::vector<int64_t> * grad_x_reduce_idx,std::vector<int64_t> * grad_y_reduce_idy)61 void ComputeReduceIndex(const std::vector<int64_t> &reverse_x, const std::vector<int64_t> &reverse_y,
62 std::vector<int64_t> *grad_x_reduce_idx, std::vector<int64_t> *grad_y_reduce_idy) {
63 MS_EXCEPTION_IF_NULL(grad_x_reduce_idx);
64 MS_EXCEPTION_IF_NULL(grad_y_reduce_idy);
65 const size_t n = reverse_x.size();
66 if (reverse_y.size() < n) {
67 MS_LOG(EXCEPTION) << "The size of reverse_y is less than the size of reverse_x.";
68 }
69 for (size_t i = 0; i < n; ++i) {
70 State curr;
71 const int64_t x_i = reverse_x[i];
72 const int64_t y_i = reverse_y[i];
73 const int64_t reduce_idx = SizeToLong(n - 1 - i);
74 if (x_i == y_i) {
75 curr = State::SAME;
76 } else if (x_i == 1) {
77 grad_x_reduce_idx->push_back(reduce_idx);
78 curr = State::X_ONE;
79 } else if (y_i == 1) {
80 grad_y_reduce_idy->push_back(reduce_idx);
81 curr = State::Y_ONE;
82 } else {
83 MS_LOG(EXCEPTION) << "Not compatible shape input for BroadcastGradientArgs.";
84 }
85 if (curr == State::SAME && x_i == 1) {
86 grad_x_reduce_idx->push_back(reduce_idx);
87 grad_y_reduce_idy->push_back(reduce_idx);
88 continue;
89 }
90 }
91
92 std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
93 std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
94 }
95
BroadcastGradientArgsDiff(const std::vector<ValuePtr> & x_shape,const std::vector<ValuePtr> & y_shape)96 AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
97 std::vector<int64_t> reverse_x;
98 std::vector<int64_t> reverse_y;
99
100 (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
101 [](const ValuePtr &v) { return v->cast<Int64ImmPtr>()->value(); });
102 (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
103 [](const ValuePtr &v) { return v->cast<Int64ImmPtr>()->value(); });
104
105 if (reverse_x.size() > reverse_y.size()) {
106 reverse_y.resize(reverse_x.size(), 1);
107 } else {
108 reverse_x.resize(reverse_y.size(), 1);
109 }
110
111 std::vector<int64_t> grad_x_reduce_idx;
112 std::vector<int64_t> grad_y_reduce_idy;
113 ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
114
115 AbstractBasePtrList abs_list_x;
116 AbstractBasePtrList abs_list_y;
117 (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
118 [](int64_t v) { return abstract::FromValue(v); });
119 (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
120 [](int64_t v) { return abstract::FromValue(v); });
121 auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
122 auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
123 AbstractBasePtrList elem_list;
124 elem_list.push_back(x_reduce_idx);
125 elem_list.push_back(y_reduce_idx);
126
127 return std::make_shared<AbstractTuple>(elem_list);
128 }
129
InferImplTypeof(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)130 AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
131 const AbstractBasePtrList &args_abs_list) {
132 // Inputs: a pointer to an AbstractBase object
133 if (args_abs_list.size() != 1) {
134 MS_LOG(EXCEPTION) << "The Typeof operator must requires 1 argument, but the size of arguments is "
135 << args_abs_list.size() << ".";
136 }
137 AbstractBasePtr abs_base = args_abs_list[0];
138 MS_EXCEPTION_IF_NULL(abs_base);
139 TypePtr type = abs_base->BuildType();
140 return std::make_shared<AbstractType>(type);
141 }
142
InferImplTopTypeof(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_abs_list)143 AbstractBasePtr InferImplTopTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
144 const AbstractBasePtrList &args_abs_list) {
145 // Inputs: a pointer to an AbstractBase object
146 if (args_abs_list.size() != 1) {
147 MS_LOG(EXCEPTION) << "The Typeof operator must requires 1 argument, but the size of arguments is "
148 << args_abs_list.size() << ".";
149 }
150 AbstractBasePtr abs_base = args_abs_list[0];
151 MS_EXCEPTION_IF_NULL(abs_base);
152 TypeId type_id = abs_base->BuildType()->type_id();
153 return std::make_shared<AbstractType>(TypeIdToType(type_id));
154 }
155
InferImplStringUpper(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)156 AbstractBasePtr InferImplStringUpper(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
157 const AbstractBasePtrList &args_abs_list) {
158 MS_EXCEPTION_IF_NULL(primitive);
159 if (args_abs_list.size() != 1) {
160 MS_LOG(INTERNAL_EXCEPTION) << "StringUpper takes 1 argument, but got " << args_abs_list.size();
161 }
162 constexpr size_t index_str = 0;
163 auto abs_str = args_abs_list[index_str];
164 MS_EXCEPTION_IF_NULL(abs_str);
165 auto value_str = abs_str->BuildValue();
166 MS_EXCEPTION_IF_NULL(value_str);
167 if (!value_str->isa<StringImm>()) {
168 MS_INTERNAL_EXCEPTION(TypeError) << "StringUpper expected to get a string as input, but got:"
169 << value_str->ToString();
170 }
171 auto str = value_str->cast<StringImmPtr>()->value();
172 (void)std::transform(str.begin(), str.end(), str.begin(), ::toupper);
173 auto new_str = MakeValue(str);
174 return new_str->ToAbstract();
175 }
176
InferImplStringLower(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)177 AbstractBasePtr InferImplStringLower(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
178 const AbstractBasePtrList &args_abs_list) {
179 MS_EXCEPTION_IF_NULL(primitive);
180 if (args_abs_list.size() != 1) {
181 MS_LOG(EXCEPTION) << "StringLower takes 1 argument, but got " << args_abs_list.size();
182 }
183 constexpr size_t index_str = 0;
184 auto abs_str = args_abs_list[index_str];
185 MS_EXCEPTION_IF_NULL(abs_str);
186 auto value_str = abs_str->BuildValue();
187 MS_EXCEPTION_IF_NULL(value_str);
188 if (!value_str->isa<StringImm>()) {
189 MS_INTERNAL_EXCEPTION(TypeError) << "StringLower expected to get a string as input, but got:"
190 << value_str->ToString();
191 }
192 auto str = value_str->cast<StringImmPtr>()->value();
193 (void)std::transform(str.begin(), str.end(), str.begin(), ::tolower);
194 auto new_str = MakeValue(str);
195 return new_str->ToAbstract();
196 }
197
InferImplHasType(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)198 AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
199 const AbstractBasePtrList &args_abs_list) {
200 MS_EXCEPTION_IF_NULL(primitive);
201 // Inputs: a pointer to an AbstractBase object and a pointer to a Type
202 const std::string op_name = primitive->name();
203 const size_t args_num = 2;
204 CheckArgsSize(op_name, args_abs_list, args_num);
205 AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_abs_list, 1);
206 MS_EXCEPTION_IF_NULL(abs_type);
207 auto mode_v = abs_type->GetValueTrack();
208 MS_EXCEPTION_IF_NULL(mode_v);
209 if (!mode_v->isa<Type>()) {
210 MS_LOG(INTERNAL_EXCEPTION) << "Get the type from AbstractType value failed.";
211 }
212
213 auto tmpMode = mode_v->cast<TypePtr>();
214 MS_EXCEPTION_IF_NULL(args_abs_list[0]);
215 bool v = IsSubtype(args_abs_list[0], tmpMode);
216 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
217 }
218
IsAdapterTensor(const AbstractBasePtr & x)219 bool IsAdapterTensor(const AbstractBasePtr &x) {
220 if (!x->isa<abstract::AbstractTensor>()) {
221 return false;
222 }
223 return x->cast<abstract::AbstractTensorPtr>()->is_adapter();
224 }
225
CheckIsInstanceForAdapter(const AbstractBasePtr & x,const AbstractBasePtr & cmp)226 bool CheckIsInstanceForAdapter(const AbstractBasePtr &x, const AbstractBasePtr &cmp) {
227 if (cmp->isa<abstract::AbstractTuple>()) {
228 const auto &elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
229 return std::any_of(elements.begin(), elements.end(),
230 [=](const AbstractBasePtr &element) { return CheckIsInstanceForAdapter(x, element); });
231 }
232 auto cmp_value = cmp->BuildValue();
233 MS_EXCEPTION_IF_NULL(cmp_value);
234 if (cmp_value->isa<parse::ClassType>()) {
235 auto class_obj = cmp_value->cast<parse::ClassTypePtr>()->obj();
236 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
237 // isinstance(tensor_x, Tensor) -> true, isinstance(tensor_x, Parameter) -> false.
238 // isinstance(parameter_x, Tensor) -> true, isinstance(parameter_x, Parameter) -> true.
239 bool is_cmp_tensor =
240 python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_IS_ADAPTER_TENSOR_CLASS, class_obj).cast<bool>();
241 if (is_cmp_tensor) {
242 return true;
243 }
244 bool is_x_parameter = x->isa<abstract::AbstractRefTensor>();
245 bool is_cmp_parameter =
246 python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_IS_ADAPTER_PARAMETER_CLASS, class_obj).cast<bool>();
247 return is_x_parameter && is_cmp_parameter;
248 }
249 return false;
250 }
251
CheckPythonIsInstance(const py::object & x,const AbstractBasePtr & cmp,const py::module & mod,bool is_const)252 bool CheckPythonIsInstance(const py::object &x, const AbstractBasePtr &cmp, const py::module &mod, bool is_const) {
253 if (cmp->isa<abstract::AbstractTuple>()) {
254 const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
255 return std::any_of(cmp_tuple_elements.begin(), cmp_tuple_elements.end(),
256 [&x, &mod, is_const](const AbstractBasePtr &element) {
257 return CheckPythonIsInstance(x, element, mod, is_const);
258 });
259 }
260 if (std::find(kSparsePrimStr.begin(), kSparsePrimStr.end(), cmp->ToString()) != kSparsePrimStr.end()) {
261 return false;
262 }
263
264 py::object cmp_type;
265 if (cmp->isa<abstract::PartialAbstractClosure>()) {
266 const auto &cmp_closure_args = cmp->cast<abstract::PartialAbstractClosurePtr>()->args();
267 // CheckCmpValid ensures size of cmp_closure_args to be 1.
268 auto cmp_closure_first_input = cmp_closure_args[0];
269 cmp_type = ValueToPyData(cmp_closure_first_input->BuildValue());
270 } else {
271 auto cmp_value = cmp->BuildValue();
272 if (cmp_value->ContainsValueAny()) {
273 return false;
274 }
275 cmp_type = ValueToPyData(cmp_value);
276 }
277
278 py::object result = is_const ? python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_PYTHON_ISINSTANCE, x, cmp_type)
279 : python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_MS_ISINSTANCE, x, cmp_type);
280 return result.cast<bool>();
281 }
282
CheckIsInstanceForFunc(const py::object & x_py_obj,const AbstractBasePtr & cmp,const py::module & mod)283 bool CheckIsInstanceForFunc(const py::object &x_py_obj, const AbstractBasePtr &cmp, const py::module &mod) {
284 MS_EXCEPTION_IF_NULL(cmp);
285 if (cmp->isa<abstract::AbstractTuple>()) {
286 const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
287 return std::any_of(
288 cmp_tuple_elements.begin(), cmp_tuple_elements.end(),
289 [&x_py_obj, &mod](const AbstractBasePtr &element) { return CheckIsInstanceForFunc(x_py_obj, element, mod); });
290 }
291
292 if (!cmp->isa<abstract::PartialAbstractClosure>()) {
293 return false;
294 }
295 const auto &cmp_closure_args = cmp->cast<abstract::PartialAbstractClosurePtr>()->args();
296 // CheckCmpValid ensures size of cmp_closure_args to be 1.
297 auto cmp_closure_first_input = cmp_closure_args[0];
298 auto cmp_py_obj = ValueToPyData(cmp_closure_first_input->BuildValue());
299 auto result = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_PYTHON_ISINSTANCE, x_py_obj, cmp_py_obj);
300 return result.cast<bool>();
301 }
302
CheckIsInstanceForSparse(const AbstractBasePtr & cmp,const std::string & target)303 bool CheckIsInstanceForSparse(const AbstractBasePtr &cmp, const std::string &target) {
304 MS_EXCEPTION_IF_NULL(cmp);
305 if (!cmp->isa<abstract::AbstractTuple>()) {
306 return cmp->ToString() == target;
307 }
308 const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
309 return std::any_of(cmp_tuple_elements.begin(), cmp_tuple_elements.end(),
310 [&target](const AbstractBasePtr &element) { return CheckIsInstanceForSparse(element, target); });
311 }
312
GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr & prim_abs)313 py::object GetPrimitivePyObj(const abstract::PrimitiveAbstractClosurePtr &prim_abs) {
314 MS_EXCEPTION_IF_NULL(prim_abs);
315 auto prim = prim_abs->prim();
316 MS_EXCEPTION_IF_NULL(prim);
317 auto prim_signature = prim->cast<prim::DoSignaturePrimitivePtr>();
318 MS_EXCEPTION_IF_NULL(prim_signature);
319 auto function = prim_signature->function();
320 MS_EXCEPTION_IF_NULL(function);
321 auto primitive_py_function = function->cast<PrimitivePyPtr>();
322 return primitive_py_function->GetPyObj();
323 }
324
GetMsClassPyObj(const abstract::PartialAbstractClosurePtr & ms_class_abs)325 py::object GetMsClassPyObj(const abstract::PartialAbstractClosurePtr &ms_class_abs) {
326 MS_EXCEPTION_IF_NULL(ms_class_abs);
327 const auto &ms_class_args = ms_class_abs->args();
328 if (ms_class_args.size() != 1) {
329 MS_LOG(INTERNAL_EXCEPTION)
330 << "When the first input to IsInstance is PartialAbstractClosure, its args size should be 1 but "
331 << "got: " << ms_class_args.size() << ".";
332 }
333 auto first_arg = ms_class_args[0];
334 auto class_value = first_arg->BuildValue();
335 MS_EXCEPTION_IF_NULL(class_value);
336 return ValueToPyData(class_value);
337 }
338
CheckCmpValid(const AbstractBasePtr & cmp)339 bool CheckCmpValid(const AbstractBasePtr &cmp) {
340 MS_EXCEPTION_IF_NULL(cmp);
341 if (cmp->isa<abstract::AbstractSequence>()) {
342 if (!cmp->isa<abstract::AbstractTuple>()) {
343 return false;
344 }
345 const auto &elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
346 return std::all_of(elements.begin(), elements.end(),
347 [](const AbstractBasePtr &element) { return CheckCmpValid(element); });
348 }
349 if (cmp->isa<abstract::AbstractScalar>()) {
350 auto cmp_type = cmp->BuildType();
351 MS_EXCEPTION_IF_NULL(cmp_type);
352 return cmp_type->type_id() == kMetaTypeTypeType;
353 } else if (cmp->isa<abstract::PartialAbstractClosure>()) {
354 auto cmp_closure = cmp->cast<abstract::PartialAbstractClosurePtr>();
355 const auto &cmp_closure_args = cmp_closure->args();
356 if (cmp_closure_args.size() != 1) {
357 return false;
358 }
359 auto cmp_closure_first_input = cmp_closure_args[0];
360 auto cmp_type = cmp_closure_first_input->BuildType();
361 MS_EXCEPTION_IF_NULL(cmp_type);
362 auto cmp_type_id = cmp_type->type_id();
363 if (cmp_type_id == kObjectTypeClass) {
364 // When cmp type is ms_class, fn should be create_instance.
365 auto cmp_closure_fn = cmp_closure->fn();
366 MS_EXCEPTION_IF_NULL(cmp_closure_fn);
367 const std::string ms_class_type_fn_name = "PrimitiveAbstractClosure: create_instance";
368 return cmp_closure_fn->ToString() == ms_class_type_fn_name;
369 }
370 return cmp_type_id == kMetaTypeTypeType;
371 } else if (cmp->isa<abstract::AbstractAny>()) {
372 return true;
373 }
374 return std::find(kSparsePrimStr.cbegin(), kSparsePrimStr.cend(), cmp->ToString()) != kSparsePrimStr.cend();
375 }
376
InferImplIsInstance(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)377 AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
378 const AbstractBasePtrList &args_abs_list) {
379 MS_EXCEPTION_IF_NULL(primitive);
380 constexpr size_t args_num = 2;
381 CheckArgsSize(primitive->name(), args_abs_list, args_num);
382 py::gil_scoped_acquire gil;
383 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
384 auto x = args_abs_list[0];
385 MS_EXCEPTION_IF_NULL(x);
386 auto cmp = args_abs_list[1];
387 MS_EXCEPTION_IF_NULL(cmp);
388
389 if (!CheckCmpValid(cmp)) {
390 auto cmp_type = cmp->BuildType();
391 MS_EXCEPTION_IF_NULL(cmp_type);
392 MS_LOG(ERROR) << "cmp: " << cmp->ToString() << ", cmp_type: " << cmp_type->ToString()
393 << ", cmp_type_id: " << TypeIdToType(cmp_type->type_id());
394 MS_EXCEPTION(TypeError) << "isinstance() arg 2 must be a type or tuple of types.";
395 }
396
397 // If x is AbstractAny the result of isinstance can not determined in frontend,
398 // isinstance should be converted to pyexecute later.
399 // So we set the abstract of instance to variable boolean scalar.
400 if (x->isa<abstract::AbstractAny>()) {
401 return std::make_shared<AbstractScalar>(kValueAny, kBool);
402 }
403
404 MS_EXCEPTION_IF_NULL(x);
405 bool result = false;
406 if (x->isa<abstract::FuncGraphAbstractClosure>()) {
407 // x is Cell object.
408 auto x_fg = x->cast<abstract::FuncGraphAbstractClosurePtr>()->func_graph();
409 MS_EXCEPTION_IF_NULL(x_fg);
410 auto wrapper_obj = x_fg->python_obj();
411 if (wrapper_obj != nullptr) {
412 if (!wrapper_obj->isa<parse::PyObjectWrapper>()) {
413 MS_LOG(INTERNAL_EXCEPTION) << "The wrapper_obj of FuncGraphAbstractClosure must be PyObjectWrapper but got: "
414 << wrapper_obj->ToString() << ".";
415 }
416 auto x_py_obj = wrapper_obj->cast<parse::PyObjectWrapperPtr>()->obj();
417 result = CheckIsInstanceForFunc(x_py_obj, cmp, mod);
418 }
419 } else if (x->isa<abstract::PrimitiveAbstractClosure>()) {
420 // x is Primitive.
421 auto x_py_obj = GetPrimitivePyObj(x->cast<abstract::PrimitiveAbstractClosurePtr>());
422 result = CheckIsInstanceForFunc(x_py_obj, cmp, mod);
423 } else if (x->isa<abstract::AbstractClass>()) {
424 // x is ms_class.
425 auto class_value = x->BuildValue();
426 MS_EXCEPTION_IF_NULL(class_value);
427 auto x_py = ValueToPyData(class_value);
428 result = CheckIsInstanceForFunc(x_py, cmp, mod);
429 } else if (x->isa<abstract::AbstractCSRTensor>()) {
430 // x is sparse tensor with type CSRTensor.
431 const size_t csr_index = 0;
432 result = CheckIsInstanceForSparse(cmp, kSparsePrimStr[csr_index]);
433 } else if (x->isa<abstract::AbstractCOOTensor>()) {
434 // x is sparse tensor with type COOTensor.
435 const size_t coo_index = 1;
436 result = CheckIsInstanceForSparse(cmp, kSparsePrimStr[coo_index]);
437 } else if (x->isa<abstract::AbstractRowTensor>()) {
438 // x is sparse tensor with type RowTensor.
439 const size_t row_index = 2;
440 result = CheckIsInstanceForSparse(cmp, kSparsePrimStr[row_index]);
441 } else if (IsAdapterTensor(x)) {
442 // x is adapter tensor.
443 result = CheckIsInstanceForAdapter(x, cmp);
444 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(result), kBool);
445 } else if (x->BuildValue()->ContainsValueAny()) {
446 // x is variable built-in type.
447 auto x_abs_type = std::make_shared<AbstractType>(x->BuildType());
448 auto py_x_type = ValueToPyData(x_abs_type->BuildValue());
449 result = CheckPythonIsInstance(py_x_type, cmp, mod, false);
450 } else {
451 // x is python built-in constant type or external type.
452 py::object x_py_obj = ValueToPyData(x->BuildValue());
453 result = CheckPythonIsInstance(x_py_obj, cmp, mod, true);
454 }
455
456 // If no constant type in cmp match the type of x and cmp contains AbstractAny,
457 // the result of isinstance can not determined in frontend, should be converted to pyexecute later.
458 // So we set the abstract of instance to variable boolean scalar.
459 if (!result && fallback::ContainsSequenceAnyType(cmp)) {
460 return std::make_shared<AbstractScalar>(kValueAny, kBool);
461 }
462 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(result), kBool);
463 }
464
CompareShape(const std::vector<ValuePtr> & x_shape,const std::vector<ValuePtr> & y_shape)465 bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
466 if (x_shape.size() != y_shape.size()) {
467 return false;
468 }
469
470 for (size_t i = 0; i < x_shape.size(); ++i) {
471 if (GetValue<int64_t>(x_shape[i]) != GetValue<int64_t>(y_shape[i])) {
472 return false;
473 }
474 }
475
476 return true;
477 }
478
DoInferReduceShape(const AbstractTuplePtr & x_shape,const ValuePtr & x_shp_value,const ValueSequencePtr & axis_value_ptr,const PrimitivePtr & primitive)479 AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
480 const ValueSequencePtr &axis_value_ptr, const PrimitivePtr &primitive) {
481 size_t x_rank = x_shape->size();
482 std::set<int64_t> axis_set;
483 auto axis_data = axis_value_ptr->value();
484 if (axis_data.empty()) {
485 int64_t size = 1;
486 AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
487 return std::make_shared<AbstractTuple>(values);
488 }
489
490 for (auto &elem : axis_data) {
491 auto x_rank_tmp = x_rank;
492 if (x_rank_tmp == 0) {
493 x_rank_tmp = 1;
494 }
495 int64_t e_value =
496 CheckAxis(primitive->name(), "axis", elem, -SizeToLong(x_rank_tmp), SizeToLong(x_rank_tmp), "input_x");
497 (void)axis_set.insert(e_value);
498 }
499 MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());
500 auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
501 if (x_shp_data.size() < x_rank) {
502 MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank << ".";
503 }
504 AbstractBasePtrList values;
505 for (size_t i = 0; i < x_rank; i++) {
506 if (axis_set.count(SizeToLong(i)) || axis_set.count(SizeToLong(i) - SizeToLong(x_rank))) {
507 auto axis_v = MakeValue(static_cast<int64_t>(1));
508 values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
509 } else {
510 int64_t dim_value = x_shp_data[i]->cast<Int64ImmPtr>()->value();
511 auto dim = MakeValue(dim_value);
512 values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
513 }
514 }
515
516 return std::make_shared<AbstractTuple>(values);
517 }
518
InferImplBroadcastGradientArgs(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)519 AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
520 const AbstractBasePtrList &args_abs_list) {
521 // this primitive get the index that need to reduce
522 // input: x's shape and y's shape, inputs should be tuple
523 // output: tuple of x and y 's reduce index, reduce index should be a tuple
524 MS_EXCEPTION_IF_NULL(primitive);
525 const std::string op_name = primitive->name();
526 const size_t inputs_size = 2;
527 CheckArgsSize(op_name, args_abs_list, inputs_size);
528 auto arg_x = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
529 auto arg_y = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
530
531 auto arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
532 MS_EXCEPTION_IF_NULL(arg_x_value);
533
534 auto arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
535 MS_EXCEPTION_IF_NULL(arg_y_value);
536
537 const std::vector<ValuePtr> x_shape = arg_x_value->value();
538 const std::vector<ValuePtr> y_shape = arg_y_value->value();
539 bool is_same_shape = CompareShape(x_shape, y_shape);
540 // if it is the same shape , do not need reduce , return empty tuple
541 if (is_same_shape) {
542 AbstractBasePtrList empty_list;
543 auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
544 auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
545
546 AbstractBasePtrList elem_list;
547 elem_list.push_back(x_reduce_idx);
548 elem_list.push_back(y_reduce_idx);
549
550 return std::make_shared<AbstractTuple>(elem_list);
551 }
552 return BroadcastGradientArgsDiff(x_shape, y_shape);
553 }
554
InferImplListReduce(const AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)555 AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
556 const AbstractBasePtrList &args_abs_list) {
557 // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
558 MS_EXCEPTION_IF_NULL(engine);
559 MS_EXCEPTION_IF_NULL(primitive);
560 const std::string op_name = primitive->name();
561 const size_t inputs_size = 3;
562 CheckArgsSize(op_name, args_abs_list, inputs_size);
563 AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_abs_list, 0);
564 AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_abs_list, 1);
565 MS_EXCEPTION_IF_NULL(lst);
566 AbstractBasePtr dflt = args_abs_list[2];
567
568 AbstractBasePtr list_type = AbstractJoin(lst->elements());
569 auto result1 = engine->Execute(fn, lst->elements());
570 MS_EXCEPTION_IF_NULL(result1);
571 auto result2 = engine->Execute(fn, {dflt, list_type});
572 MS_EXCEPTION_IF_NULL(result2);
573 MS_EXCEPTION_IF_NULL(result1->abstract());
574 MS_EXCEPTION_IF_NULL(result2->abstract());
575 return result1->abstract()->Join(result2->abstract());
576 }
577
InferImplTupleReversed(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)578 AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
579 const AbstractBasePtrList &args_abs_list) {
580 // Inputs: a tuple
581 MS_EXCEPTION_IF_NULL(primitive);
582 const std::string op_name = primitive->name();
583 CheckArgsSize(op_name, args_abs_list, 1);
584 AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
585 MS_EXCEPTION_IF_NULL(input);
586 auto tuple_elements = input->elements();
587 AbstractBasePtrList elem_list;
588 (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
589 [](const AbstractBasePtr &elem) { return elem->Clone(); });
590 return std::make_shared<AbstractTuple>(elem_list);
591 }
592
InferImplReduceShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)593 AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
594 const AbstractBasePtrList &args_abs_list) {
595 // Inputs: x_shape, axis
596 MS_EXCEPTION_IF_NULL(primitive);
597 const std::string op_name = primitive->name();
598 constexpr size_t arg_size = 2;
599 CheckArgsSize(op_name, args_abs_list, arg_size);
600 AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
601 MS_EXCEPTION_IF_NULL(shape_x);
602 MS_EXCEPTION_IF_NULL(args_abs_list[1]);
603
604 auto x_shp_value = shape_x->BuildValue();
605 if (x_shp_value->ContainsValueAny()) {
606 MS_LOG(INTERNAL_EXCEPTION) << "The ReduceShape operator's data field can't be anything: "
607 << args_abs_list[1]->ToString() << ".";
608 }
609
610 // Axis can be scalar, tuple or list
611 AbstractSequencePtr axis = nullptr;
612 if (args_abs_list[1]->isa<AbstractScalar>()) {
613 MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar.";
614 AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_abs_list[1])};
615 axis = std::make_shared<AbstractTuple>(axis_list);
616 } else if (args_abs_list[1]->isa<AbstractSequence>()) {
617 MS_LOG(DEBUG) << "The type of second argument of ReduceShape operator is sequence.";
618 axis = args_abs_list[1]->cast<AbstractSequencePtr>();
619 } else {
620 MS_LOG(EXCEPTION) << "The second argument of ReduceShape operator should be a scalar or tuple or list, "
621 << "but got " << args_abs_list[1]->ToString() << ".";
622 }
623
624 auto axis_value = axis->BuildValue();
625 if (axis_value->ContainsValueAny()) {
626 MS_LOG(INTERNAL_EXCEPTION) << "The ReduceShape operator's data field can't be anything: "
627 << args_abs_list[1]->ToString() << ".";
628 }
629 auto axis_value_ptr = axis_value->cast<ValueSequencePtr>();
630 MS_EXCEPTION_IF_NULL(axis_value_ptr);
631 return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
632 }
633
InferImplTupleDiv(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)634 AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
635 const AbstractBasePtrList &args_abs_list) {
636 // Inputs: two tuples.
637 MS_EXCEPTION_IF_NULL(primitive);
638 const std::string op_name = primitive->name();
639 constexpr size_t arg_size = 2;
640 CheckArgsSize(op_name, args_abs_list, arg_size);
641 AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
642 AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_abs_list, 1);
643 MS_EXCEPTION_IF_NULL(shape_x);
644 MS_EXCEPTION_IF_NULL(div_shp);
645 MS_LOG(INFO) << "The shape of dividend:" << shape_x->ToString() << ", the shape of divisor:" << div_shp->ToString();
646
647 auto div_shp_value = div_shp->BuildValue();
648 MS_EXCEPTION_IF_NULL(div_shp_value);
649 if (div_shp_value->ContainsValueAny()) {
650 MS_LOG(INTERNAL_EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
651 << args_abs_list[0]->ToString() << ".";
652 }
653
654 auto shape_x_value = shape_x->BuildValue();
655 MS_EXCEPTION_IF_NULL(shape_x_value);
656 if (shape_x_value->ContainsValueAny()) {
657 MS_LOG(INTERNAL_EXCEPTION) << "The 'tuple_div' operator shape's data field can't be anything, but got "
658 << args_abs_list[1]->ToString() << ".";
659 }
660
661 if (div_shp->size() != shape_x->size()) {
662 MS_LOG(INTERNAL_EXCEPTION)
663 << "The size of inputs of 'tuple_div' operator must be the same, but the size of divisor tuple is"
664 << " " << div_shp->size() << ", the size of dividend tuple is " << shape_x->size() << ".";
665 }
666 auto shape_x_tuple_value = shape_x_value->cast<ValueTuplePtr>();
667 auto div_shape_tuple_value = div_shp_value->cast<ValueTuplePtr>();
668 MS_EXCEPTION_IF_NULL(shape_x_tuple_value);
669 MS_EXCEPTION_IF_NULL(div_shape_tuple_value);
670 auto shape_x_data = shape_x_tuple_value->value();
671 auto div_shape_data = div_shape_tuple_value->value();
672 AbstractBasePtrList values;
673
674 for (size_t i = 0; i < div_shape_data.size(); i++) {
675 MS_EXCEPTION_IF_NULL(div_shape_data[i]);
676 if (div_shape_data[i]->cast<Int64ImmPtr>() == nullptr) {
677 auto value_type = div_shape_data[i]->type();
678 std::string str_type;
679 if (value_type) {
680 str_type = value_type->ToString();
681 } else {
682 str_type = "ValueAny";
683 }
684 MS_LOG(EXCEPTION) << "The data type of inputs of 'tuple_div' operator should be an int64 number, but got a "
685 << str_type << " number " << div_shape_data[i]->ToString() << ".";
686 }
687 auto shapex_value = GetValue<int64_t>(shape_x_data[i]);
688 auto div_value = GetValue<int64_t>(div_shape_data[i]);
689 MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
690 if (div_value == 0) {
691 MS_LOG(EXCEPTION) << "The divisor value should not be 0!";
692 }
693 if ((shapex_value % div_value) != 0) {
694 MS_LOG(EXCEPTION) << "The inputs of 'tuple_div' operator should be divisible, but they are not divisible now, "
695 << "the dividend is " << shapex_value << ", the divisor is " << div_value << ".";
696 }
697
698 int64_t result = shapex_value / div_value;
699 auto result_v = MakeValue(result);
700 values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
701 }
702 return std::make_shared<AbstractTuple>(values);
703 }
704
InferImplTuple2Array(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)705 AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
706 const AbstractBasePtrList &args_abs_list) {
707 // Inputs: a tuple
708 const std::string op_name = primitive->name();
709 CheckArgsSize(op_name, args_abs_list, 1);
710 AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_abs_list, 0);
711 MS_EXCEPTION_IF_NULL(input);
712 py::tuple data_tuple = ValueToPyData(input->BuildValue());
713 py::array data = py::array(data_tuple);
714 auto tensor = tensor::TensorPy::MakeTensor(data);
715 auto ret = tensor->ToAbstract();
716 ret->set_value(tensor);
717 MS_LOG(DEBUG) << "The infer result of Tuple2Array operator is tensor, the infer result is " << ret->ToString() << ".";
718 return ret;
719 }
720
InferImplSliceGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)721 AbstractBasePtr InferImplSliceGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
722 const AbstractBasePtrList &args_abs_list) {
723 auto op_name = primitive->name();
724 constexpr auto slice_getitem_input_size = 2;
725 CheckArgsSize(op_name, args_abs_list, slice_getitem_input_size);
726 AbstractSlicePtr slice_abs = CheckArg<AbstractSlice>(op_name, args_abs_list, 0);
727 const std::map<std::string, AbstractBasePtr> result_map = {
728 {kSliceStart, slice_abs->start()}, {kSliceStop, slice_abs->stop()}, {kSliceStep, slice_abs->step()}};
729 auto slice_attr = args_abs_list[1]->BuildValue();
730 MS_EXCEPTION_IF_NULL(slice_attr);
731 if (!slice_attr->isa<StringImm>()) {
732 MS_LOG(EXCEPTION) << "The second argument of SliceGetItem operator should be a string, but got "
733 << slice_attr->ToString() << ".";
734 }
735 auto slice_str = GetValue<std::string>(slice_attr);
736 auto iter = result_map.find(slice_str);
737 if (iter == result_map.end()) {
738 MS_INTERNAL_EXCEPTION(AttributeError) << "The 'slice' object has no attribute:" << slice_str << ".";
739 }
740 return iter->second;
741 }
742
InferImplMakeSlice(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)743 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
744 const AbstractBasePtrList &args_abs_list) {
745 // Inputs: three scalars whose value is an int32 number.
746 constexpr auto make_slice_input_size = 3;
747 CheckArgsSize(primitive->name(), args_abs_list, make_slice_input_size);
748 size_t args_size = args_abs_list.size();
749 AbstractBasePtrList slice_args;
750 for (size_t index = 0; index < args_size; index++) {
751 MS_EXCEPTION_IF_NULL(args_abs_list[index]);
752 if (args_abs_list[index]->isa<AbstractNone>()) {
753 slice_args.push_back(args_abs_list[index]);
754 } else if (args_abs_list[index]->isa<AbstractScalar>()) {
755 ValuePtr scalar_value = args_abs_list[index]->cast<AbstractScalarPtr>()->BuildValue();
756 MS_EXCEPTION_IF_NULL(scalar_value);
757 if (scalar_value->isa<IntegerImm>() || scalar_value->ContainsValueAny()) {
758 slice_args.push_back(args_abs_list[index]);
759 } else if (scalar_value->isa<BoolImm>()) {
760 ValuePtr scalar_index = MakeValue(static_cast<int64_t>(scalar_value->cast<BoolImmPtr>()->value()));
761 slice_args.push_back(scalar_index->ToAbstract());
762 } else {
763 auto type = scalar_value->type();
764 MS_EXCEPTION_IF_NULL(type);
765 MS_EXCEPTION(TypeError) << "Slice indices must be integers or bool. But got a " << type->ToString()
766 << " number.";
767 }
768 } else if (args_abs_list[index]->isa<AbstractTensor>()) {
769 auto arg = args_abs_list[index]->cast<AbstractTensorPtr>();
770 TypePtr tensor_dtype = arg->element()->BuildType();
771 auto build_value = arg->BuildValue();
772 MS_EXCEPTION_IF_NULL(build_value);
773 auto value = build_value->cast<tensor::TensorPtr>();
774 if (value != nullptr) {
775 if (value->DataSize() != 1) {
776 MS_EXCEPTION(TypeError) << "The input tensor of the MakeSlice operator must contain only one element,"
777 << "but " << value->ToString() << " has " << value->DataSize() << " elements.";
778 }
779
780 if (tensor_dtype->isa<Bool>()) {
781 auto *bool_value = static_cast<bool *>(value->data_c());
782 slice_args.push_back(MakeValue((static_cast<int64_t>(*bool_value)))->ToAbstract());
783 } else if (tensor_dtype == kInt64) {
784 auto *int_value = static_cast<int64_t *>(value->data_c());
785 slice_args.push_back(MakeValue((*int_value))->ToAbstract());
786 } else if (tensor_dtype == kInt32) {
787 auto *int_value = static_cast<int32_t *>(value->data_c());
788 slice_args.push_back(MakeValue((*int_value))->ToAbstract());
789 } else {
790 MS_EXCEPTION(TypeError) << "The input tensor type of the MakeSlice operator must be int or bool, but got "
791 << tensor_dtype->ToString();
792 }
793 } else {
794 slice_args.push_back(args_abs_list[index]);
795 }
796 } else {
797 MS_EXCEPTION(TypeError) << "The " << index << "th input of MakeSlice operator should be scalar, none or tensor, "
798 << "but got " << args_abs_list[index]->ToString() << ".";
799 }
800 }
801 // Slice: start, end, step
802 constexpr size_t kMakeSliceInput0 = 0;
803 constexpr size_t kMakeSliceInput1 = 1;
804 constexpr size_t kMakeSliceInput2 = 2;
805 return std::make_shared<AbstractSlice>(slice_args[kMakeSliceInput0], slice_args[kMakeSliceInput1],
806 slice_args[kMakeSliceInput2]);
807 }
808
InferImplStopGradient(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)809 AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
810 const AbstractBasePtrList &args_abs_list) {
811 // Inputs: any value;
812 CheckArgsSize(primitive->name(), args_abs_list, 1);
813 return args_abs_list[0]->Clone();
814 }
815
InferImplDictLen(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)816 AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
817 const AbstractBasePtrList &args_abs_list) {
818 return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_abs_list);
819 }
820
InferImplJ(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)821 AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
822 const AbstractBasePtrList &args_abs_list) {
823 // args: An object of AbstractFunction.
824 CheckArgsSize(primitive->name(), args_abs_list, 1);
825 MS_LOG(DEBUG) << "evaluate J: " << args_abs_list[0]->ToString();
826
827 AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
828 if (x == nullptr) {
829 return std::make_shared<AbstractJTagged>(args_abs_list[0]);
830 }
831
832 AbstractFuncAtomPtrList jv;
833 auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
834 auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
835 jv.push_back(j_closure);
836 };
837 x->Visit(build_jv);
838
839 return AbstractFunction::MakeAbstractFunction(jv);
840 }
841
InferImplTaylor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)842 AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
843 const AbstractBasePtrList &args_abs_list) {
844 // args: An object of AbstractFunction.
845 CheckArgsSize(primitive->name(), args_abs_list, 1);
846 MS_LOG(DEBUG) << "evaluate Taylor: " << args_abs_list[0]->ToString();
847
848 AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
849 MS_EXCEPTION_IF_NULL(x);
850
851 AbstractFuncAtomPtrList taylor_v;
852 auto build_taylor_v = [&taylor_v](const AbstractFuncAtomPtr &func) {
853 auto taylor_closure = std::make_shared<TaylorTransformedAbstractClosure>(func);
854 taylor_v.push_back(taylor_closure);
855 };
856 x->Visit(build_taylor_v);
857
858 return AbstractFunction::MakeAbstractFunction(taylor_v);
859 }
860
InferImplReusing(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)861 AbstractBasePtr InferImplReusing(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
862 const AbstractBasePtrList &args_abs_list) {
863 // args: An object of AbstractFunction.
864 CheckArgsSize(primitive->name(), args_abs_list, 1);
865 MS_LOG(DEBUG) << "evaluate Reusing: " << args_abs_list[0]->ToString();
866 AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
867 MS_EXCEPTION_IF_NULL(x);
868 auto set_graph_no_inline = [](const AbstractFuncAtomPtr &func) {
869 auto fg_closure = dyn_cast<FuncGraphAbstractClosure>(func);
870 if (fg_closure != nullptr) {
871 fg_closure->func_graph()->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
872 MS_LOG(DEBUG) << " Reusing: " << func->ToString()
873 << " no_inline: " << fg_closure->func_graph()->has_flag(FUNC_GRAPH_FLAG_NO_INLINE);
874 }
875 };
876 x->Visit(set_graph_no_inline);
877 return x;
878 }
879
InferImplShard(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)880 AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
881 const AbstractBasePtrList &args_abs_list) {
882 // Inputs: func, in_axes, out_axes, device, level.
883 constexpr size_t shard_input_size = 5;
884 CheckArgsSize(primitive->name(), args_abs_list, shard_input_size);
885 MS_LOG(DEBUG) << "Evaluate Shard: " << args_abs_list[0]->ToString();
886
887 AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_abs_list[0]);
888 MS_EXCEPTION_IF_NULL(x);
889
890 AbstractFuncAtomPtrList shard_v;
891 auto build_shard_v = [&shard_v](const AbstractFuncAtomPtr &func) {
892 auto shard_closure = std::make_shared<ShardTransformedAbstractClosure>(func);
893 shard_v.push_back(shard_closure);
894 };
895 x->Visit(build_shard_v);
896
897 return AbstractFunction::MakeAbstractFunction(shard_v);
898 }
899
InferImplVmap(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)900 AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
901 const AbstractBasePtrList &args_abs_list) {
902 // args: An object of AbstractFunction.
903 CheckArgsSize(primitive->name(), args_abs_list, 1);
904 auto fn_arg = args_abs_list[0];
905 MS_LOG(DEBUG) << "Evaluate Vmap: " << fn_arg->ToString() << ".";
906
907 AbstractFuncAtomPtrList vmap_v;
908 ValuePtr in_axes = primitive->GetAttr("in_axes");
909 ValuePtr out_axes = primitive->GetAttr("out_axes");
910 ValuePtr cell_size_value = primitive->GetAttr("cell_size");
911 MS_EXCEPTION_IF_NULL(cell_size_value);
912 auto cell_size = cell_size_value->isa<UInt64Imm>() ? dyn_cast<UInt64Imm>(cell_size_value)->value() : 0;
913
914 auto traverse_fn = [&vmap_v, &in_axes, &out_axes, &cell_size](const AbstractBasePtr &fn_arg) {
915 AbstractFunctionPtr x = dyn_cast<AbstractFunction>(fn_arg);
916 MS_EXCEPTION_IF_NULL(x);
917 auto build_vmap_v = [&vmap_v, &in_axes, &out_axes, &cell_size](const AbstractFuncAtomPtr &func) {
918 auto vmap_closure = std::make_shared<VmapTransformedAbstractClosure>(func, in_axes, out_axes, cell_size);
919 vmap_v.push_back(vmap_closure);
920 };
921 x->Visit(build_vmap_v);
922 };
923
924 AbstractTuplePtr cell_list = dyn_cast<AbstractTuple>(fn_arg);
925 if (cell_list != nullptr) {
926 const auto &cell_list_fns = cell_list->elements();
927 for (const auto &fn : cell_list_fns) {
928 traverse_fn(fn);
929 }
930 } else {
931 traverse_fn(fn_arg);
932 }
933
934 return AbstractFunction::MakeAbstractFunction(vmap_v);
935 }
936
InferImplFakeBprop(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)937 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
938 const AbstractBasePtrList &args_abs_list) {
939 // Inputs: a tensor.
940 CheckArgsSize(primitive->name(), args_abs_list, 1);
941 return args_abs_list[0]->Broaden();
942 }
943
GetStringAndNumberFromAbstract(const std::string & op_name,const AbstractBasePtrList & args_abs_list,std::string * str,int64_t * num)944 void GetStringAndNumberFromAbstract(const std::string &op_name, const AbstractBasePtrList &args_abs_list,
945 std::string *str, int64_t *num) {
946 constexpr size_t args_num = 2;
947 CheckArgsSize(op_name, args_abs_list, args_num);
948 AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_abs_list, 0);
949 AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_abs_list, 1);
950 ValuePtr value_x = scalar_x->BuildValue();
951 ValuePtr value_y = scalar_y->BuildValue();
952
953 bool is_match = false;
954 if (value_x->isa<StringImm>()) {
955 *str = GetValue<std::string>(value_x);
956 if (value_y->isa<Int32Imm>()) {
957 *num = IntToLong(GetValue<int32_t>(value_y));
958 is_match = true;
959 } else if (value_y->isa<Int64Imm>()) {
960 *num = GetValue<int64_t>(value_y);
961 is_match = true;
962 }
963 } else if (value_y->isa<StringImm>()) {
964 *str = GetValue<std::string>(value_y);
965 if (value_x->isa<Int32Imm>()) {
966 *num = IntToLong(GetValue<int32_t>(value_x));
967 is_match = true;
968 } else if (value_x->isa<Int64Imm>()) {
969 *num = GetValue<int64_t>(value_x);
970 is_match = true;
971 }
972 }
973 if (!is_match) {
974 MS_LOG(EXCEPTION) << op_name << " requires the input to be a string and an integer, but got " << value_x->ToString()
975 << " and " << value_y->ToString() << ".";
976 }
977 }
978
InferImplStringMul(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)979 AbstractBasePtr InferImplStringMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
980 const AbstractBasePtrList &args_abs_list) {
981 // Inputs: a string and an integer.
982 std::string str;
983 int64_t num = 0;
984 const std::string op_name = primitive->name();
985 GetStringAndNumberFromAbstract(op_name, args_abs_list, &str, &num);
986 std::string res;
987 // If num is less than or equal to 0, return an empty string.
988 if (num > 0) {
989 for (auto i = 0; i < num; i++) {
990 res += str;
991 }
992 }
993 return std::make_shared<AbstractScalar>(res);
994 }
995
InferImplStringGetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)996 AbstractBasePtr InferImplStringGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
997 const AbstractBasePtrList &args_abs_list) {
998 // Inputs: a string and an integer.
999 std::string str;
1000 int64_t num = 0;
1001 const std::string op_name = primitive->name();
1002 GetStringAndNumberFromAbstract(op_name, args_abs_list, &str, &num);
1003 int64_t len = SizeToLong(str.length());
1004 if (num >= len || num < -len) {
1005 MS_LOG(EXCEPTION) << "String index out of range, expect:[" << -len << ", " << (len - 1) << "], but got " << num
1006 << ".";
1007 }
1008 if (num < 0) {
1009 num += len;
1010 }
1011 std::string res;
1012 (void)res.append(1, str.at(num));
1013 return std::make_shared<AbstractScalar>(res);
1014 }
1015
PrimNeedFrontendInferValue(const PrimitivePtr & primitive)1016 bool PrimNeedFrontendInferValue(const PrimitivePtr &primitive) {
1017 // The operators in this list are registered on the core/ops, which means operators are registered on both frontend
1018 // and backend, affects the infer value of the frontend. We use this list to skip the registration of the backend, so
1019 // that the optimization of the frontend like constant folding, can be carried out smoothly. We need to delete this
1020 // list when the infer value can be mapped to the CPU backend operator.
1021 static std::vector<PrimitivePtr> skip_frontend_registration_list{
1022 prim::kPrimAdd, prim::kPrimMod, prim::kPrimMul, prim::kPrimRealDiv,
1023 prim::kPrimSub, prim::kPrimStridedSlice, prim::kPrimStack, prim::kPrimTensorScatterUpdate,
1024 prim::kPrimTile};
1025 if (std::any_of(skip_frontend_registration_list.begin(), skip_frontend_registration_list.end(),
1026 [&primitive](const PrimitivePtr &item) {
1027 return IsPrimitiveEquals(primitive, item) && primitive->HasPyEvaluator();
1028 })) {
1029 return true;
1030 }
1031 return false;
1032 }
1033
1034 static PrimitiveEvalImplMap frontend_prim_infer_map{
1035 // frontend
1036 };
GetFrontendPrimitiveInferMapPtr()1037 PrimitiveEvalImplMap *GetFrontendPrimitiveInferMapPtr() { return &frontend_prim_infer_map; }
GetFrontendPrimitiveInferMap()1038 const PrimitiveEvalImplMap &GetFrontendPrimitiveInferMap() { return frontend_prim_infer_map; }
GetFrontendPrimitiveInferImpl(const PrimitivePtr & primitive)1039 std::optional<StandardPrimitiveImplReg> GetFrontendPrimitiveInferImpl(const PrimitivePtr &primitive) {
1040 auto iter = GetFrontendPrimitiveInferMap().find(primitive);
1041 if (iter != GetFrontendPrimitiveInferMap().end()) {
1042 return iter->second;
1043 }
1044
1045 // We need to delete this when the infer value can be mapped to the CPU backend operator.
1046 if (PrimNeedFrontendInferValue(primitive)) {
1047 return std::optional<StandardPrimitiveImplReg>();
1048 }
1049
1050 auto find = abstract::GetPrimitiveInferImpl(primitive);
1051 if (find.has_value()) {
1052 return find.value();
1053 }
1054 return std::optional<StandardPrimitiveImplReg>();
1055 }
1056
// Returns a clone of `abs_input` whose adapter flag is set to `adapter_flag`.
// op_name: operator name used in the error message.
// abs_input: must be non-null and either an AbstractRefTensor or an AbstractTensor; an exception is raised otherwise.
// The input itself is left untouched: the flag is set on the clone only.
AbstractBasePtr SetAdapterFlag(const std::string &op_name, const AbstractBasePtr &abs_input, bool adapter_flag) {
  MS_EXCEPTION_IF_NULL(abs_input);
  // The ref-tensor case is tested first and takes precedence over the plain tensor case.
  const bool is_ref_tensor = abs_input->isa<AbstractRefTensor>();
  if (!is_ref_tensor && !abs_input->isa<AbstractTensor>()) {
    MS_LOG(EXCEPTION) << op_name << " requires a tensor as the first argument, but got " << abs_input->ToString();
  }

  // Clone is needed here.
  const auto cloned_abs = abs_input->Clone();
  if (is_ref_tensor) {
    auto ref_abs = cloned_abs->cast<AbstractRefPtr>();
    ref_abs->set_is_adapter(adapter_flag);
    return ref_abs;
  }
  auto tensor_abs = cloned_abs->cast<AbstractTensorPtr>();
  tensor_abs->set_is_adapter(adapter_flag);
  return tensor_abs;
}
1072
InferImplConvertToAdapterTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)1073 AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1074 const AbstractBasePtrList &args_abs_list) {
1075 // Inputs: a tensor.
1076 constexpr size_t args_num = 1;
1077 constexpr size_t input_index = 0;
1078 const std::string op_name = primitive->name();
1079 CheckArgsSize(op_name, args_abs_list, args_num);
1080 return SetAdapterFlag(op_name, args_abs_list[input_index], true);
1081 }
1082
InferImplConvertToMsTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)1083 AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1084 const AbstractBasePtrList &args_abs_list) {
1085 // Inputs: a tensor.
1086 constexpr size_t args_num = 1;
1087 constexpr size_t input_index = 0;
1088 const std::string op_name = primitive->name();
1089 CheckArgsSize(op_name, args_abs_list, args_num);
1090 return SetAdapterFlag(op_name, args_abs_list[input_index], false);
1091 }
1092
InferImplDtypeToEnum(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)1093 AbstractBasePtr InferImplDtypeToEnum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
1094 const AbstractBasePtrList &args_abs_list) {
1095 constexpr size_t args_num = 3;
1096 CheckArgsSize(primitive->name(), args_abs_list, args_num);
1097 auto abs_type = args_abs_list[ops::kInputIndex2]->cast<AbstractTypePtr>();
1098 if (abs_type == nullptr) {
1099 const auto &op_name = GetValue<std::string>(args_abs_list[ops::kInputIndex0]->GetValue());
1100 const auto &arg_name = GetValue<std::string>(args_abs_list[ops::kInputIndex1]->GetValue());
1101 MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input '" << arg_name << "' expect a type, but got "
1102 << args_abs_list[ops::kInputIndex2]->ToString();
1103 }
1104 auto val_type = abs_type->BuildValue();
1105 MS_EXCEPTION_IF_NULL(val_type);
1106 auto dtype = val_type->cast<TypePtr>();
1107 MS_EXCEPTION_IF_NULL(dtype);
1108 int64_t type_id = GetTypeId(dtype->type_id());
1109 return std::make_shared<AbstractScalar>(type_id);
1110 }
1111
1112 #ifndef _MSC_VER
1113 // String
1114 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringMul, prim::kPrimStringMul, InferImplStringMul, nullptr);
1115 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringGetItem, prim::kPrimStringGetItem, InferImplStringGetItem, nullptr);
1116 // Tuple
1117 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
1118 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
1119 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
1120 // List
1121 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
1122 // Dict
1123 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
1124 // Slice
1125 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeSlice, prim::kPrimMakeSlice, InferImplMakeSlice, nullptr);
1126 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(SliceGetItem, prim::kPrimSliceGetItem, InferImplSliceGetItem, nullptr);
1127 // Type
1128 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr);
1129 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TopTypeOf, prim::kPrimTopTypeOf, InferImplTopTypeof, nullptr);
1130 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr);
1131 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(IsInstance, prim::kPrimIsInstance, InferImplIsInstance, nullptr);
1132 // Shape
1133 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr);
1134 // Auto-Grad
1135 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr);
1136 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
1137 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
1138 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
1139 InferImplBroadcastGradientArgs, nullptr);
1140 // Other
1141 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Reusing, prim::kPrimReusing, InferImplReusing, nullptr);
1142 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, nullptr);
1143 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
1144 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
1145 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringUpper, prim::kPrimStringUpper, InferImplStringUpper, nullptr);
1146 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringLower, prim::kPrimStringLower, InferImplStringLower, nullptr);
1147 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToAdapterTensor, prim::kPrimConvertToAdapterTensor,
1148 InferImplConvertToAdapterTensor, nullptr);
1149 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToMsTensor, prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor,
1150 nullptr);
1151 REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DtypeToEnum, prim::kPrimDtypeToEnum, InferImplDtypeToEnum, nullptr);
1152 #else
RegPrimitiveFrontEval()1153 void RegPrimitiveFrontEval() {
1154 // String
1155 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringMul,
1156 InferImplStringMul, nullptr);
1157 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringGetItem,
1158 InferImplStringGetItem, nullptr);
1159 // Tuple
1160 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleReversed,
1161 InferImplTupleReversed, nullptr);
1162 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleDiv,
1163 InferImplTupleDiv, nullptr);
1164 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleToArray,
1165 InferImplTuple2Array, nullptr);
1166 // List
1167 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListReduce,
1168 InferImplListReduce, nullptr);
1169 // Dict
1170 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDictLen,
1171 InferImplDictLen, nullptr);
1172 // Slice
1173 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimMakeSlice,
1174 InferImplMakeSlice, nullptr);
1175 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimSliceGetItem,
1176 InferImplSliceGetItem, nullptr);
1177 // Type
1178 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTypeOf,
1179 InferImplTypeof, nullptr);
1180 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTopTypeOf,
1181 InferImplTopTypeof, nullptr);
1182 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimHasType,
1183 InferImplHasType, nullptr);
1184 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimIsInstance,
1185 InferImplIsInstance, nullptr);
1186 // Shape
1187 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimReducedShape,
1188 InferImplReduceShape, nullptr);
1189 // Auto-Grad
1190 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStopGradient,
1191 InferImplStopGradient, nullptr);
1192 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimFakeBprop,
1193 InferImplFakeBprop, nullptr);
1194 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimJ, InferImplJ,
1195 nullptr);
1196 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
1197 prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
1198 nullptr);
1199 // Other
1200 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTaylor,
1201 InferImplTaylor, nullptr);
1202 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimShard,
1203 InferImplShard, nullptr);
1204 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimVmap,
1205 InferImplVmap, nullptr);
1206 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringUpper,
1207 InferImplStringUpper, nullptr);
1208 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimStringLower,
1209 InferImplStringLower, nullptr);
1210 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
1211 prim::kPrimConvertToAdapterTensor, InferImplConvertToAdapterTensor,
1212 nullptr);
1213 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
1214 prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor, nullptr);
1215 abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDtypeToEnum,
1216 InferImplDtypeToEnum, nullptr);
1217 } // namespace abstract
1218 #endif
1219 } // namespace abstract
1220 } // namespace mindspore
1221