1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include "ir/dtype.h"
19 #include "utils/log_adapter.h"
20 #include "abstract/param_validator.h"
21 #include "abstract/ops/infer_functions.h"
22 #include "abstract/utils.h"
23 #include "utils/anf_utils.h"
24 #include "utils/ms_context.h"
25 #include "utils/symbolic.h"
26 #include "utils/shape_utils.h"
27 #include "ops/ops_func_impl/real_div.h"
28 #include "ops/ops_func_impl/add.h"
29 #include "ops/ops_func_impl/mul.h"
30 #include "ops/ops_func_impl/square.h"
31 #include "utils/check_convert_utils.h"
32
namespace {
// Attribute key looked up on the primitive by InferImplReduceScatter.
constexpr const char *kRankSize = "rank_size";
}  // namespace
36
37 namespace mindspore {
38 namespace ops {
// Apply ops will be refactored; AddInfer is only a temporary modification.
40 auto AddInfer = [](const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
__anon7743a9f50202(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &input_args) 41 const AbstractBasePtrList &input_args) {
42 auto add_op = AddFuncImpl();
43 return abstract::MakeAbstract(add_op.InferShape(primitive, input_args), add_op.InferType(primitive, input_args));
44 };
45 } // namespace ops
46
47 namespace abstract {
InferImplidentity(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)48 AbstractBasePtr InferImplidentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
49 const AbstractBasePtrList &args_abs_list) {
50 // An object of a subclass of AbstractBase
51 CheckArgsSize(primitive->name(), args_abs_list, 1);
52 return args_abs_list[0];
53 }
54
InferImplEnvironAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)55 AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
56 const AbstractBasePtrList &args_abs_list) {
57 // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
58 constexpr auto environ_add_input_size = 2;
59 CheckArgsSize(primitive->name(), args_abs_list, environ_add_input_size);
60 return std::make_shared<AbstractScalar>(kValueAny, std::make_shared<EnvType>());
61 }
62
InferImplStateSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)63 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
64 const AbstractBasePtrList &args_abs_list) {
65 // args: Two objects of a subclass of AbstractBase, key and value.
66 constexpr auto state_setitem_input_size = 2;
67 CheckArgsSize(primitive->name(), args_abs_list, state_setitem_input_size);
68
69 TypePtr type = args_abs_list[0]->GetTypeTrack();
70 MS_EXCEPTION_IF_NULL(type);
71 if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) {
72 MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString();
73 }
74 return std::make_shared<AbstractScalar>(kValueAny, kBool);
75 }
76
InferImplDepend(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)77 AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78 const AbstractBasePtrList &args_abs_list) {
79 constexpr auto depend_input_size = 2;
80 CheckArgsSize(primitive->name(), args_abs_list, depend_input_size);
81
82 // If the dependent has a value, just return depended node.
83 // If depended node is not Any, the dependent maybe eliminated.
84 auto dependant_abstract = args_abs_list[1];
85 auto dependant_value = dependant_abstract->BuildValue();
86 MS_EXCEPTION_IF_NULL(dependant_value);
87 if (!dependant_value->ContainsValueAny()) {
88 return args_abs_list[0];
89 }
90 auto depends = args_abs_list[0];
91
92 if (depends->isa<AbstractRefTensor>()) {
93 auto abs_ref = depends->cast<AbstractRefPtr>();
94 auto tensor_abs = abs_ref->ref();
95 MS_EXCEPTION_IF_NULL(tensor_abs);
96 return std::make_shared<AbstractRefTensor>(tensor_abs->Broaden()->cast<AbstractTensorPtr>(),
97 abs_ref->ref_key_value());
98 }
99
100 auto depends_abs = depends->Broaden(); // Avoid eliminating the dependent node.
101 if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
102 // For scalar, need to set value to kValueAny, because broaden scalar will not change the value.
103 if (depends_abs->isa<AbstractScalar>()) {
104 depends_abs->set_value(kValueAny);
105 }
106 }
107 return depends_abs;
108 }
109
InferImplUpdateState(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)110 AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
111 const AbstractBasePtrList &args_abs_list) {
112 if (args_abs_list.empty()) {
113 MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
114 }
115 MS_EXCEPTION_IF_NULL(args_abs_list[0]);
116 return args_abs_list[0]->Broaden();
117 }
118
InferImplMakeRowTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)119 AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
120 const AbstractBasePtrList &args_abs_list) {
121 // Inputs: two tensors and a tuple.
122 const std::string op_name = primitive->name();
123 constexpr size_t size_expected = 3;
124 CheckArgsSize(op_name, args_abs_list, size_expected);
125 auto indices = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
126 auto values = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
127 auto dense_shape = CheckArg<AbstractTuple>(op_name, args_abs_list, 2);
128
129 auto indices_dtype = indices->element()->BuildType();
130 if (!indices_dtype->isa<Int>()) {
131 MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
132 }
133 auto indices_shp = indices->shape()->shape();
134 auto values_shp = values->shape()->shape();
135 auto is_values_dynamic = IsDynamic(values_shp);
136 if (!IsDynamic(indices_shp) && !is_values_dynamic) {
137 if (indices_shp.size() != 1) {
138 MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
139 << " dimension tensor";
140 }
141 if (indices_shp[0] != values_shp[0]) {
142 MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
143 << values_shp[0] << ", but got " << indices_shp[0];
144 }
145 }
146
147 for (const auto &elem_type : dense_shape->ElementsType()) {
148 if (!elem_type->isa<Int>()) {
149 MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
150 }
151 }
152 auto dense_shape_value = dense_shape->BuildValue();
153 MS_EXCEPTION_IF_NULL(dense_shape_value);
154 auto dense_shape_valuetuple = dense_shape_value->cast<ValueTuplePtr>();
155 MS_EXCEPTION_IF_NULL(dense_shape_valuetuple);
156 auto shp = dense_shape_valuetuple->value();
157 ShapeVector dense_shape_vec;
158 (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
159 [](const ValuePtr &e) -> int64_t {
160 auto elem = GetValue<int64_t>(e);
161 return elem;
162 });
163 if (dense_shape_vec.size() != values_shp.size() && !is_values_dynamic) {
164 MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
165 << values_shp.size() << ", but got " << dense_shape_valuetuple->size();
166 }
167 for (size_t i = 0; i < dense_shape_vec.size(); i++) {
168 if (dense_shape_vec[i] < 0) {
169 MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
170 << dense_shape_vec[i];
171 }
172 // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
173 if (!is_values_dynamic && i != 0 && dense_shape_vec[i] != values_shp[i]) {
174 MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
175 << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
176 }
177 }
178 auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec);
179 ret->set_indices(indices);
180 ret->set_values(values);
181 ret->set_dense_shape(dense_shape);
182 return ret;
183 }
184
InferImplRowTensorGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)185 AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
186 const AbstractBasePtrList &args_abs_list) {
187 // Inputs: two tensors and a tuple.
188 const std::string op_name = primitive->name();
189 CheckArgsSize(op_name, args_abs_list, 1);
190 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
191 MS_EXCEPTION_IF_NULL(row_tensor->values());
192 return row_tensor->values();
193 }
194
InferImplRowTensorGetIndices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)195 AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
196 const AbstractBasePtrList &args_abs_list) {
197 // Inputs: two tensors and a tuple.
198 const std::string op_name = primitive->name();
199 CheckArgsSize(op_name, args_abs_list, 1);
200 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
201 MS_EXCEPTION_IF_NULL(row_tensor->indices());
202 return row_tensor->indices();
203 }
204
InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)205 AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
206 const AbstractBasePtrList &args_abs_list) {
207 // Inputs: two tensors and a tuple.
208 const std::string op_name = primitive->name();
209 CheckArgsSize(op_name, args_abs_list, 1);
210 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
211 MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
212 return row_tensor->dense_shape();
213 }
214
InferImplRowTensorAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)215 AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
216 const AbstractBasePtrList &args_abs_list) {
217 // Inputs: row tensor and tensor.
218 const std::string op_name = primitive->name();
219 constexpr size_t args_size = 2;
220 CheckArgsSize(op_name, args_abs_list, args_size);
221 auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
222 auto tensor = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
223 MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
224 MS_EXCEPTION_IF_NULL(tensor->shape());
225 return args_abs_list[0];
226 }
227
InferImplAllReduce(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)228 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
229 const AbstractBasePtrList &args_abs_list) {
230 const std::string op_name = primitive->name();
231 CheckArgsSize(op_name, args_abs_list, 1);
232 auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
233 MS_EXCEPTION_IF_NULL(x);
234 MS_EXCEPTION_IF_NULL(x->shape());
235 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
236 }
237
InferImplReduceScatter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)238 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
239 const AbstractBasePtrList &args_abs_list) {
240 const std::string op_name = primitive->name();
241 CheckArgsSize(op_name, args_abs_list, 1);
242 auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
243 MS_EXCEPTION_IF_NULL(x);
244 MS_EXCEPTION_IF_NULL(x->shape());
245 auto tmp_shape = x->shape()->shape();
246 if (!primitive->HasAttr(kRankSize)) {
247 MS_LOG(EXCEPTION) << "Primitive don't have rank_size attr";
248 }
249 auto rank_size = GetValue<int64_t>(primitive->GetAttr(kRankSize));
250 if (tmp_shape.empty()) {
251 MS_LOG(EXCEPTION) << "shape size is 0";
252 }
253 tmp_shape[0] = LongMulWithOverflowCheck(tmp_shape[0], rank_size);
254 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
255 }
256
InferImplIsDimUnknown(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)257 AbstractBasePtr InferImplIsDimUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
258 const AbstractBasePtrList &args_abs_list) {
259 constexpr size_t input_size = 1;
260 const std::string &op_name = primitive->name();
261 CheckArgsSize(op_name, args_abs_list, input_size);
262 auto abs = args_abs_list[0];
263 if (abs->isa<AbstractAny>()) {
264 return std::make_shared<AbstractAny>();
265 }
266 if (!abs->isa<AbstractSequence>()) {
267 MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be tuple but got " << abs->ToString();
268 }
269 auto abs_seq = abs->cast<AbstractSequencePtr>();
270 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(abs_seq->dynamic_len()), kBool);
271 }
272
InferImplIsTensorBoolCond(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)273 AbstractBasePtr InferImplIsTensorBoolCond(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
274 const AbstractBasePtrList &args_abs_list) {
275 constexpr size_t input_size = 1;
276 const std::string &op_name = primitive->name();
277 CheckArgsSize(op_name, args_abs_list, input_size);
278 auto abs = args_abs_list[0];
279 if (!abs->isa<AbstractTensor>()) {
280 MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be a tensor but got " << abs->ToString();
281 }
282
283 auto build_shape = abs->cast<AbstractTensorPtr>()->GetShape();
284 MS_EXCEPTION_IF_NULL(build_shape);
285 if (build_shape->IsDimUnknown()) {
286 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
287 }
288 auto shape = build_shape->cast<abstract::ShapePtr>()->shape();
289 if (shape.size() == 0) {
290 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
291 }
292 if (shape.size() == 1 && (shape[0] == abstract::Shape::kShapeDimAny || shape[0] == 1)) {
293 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
294 }
295 MS_EXCEPTION(ValueError) << "Only tensor which shape is () or (1,) can be converted to bool, "
296 << "but got tensor shape is " << build_shape->ToString();
297 }
298
InferImplIsShapeUnknown(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)299 AbstractBasePtr InferImplIsShapeUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
300 const AbstractBasePtrList &args_abs_list) {
301 constexpr size_t input_size = 1;
302 const std::string &op_name = primitive->name();
303 CheckArgsSize(op_name, args_abs_list, input_size);
304 auto abs = args_abs_list[0];
305 if (!abs->isa<AbstractSequence>()) {
306 MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be tuple or list but got " << abs->ToString();
307 }
308 auto abs_seq = abs->cast<AbstractSequencePtr>();
309 bool is_shape_unknown = false;
310 if (abs_seq->dynamic_len()) {
311 is_shape_unknown = true;
312 } else {
313 auto &elements = abs_seq->elements();
314 for (size_t i = 0; i < elements.size(); ++i) {
315 auto cur = elements[i];
316 MS_EXCEPTION_IF_NULL(cur);
317 auto cur_val = cur->BuildValue();
318 MS_EXCEPTION_IF_NULL(cur_val);
319 if (cur_val->ContainsValueAny()) {
320 is_shape_unknown = true;
321 break;
322 }
323 }
324 }
325 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(is_shape_unknown), kBool);
326 }
327
InferImplIsElementUnknown(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)328 AbstractBasePtr InferImplIsElementUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
329 const AbstractBasePtrList &args_abs_list) {
330 constexpr size_t input_size = 1;
331 const std::string &op_name = primitive->name();
332 CheckArgsSize(op_name, args_abs_list, input_size);
333 auto abs = args_abs_list[0];
334 if (!abs->isa<AbstractSequence>()) {
335 MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be tuple or list but got " << abs->ToString();
336 }
337 auto abs_seq = abs->cast<AbstractSequencePtr>();
338 if (!abs_seq->dynamic_len()) {
339 MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be variable length sequence.";
340 }
341 bool is_element_unknown = (abs_seq->dynamic_len_element_abs() == nullptr);
342 return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(is_element_unknown), kBool);
343 }
344
InferImplLoad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)345 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
346 const AbstractBasePtrList &args_abs_list) {
347 // Inputs: Ref/Tensor, universal
348 constexpr auto load_input_size = 2;
349 CheckArgsSize(primitive->name(), args_abs_list, load_input_size);
350 auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(args_abs_list[0]);
351 if (ref_abs != nullptr) {
352 // Return tensor value if input is Ref.
353 return ref_abs->CloneAsTensor();
354 }
355 return args_abs_list[0]->Broaden();
356 }
357
InferImplTransData(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)358 AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
359 const AbstractBasePtrList &args_abs_list) {
360 // An object of a subclass of AbstractBase
361 CheckArgsSize(primitive->name(), args_abs_list, 1);
362 auto output = args_abs_list[0];
363 MS_EXCEPTION_IF_NULL(output);
364 return output;
365 }
366
InferImplTensorMove(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)367 AbstractBasePtr InferImplTensorMove(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
368 const AbstractBasePtrList &args_abs_list) {
369 // An object of a subclass of AbstractBase
370 CheckArgsSize(primitive->name(), args_abs_list, 1);
371 auto output = args_abs_list[0];
372 MS_EXCEPTION_IF_NULL(output);
373 return output;
374 }
375
376 // Infer for MapTensor.default_value.
InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)377 AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
378 const AbstractBasePtrList &args_abs_list) {
379 CheckArgsSize(primitive->name(), args_abs_list, 1);
380 const auto &arg = args_abs_list[0];
381 MS_EXCEPTION_IF_NULL(arg);
382 auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
383 if (abs_map_tensor == nullptr) {
384 MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
385 }
386 return std::make_shared<AbstractScalar>(abs_map_tensor->default_value());
387 }
388 // Infer for MapTensor.permit_filter_value.
InferImplMapTensorGetPermitFilterValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)389 AbstractBasePtr InferImplMapTensorGetPermitFilterValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
390 const AbstractBasePtrList &args_abs_list) {
391 CheckArgsSize(primitive->name(), args_abs_list, 1);
392 const auto &arg = args_abs_list[0];
393 MS_EXCEPTION_IF_NULL(arg);
394 auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
395 if (abs_map_tensor == nullptr) {
396 MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
397 }
398 return std::make_shared<AbstractScalar>(abs_map_tensor->permit_filter_value());
399 }
400 // Infer for MapTensor.evict_filter_value.
InferImplMapTensorGetEvictFilterValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)401 AbstractBasePtr InferImplMapTensorGetEvictFilterValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
402 const AbstractBasePtrList &args_abs_list) {
403 CheckArgsSize(primitive->name(), args_abs_list, 1);
404 const auto &arg = args_abs_list[0];
405 MS_EXCEPTION_IF_NULL(arg);
406 auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
407 if (abs_map_tensor == nullptr) {
408 MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
409 }
410 return std::make_shared<AbstractScalar>(abs_map_tensor->evict_filter_value());
411 }
412 } // namespace abstract
413 } // namespace mindspore
414