• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include "ir/dtype.h"
19 #include "utils/log_adapter.h"
20 #include "abstract/param_validator.h"
21 #include "abstract/ops/infer_functions.h"
22 #include "abstract/utils.h"
23 #include "utils/anf_utils.h"
24 #include "utils/ms_context.h"
25 #include "utils/symbolic.h"
26 #include "utils/shape_utils.h"
27 #include "ops/ops_func_impl/real_div.h"
28 #include "ops/ops_func_impl/add.h"
29 #include "ops/ops_func_impl/mul.h"
30 #include "ops/ops_func_impl/square.h"
31 #include "utils/check_convert_utils.h"
32 
namespace {
// Key of the primitive attribute read by InferImplReduceScatter; presumably the number of ranks in the
// communication group (confirm against the primitive definition).
constexpr const char *kRankSize = "rank_size";
}  // namespace
36 
37 namespace mindspore {
38 namespace ops {
39 // Apply ops will have a refractor and add_infer is just a temp modify
40 auto AddInfer = [](const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
__anon7743a9f50202(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &input_args) 41                    const AbstractBasePtrList &input_args) {
42   auto add_op = AddFuncImpl();
43   return abstract::MakeAbstract(add_op.InferShape(primitive, input_args), add_op.InferType(primitive, input_args));
44 };
45 }  // namespace ops
46 
47 namespace abstract {
InferImplidentity(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)48 AbstractBasePtr InferImplidentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
49                                   const AbstractBasePtrList &args_abs_list) {
50   // An object of a subclass of AbstractBase
51   CheckArgsSize(primitive->name(), args_abs_list, 1);
52   return args_abs_list[0];
53 }
54 
InferImplEnvironAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)55 AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
56                                     const AbstractBasePtrList &args_abs_list) {
57   // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
58   constexpr auto environ_add_input_size = 2;
59   CheckArgsSize(primitive->name(), args_abs_list, environ_add_input_size);
60   return std::make_shared<AbstractScalar>(kValueAny, std::make_shared<EnvType>());
61 }
62 
InferImplStateSetItem(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)63 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
64                                       const AbstractBasePtrList &args_abs_list) {
65   // args: Two objects of a subclass of AbstractBase, key and value.
66   constexpr auto state_setitem_input_size = 2;
67   CheckArgsSize(primitive->name(), args_abs_list, state_setitem_input_size);
68 
69   TypePtr type = args_abs_list[0]->GetTypeTrack();
70   MS_EXCEPTION_IF_NULL(type);
71   if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) {
72     MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString();
73   }
74   return std::make_shared<AbstractScalar>(kValueAny, kBool);
75 }
76 
InferImplDepend(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)77 AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78                                 const AbstractBasePtrList &args_abs_list) {
79   constexpr auto depend_input_size = 2;
80   CheckArgsSize(primitive->name(), args_abs_list, depend_input_size);
81 
82   // If the dependent has a value, just return depended node.
83   // If depended node is not Any, the dependent maybe eliminated.
84   auto dependant_abstract = args_abs_list[1];
85   auto dependant_value = dependant_abstract->BuildValue();
86   MS_EXCEPTION_IF_NULL(dependant_value);
87   if (!dependant_value->ContainsValueAny()) {
88     return args_abs_list[0];
89   }
90   auto depends = args_abs_list[0];
91 
92   if (depends->isa<AbstractRefTensor>()) {
93     auto abs_ref = depends->cast<AbstractRefPtr>();
94     auto tensor_abs = abs_ref->ref();
95     MS_EXCEPTION_IF_NULL(tensor_abs);
96     return std::make_shared<AbstractRefTensor>(tensor_abs->Broaden()->cast<AbstractTensorPtr>(),
97                                                abs_ref->ref_key_value());
98   }
99 
100   auto depends_abs = depends->Broaden();  // Avoid eliminating the dependent node.
101   if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
102     // For scalar, need to set value to kValueAny, because broaden scalar will not change the value.
103     if (depends_abs->isa<AbstractScalar>()) {
104       depends_abs->set_value(kValueAny);
105     }
106   }
107   return depends_abs;
108 }
109 
InferImplUpdateState(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)110 AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
111                                      const AbstractBasePtrList &args_abs_list) {
112   if (args_abs_list.empty()) {
113     MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
114   }
115   MS_EXCEPTION_IF_NULL(args_abs_list[0]);
116   return args_abs_list[0]->Broaden();
117 }
118 
InferImplMakeRowTensor(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)119 AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
120                                        const AbstractBasePtrList &args_abs_list) {
121   // Inputs: two tensors and a tuple.
122   const std::string op_name = primitive->name();
123   constexpr size_t size_expected = 3;
124   CheckArgsSize(op_name, args_abs_list, size_expected);
125   auto indices = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
126   auto values = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
127   auto dense_shape = CheckArg<AbstractTuple>(op_name, args_abs_list, 2);
128 
129   auto indices_dtype = indices->element()->BuildType();
130   if (!indices_dtype->isa<Int>()) {
131     MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
132   }
133   auto indices_shp = indices->shape()->shape();
134   auto values_shp = values->shape()->shape();
135   auto is_values_dynamic = IsDynamic(values_shp);
136   if (!IsDynamic(indices_shp) && !is_values_dynamic) {
137     if (indices_shp.size() != 1) {
138       MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
139                               << " dimension tensor";
140     }
141     if (indices_shp[0] != values_shp[0]) {
142       MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
143                               << values_shp[0] << ", but got " << indices_shp[0];
144     }
145   }
146 
147   for (const auto &elem_type : dense_shape->ElementsType()) {
148     if (!elem_type->isa<Int>()) {
149       MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
150     }
151   }
152   auto dense_shape_value = dense_shape->BuildValue();
153   MS_EXCEPTION_IF_NULL(dense_shape_value);
154   auto dense_shape_valuetuple = dense_shape_value->cast<ValueTuplePtr>();
155   MS_EXCEPTION_IF_NULL(dense_shape_valuetuple);
156   auto shp = dense_shape_valuetuple->value();
157   ShapeVector dense_shape_vec;
158   (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
159                        [](const ValuePtr &e) -> int64_t {
160                          auto elem = GetValue<int64_t>(e);
161                          return elem;
162                        });
163   if (dense_shape_vec.size() != values_shp.size() && !is_values_dynamic) {
164     MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
165                             << values_shp.size() << ", but got " << dense_shape_valuetuple->size();
166   }
167   for (size_t i = 0; i < dense_shape_vec.size(); i++) {
168     if (dense_shape_vec[i] < 0) {
169       MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
170                               << dense_shape_vec[i];
171     }
172     // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
173     if (!is_values_dynamic && i != 0 && dense_shape_vec[i] != values_shp[i]) {
174       MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
175                               << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
176     }
177   }
178   auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec);
179   ret->set_indices(indices);
180   ret->set_values(values);
181   ret->set_dense_shape(dense_shape);
182   return ret;
183 }
184 
InferImplRowTensorGetValues(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)185 AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
186                                             const AbstractBasePtrList &args_abs_list) {
187   // Inputs: two tensors and a tuple.
188   const std::string op_name = primitive->name();
189   CheckArgsSize(op_name, args_abs_list, 1);
190   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
191   MS_EXCEPTION_IF_NULL(row_tensor->values());
192   return row_tensor->values();
193 }
194 
InferImplRowTensorGetIndices(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)195 AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
196                                              const AbstractBasePtrList &args_abs_list) {
197   // Inputs: two tensors and a tuple.
198   const std::string op_name = primitive->name();
199   CheckArgsSize(op_name, args_abs_list, 1);
200   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
201   MS_EXCEPTION_IF_NULL(row_tensor->indices());
202   return row_tensor->indices();
203 }
204 
InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)205 AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
206                                                 const AbstractBasePtrList &args_abs_list) {
207   // Inputs: two tensors and a tuple.
208   const std::string op_name = primitive->name();
209   CheckArgsSize(op_name, args_abs_list, 1);
210   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
211   MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
212   return row_tensor->dense_shape();
213 }
214 
InferImplRowTensorAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)215 AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
216                                       const AbstractBasePtrList &args_abs_list) {
217   // Inputs: row tensor and tensor.
218   const std::string op_name = primitive->name();
219   constexpr size_t args_size = 2;
220   CheckArgsSize(op_name, args_abs_list, args_size);
221   auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_abs_list, 0);
222   auto tensor = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
223   MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
224   MS_EXCEPTION_IF_NULL(tensor->shape());
225   return args_abs_list[0];
226 }
227 
InferImplAllReduce(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)228 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
229                                    const AbstractBasePtrList &args_abs_list) {
230   const std::string op_name = primitive->name();
231   CheckArgsSize(op_name, args_abs_list, 1);
232   auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
233   MS_EXCEPTION_IF_NULL(x);
234   MS_EXCEPTION_IF_NULL(x->shape());
235   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
236 }
237 
InferImplReduceScatter(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)238 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
239                                        const AbstractBasePtrList &args_abs_list) {
240   const std::string op_name = primitive->name();
241   CheckArgsSize(op_name, args_abs_list, 1);
242   auto x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
243   MS_EXCEPTION_IF_NULL(x);
244   MS_EXCEPTION_IF_NULL(x->shape());
245   auto tmp_shape = x->shape()->shape();
246   if (!primitive->HasAttr(kRankSize)) {
247     MS_LOG(EXCEPTION) << "Primitive don't have rank_size attr";
248   }
249   auto rank_size = GetValue<int64_t>(primitive->GetAttr(kRankSize));
250   if (tmp_shape.empty()) {
251     MS_LOG(EXCEPTION) << "shape size is 0";
252   }
253   tmp_shape[0] = LongMulWithOverflowCheck(tmp_shape[0], rank_size);
254   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
255 }
256 
InferImplIsDimUnknown(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)257 AbstractBasePtr InferImplIsDimUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
258                                       const AbstractBasePtrList &args_abs_list) {
259   constexpr size_t input_size = 1;
260   const std::string &op_name = primitive->name();
261   CheckArgsSize(op_name, args_abs_list, input_size);
262   auto abs = args_abs_list[0];
263   if (abs->isa<AbstractAny>()) {
264     return std::make_shared<AbstractAny>();
265   }
266   if (!abs->isa<AbstractSequence>()) {
267     MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be tuple but got " << abs->ToString();
268   }
269   auto abs_seq = abs->cast<AbstractSequencePtr>();
270   return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(abs_seq->dynamic_len()), kBool);
271 }
272 
InferImplIsTensorBoolCond(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)273 AbstractBasePtr InferImplIsTensorBoolCond(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
274                                           const AbstractBasePtrList &args_abs_list) {
275   constexpr size_t input_size = 1;
276   const std::string &op_name = primitive->name();
277   CheckArgsSize(op_name, args_abs_list, input_size);
278   auto abs = args_abs_list[0];
279   if (!abs->isa<AbstractTensor>()) {
280     MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be a tensor but got " << abs->ToString();
281   }
282 
283   auto build_shape = abs->cast<AbstractTensorPtr>()->GetShape();
284   MS_EXCEPTION_IF_NULL(build_shape);
285   if (build_shape->IsDimUnknown()) {
286     return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
287   }
288   auto shape = build_shape->cast<abstract::ShapePtr>()->shape();
289   if (shape.size() == 0) {
290     return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
291   }
292   if (shape.size() == 1 && (shape[0] == abstract::Shape::kShapeDimAny || shape[0] == 1)) {
293     return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
294   }
295   MS_EXCEPTION(ValueError) << "Only tensor which shape is () or (1,) can be converted to bool, "
296                            << "but got tensor shape is " << build_shape->ToString();
297 }
298 
InferImplIsShapeUnknown(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)299 AbstractBasePtr InferImplIsShapeUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
300                                         const AbstractBasePtrList &args_abs_list) {
301   constexpr size_t input_size = 1;
302   const std::string &op_name = primitive->name();
303   CheckArgsSize(op_name, args_abs_list, input_size);
304   auto abs = args_abs_list[0];
305   if (!abs->isa<AbstractSequence>()) {
306     MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be tuple or list but got " << abs->ToString();
307   }
308   auto abs_seq = abs->cast<AbstractSequencePtr>();
309   bool is_shape_unknown = false;
310   if (abs_seq->dynamic_len()) {
311     is_shape_unknown = true;
312   } else {
313     auto &elements = abs_seq->elements();
314     for (size_t i = 0; i < elements.size(); ++i) {
315       auto cur = elements[i];
316       MS_EXCEPTION_IF_NULL(cur);
317       auto cur_val = cur->BuildValue();
318       MS_EXCEPTION_IF_NULL(cur_val);
319       if (cur_val->ContainsValueAny()) {
320         is_shape_unknown = true;
321         break;
322       }
323     }
324   }
325   return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(is_shape_unknown), kBool);
326 }
327 
InferImplIsElementUnknown(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)328 AbstractBasePtr InferImplIsElementUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
329                                           const AbstractBasePtrList &args_abs_list) {
330   constexpr size_t input_size = 1;
331   const std::string &op_name = primitive->name();
332   CheckArgsSize(op_name, args_abs_list, input_size);
333   auto abs = args_abs_list[0];
334   if (!abs->isa<AbstractSequence>()) {
335     MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be tuple or list but got " << abs->ToString();
336   }
337   auto abs_seq = abs->cast<AbstractSequencePtr>();
338   if (!abs_seq->dynamic_len()) {
339     MS_EXCEPTION(TypeError) << "The input of " << op_name << " should be variable length sequence.";
340   }
341   bool is_element_unknown = (abs_seq->dynamic_len_element_abs() == nullptr);
342   return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(is_element_unknown), kBool);
343 }
344 
InferImplLoad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)345 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
346                               const AbstractBasePtrList &args_abs_list) {
347   // Inputs: Ref/Tensor, universal
348   constexpr auto load_input_size = 2;
349   CheckArgsSize(primitive->name(), args_abs_list, load_input_size);
350   auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(args_abs_list[0]);
351   if (ref_abs != nullptr) {
352     // Return tensor value if input is Ref.
353     return ref_abs->CloneAsTensor();
354   }
355   return args_abs_list[0]->Broaden();
356 }
357 
InferImplTransData(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)358 AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
359                                    const AbstractBasePtrList &args_abs_list) {
360   // An object of a subclass of AbstractBase
361   CheckArgsSize(primitive->name(), args_abs_list, 1);
362   auto output = args_abs_list[0];
363   MS_EXCEPTION_IF_NULL(output);
364   return output;
365 }
366 
InferImplTensorMove(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)367 AbstractBasePtr InferImplTensorMove(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
368                                     const AbstractBasePtrList &args_abs_list) {
369   // An object of a subclass of AbstractBase
370   CheckArgsSize(primitive->name(), args_abs_list, 1);
371   auto output = args_abs_list[0];
372   MS_EXCEPTION_IF_NULL(output);
373   return output;
374 }
375 
376 // Infer for MapTensor.default_value.
InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)377 AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
378                                                   const AbstractBasePtrList &args_abs_list) {
379   CheckArgsSize(primitive->name(), args_abs_list, 1);
380   const auto &arg = args_abs_list[0];
381   MS_EXCEPTION_IF_NULL(arg);
382   auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
383   if (abs_map_tensor == nullptr) {
384     MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
385   }
386   return std::make_shared<AbstractScalar>(abs_map_tensor->default_value());
387 }
388 // Infer for MapTensor.permit_filter_value.
InferImplMapTensorGetPermitFilterValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)389 AbstractBasePtr InferImplMapTensorGetPermitFilterValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
390                                                        const AbstractBasePtrList &args_abs_list) {
391   CheckArgsSize(primitive->name(), args_abs_list, 1);
392   const auto &arg = args_abs_list[0];
393   MS_EXCEPTION_IF_NULL(arg);
394   auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
395   if (abs_map_tensor == nullptr) {
396     MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
397   }
398   return std::make_shared<AbstractScalar>(abs_map_tensor->permit_filter_value());
399 }
400 // Infer for MapTensor.evict_filter_value.
InferImplMapTensorGetEvictFilterValue(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)401 AbstractBasePtr InferImplMapTensorGetEvictFilterValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
402                                                       const AbstractBasePtrList &args_abs_list) {
403   CheckArgsSize(primitive->name(), args_abs_list, 1);
404   const auto &arg = args_abs_list[0];
405   MS_EXCEPTION_IF_NULL(arg);
406   auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
407   if (abs_map_tensor == nullptr) {
408     MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
409   }
410   return std::make_shared<AbstractScalar>(abs_map_tensor->evict_filter_value());
411 }
412 }  // namespace abstract
413 }  // namespace mindspore
414