• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/op_utils.h"
18 
19 #include <algorithm>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "abstract/dshape.h"
28 #include "abstract/ops/primitive_infer_map.h"
29 #include "abstract/param_validator.h"
30 #include "ir/dtype/tensor_type.h"
31 #include "ir/dtype/type.h"
32 #include "ir/named.h"
33 #include "ir/primitive.h"
34 #include "ir/scalar.h"
35 #include "ir/tensor.h"
36 #include "ir/value.h"
37 #include "ir/kernel_tensor_value.h"
38 #include "mindapi/base/type_id.h"
39 #include "mindapi/src/helper.h"
40 #include "ops/op_name.h"
41 #include "ops/op_def.h"
42 #include "utils/check_convert_utils.h"
43 #include "utils/convert_utils_base.h"
44 #include "utils/log_adapter.h"
45 #include "utils/shape_utils.h"
46 #include "ir/func_graph.h"
47 #include "ops/ops_func_impl/simple_infer.h"
48 
49 namespace mindspore {
50 namespace ops {
CalBroadCastShape(const std::vector<int64_t> & x_shape,const std::vector<int64_t> & y_shape,const std::string & op_name,const std::string & op_x_name,const std::string & op_y_name)51 std::vector<int64_t> CalBroadCastShape(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
52                                        const std::string &op_name, const std::string &op_x_name,
53                                        const std::string &op_y_name) {
54   if (x_shape == y_shape) {
55     return x_shape;
56   }
57 
58   if (IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) {
59     return {abstract::Shape::kShapeRankAny};
60   }
61 
62   std::vector<int64_t> broadcast_shape;
63   auto x_length = x_shape.size();
64   auto y_length = y_shape.size();
65   auto res = x_length > y_length;
66   size_t max_len = res ? x_length : y_length;
67   size_t min_len = res ? y_length : x_length;
68   const std::vector<int64_t> &max_shape = res ? x_shape : y_shape;
69   const std::vector<int64_t> &min_shape = res ? y_shape : x_shape;
70 
71   broadcast_shape = max_shape;
72   auto miss = max_len - min_len;
73   for (size_t i = 0; i < min_len; i++) {
74     auto dst_i = miss + i;
75     if (max_shape[dst_i] == 1) {
76       broadcast_shape[dst_i] = min_shape[i];
77     } else if (MS_UNLIKELY(max_shape[dst_i] == -1)) {
78       if (min_shape[i] != 1) {
79         broadcast_shape[dst_i] = min_shape[i];
80       }
81     } else if (MS_UNLIKELY(max_shape[dst_i] != min_shape[i] && min_shape[i] != -1 && min_shape[i] != 1)) {
82       auto x_shape_name = op_x_name + ".shape";
83       auto y_shape_name = op_y_name + ".shape";
84       MS_EXCEPTION(ValueError) << "For '" << op_name << "', " << x_shape_name << " and " << y_shape_name
85                                << " need to broadcast. The value of " << x_shape_name << "["
86                                << std::to_string(x_length + i) << "] or " << y_shape_name << "["
87                                << std::to_string(y_length + i)
88                                << "] must be 1 or -1 when they are not the same, but got " << x_shape_name << " = "
89                                << tensor::ShapeToString(x_shape) << " and " << y_shape_name << " = "
90                                << tensor::ShapeToString(y_shape);
91     }
92   }
93   return broadcast_shape;
94 }
95 
BroadCastInferShape(const std::string & op_name,const std::vector<AbstractBasePtr> & input_args)96 abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
97   MS_EXCEPTION_IF_NULL(input_args[kIndex0]);
98   MS_EXCEPTION_IF_NULL(input_args[kIndex1]);
99   ShapeVector x_shape;
100   if (!input_args[0]->GetShape()->isa<abstract::NoShape>()) {
101     x_shape = GetShapeFromTensor(input_args[0]);
102   }
103 
104   ShapeVector y_shape;
105   if (!input_args[1]->GetShape()->isa<abstract::NoShape>()) {
106     y_shape = GetShapeFromTensor(input_args[1]);
107   }
108 
109   auto broadcast_shape = CalBroadCastShape(x_shape, y_shape, op_name);
110   return std::make_shared<abstract::Shape>(broadcast_shape);
111 }
112 
BroadCastInferShape(const std::string & op_name,const ValuePtrList & input_values)113 ShapeVector BroadCastInferShape(const std::string &op_name, const ValuePtrList &input_values) {
114   const auto &x_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
115   const auto &y_tensor = input_values[kIndex1]->cast<tensor::BaseTensorPtr>();
116   MS_EXCEPTION_IF_NULL(x_tensor);
117   MS_EXCEPTION_IF_NULL(y_tensor);
118 
119   auto x_shape = x_tensor->shape();
120   auto y_shape = y_tensor->shape();
121 
122   auto broadcast_shape = CalBroadCastShape(x_shape, y_shape, op_name);
123   return broadcast_shape;
124 }
125 
IsBroadcastable(const std::vector<int64_t> & x_shape,const std::vector<int64_t> & y_shape)126 bool IsBroadcastable(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape) {
127   if (x_shape == y_shape) {
128     return true;
129   }
130 
131   if (IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) {
132     return true;
133   }
134 
135   if (x_shape.size() < y_shape.size()) {
136     return false;
137   }
138 
139   auto miss = x_shape.size() - y_shape.size();
140   for (size_t i = 0; i < y_shape.size(); i++) {
141     if (x_shape[miss + i] == y_shape[i]) {
142       continue;
143     }
144     if (x_shape[miss + i] == -1) {
145       continue;
146     }
147     if (y_shape[i] == -1 || y_shape[i] == 1) {
148       continue;
149     }
150     return false;
151   }
152   return true;
153 }
154 
EltwiseGradInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)155 BaseShapePtr EltwiseGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
156   MS_EXCEPTION_IF_NULL(input_args[0]);
157   MS_EXCEPTION_IF_NULL(input_args[1]);
158   auto prim_name = primitive->name();
159   auto x = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, 0, kObjectTypeTensorType);
160   auto dout = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, 1, kObjectTypeTensorType);
161   auto x_shape_ptr = x->GetShape();
162   auto dout_shape_ptr = dout->GetShape();
163   MS_EXCEPTION_IF_NULL(x_shape_ptr);
164   MS_EXCEPTION_IF_NULL(dout_shape_ptr);
165   auto x_shape = x_shape_ptr->GetShapeVector();
166   auto dout_shape = dout_shape_ptr->GetShapeVector();
167   if (IsDynamicRank(x_shape) || IsDynamicRank(dout_shape)) {
168     return input_args[1]->GetShape()->Clone();
169   } else if (x_shape.size() != dout_shape.size()) {
170     MS_EXCEPTION(ValueError) << "Rank of x(" << x_shape.size() << ") and dout(" << dout_shape.size()
171                              << ") not equal, primitive name: " << prim_name << ".";
172   }
173 
174   for (size_t i = 0; i < x_shape.size(); i++) {
175     if (x_shape[i] != abstract::Shape::kShapeDimAny && dout_shape[i] != abstract::Shape::kShapeDimAny &&
176         x_shape[i] != dout_shape[i]) {
177       MS_EXCEPTION(ValueError) << "The " << i << "th dim of x(" << x_shape[i] << ") and dout(" << dout_shape[i]
178                                << ") not equal, primitive name: " << prim_name << ".";
179     }
180   }
181   return input_args[0]->GetShape()->Clone();
182 }
183 
EltwiseGradInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)184 TypePtr EltwiseGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
185   MS_EXCEPTION_IF_NULL(primitive);
186   MS_EXCEPTION_IF_NULL(input_args[0]);
187   MS_EXCEPTION_IF_NULL(input_args[1]);
188   auto grad_type = input_args[0]->GetType();
189   MS_EXCEPTION_IF_NULL(grad_type);
190   auto x_type = input_args[1]->GetType();
191   MS_EXCEPTION_IF_NULL(x_type);
192   if (grad_type->type_id() != x_type->type_id()) {
193     MS_LOG_EXCEPTION << "For " << primitive->name()
194                      << ", the grad type must be same as input type, but got grad_type: " << grad_type->ToString()
195                      << " and x_type: " << x_type->ToString();
196   }
197   return grad_type->Clone();
198 }
199 
EltwiseGradSimpleInferShape(const PrimitivePtr & primitive,const ValuePtrList & input_values)200 ShapeArray EltwiseGradSimpleInferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) {
201   MS_EXCEPTION_IF_NULL(primitive);
202   const auto &prim_name = primitive->name();
203   const auto &dout_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
204   MS_EXCEPTION_IF_NULL(dout_tensor);
205   const auto &y_tensor = input_values[kIndex1]->cast<tensor::BaseTensorPtr>();
206   MS_EXCEPTION_IF_NULL(y_tensor);
207 
208   const auto &dout_shape = dout_tensor->shape();
209   const auto &y_shape = y_tensor->shape();
210 
211   if (dout_shape.size() != y_shape.size()) {
212     MS_EXCEPTION(ValueError) << "Rank of x(" << y_shape.size() << ") and dout(" << dout_shape.size()
213                              << ") not equal, primitive name: " << prim_name << ".";
214   }
215 
216   for (size_t i = 0; i < y_shape.size(); i++) {
217     if (y_shape[i] != dout_shape[i]) {
218       MS_EXCEPTION(ValueError) << "The " << i << "th dim of x(" << y_shape[i] << ") and dout(" << dout_shape[i]
219                                << ") not equal, primitive name: " << prim_name << ".";
220     }
221   }
222   return {dout_shape};
223 }
224 
EltwiseGradSimpleInferType(const PrimitivePtr & primitive,const ValuePtrList & input_values)225 TypePtrList EltwiseGradSimpleInferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) {
226   MS_EXCEPTION_IF_NULL(primitive);
227   const auto &dout_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
228   MS_EXCEPTION_IF_NULL(dout_tensor);
229   const auto &y_tensor = input_values[kIndex1]->cast<tensor::BaseTensorPtr>();
230   MS_EXCEPTION_IF_NULL(y_tensor);
231 
232   const auto &dout_type = dout_tensor->Dtype();
233   const auto &y_type = y_tensor->Dtype();
234 
235   if (dout_type->type_id() != y_type->type_id()) {
236     MS_LOG_EXCEPTION << "For " << primitive->name()
237                      << ", the grad type must be same as input type, but got grad_type: " << dout_type->ToString()
238                      << " and x_type: " << y_type->ToString();
239   }
240   return {dout_type};
241 }
242 
ReduceFuncCheckAxisInferImpl(const PrimitivePtr & prim,std::vector<int64_t> * axis,const size_t dim)243 void ReduceFuncCheckAxisInferImpl(const PrimitivePtr &prim, std::vector<int64_t> *axis, const size_t dim) {
244   MS_EXCEPTION_IF_NULL(axis);
245   int64_t dim_ = static_cast<int64_t>(dim);
246   for (size_t i = 0; i < axis->size(); i++) {
247     if (dim == 0) {
248       if ((axis->at(i) != -1 && axis->at(i) != 0)) {
249         MS_EXCEPTION(ValueError) << "For '" << prim->name()
250                                  << "', 'axis' must be in [-1, 0]. But got 'axis' = " << axis->at(i) << ".";
251       }
252       axis->at(i) = 0;
253       continue;
254     }
255     if (axis->at(i) < -dim_ || axis->at(i) >= dim_) {
256       MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', 'axis' must be in [" << -dim_ << ", " << dim_
257                                << "). But got 'axis' = " << axis->at(i) << ".";
258     }
259     if (axis->at(i) >= -dim_ && axis->at(i) < 0) {
260       axis->at(i) += dim_;
261     }
262   }
263 }
264 
// Computes the output shape of a reduction of x_shape over `axis`.
// `axis` is expected to hold in-range, non-negative indices (e.g. as normalized by
// ReduceFuncCheckAxisInferImpl); duplicates are ignored. An empty `axis` means "reduce every dim".
// - keep_dims_value: reduced dims become 1 (all dims when `axis` is empty).
// - otherwise: reduced dims are removed; an empty `axis` yields a scalar shape {}.
// A rank-0 x_shape always yields {}.
ShapeVector ReduceFuncCalShapeInferImpl(const PrimitivePtr &, const ShapeVector &x_shape,
                                        const std::vector<int64_t> &axis, bool keep_dims_value) {
  if (x_shape.empty()) {
    return {};
  }

  ShapeVector sorted_axes(axis.begin(), axis.end());
  std::sort(sorted_axes.begin(), sorted_axes.end());
  sorted_axes.erase(std::unique(sorted_axes.begin(), sorted_axes.end()), sorted_axes.end());

  ShapeVector out_shape(x_shape.begin(), x_shape.end());
  if (keep_dims_value) {
    if (sorted_axes.empty()) {
      std::fill(out_shape.begin(), out_shape.end(), 1);
    } else {
      for (const auto reduced : sorted_axes) {
        out_shape.at(LongToSize(reduced)) = 1;
      }
    }
    return out_shape;
  }

  if (sorted_axes.empty()) {
    return {};
  }
  // Erase from the highest axis down so earlier erasures do not shift later positions.
  for (auto it = sorted_axes.rbegin(); it != sorted_axes.rend(); ++it) {
    (void)out_shape.erase(out_shape.begin() + *it);
  }
  return out_shape;
}
297 
// Output shape of a reduction whose axes are not known at compile time.
// With keep_dims the rank is preserved but every dim is -1; without it even the rank is unknown ({-2}).
ShapeVector ReduceFuncCalShapeAxisDyn(const ShapeVector &x_shape, bool keep_dims) {
  if (keep_dims) {
    return ShapeVector(x_shape.size(), -1LL);
  }
  constexpr int dynamic_rank_value = -2;
  return ShapeVector{dynamic_rank_value};
}
308 
CheckAndGetAxisValueFromAttr(const PrimitivePtr & primitive,std::vector<int64_t> * axis_value,int64_t *)309 void CheckAndGetAxisValueFromAttr(const PrimitivePtr &primitive, std::vector<int64_t> *axis_value, int64_t *) {
310   auto op_name = primitive->name();
311   auto axis_ptr = primitive->GetAttr("axis");
312   MS_EXCEPTION_IF_NULL(axis_ptr);
313   if (axis_ptr->isa<tensor::BaseTensor>()) {
314     *axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", axis_ptr, op_name);
315   } else {
316     *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", axis_ptr, op_name);
317   }
318 }
319 
CheckAndGetAxisValueFromScalar(const ValuePtr & input_value,const std::string & op_name,std::vector<int64_t> * axis_value,int64_t * axis_shape_v)320 bool CheckAndGetAxisValueFromScalar(const ValuePtr &input_value, const std::string &op_name,
321                                     std::vector<int64_t> *axis_value, int64_t *axis_shape_v) {
322   *axis_shape_v = 1;
323   bool is_dynamic = false;
324   if (IsValueKnown(input_value)) {
325     *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name);
326   } else {
327     is_dynamic = true;
328   }
329   return is_dynamic;
330 }
331 
CheckAndGetAxisValueFromSequence(const abstract::AbstractBasePtr & abs,const ValuePtr & input_value,const std::string & op_name,std::vector<int64_t> * axis_value,int64_t * axis_shape_v)332 bool CheckAndGetAxisValueFromSequence(const abstract::AbstractBasePtr &abs, const ValuePtr &input_value,
333                                       const std::string &op_name, std::vector<int64_t> *axis_value,
334                                       int64_t *axis_shape_v) {
335   bool is_dynamic = false;
336   if (IsValueKnown(input_value)) {
337     *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name);
338     if (axis_value->empty()) {
339       *axis_shape_v = 0;
340     }
341   } else {
342     is_dynamic = true;
343     auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
344     MS_EXCEPTION_IF_NULL(seq_abs);
345     *axis_shape_v = seq_abs->dynamic_len() ? -1 : SizeToLong(seq_abs->size());
346   }
347 
348   return is_dynamic;
349 }
350 
CheckAndGetAxisValueFromTensor(const std::vector<abstract::AbstractBasePtr> & input_args,const ValuePtr & input_value,const std::string & op_name,std::vector<int64_t> * axis_value,int64_t * axis_shape_v)351 bool CheckAndGetAxisValueFromTensor(const std::vector<abstract::AbstractBasePtr> &input_args,
352                                     const ValuePtr &input_value, const std::string &op_name,
353                                     std::vector<int64_t> *axis_value, int64_t *axis_shape_v) {
354   bool is_dynamic = false;
355   (void)CheckAndConvertUtils::CheckTensorTypeValid("axis", input_args[kInputIndex1]->GetType(), {kInt32, kInt64},
356                                                    op_name);
357   if (input_value->isa<tensor::BaseTensor>()) {
358     *axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, op_name);
359     if (axis_value->empty()) {
360       *axis_shape_v = 0;
361     }
362   } else {
363     is_dynamic = true;
364     auto axis_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1);
365     if (axis_shape->shape().size() > 1) {
366       MS_EXCEPTION(ValueError) << "For '" << op_name << "', the axis's shape length should be 1, but got '"
367                                << axis_shape->shape().size() << "'.";
368     } else if (axis_shape->shape().size() == 0) {
369       *axis_shape_v = 1;
370     } else {
371       *axis_shape_v = axis_shape->shape()[0];
372     }
373   }
374   return is_dynamic;
375 }
376 
CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> & input_args,std::vector<int64_t> * axis_value,int64_t * axis_shape_v,const PrimitivePtr & primitive)377 bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_args, std::vector<int64_t> *axis_value,
378                           int64_t *axis_shape_v, const PrimitivePtr &primitive) {
379   MS_EXCEPTION_IF_NULL(axis_value);
380   MS_EXCEPTION_IF_NULL(axis_shape_v);
381   bool is_dynamic = false;
382   const std::string &op_name = primitive->name();
383   if (input_args.size() == 1) {
384     CheckAndGetAxisValueFromAttr(primitive, axis_value, axis_shape_v);
385     return false;
386   }
387   MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
388   auto input_value = input_args[kInputIndex1]->GetValue();
389   if (input_value->isa<KernelTensorValue>()) {
390     auto value_opt = GetArrayValue<int64_t>(input_args[kInputIndex1]);
391     auto value_array = value_opt.value();
392     *axis_value = value_array.ToVector();
393     return !value_opt.has_value();
394   }
395   if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
396     is_dynamic = CheckAndGetAxisValueFromScalar(input_value, op_name, axis_value, axis_shape_v);
397   } else if (input_args[kInputIndex1]->isa<abstract::AbstractSequence>()) {
398     is_dynamic =
399       CheckAndGetAxisValueFromSequence(input_args[kInputIndex1], input_value, op_name, axis_value, axis_shape_v);
400   } else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
401     is_dynamic = CheckAndGetAxisValueFromTensor(input_args, input_value, op_name, axis_value, axis_shape_v);
402   } else {
403     MS_EXCEPTION(ValueError) << "For '" << op_name
404                              << "', the second input type should be tensor or scalar, but got invalid abstract type:"
405                              << input_args[kInputIndex1]->type_name() << ".";
406   }
407   return is_dynamic;
408 }
409 
IsDynamicShapeSkipExecute(const bool skip_mode,const ShapeVector & axes_shape)410 bool IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape) {
411   // Skip run ReduceSum when axis is a Empty Tensor
412   if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
413     return true;
414   }
415   return false;
416 }
// Converts a possibly negative dim into its non-negative form for a tensor of rank dim_post_expr.
// A rank <= 0 is treated as rank 1, so a scalar maps both -1 and 0 to 0.
// No range validation is performed: out-of-range dims are returned shifted (if negative) or unchanged.
int64_t MakeWrapDim(int64_t dim, int64_t dim_post_expr) {
  const int64_t rank = dim_post_expr > 0 ? dim_post_expr : 1;
  return dim < 0 ? dim + rank : dim;
}
429 
MakeDimMask(std::vector<int64_t> dims,int64_t ndim)430 std::bitset<kBitSize> MakeDimMask(std::vector<int64_t> dims, int64_t ndim) {
431   std::bitset<kBitSize> mask = std::bitset<kBitSize>();
432   if (dims.empty()) {
433     mask.flip();
434   } else {
435     for (int64_t dim : dims) {
436       mask.set(MakeWrapDim(dim, ndim));
437     }
438   }
439 
440   return mask;
441 }
442 
ReduceExtInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)443 abstract::ShapePtr ReduceExtInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
444   auto input_shape_ptr = input_args[0]->GetShape();
445   const auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_shape_ptr)[kShape];
446   int64_t ndim = static_cast<int64_t>(input_shape.size());
447   auto dim = GetValue<std::vector<int64_t>>(input_args[1]->GetValue());
448   auto keepdim = GetValue<bool>(input_args[2]->GetValue());
449   std::bitset<kBitSize> mask = MakeDimMask(dim, ndim);
450   auto shape = input_shape;
451 
452   for (int dim_temp = static_cast<int64_t>(shape.size()) - 1; dim_temp >= 0; dim_temp--) {
453     if (mask[dim_temp]) {
454       if (keepdim) {
455         shape[dim_temp] = 1;
456       } else {
457         shape.erase(shape.begin() + dim_temp);
458       }
459     }
460   }
461   return std::make_shared<abstract::Shape>(shape);
462 }
463 
ReduceExtInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)464 TypePtr ReduceExtInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
465   auto dtype_ptr = input_args[3]->GetValue();
466   (void)CheckAndConvertUtils::CheckTypeValid("input", input_args[0]->BuildType(),
467                                              common_valid_types_with_complex_and_bool, prim->name());
468   auto dtype_type_ptr = dtype_ptr->cast<TypePtr>();
469   if (dtype_type_ptr->type_id() == kMetaTypeNone) {
470     return input_args[0]->BuildType();
471   } else {
472     return dtype_ptr->cast<TypePtr>();
473   }
474 }
475 
ReduceBaseInferShape(const PrimitivePtr & primitive,const std::vector<abstract::AbstractBasePtr> & input_args,const std::string & prim_name)476 abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
477                                         const std::vector<abstract::AbstractBasePtr> &input_args,
478                                         const std::string &prim_name) {
479   MS_EXCEPTION_IF_NULL(primitive);
480   auto x_shape = GetShapeFromTensor(input_args[0]);
481   bool skip_mode = false;
482   if (primitive->HasAttr(kSkipMode)) {
483     auto skip_mode_value_ptr = primitive->GetAttr(kSkipMode);
484     MS_EXCEPTION_IF_NULL(skip_mode_value_ptr);
485     skip_mode = GetValue<bool>(skip_mode_value_ptr);
486   }
487   auto keep_dimis_value_ptr = primitive->GetAttr(kKeepDims);
488   MS_EXCEPTION_IF_NULL(keep_dimis_value_ptr);
489   if (!keep_dimis_value_ptr->isa<BoolImm>()) {
490     MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'keep_dims' must be Bool.";
491   }
492   bool keep_dims = GetValue<bool>(keep_dimis_value_ptr);
493   std::vector<int64_t> axis_value;
494   int64_t axis_shape = 1;
495   bool axis_is_dynamic = CheckAndGetAxisValue(input_args, &axis_value, &axis_shape, primitive);
496   if (IsDynamicShapeSkipExecute(skip_mode, {axis_shape})) {
497     return std::make_shared<abstract::Shape>(x_shape);
498   }
499   ShapeVector out_shape = {};
500   constexpr int dynamic_rank_value = -2;
501   if (IsDynamicRank(x_shape)) {
502     if (axis_shape == 0 && !keep_dims) {
503       return std::make_shared<abstract::Shape>(out_shape);
504     }
505     out_shape.push_back(dynamic_rank_value);
506     return std::make_shared<abstract::Shape>(out_shape);
507   }
508   if (axis_shape == -1 && !keep_dims) {
509     out_shape.push_back(dynamic_rank_value);
510     return std::make_shared<abstract::Shape>(out_shape);
511   }
512   ReduceFuncCheckAxisInferImpl(primitive, &axis_value, x_shape.size());
513 
514   if (axis_is_dynamic) {
515     out_shape = ReduceFuncCalShapeAxisDyn(x_shape, keep_dims);
516     return std::make_shared<abstract::Shape>(out_shape);
517   }
518   out_shape = ReduceFuncCalShapeInferImpl(primitive, x_shape, axis_value, keep_dims);
519   return std::make_shared<abstract::Shape>(out_shape);
520 }
521 
ReduceBaseInferType(const PrimitivePtr & prim,const std::vector<abstract::AbstractBasePtr> & input_args,const std::set<TypePtr> & check_list)522 TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector<abstract::AbstractBasePtr> &input_args,
523                             const std::set<TypePtr> &check_list) {
524   MS_EXCEPTION_IF_NULL(prim);
525   MS_EXCEPTION_IF_NULL(input_args[0]);
526   auto x_type = input_args[0]->GetType();
527   (void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, check_list, prim->name());
528   return x_type;
529 }
530 
SetPadShape(const ShapeVector & x_shape,const ArrayValue<int64_t> & paddings)531 BaseShapePtr SetPadShape(const ShapeVector &x_shape, const ArrayValue<int64_t> &paddings) {
532   const size_t kNum2 = 2;
533   auto out_shape = x_shape;
534   auto x_rank = x_shape.size();
535   for (size_t i = 0; i < paddings.size() / kNum2; i++) {
536     auto pad_idx = i * kNum2;
537     if (out_shape[x_rank - i - 1] != abstract::Shape::kShapeDimAny && !paddings.IsValueUnknown(pad_idx) &&
538         !paddings.IsValueUnknown(pad_idx + kIndex1)) {
539       auto paddings_l = paddings[pad_idx];
540       auto paddings_r = paddings[pad_idx + kIndex1];
541       out_shape[x_rank - i - kIndex1] = out_shape[x_rank - i - kIndex1] + paddings_l + paddings_r;
542     } else {
543       out_shape[x_rank - i - kIndex1] = abstract::Shape::kShapeDimAny;
544     }
545   }
546   return std::make_shared<abstract::Shape>(out_shape);
547 }
548 
PadInferShapeBase(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args,const size_t pad_dim)549 BaseShapePtr PadInferShapeBase(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
550                                const size_t pad_dim) {
551   MS_EXCEPTION_IF_NULL(primitive);
552   auto x_base_shape = input_args[kInputIndex0]->GetShape();
553   auto x_shape = x_base_shape->GetShapeVector();
554   // input x dynamic rank
555   MS_EXCEPTION_IF_NULL(x_base_shape);
556   if (x_base_shape->IsDimUnknown()) {
557     return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
558   }
559   // input x dynamic shape
560   auto x_rank = x_shape.size();
561   constexpr size_t minValidDim = 1;
562   constexpr size_t maxValidDim = 2;
563   if (x_rank != pad_dim + minValidDim && x_rank != pad_dim + maxValidDim) {
564     MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input should be " << pad_dim + minValidDim
565                              << "D or " << pad_dim + maxValidDim << "D, but got " << x_rank;
566   }
567   // padding
568   auto paddings_opt = GetArrayValue<int64_t>(input_args[kInputIndex1]);
569   if (!paddings_opt.has_value()) {
570     ShapeVector out_shape = x_shape;
571     for (size_t dim = 1; dim <= pad_dim; ++dim) {
572       out_shape[x_rank - dim] = abstract::Shape::kShapeDimAny;
573     }
574     return std::make_shared<abstract::Shape>(std::move(out_shape));
575   }
576   constexpr size_t kScaleNum = 2;
577   auto paddings = paddings_opt.value();
578   if (paddings.size() != pad_dim * kScaleNum) {
579     MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the padding length should be "
580                              << pad_dim * kScaleNum << ", but got " << paddings.size();
581   }
582 
583   auto out_shape = SetPadShape(x_shape, paddings);
584   return out_shape;
585 }
586 
ObscureShapeEqual(const ShapeVector & lhs,const ShapeVector & rhs)587 bool ObscureShapeEqual(const ShapeVector &lhs, const ShapeVector &rhs) {
588   if (lhs == rhs) {
589     return true;
590   }
591   if (lhs.size() != rhs.size()) {
592     return false;
593   }
594   for (size_t i = 0; i < lhs.size(); ++i) {
595     if (lhs[i] != rhs[i] && lhs[i] != -1 && rhs[i] != -1) {
596       return false;
597     }
598   }
599   return true;
600 }
601 
GetSequenceValue(const std::string & arg_name,const AbstractBasePtr & abs,const std::string & prim_name)602 std::vector<int64_t> GetSequenceValue(const std::string &arg_name, const AbstractBasePtr &abs,
603                                       const std::string &prim_name) {
604   MS_EXCEPTION_IF_NULL(abs);
605   auto abs_seq = dyn_cast<abstract::AbstractSequence>(abs);
606   MS_EXCEPTION_IF_NULL(abs_seq);
607   if (abs_seq->dynamic_len()) {
608     return std::vector<int64_t>{abstract::Shape::kShapeRankAny};
609   }
610   std::vector<int64_t> out_shape;
611   for (auto element : abs_seq->elements()) {
612     auto element_val = element->GetValue();
613     if (element_val->ContainsValueAny()) {
614       out_shape.push_back(abstract::Shape::kShapeDimAny);
615     } else if (element_val->isa<Int64Imm>()) {
616       (void)out_shape.emplace_back(GetValue<ShapeValueDType>(element_val));
617     } else if (element_val->isa<Int32Imm>()) {
618       (void)out_shape.emplace_back(GetValue<int32_t>(element_val));
619     } else {
620       MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
621                               << " must be one of ['tuple', 'list'] with all Int elements, but got " << abs->ToString();
622     }
623   }
624   return out_shape;
625 }
626 
GetShapeValue(const PrimitivePtr & primitive,const AbstractBasePtr & arg)627 ShapeVector GetShapeValue(const PrimitivePtr &primitive, const AbstractBasePtr &arg) {
628   MS_EXCEPTION_IF_NULL(primitive);
629   auto prim_name = primitive->name();
630   auto abs_value = arg->GetValue();
631   MS_EXCEPTION_IF_NULL(abs_value);
632   auto arg_type = arg->GetType();
633   MS_EXCEPTION_IF_NULL(arg_type);
634 
635   if (IsValueKnown(abs_value)) {
636     if (CheckAndConvertUtils::IsTensor(arg)) {
637       return CheckAndConvertUtils::CheckTensorIntValue("shape", abs_value, "", arg_type);
638     } else if (CheckAndConvertUtils::IsSequence(arg)) {
639       return CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", arg, prim_name);
640     }
641   } else if (CheckAndConvertUtils::IsTensor(arg)) {
642     auto arg_shape = arg->GetShape()->GetShapeVector();
643     if (arg_shape.size() != 1) {
644       MS_EXCEPTION(ValueError) << "For Primitive[" << primitive->name()
645                                << "], Shape of shape value only could be one-dimensional";
646     }
647     if (IsDynamic(arg_shape)) {
648       return {abstract::Shape::kShapeRankAny};
649     }
650     auto shape_size = arg_shape[0];
651     return ShapeVector(shape_size, abstract::Shape::kShapeDimAny);
652   } else if (arg->isa<abstract::AbstractSequence>()) {
653     return GetSequenceValue("input[shape]", arg, prim_name);
654   }
655 
656   MS_EXCEPTION(TypeError) << "For " << prim_name << ", the input type must be Tensor/Tuple/List , but got"
657                           << arg_type->ToString() << ".";
658 }
659 
CheckSparseShape(ShapeVector sparse_shp,ShapeVector dense_shp)660 void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
661   constexpr auto csr_mul_batch_pos = 2;
662   int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size());
663   if (dlen < 0) {
664     MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
665                              << "but sparse tensor has " << sparse_shp.size() << " dimensions, "
666                              << "and dense tensor has " << dense_shp.size() << " dimensions. ";
667   }
668   for (int i = 0; i < dlen; i++) {
669     (void)dense_shp.insert(dense_shp.begin(), 1);
670   }
671   if (sparse_shp.size() != dense_shp.size()) {
672     MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
673   }
674   if (sparse_shp.size() < 1) {
675     MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
676   }
677   for (size_t i = 0; i < sparse_shp.size(); i++) {
678     auto s = sparse_shp[i];
679     auto d = dense_shp[i];
680     if (i < csr_mul_batch_pos) {
681       if (d != s && d != 1) {
682         MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape.";
683       }
684     } else {
685       if (d != s) {
686         MS_EXCEPTION(ValueError) << "Currently, sparse shape and dense shape must equal in feature dimensions.";
687       }
688     }
689   }
690 }
691 
CheckSparseShape(const size_t shape_size,const size_t expected_dim,const std::string & arg_name)692 void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) {
693   if (shape_size != expected_dim) {
694     MS_EXCEPTION(ValueError) << arg_name << " must be a " << expected_dim << "-dimensional tensor, but got a "
695                              << shape_size << "-dimensional tensor.";
696   }
697 }
698 
CheckSparseIndicesDtype(const TypePtr data_type,const std::string & arg_name)699 void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name) {
700   if (!(data_type->equal(kInt16) || data_type->equal(kInt32) || data_type->equal(kInt64))) {
701     MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
702                             << data_type->ToString() << ".";
703   }
704 }
705 
CheckSparseIndicesDtypeInt32(const TypePtr data_type,const std::string & arg_name)706 void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name) {
707   if (!data_type->equal(kInt32)) {
708     MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got "
709                             << data_type->ToString() << ".";
710   }
711 }
712 
ConvertToShapeVector(const abstract::AbstractTuplePtr & shape)713 ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape) {
714   auto shape_value = shape->GetValue()->cast<ValueTuplePtr>();
715   MS_EXCEPTION_IF_NULL(shape_value);
716   ShapeVector shape_vec;
717   (void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec),
718                        [](const ValuePtr &e) -> int64_t {
719                          auto elem = GetValue<int64_t>(e);
720                          return elem;
721                        });
722   return shape_vec;
723 }
724 
725 template <typename T>
InferSparseAttr(const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)726 std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list) {
727   MS_EXCEPTION_IF_NULL(primitive);
728   constexpr size_t kSizeExpect = 1;
729   if (args_abs_list.size() != kSizeExpect) {
730     MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the number of input should be " << kSizeExpect
731                       << ", but got " << args_abs_list.size() << ".";
732   }
733   constexpr size_t kIndex = 0;
734   auto abs = args_abs_list[kIndex];
735   MS_EXCEPTION_IF_NULL(abs);
736   // To avoid AbstractSparseTensors being generalized to AbstractTuple.
737   if (dyn_cast<T>(abs) == nullptr) {
738     auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
739     if (abs_tuple != nullptr) {
740       return std::make_shared<T>(abs_tuple->elements());
741     }
742   } else if (dyn_cast<T>(abs) != nullptr) {
743     return dyn_cast<T>(abs);
744   }
745   MS_EXCEPTION(TypeError) << "For \'" << primitive->name() << "\', input[" << kIndex
746                           << "] should be AbstractSparseTensor or AbstractTuple, but got " << abs->GetType()->ToString()
747                           << ".";
748 }
749 template std::shared_ptr<abstract::AbstractCSRTensor> InferSparseAttr(const PrimitivePtr &primitive,
750                                                                       const AbstractBasePtrList &args_abs_list);
751 template std::shared_ptr<abstract::AbstractCOOTensor> InferSparseAttr(const PrimitivePtr &primitive,
752                                                                       const AbstractBasePtrList &args_abs_list);
753 
754 template <typename T>
TensorToSequenceInfer(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)755 AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
756   MS_EXCEPTION_IF_NULL(primitive);
757   auto prim_name = primitive->name();
758   constexpr size_t input_0_index = 0;
759 
760   auto x_shape = input_args[kInputIndex0]->GetShape()->GetShapeVector();
761   if (x_shape.size() > 1) {
762     MS_EXCEPTION(ValueError) << "For Primitive[" << prim_name << "], the input must be a 1-D Tensor, but got Tensor "
763                              << "with shape: " << x_shape << ".";
764   }
765 
766   auto x_type = input_args[input_0_index]->GetType();
767   MS_EXCEPTION_IF_NULL(x_type);
768   if (!x_type->isa<TensorType>()) {
769     MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input must be a Tensor but got "
770                             << x_type->ToString() << ".";
771   }
772   auto tensor_type = x_type->cast<TensorTypePtr>();
773   const auto &element_type = tensor_type->element();
774   MS_EXCEPTION_IF_NULL(element_type);
775   AbstractBasePtrList abs_list;
776   if (IsDynamic(x_shape)) {
777     abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kValueAny, element_type));
778     auto abs = std::make_shared<T>(abs_list);
779     abs->CheckAndConvertToDynamicLenSequence();
780     return abs;
781   }
782   if (x_shape.empty()) {
783     abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kValueAny, element_type));
784   } else {
785     for (int64_t i = 0; i < x_shape[0]; i++) {
786       abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kValueAny, element_type));
787     }
788   }
789   auto abs = std::make_shared<T>(abs_list);
790   return abs;
791 }
792 
CheckDynamicLengthSequenceSetItem(const std::string & op_name,const abstract::AbstractSequencePtr & queue,const AbstractBasePtr & target)793 void CheckDynamicLengthSequenceSetItem(const std::string &op_name, const abstract::AbstractSequencePtr &queue,
794                                        const AbstractBasePtr &target) {
795   auto element_abs = queue->dynamic_len_element_abs();
796   if (element_abs == nullptr) {
797     MS_LOG(EXCEPTION) << "Empty variable len sequence can not setitem.";
798   }
799   const auto precondition_log = "For " + op_name + ", when the queue is dynamic length";
800   const auto standard_abs_description = "element within dynamic length sequence";
801   const auto differ_abs_description = "target element";
802   CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{element_abs, target},
803                                                       precondition_log, standard_abs_description,
804                                                       differ_abs_description);
805 }
806 
807 template <typename T>
InferSequenceSetItem(const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)808 AbstractBasePtr InferSequenceSetItem(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list) {
809   // Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
810   MS_EXCEPTION_IF_NULL(primitive);
811   auto op_name = primitive->name();
812   constexpr int args_spec_size = 3;
813   constexpr size_t kIndex2 = 2;
814   abstract::CheckArgsSize(op_name, args_abs_list, args_spec_size);
815   auto queue = abstract::CheckArg<T>(op_name, args_abs_list, 0);
816   auto index = abstract::CheckArg<abstract::AbstractScalar>(op_name, args_abs_list, 1);
817 
818   auto index_type = index->GetType();
819   MS_EXCEPTION_IF_NULL(index_type);
820   if (index_type->type_id() != kInt64->type_id()) {
821     MS_EXCEPTION(TypeError) << op_name << " evaluator index should be an int64 number, but got a "
822                             << index_type->ToString() << " number.";
823   }
824   ValuePtr index_value = index->GetValue();
825   MS_EXCEPTION_IF_NULL(index_value);
826   auto target = args_abs_list[kIndex2];
827   MS_EXCEPTION_IF_NULL(target);
828   if (queue->dynamic_len()) {
829     CheckDynamicLengthSequenceSetItem(op_name, queue, target);
830     return queue->Clone();
831   }
832   if (index_value->ContainsValueAny()) {
833     // If the index is variable and the sequence is constant length, then all of the element within the sequence
834     // should have the same type and shape with the target input. The element within the return sequence should
835     // be all broadened.
836     const auto &elements = queue->elements();
837     if (elements.size() == 0) {
838       MS_LOG(EXCEPTION) << "Empty sequence can not setitem.";
839     }
840     const auto precondition_log = "For " + op_name + ", when the index is variable and the queue is constant length";
841     CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(elements, precondition_log);
842     auto first_element = elements[0];
843     const auto standard_abs_description = "element within constant length sequence";
844     const auto differ_abs_description = "target element";
845     CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{first_element, target},
846                                                         precondition_log, standard_abs_description,
847                                                         differ_abs_description);
848     return CheckAndConvertUtils::BroadenAllSequenceElements(queue);
849   }
850   auto index_int64_value = GetValue<int64_t>(index_value);
851   AbstractBasePtrList elements = queue->elements();
852   std::size_t nelems = elements.size();
853   if (nelems == 0) {
854     MS_EXCEPTION(ValueError) << "Can not setitem for an empty sequence.";
855   }
856   int64_t index_positive_value = index_int64_value >= 0 ? index_int64_value : index_int64_value + SizeToLong(nelems);
857   if (index_positive_value < 0 || index_positive_value >= SizeToLong(nelems)) {
858     MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << index_int64_value << " to set out of range: [-"
859                              << nelems << "," << (nelems - 1) << "].";
860   }
861   size_t index_unsigned_value = LongToSize(index_positive_value);
862   elements[index_unsigned_value] = args_abs_list[kIndex2];
863   MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
864   return std::make_shared<T>(elements, queue->sequence_nodes());
865 }
866 
867 template AbstractBasePtr InferSequenceSetItem<abstract::AbstractList>(const PrimitivePtr &primitive,
868                                                                       const AbstractBasePtrList &args_abs_list);
869 template AbstractBasePtr InferSequenceSetItem<abstract::AbstractTuple>(const PrimitivePtr &primitive,
870                                                                        const AbstractBasePtrList &args_abs_list);
871 
872 template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractList>(const PrimitivePtr &primitive,
873                                                                        const std::vector<AbstractBasePtr> &input_args);
874 
875 template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractTuple>(const PrimitivePtr &primitive,
876                                                                         const std::vector<AbstractBasePtr> &input_args);
877 
878 template <typename T>
GetScalarCastValue(const std::string & op_name,const ValuePtr & elem)879 T GetScalarCastValue(const std::string &op_name, const ValuePtr &elem) {
880   T res;
881   MS_EXCEPTION_IF_NULL(elem);
882   if (elem->isa<Int64Imm>()) {
883     auto elem_value = GetValue<int64_t>(elem);
884     res = static_cast<T>(elem_value);
885   } else if (elem->isa<Int32Imm>()) {
886     auto elem_value = GetValue<int32_t>(elem);
887     res = static_cast<T>(elem_value);
888   } else if (elem->isa<Int16Imm>()) {
889     auto elem_value = GetValue<int16_t>(elem);
890     res = static_cast<T>(elem_value);
891   } else if (elem->isa<Int8Imm>()) {
892     auto elem_value = GetValue<int8_t>(elem);
893     res = static_cast<T>(elem_value);
894   } else if (elem->isa<UInt64Imm>()) {
895     auto elem_value = GetValue<uint64_t>(elem);
896     res = static_cast<T>(elem_value);
897   } else if (elem->isa<UInt32Imm>()) {
898     auto elem_value = GetValue<uint32_t>(elem);
899     res = static_cast<T>(elem_value);
900   } else if (elem->isa<UInt16Imm>()) {
901     auto elem_value = GetValue<uint16_t>(elem);
902     res = static_cast<T>(elem_value);
903   } else if (elem->isa<UInt8Imm>()) {
904     auto elem_value = GetValue<uint8_t>(elem);
905     res = static_cast<T>(elem_value);
906   } else if (elem->isa<FP64Imm>()) {
907     auto elem_value = GetValue<double>(elem);
908     res = static_cast<T>(elem_value);
909   } else if (elem->isa<FP32Imm>()) {
910     auto elem_value = GetValue<float>(elem);
911     res = static_cast<T>(elem_value);
912   } else if (elem->isa<BoolImm>()) {
913     auto elem_value = GetValue<bool>(elem);
914     res = static_cast<T>(elem_value);
915   } else {
916     MS_EXCEPTION(TypeError) << "For op '" << op_name
917                             << "' input must be [int32, int64, float32, float64, bool], but got " << elem->ToString();
918   }
919   return res;
920 }
921 
922 template int64_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
923 template int32_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
924 template int16_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
925 template int8_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
926 template uint64_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
927 template uint32_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
928 template uint16_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
929 template uint8_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
930 template double GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
931 template float GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
932 template bool GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
933 
// Returns whichever of `x_type` / `y_type` ranks higher in the order
// Float64 > Float32 > Int64 > Int32 > Bool; when both are Bool, kInt32 is returned.
// Raises ValueError if either type is outside this set.
TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(x_type);
  MS_EXCEPTION_IF_NULL(y_type);
  // Smaller number means higher priority. The table is never modified, so it is const.
  static const std::map<TypeId, size_t> prio_map = {{kNumberTypeFloat64, 1},
                                                    {kNumberTypeFloat32, 2},
                                                    {kNumberTypeInt64, 3},
                                                    {kNumberTypeInt32, 4},
                                                    {kNumberTypeBool, 5}};
  const auto x_iter = prio_map.find(x_type->type_id());
  const auto y_iter = prio_map.find(y_type->type_id());
  if (x_iter == prio_map.end() || y_iter == prio_map.end()) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the x and y type should be int or float, but got x type: " << x_type
                             << " y type: " << y_type;
  }
  if (x_iter->second < y_iter->second) {
    return x_type;
  }
  // Priorities are unique per type, so equal priorities with Bool means both inputs are Bool.
  if (x_iter->second == y_iter->second && x_iter->first == kNumberTypeBool) {
    return kInt32;
  }
  return y_type;
}
955 
GetInputDependValueList(const PrimitivePtr & op_prim)956 std::set<int64_t> GetInputDependValueList(const PrimitivePtr &op_prim) {
957   MS_EXCEPTION_IF_NULL(op_prim);
958   std::set<int64_t> depend_list;
959   mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_prim->name());
960   if (op_def == nullptr) {
961     // Use old Primitive infer.
962     auto op_infer_opt = abstract::GetPrimitiveInferImpl(op_prim);
963     if (!op_infer_opt.has_value()) {
964       if (op_prim->HasAttr(kAttrMeOpName)) {
965         auto ori_prim_name = GetValue<std::string>(op_prim->GetAttr(kAttrMeOpName));
966         op_infer_opt = abstract::GetPrimitiveInferImpl(std::make_shared<Primitive>(ori_prim_name));
967       }
968     }
969     if (op_infer_opt.has_value()) {
970       auto op_infer = op_infer_opt.value().Get();
971       if (op_infer != nullptr && depend_list.empty()) {
972         depend_list = op_infer->GetValueDependArgIndices();
973       }
974     }
975     return depend_list;
976   }
977 
978   depend_list = op_def->func_impl_.GetValueDependArgIndices();
979   if (!depend_list.empty()) {
980     return depend_list;
981   }
982   // if not defined the GetValueDependArgIndices() func in infer, consider all the no-Tensor
983   // input as value depend.
984   auto args = op_def->args_;
985   for (size_t i = 0; i < args.size(); i++) {
986     if (args[i].arg_dtype_ != mindspore::ops::OP_DTYPE::DT_TENSOR &&
987         args[i].arg_dtype_ != mindspore::ops::OP_DTYPE::DT_TUPLE_TENSOR &&
988         args[i].arg_dtype_ != mindspore::ops::OP_DTYPE::DT_LIST_TENSOR) {
989       (void)depend_list.insert(i);
990     }
991   }
992   return depend_list;
993 }
994 
GetInputIndexByName(const std::string & op_name,const std::string & input_name)995 size_t GetInputIndexByName(const std::string &op_name, const std::string &input_name) {
996   mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
997   if (op_def == nullptr) {
998     MS_LOG(INFO) << op_name << " is not defined in opdef.";
999     return SIZE_MAX;
1000   }
1001   auto ks_iter = op_def->indexes_.find(input_name);
1002   if (ks_iter != op_def->indexes_.end()) {
1003     size_t index = ks_iter->second;
1004     MS_LOG(INFO) << "Find " << input_name << "in " << index << "th input of OP " << op_name;
1005     return index;
1006   }
1007   MS_LOG(INFO) << "Not Find " << input_name << "in OP " << op_name;
1008   return SIZE_MAX;
1009 }
1010 
GetInputNameByIndex(const std::string & op_name,size_t index)1011 std::string GetInputNameByIndex(const std::string &op_name, size_t index) {
1012   mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
1013   if (op_def == nullptr) {
1014     return "";
1015   }
1016   if (index >= op_def->args_.size()) {
1017     MS_LOG(INTERNAL_EXCEPTION) << "Get input name by index out of range, index: " << index
1018                                << ", size: " << op_def->args_.size() << ", op name: " << op_name;
1019   }
1020   auto input = op_def->args_[index];
1021   return input.arg_name_;
1022 }
1023 
HasOpDef(const std::string & op_name)1024 bool HasOpDef(const std::string &op_name) {
1025   mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
1026   return op_def != nullptr;
1027 }
1028 
GetOpInputsNum(const std::string & op_name)1029 size_t GetOpInputsNum(const std::string &op_name) {
1030   mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
1031   if (op_def == nullptr) {
1032     MS_LOG(INFO) << op_name << " is not defined in opdef.";
1033     return SIZE_MAX;
1034   }
1035   return op_def->indexes_.size();
1036 }
1037 
1038 // This is used to convert arg with 'prim_init' of cnode convert to attr of primitive.
1039 // CNode in new mindir can be converted to old mindir by this function.
1040 // For example, {PrimAvgPool, x, kernel_size, strides, pad_mode, data_format} =>
1041 //              {PrimAvgPool, x}
ConvertArgsToAttr(const CNodePtr & cnode)1042 CNodePtr ConvertArgsToAttr(const CNodePtr &cnode) {
1043   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1044   MS_EXCEPTION_IF_NULL(prim);
1045   auto prim_name = prim->name();
1046   auto op_def = mindspore::ops::GetOpDef(prim_name);
1047   if (op_def == nullptr) {
1048     MS_LOG(DEBUG) << "Prim:" << prim->ToString()
1049                   << "is not a primitive defined in yaml, cannot convert args to attr, cnode:" << cnode->DebugString();
1050     return nullptr;
1051   }
1052   std::vector<AnfNodePtr> new_node_inputs = {cnode->input(0)};
1053   for (size_t arg_index = 0; arg_index < op_def->args_.size(); ++arg_index) {
1054     auto arg = op_def->args_[arg_index];
1055     if (!arg.as_init_arg_) {
1056       // origin is input , put the node input into new node inputs vector
1057       (void)new_node_inputs.emplace_back(cnode->input(arg_index + 1));
1058       continue;
1059     }
1060 
1061     auto arg_input_node = cnode->input(arg_index + 1);
1062     if (!arg_input_node->isa<ValueNode>()) {
1063       // arg is not ValueNode, Network has dynamic args, not support
1064       MS_LOG(INTERNAL_EXCEPTION) << "Node " << cnode->DebugString() << " with arg " << arg_input_node->DebugString()
1065                                  << " is dynamic, not supported now.";
1066       continue;
1067     }
1068     auto arg_value_node = arg_input_node->cast<ValueNodePtr>();
1069     auto arg_value = arg_value_node->value();
1070     prim->AddAttr(arg.arg_name_, arg_value);
1071   }
1072 
1073   auto func_graph = cnode->func_graph();
1074   MS_EXCEPTION_IF_NULL(func_graph);
1075   auto new_node = func_graph->NewCNode(new_node_inputs);
1076   new_node->set_abstract(cnode->abstract());
1077   new_node->set_fullname_with_scope(cnode->fullname_with_scope());
1078   return new_node;
1079 }
1080 
1081 template <typename T>
GetScalarValue(const ValuePtr & value)1082 std::optional<T> GetScalarValue(const ValuePtr &value) {
1083   MS_EXCEPTION_IF_NULL(value);
1084   if (value->isa<ValueAny>()) {
1085     return std::nullopt;
1086   }
1087 
1088   if (value->isa<KernelTensorValue>()) {
1089     auto kernel_tensor_value = value->cast<KernelTensorValuePtr>();
1090     MS_EXCEPTION_IF_NULL(kernel_tensor_value);
1091 
1092     MS_EXCEPTION_IF_CHECK_FAIL((kernel_tensor_value->GetDataSize() == sizeof(T)),
1093                                "The data size in kernel tensor value which contains a scalar [" +
1094                                  std::to_string(kernel_tensor_value->GetDataSize()) +
1095                                  "] is not equal to the data type size [" + std::to_string(sizeof(T)) + "]");
1096 
1097     const T *data_ptr = reinterpret_cast<const T *>(kernel_tensor_value->GetDataPtr());
1098     MS_EXCEPTION_IF_NULL(data_ptr);
1099     return *data_ptr;
1100   }
1101 
1102   return GetValue<T>(value);
1103 }
1104 
1105 // Specialization for std::string type.
1106 template <>
GetScalarValue(const ValuePtr & value)1107 MS_CORE_API std::optional<std::string> GetScalarValue(const ValuePtr &value) {
1108   MS_EXCEPTION_IF_NULL(value);
1109   if (value->isa<ValueAny>()) {
1110     return std::nullopt;
1111   }
1112 
1113   if (value->isa<KernelTensorValue>()) {
1114     auto kernel_tensor_value = value->cast<KernelTensorValuePtr>();
1115     MS_EXCEPTION_IF_NULL(kernel_tensor_value);
1116     const char *data_ptr = reinterpret_cast<const char *>(kernel_tensor_value->GetDataPtr());
1117     MS_EXCEPTION_IF_NULL(data_ptr);
1118     size_t str_len = kernel_tensor_value->GetDataSize();
1119 
1120     return std::string(data_ptr, data_ptr + str_len);
1121   }
1122 
1123   return GetValue<std::string>(value);
1124 }
1125 
1126 template MS_CORE_API std::optional<int64_t> GetScalarValue(const ValuePtr &value);
1127 template MS_CORE_API std::optional<int32_t> GetScalarValue(const ValuePtr &value);
1128 template MS_CORE_API std::optional<int16_t> GetScalarValue(const ValuePtr &value);
1129 template MS_CORE_API std::optional<int8_t> GetScalarValue(const ValuePtr &value);
1130 template MS_CORE_API std::optional<uint64_t> GetScalarValue(const ValuePtr &value);
1131 template MS_CORE_API std::optional<uint32_t> GetScalarValue(const ValuePtr &value);
1132 template MS_CORE_API std::optional<uint16_t> GetScalarValue(const ValuePtr &value);
1133 template MS_CORE_API std::optional<uint8_t> GetScalarValue(const ValuePtr &value);
1134 template MS_CORE_API std::optional<double> GetScalarValue(const ValuePtr &value);
1135 template MS_CORE_API std::optional<float> GetScalarValue(const ValuePtr &value);
1136 template MS_CORE_API std::optional<bool> GetScalarValue(const ValuePtr &value);
1137 
1138 // This interface is only used to convert values of type Sequence or Tensor to std::vector.
1139 template <typename T>
GetArrayValue(const ValuePtr & value)1140 std::optional<ArrayValue<T>> GetArrayValue(const ValuePtr &value) {
1141   MS_EXCEPTION_IF_NULL(value);
1142   if (value->isa<ValueAny>()) {
1143     return std::nullopt;
1144   }
1145 
1146   std::vector<T> array_data;
1147   if (value->isa<KernelTensorValue>()) {
1148     auto kernel_tensor_value = value->cast<KernelTensorValuePtr>();
1149     MS_EXCEPTION_IF_NULL(kernel_tensor_value);
1150 
1151     if (kernel_tensor_value->GetDataSize() % sizeof(T) != 0) {
1152       MS_LOG(EXCEPTION) << "The size is incompatible, kernel tensor value size: " << kernel_tensor_value->GetDataSize()
1153                         << ", expected element size: " << sizeof(T);
1154     }
1155 
1156     size_t element_size = kernel_tensor_value->GetDataSize() / sizeof(T);
1157     if (element_size != 0) {
1158       const T *data_ptr = reinterpret_cast<const T *>(kernel_tensor_value->GetDataPtr());
1159       MS_EXCEPTION_IF_NULL(data_ptr);
1160       array_data.assign(data_ptr, data_ptr + element_size);
1161     }
1162   } else if (value->isa<ValueSequence>()) {
1163     // Sequence structure: Data is stored discretely.
1164     auto value_seq = value->cast<ValueSequencePtr>();
1165     MS_EXCEPTION_IF_NULL(value_seq);
1166 
1167     const auto &element_values = value_seq->value();
1168     size_t element_size = element_values.size();
1169     array_data.reserve(element_size);
1170     for (size_t i = 0; i < element_size; i++) {
1171       const auto &element = element_values[i];
1172       MS_EXCEPTION_IF_NULL(element);
1173       if (element->isa<ValueAny>() || element->isa<None>()) {
1174         return std::nullopt;
1175       }
1176       if constexpr (std::is_same_v<T, float16>) {
1177         MS_LOG(EXCEPTION) << "For ValueSequence, float16 type is not support!";
1178       } else {
1179         array_data.push_back(GetValue<T>(element));
1180       }
1181     }
1182   } else if (value->isa<tensor::BaseTensor>()) {
1183     // Tensor structure: Data is stored continuously.
1184     auto tensor = value->cast<tensor::BaseTensorPtr>();
1185     MS_EXCEPTION_IF_NULL(tensor);
1186     size_t element_size = tensor->DataSize();
1187     T *data = reinterpret_cast<T *>(tensor->data_c());
1188     array_data.assign(data, data + element_size);
1189   } else {
1190     MS_LOG(EXCEPTION) << "Failed to get array value, expect sequence or tensor type, but got: " << value->type_name();
1191   }
1192   return std::optional<ArrayValue<T>>(std::in_place, std::move(array_data), std::set<size_t>());
1193 }
1194 
1195 template <typename T>
GetArrayValue(const AbstractBasePtr & abs_base)1196 std::optional<ArrayValue<T>> GetArrayValue(const AbstractBasePtr &abs_base) {
1197   MS_EXCEPTION_IF_NULL(abs_base);
1198   auto value = abs_base->GetValue();
1199   // If value is constant or is value sequence with some constant elements.
1200   if (!value->isa<ValueAny>()) {
1201     return GetArrayValue<T>(value);
1202   }
1203 
1204   // If value is ValueAny, need check whether abstract is AbstractSequence, it is in frontend.
1205   std::vector<T> array_data;
1206   std::set<size_t> unknown_value_indexes;
1207   if (abs_base->isa<abstract::AbstractSequence>()) {
1208     auto abs_sequence = abs_base->cast<abstract::AbstractSequencePtr>();
1209     if (abs_sequence->dynamic_len()) {
1210       return std::nullopt;
1211     }
1212     for (size_t i = 0; i < abs_sequence->size(); ++i) {
1213       auto elem_value = abs_sequence->elements()[i]->GetValue();
1214       if (elem_value->isa<ValueAny>() || elem_value->isa<None>()) {
1215         array_data.push_back(static_cast<T>(0));
1216         (void)unknown_value_indexes.insert(i);
1217         continue;
1218       }
1219       if constexpr (std::is_same_v<T, float16>) {
1220         MS_LOG(EXCEPTION) << "For ValueSequence, float16 type is not support!";
1221       } else {
1222         array_data.push_back(GetValue<T>(elem_value));
1223       }
1224     }
1225     return std::optional<ArrayValue<T>>(std::in_place, std::move(array_data), std::move(unknown_value_indexes));
1226   }
1227   // Only abstract sequence with ValueAny need to handle, other situation just return nullopt.
1228   return std::nullopt;
1229 }
1230 
1231 template MS_CORE_API std::optional<ArrayValue<int64_t>> GetArrayValue(const ValuePtr &value);
1232 template MS_CORE_API std::optional<ArrayValue<int32_t>> GetArrayValue(const ValuePtr &value);
1233 template MS_CORE_API std::optional<ArrayValue<int16_t>> GetArrayValue(const ValuePtr &value);
1234 template MS_CORE_API std::optional<ArrayValue<int8_t>> GetArrayValue(const ValuePtr &value);
1235 template MS_CORE_API std::optional<ArrayValue<uint64_t>> GetArrayValue(const ValuePtr &value);
1236 template MS_CORE_API std::optional<ArrayValue<uint32_t>> GetArrayValue(const ValuePtr &value);
1237 template MS_CORE_API std::optional<ArrayValue<uint16_t>> GetArrayValue(const ValuePtr &value);
1238 template MS_CORE_API std::optional<ArrayValue<uint8_t>> GetArrayValue(const ValuePtr &value);
1239 template MS_CORE_API std::optional<ArrayValue<double>> GetArrayValue(const ValuePtr &value);
1240 template MS_CORE_API std::optional<ArrayValue<float>> GetArrayValue(const ValuePtr &value);
1241 template MS_CORE_API std::optional<ArrayValue<bool>> GetArrayValue(const ValuePtr &value);
1242 template MS_CORE_API std::optional<ArrayValue<std::string>> GetArrayValue(const ValuePtr &value);
1243 template MS_CORE_API std::optional<ArrayValue<float16>> GetArrayValue(const ValuePtr &value);
1244 template MS_CORE_API std::optional<ArrayValue<bfloat16>> GetArrayValue(const ValuePtr &value);
1245 
1246 template MS_CORE_API std::optional<ArrayValue<int64_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1247 template MS_CORE_API std::optional<ArrayValue<int32_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1248 template MS_CORE_API std::optional<ArrayValue<int16_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1249 template MS_CORE_API std::optional<ArrayValue<int8_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1250 template MS_CORE_API std::optional<ArrayValue<uint64_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1251 template MS_CORE_API std::optional<ArrayValue<uint32_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1252 template MS_CORE_API std::optional<ArrayValue<uint16_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1253 template MS_CORE_API std::optional<ArrayValue<uint8_t>> GetArrayValue(const AbstractBasePtr &abs_base);
1254 template MS_CORE_API std::optional<ArrayValue<double>> GetArrayValue(const AbstractBasePtr &abs_base);
1255 template MS_CORE_API std::optional<ArrayValue<float>> GetArrayValue(const AbstractBasePtr &abs_base);
1256 template MS_CORE_API std::optional<ArrayValue<bool>> GetArrayValue(const AbstractBasePtr &abs_base);
1257 template MS_CORE_API std::optional<ArrayValue<std::string>> GetArrayValue(const AbstractBasePtr &abs_base);
1258 template MS_CORE_API std::optional<ArrayValue<float16>> GetArrayValue(const AbstractBasePtr &abs_base);
1259 template MS_CORE_API std::optional<ArrayValue<bfloat16>> GetArrayValue(const AbstractBasePtr &abs_base);
1260 
CheckTensorScalarRank(const PrimitivePtr & primitive,const AbstractBasePtr input_arg,const std::string & arg_name)1261 void CheckTensorScalarRank(const PrimitivePtr &primitive, const AbstractBasePtr input_arg,
1262                            const std::string &arg_name) {
1263   MS_EXCEPTION_IF_NULL(input_arg);
1264   auto shape_ptr = input_arg->GetShape();
1265   MS_EXCEPTION_IF_NULL(shape_ptr);
1266   const auto &input_shape = shape_ptr->GetShapeVector();
1267   const int64_t kDimZero = 0;
1268   if (MS_LIKELY(!IsDynamic(input_shape))) {
1269     MS_CHECK_VALUE(input_shape.size() == LongToSize(kDimZero),
1270                    CheckAndConvertUtils::FormatCheckIntegerMsg("rank of " + arg_name, SizeToLong(input_shape.size()),
1271                                                                kEqual, kDimZero, primitive));
1272   }
1273 }
1274 }  // namespace ops
1275 }  // namespace mindspore
1276