/**
 * Copyright 2020-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/op_utils.h"

#include <algorithm>
#include <bitset>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "abstract/dshape.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "ir/dtype/tensor_type.h"
#include "ir/dtype/type.h"
#include "ir/named.h"
#include "ir/primitive.h"
#include "ir/scalar.h"
#include "ir/tensor.h"
#include "ir/value.h"
#include "ir/kernel_tensor_value.h"
#include "mindapi/base/type_id.h"
#include "mindapi/src/helper.h"
#include "ops/op_name.h"
#include "ops/op_def.h"
#include "utils/check_convert_utils.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "utils/shape_utils.h"
#include "ir/func_graph.h"
#include "ops/ops_func_impl/simple_infer.h"

namespace mindspore {
namespace ops {
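// Computes the broadcast result shape of x_shape and y_shape. Shapes are right-aligned and,
// per dimension, an extent of 1 broadcasts to the other operand's extent, while -1
// (kShapeDimAny) marks a dynamic dimension resolved at runtime. For example, broadcasting
// [8, 1, 3] with [4, 3] yields [8, 4, 3]. Incompatible static extents raise a ValueError.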
std::vector<int64_t> CalBroadCastShape(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
                                       const std::string &op_name, const std::string &op_x_name,
                                       const std::string &op_y_name) {
  if (x_shape == y_shape) {
    return x_shape;
  }

  if (IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) {
    return {abstract::Shape::kShapeRankAny};
  }

  std::vector<int64_t> broadcast_shape;
  auto x_length = x_shape.size();
  auto y_length = y_shape.size();
  auto res = x_length > y_length;
  size_t max_len = res ? x_length : y_length;
  size_t min_len = res ? y_length : x_length;
  const std::vector<int64_t> &max_shape = res ? x_shape : y_shape;
  const std::vector<int64_t> &min_shape = res ? y_shape : x_shape;

  broadcast_shape = max_shape;
  auto miss = max_len - min_len;
  for (size_t i = 0; i < min_len; i++) {
    auto dst_i = miss + i;
    if (max_shape[dst_i] == 1) {
      broadcast_shape[dst_i] = min_shape[i];
    } else if (MS_UNLIKELY(max_shape[dst_i] == -1)) {
      if (min_shape[i] != 1) {
        broadcast_shape[dst_i] = min_shape[i];
      }
    } else if (MS_UNLIKELY(max_shape[dst_i] != min_shape[i] && min_shape[i] != -1 && min_shape[i] != 1)) {
      auto x_shape_name = op_x_name + ".shape";
      auto y_shape_name = op_y_name + ".shape";
      MS_EXCEPTION(ValueError) << "For '" << op_name << "', " << x_shape_name << " and " << y_shape_name
                               << " need to broadcast. The value of " << x_shape_name << "["
                               << std::to_string(x_length + i) << "] or " << y_shape_name << "["
                               << std::to_string(y_length + i)
                               << "] must be 1 or -1 when they are not the same, but got " << x_shape_name << " = "
                               << tensor::ShapeToString(x_shape) << " and " << y_shape_name << " = "
                               << tensor::ShapeToString(y_shape);
    }
  }
  return broadcast_shape;
}

abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(input_args[kIndex0]);
  MS_EXCEPTION_IF_NULL(input_args[kIndex1]);
  ShapeVector x_shape;
  if (!input_args[0]->GetShape()->isa<abstract::NoShape>()) {
    x_shape = GetShapeFromTensor(input_args[0]);
  }

  ShapeVector y_shape;
  if (!input_args[1]->GetShape()->isa<abstract::NoShape>()) {
    y_shape = GetShapeFromTensor(input_args[1]);
  }

  auto broadcast_shape = CalBroadCastShape(x_shape, y_shape, op_name);
  return std::make_shared<abstract::Shape>(broadcast_shape);
}

ShapeVector BroadCastInferShape(const std::string &op_name, const ValuePtrList &input_values) {
  const auto &x_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);

  auto x_shape = x_tensor->shape();
  auto y_shape = y_tensor->shape();

  auto broadcast_shape = CalBroadCastShape(x_shape, y_shape, op_name);
  return broadcast_shape;
}

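// Checks whether y_shape can be broadcast to x_shape: shapes are right-aligned, x_shape must
// have rank >= y_shape, and each aligned pair of extents must match unless the y extent is 1
// or either side is dynamic (-1). Dynamic-rank inputs are conservatively treated as broadcastable.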
bool IsBroadcastable(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape) {
  if (x_shape == y_shape) {
    return true;
  }

  if (IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) {
    return true;
  }

  if (x_shape.size() < y_shape.size()) {
    return false;
  }

  auto miss = x_shape.size() - y_shape.size();
  for (size_t i = 0; i < y_shape.size(); i++) {
    if (x_shape[miss + i] == y_shape[i]) {
      continue;
    }
    if (x_shape[miss + i] == -1) {
      continue;
    }
    if (y_shape[i] == -1 || y_shape[i] == 1) {
      continue;
    }
    return false;
  }
  return true;
}

BaseShapePtr EltwiseGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(input_args[0]);
  MS_EXCEPTION_IF_NULL(input_args[1]);
  auto prim_name = primitive->name();
  auto x = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, 0, kObjectTypeTensorType);
  auto dout = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, 1, kObjectTypeTensorType);
  auto x_shape_ptr = x->GetShape();
  auto dout_shape_ptr = dout->GetShape();
  MS_EXCEPTION_IF_NULL(x_shape_ptr);
  MS_EXCEPTION_IF_NULL(dout_shape_ptr);
  auto x_shape = x_shape_ptr->GetShapeVector();
  auto dout_shape = dout_shape_ptr->GetShapeVector();
  if (IsDynamicRank(x_shape) || IsDynamicRank(dout_shape)) {
    return input_args[1]->GetShape()->Clone();
  } else if (x_shape.size() != dout_shape.size()) {
    MS_EXCEPTION(ValueError) << "Rank of x(" << x_shape.size() << ") and dout(" << dout_shape.size()
                             << ") are not equal, primitive name: " << prim_name << ".";
  }

  for (size_t i = 0; i < x_shape.size(); i++) {
    if (x_shape[i] != abstract::Shape::kShapeDimAny && dout_shape[i] != abstract::Shape::kShapeDimAny &&
        x_shape[i] != dout_shape[i]) {
      MS_EXCEPTION(ValueError) << "The " << i << "th dim of x(" << x_shape[i] << ") and dout(" << dout_shape[i]
                               << ") are not equal, primitive name: " << prim_name << ".";
    }
  }
  return input_args[0]->GetShape()->Clone();
}

TypePtr EltwiseGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  MS_EXCEPTION_IF_NULL(input_args[1]);
  auto grad_type = input_args[0]->GetType();
  MS_EXCEPTION_IF_NULL(grad_type);
  auto x_type = input_args[1]->GetType();
  MS_EXCEPTION_IF_NULL(x_type);
  if (grad_type->type_id() != x_type->type_id()) {
    MS_LOG_EXCEPTION << "For " << primitive->name()
                     << ", the grad type must be the same as the input type, but got grad_type: "
                     << grad_type->ToString() << " and x_type: " << x_type->ToString();
  }
  return grad_type->Clone();
}

ShapeArray EltwiseGradSimpleInferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &prim_name = primitive->name();
  const auto &dout_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(dout_tensor);
  const auto &y_tensor = input_values[kIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(y_tensor);

  const auto &dout_shape = dout_tensor->shape();
  const auto &y_shape = y_tensor->shape();

  if (dout_shape.size() != y_shape.size()) {
    MS_EXCEPTION(ValueError) << "Rank of x(" << y_shape.size() << ") and dout(" << dout_shape.size()
                             << ") are not equal, primitive name: " << prim_name << ".";
  }

  for (size_t i = 0; i < y_shape.size(); i++) {
    if (y_shape[i] != dout_shape[i]) {
      MS_EXCEPTION(ValueError) << "The " << i << "th dim of x(" << y_shape[i] << ") and dout(" << dout_shape[i]
                               << ") are not equal, primitive name: " << prim_name << ".";
    }
  }
  return {dout_shape};
}

TypePtrList EltwiseGradSimpleInferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &dout_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(dout_tensor);
  const auto &y_tensor = input_values[kIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(y_tensor);

  const auto &dout_type = dout_tensor->Dtype();
  const auto &y_type = y_tensor->Dtype();

  if (dout_type->type_id() != y_type->type_id()) {
    MS_LOG_EXCEPTION << "For " << primitive->name()
                     << ", the grad type must be the same as the input type, but got grad_type: "
                     << dout_type->ToString() << " and x_type: " << y_type->ToString();
  }
  return {dout_type};
}

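// Validates every reduction axis against the input rank `dim` and normalizes negative axes
// in place: an axis a with -dim <= a < 0 becomes a + dim (e.g. axis -1 with dim 4 becomes 3).
// For 0-dimensional inputs only -1 and 0 are accepted, both mapping to axis 0.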
void ReduceFuncCheckAxisInferImpl(const PrimitivePtr &prim, std::vector<int64_t> *axis, const size_t dim) {
  MS_EXCEPTION_IF_NULL(axis);
  int64_t dim_ = static_cast<int64_t>(dim);
  for (size_t i = 0; i < axis->size(); i++) {
    if (dim == 0) {
      if ((axis->at(i) != -1 && axis->at(i) != 0)) {
        MS_EXCEPTION(ValueError) << "For '" << prim->name()
                                 << "', 'axis' must be in [-1, 0]. But got 'axis' = " << axis->at(i) << ".";
      }
      axis->at(i) = 0;
      continue;
    }
    if (axis->at(i) < -dim_ || axis->at(i) >= dim_) {
      MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', 'axis' must be in [" << -dim_ << ", " << dim_
                               << "). But got 'axis' = " << axis->at(i) << ".";
    }
    if (axis->at(i) >= -dim_ && axis->at(i) < 0) {
      axis->at(i) += dim_;
    }
  }
}

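// Computes the output shape of a reduction with statically known axes. With keep_dims the
// reduced dimensions are kept as 1 (every dimension if the axis list is empty); without
// keep_dims they are erased, and an empty axis list or a scalar input reduces to a scalar {}.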
ShapeVector ReduceFuncCalShapeInferImpl(const PrimitivePtr &, const ShapeVector &x_shape,
                                        const std::vector<int64_t> &axis, bool keep_dims_value) {
  ShapeVector out_shape;
  ShapeVector axis_value;
  (void)axis_value.insert(axis_value.end(), axis.begin(), axis.end());
  (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
  std::sort(axis_value.begin(), axis_value.end());
  auto last = std::unique(axis_value.begin(), axis_value.end());
  axis_value.erase(last, axis_value.end());
  if (keep_dims_value) {
    if (x_shape.size() == 0) {
      return {};
    }
    for (auto i : axis_value) {
      out_shape.at(LongToSize(i)) = 1;
    }
    if (axis_value.empty()) {
      for (size_t i = 0; i < out_shape.size(); i++) {
        out_shape.at(i) = 1;
      }
    }
    return out_shape;
  }
  if (axis.size() == 0 || x_shape.size() == 0) {
    return {};
  }
  std::vector<int64_t>::reverse_iterator it_re;
  for (it_re = axis_value.rbegin(); it_re != axis_value.rend(); ++it_re) {
    (void)out_shape.erase(out_shape.begin() + *it_re);
  }
  return out_shape;
}

ShapeVector ReduceFuncCalShapeAxisDyn(const ShapeVector &x_shape, bool keep_dims) {
  ShapeVector out_shape;
  constexpr int dynamic_rank_value = -2;
  if (!keep_dims) {
    out_shape.push_back(dynamic_rank_value);
  } else {
    (void)out_shape.insert(out_shape.end(), x_shape.size(), -1LL);
  }
  return out_shape;
}

void CheckAndGetAxisValueFromAttr(const PrimitivePtr &primitive, std::vector<int64_t> *axis_value, int64_t *) {
  auto op_name = primitive->name();
  auto axis_ptr = primitive->GetAttr("axis");
  MS_EXCEPTION_IF_NULL(axis_ptr);
  if (axis_ptr->isa<tensor::BaseTensor>()) {
    *axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", axis_ptr, op_name);
  } else {
    *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", axis_ptr, op_name);
  }
}

bool CheckAndGetAxisValueFromScalar(const ValuePtr &input_value, const std::string &op_name,
                                    std::vector<int64_t> *axis_value, int64_t *axis_shape_v) {
  *axis_shape_v = 1;
  bool is_dynamic = false;
  if (IsValueKnown(input_value)) {
    *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name);
  } else {
    is_dynamic = true;
  }
  return is_dynamic;
}

bool CheckAndGetAxisValueFromSequence(const abstract::AbstractBasePtr &abs, const ValuePtr &input_value,
                                      const std::string &op_name, std::vector<int64_t> *axis_value,
                                      int64_t *axis_shape_v) {
  bool is_dynamic = false;
  if (IsValueKnown(input_value)) {
    *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name);
    if (axis_value->empty()) {
      *axis_shape_v = 0;
    }
  } else {
    is_dynamic = true;
    auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abs);
    *axis_shape_v = seq_abs->dynamic_len() ? -1 : SizeToLong(seq_abs->size());
  }

  return is_dynamic;
}

bool CheckAndGetAxisValueFromTensor(const std::vector<abstract::AbstractBasePtr> &input_args,
                                    const ValuePtr &input_value, const std::string &op_name,
                                    std::vector<int64_t> *axis_value, int64_t *axis_shape_v) {
  bool is_dynamic = false;
  (void)CheckAndConvertUtils::CheckTensorTypeValid("axis", input_args[kInputIndex1]->GetType(), {kInt32, kInt64},
                                                   op_name);
  if (input_value->isa<tensor::BaseTensor>()) {
    *axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, op_name);
    if (axis_value->empty()) {
      *axis_shape_v = 0;
    }
  } else {
    is_dynamic = true;
    auto axis_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1);
    if (axis_shape->shape().size() > 1) {
      MS_EXCEPTION(ValueError) << "For '" << op_name << "', the axis's shape length should be 1, but got '"
                               << axis_shape->shape().size() << "'.";
    } else if (axis_shape->shape().size() == 0) {
      *axis_shape_v = 1;
    } else {
      *axis_shape_v = axis_shape->shape()[0];
    }
  }
  return is_dynamic;
}

bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_args, std::vector<int64_t> *axis_value,
                          int64_t *axis_shape_v, const PrimitivePtr &primitive) {
  MS_EXCEPTION_IF_NULL(axis_value);
  MS_EXCEPTION_IF_NULL(axis_shape_v);
  bool is_dynamic = false;
  const std::string &op_name = primitive->name();
  if (input_args.size() == 1) {
    CheckAndGetAxisValueFromAttr(primitive, axis_value, axis_shape_v);
    return false;
  }
  MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
  auto input_value = input_args[kInputIndex1]->GetValue();
  if (input_value->isa<KernelTensorValue>()) {
    auto value_opt = GetArrayValue<int64_t>(input_args[kInputIndex1]);
    // Check the optional before dereferencing it; an unknown value means the axis is dynamic.
    if (!value_opt.has_value()) {
      return true;
    }
    *axis_value = value_opt.value().ToVector();
    return false;
  }
  if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
    is_dynamic = CheckAndGetAxisValueFromScalar(input_value, op_name, axis_value, axis_shape_v);
  } else if (input_args[kInputIndex1]->isa<abstract::AbstractSequence>()) {
    is_dynamic =
      CheckAndGetAxisValueFromSequence(input_args[kInputIndex1], input_value, op_name, axis_value, axis_shape_v);
  } else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
    is_dynamic = CheckAndGetAxisValueFromTensor(input_args, input_value, op_name, axis_value, axis_shape_v);
  } else {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the second input type should be tensor or scalar, but got invalid abstract type: "
                             << input_args[kInputIndex1]->type_name() << ".";
  }
  return is_dynamic;
}

bool IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape) {
  // Skip running ReduceSum when the axis input is an empty tensor.
  if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
    return true;
  }
  return false;
}

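// Wraps a possibly negative dimension index into [0, dim_post_expr). For example,
// MakeWrapDim(-1, 4) returns 3 and MakeWrapDim(2, 4) returns 2.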
int64_t MakeWrapDim(int64_t dim, int64_t dim_post_expr) {
  // Treat a non-positive rank as 1, so the valid range becomes [-1, 0].
  if (dim_post_expr <= 0) {
    dim_post_expr = 1;
  }

  if (dim < 0) {
    dim += dim_post_expr;
  }

  return dim;
}

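// Builds a bitmask of the dimensions to reduce: bit d is set for each (wrapped) entry of
// dims, and an empty dims list selects every dimension, matching the reduce-all convention.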
std::bitset<kBitSize> MakeDimMask(std::vector<int64_t> dims, int64_t ndim) {
  std::bitset<kBitSize> mask = std::bitset<kBitSize>();
  if (dims.empty()) {
    mask.flip();
  } else {
    for (int64_t dim : dims) {
      mask.set(MakeWrapDim(dim, ndim));
    }
  }

  return mask;
}

abstract::ShapePtr ReduceExtInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto input_shape_ptr = input_args[0]->GetShape();
  const auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_shape_ptr)[kShape];
  int64_t ndim = static_cast<int64_t>(input_shape.size());
  auto dim = GetValue<std::vector<int64_t>>(input_args[1]->GetValue());
  auto keepdim = GetValue<bool>(input_args[2]->GetValue());
  std::bitset<kBitSize> mask = MakeDimMask(dim, ndim);
  auto shape = input_shape;

  for (int64_t dim_temp = static_cast<int64_t>(shape.size()) - 1; dim_temp >= 0; dim_temp--) {
    if (mask[dim_temp]) {
      if (keepdim) {
        shape[dim_temp] = 1;
      } else {
        shape.erase(shape.begin() + dim_temp);
      }
    }
  }
  return std::make_shared<abstract::Shape>(shape);
}

TypePtr ReduceExtInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto dtype_ptr = input_args[3]->GetValue();
  (void)CheckAndConvertUtils::CheckTypeValid("input", input_args[0]->BuildType(),
                                             common_valid_types_with_complex_and_bool, prim->name());
  auto dtype_type_ptr = dtype_ptr->cast<TypePtr>();
  MS_EXCEPTION_IF_NULL(dtype_type_ptr);
  if (dtype_type_ptr->type_id() == kMetaTypeNone) {
    return input_args[0]->BuildType();
  }
  return dtype_type_ptr;
}

abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
                                        const std::vector<abstract::AbstractBasePtr> &input_args,
                                        const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto x_shape = GetShapeFromTensor(input_args[0]);
  bool skip_mode = false;
  if (primitive->HasAttr(kSkipMode)) {
    auto skip_mode_value_ptr = primitive->GetAttr(kSkipMode);
    MS_EXCEPTION_IF_NULL(skip_mode_value_ptr);
    skip_mode = GetValue<bool>(skip_mode_value_ptr);
  }
  auto keep_dims_value_ptr = primitive->GetAttr(kKeepDims);
  MS_EXCEPTION_IF_NULL(keep_dims_value_ptr);
  if (!keep_dims_value_ptr->isa<BoolImm>()) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'keep_dims' must be Bool.";
  }
  bool keep_dims = GetValue<bool>(keep_dims_value_ptr);
  std::vector<int64_t> axis_value;
  int64_t axis_shape = 1;
  bool axis_is_dynamic = CheckAndGetAxisValue(input_args, &axis_value, &axis_shape, primitive);
  if (IsDynamicShapeSkipExecute(skip_mode, {axis_shape})) {
    return std::make_shared<abstract::Shape>(x_shape);
  }
  ShapeVector out_shape = {};
  constexpr int dynamic_rank_value = -2;
  if (IsDynamicRank(x_shape)) {
    if (axis_shape == 0 && !keep_dims) {
      return std::make_shared<abstract::Shape>(out_shape);
    }
    out_shape.push_back(dynamic_rank_value);
    return std::make_shared<abstract::Shape>(out_shape);
  }
  if (axis_shape == -1 && !keep_dims) {
    out_shape.push_back(dynamic_rank_value);
    return std::make_shared<abstract::Shape>(out_shape);
  }
  ReduceFuncCheckAxisInferImpl(primitive, &axis_value, x_shape.size());

  if (axis_is_dynamic) {
    out_shape = ReduceFuncCalShapeAxisDyn(x_shape, keep_dims);
    return std::make_shared<abstract::Shape>(out_shape);
  }
  out_shape = ReduceFuncCalShapeInferImpl(primitive, x_shape, axis_value, keep_dims);
  return std::make_shared<abstract::Shape>(out_shape);
}

TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector<abstract::AbstractBasePtr> &input_args,
                            const std::set<TypePtr> &check_list) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  auto x_type = input_args[0]->GetType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, check_list, prim->name());
  return x_type;
}

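// Applies integer paddings to x_shape, starting from the last dimension. Paddings come in
// (left, right) pairs, so pair i pads dimension x_rank - i - 1; a dynamic input dimension or
// an unknown padding value propagates kShapeDimAny. For example, x_shape [2, 3, 4] with
// paddings [1, 1, 2, 2] yields [2, 7, 6].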
BaseShapePtr SetPadShape(const ShapeVector &x_shape, const ArrayValue<int64_t> &paddings) {
  const size_t kNum2 = 2;
  auto out_shape = x_shape;
  auto x_rank = x_shape.size();
  for (size_t i = 0; i < paddings.size() / kNum2; i++) {
    auto pad_idx = i * kNum2;
    if (out_shape[x_rank - i - 1] != abstract::Shape::kShapeDimAny && !paddings.IsValueUnknown(pad_idx) &&
        !paddings.IsValueUnknown(pad_idx + kIndex1)) {
      auto paddings_l = paddings[pad_idx];
      auto paddings_r = paddings[pad_idx + kIndex1];
      out_shape[x_rank - i - kIndex1] = out_shape[x_rank - i - kIndex1] + paddings_l + paddings_r;
    } else {
      out_shape[x_rank - i - kIndex1] = abstract::Shape::kShapeDimAny;
    }
  }
  return std::make_shared<abstract::Shape>(out_shape);
}

BaseShapePtr PadInferShapeBase(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
                               const size_t pad_dim) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto x_base_shape = input_args[kInputIndex0]->GetShape();
  MS_EXCEPTION_IF_NULL(x_base_shape);
  auto x_shape = x_base_shape->GetShapeVector();
  // input x dynamic rank
  if (x_base_shape->IsDimUnknown()) {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
  }
  // input x dynamic shape
  auto x_rank = x_shape.size();
  constexpr size_t minValidDim = 1;
  constexpr size_t maxValidDim = 2;
  if (x_rank != pad_dim + minValidDim && x_rank != pad_dim + maxValidDim) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input should be " << pad_dim + minValidDim
                             << "D or " << pad_dim + maxValidDim << "D, but got " << x_rank;
  }
  // padding
  auto paddings_opt = GetArrayValue<int64_t>(input_args[kInputIndex1]);
  if (!paddings_opt.has_value()) {
    ShapeVector out_shape = x_shape;
    for (size_t dim = 1; dim <= pad_dim; ++dim) {
      out_shape[x_rank - dim] = abstract::Shape::kShapeDimAny;
    }
    return std::make_shared<abstract::Shape>(std::move(out_shape));
  }
  constexpr size_t kScaleNum = 2;
  auto paddings = paddings_opt.value();
  if (paddings.size() != pad_dim * kScaleNum) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the padding length should be "
                             << pad_dim * kScaleNum << ", but got " << paddings.size();
  }

  auto out_shape = SetPadShape(x_shape, paddings);
  return out_shape;
}

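// Compares two shapes of equal rank, treating -1 (kShapeDimAny) as a wildcard that matches
// any extent. Returns false when the ranks differ or any pair of static extents disagrees.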
bool ObscureShapeEqual(const ShapeVector &lhs, const ShapeVector &rhs) {
  if (lhs == rhs) {
    return true;
  }
  if (lhs.size() != rhs.size()) {
    return false;
  }
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] != rhs[i] && lhs[i] != -1 && rhs[i] != -1) {
      return false;
    }
  }
  return true;
}

std::vector<int64_t> GetSequenceValue(const std::string &arg_name, const AbstractBasePtr &abs,
                                      const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(abs);
  auto abs_seq = dyn_cast<abstract::AbstractSequence>(abs);
  MS_EXCEPTION_IF_NULL(abs_seq);
  if (abs_seq->dynamic_len()) {
    return std::vector<int64_t>{abstract::Shape::kShapeRankAny};
  }
  std::vector<int64_t> out_shape;
  for (auto element : abs_seq->elements()) {
    auto element_val = element->GetValue();
    if (element_val->ContainsValueAny()) {
      out_shape.push_back(abstract::Shape::kShapeDimAny);
    } else if (element_val->isa<Int64Imm>()) {
      (void)out_shape.emplace_back(GetValue<ShapeValueDType>(element_val));
    } else if (element_val->isa<Int32Imm>()) {
      (void)out_shape.emplace_back(GetValue<int32_t>(element_val));
    } else {
      MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
                              << " must be one of ['tuple', 'list'] with all Int elements, but got " << abs->ToString();
    }
  }
  return out_shape;
}

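// Extracts a shape value from an argument that may be an int sequence or a 1-D integer
// tensor. Known values produce concrete dimensions; an unknown tensor of static length n
// produces n entries of kShapeDimAny, and unknown rank produces {kShapeRankAny}.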
ShapeVector GetShapeValue(const PrimitivePtr &primitive, const AbstractBasePtr &arg) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  auto abs_value = arg->GetValue();
  MS_EXCEPTION_IF_NULL(abs_value);
  auto arg_type = arg->GetType();
  MS_EXCEPTION_IF_NULL(arg_type);

  if (IsValueKnown(abs_value)) {
    if (CheckAndConvertUtils::IsTensor(arg)) {
      return CheckAndConvertUtils::CheckTensorIntValue("shape", abs_value, "", arg_type);
    } else if (CheckAndConvertUtils::IsSequence(arg)) {
      return CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", arg, prim_name);
    }
  } else if (CheckAndConvertUtils::IsTensor(arg)) {
    auto arg_shape = arg->GetShape()->GetShapeVector();
    if (arg_shape.size() != 1) {
      MS_EXCEPTION(ValueError) << "For Primitive[" << primitive->name()
                               << "], the shape value can only be one-dimensional.";
    }
    if (IsDynamic(arg_shape)) {
      return {abstract::Shape::kShapeRankAny};
    }
    auto shape_size = arg_shape[0];
    return ShapeVector(shape_size, abstract::Shape::kShapeDimAny);
  } else if (arg->isa<abstract::AbstractSequence>()) {
    return GetSequenceValue("input[shape]", arg, prim_name);
  }

  MS_EXCEPTION(TypeError) << "For " << prim_name << ", the input type must be Tensor/Tuple/List, but got "
                          << arg_type->ToString() << ".";
}

void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
  constexpr auto csr_mul_batch_pos = 2;
  int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size());
  if (dlen < 0) {
    MS_EXCEPTION(ValueError) << "Currently, only broadcasting a dense tensor to a sparse tensor is supported, "
                             << "but the sparse tensor has " << sparse_shp.size() << " dimensions, "
                             << "and the dense tensor has " << dense_shp.size() << " dimensions. ";
  }
  for (int i = 0; i < dlen; i++) {
    (void)dense_shp.insert(dense_shp.begin(), 1);
  }
  if (sparse_shp.size() != dense_shp.size()) {
    MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
  }
  if (sparse_shp.size() < 1) {
    MS_LOG(EXCEPTION) << "Failure: the dense tensor and sparse tensor shapes cannot be zero-dimensional.";
  }
  for (size_t i = 0; i < sparse_shp.size(); i++) {
    auto s = sparse_shp[i];
    auto d = dense_shp[i];
    if (i < csr_mul_batch_pos) {
      if (d != s && d != 1) {
        MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape.";
      }
    } else {
      if (d != s) {
        MS_EXCEPTION(ValueError) << "Currently, sparse shape and dense shape must be equal in feature dimensions.";
      }
    }
  }
}

void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) {
  if (shape_size != expected_dim) {
    MS_EXCEPTION(ValueError) << arg_name << " must be a " << expected_dim << "-dimensional tensor, but got a "
                             << shape_size << "-dimensional tensor.";
  }
}

void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name) {
  if (!(data_type->equal(kInt16) || data_type->equal(kInt32) || data_type->equal(kInt64))) {
    MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
                            << data_type->ToString() << ".";
  }
}

void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name) {
  if (!data_type->equal(kInt32)) {
    MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " only supports Int32 for now, but got "
                            << data_type->ToString() << ".";
  }
}

ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape) {
  auto shape_value = shape->GetValue()->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(shape_value);
  ShapeVector shape_vec;
  (void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec),
                       [](const ValuePtr &e) -> int64_t {
                         auto elem = GetValue<int64_t>(e);
                         return elem;
                       });
  return shape_vec;
}

template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr size_t kSizeExpect = 1;
  if (args_abs_list.size() != kSizeExpect) {
    MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the number of input should be " << kSizeExpect
                      << ", but got " << args_abs_list.size() << ".";
  }
  constexpr size_t kIndex = 0;
  auto abs = args_abs_list[kIndex];
  MS_EXCEPTION_IF_NULL(abs);
  // To avoid AbstractSparseTensors being generalized to AbstractTuple.
  if (dyn_cast<T>(abs) != nullptr) {
    return dyn_cast<T>(abs);
  }
  auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
  if (abs_tuple != nullptr) {
    return std::make_shared<T>(abs_tuple->elements());
  }
  MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[" << kIndex
                          << "] should be AbstractSparseTensor or AbstractTuple, but got " << abs->GetType()->ToString()
                          << ".";
}
template std::shared_ptr<abstract::AbstractCSRTensor> InferSparseAttr(const PrimitivePtr &primitive,
                                                                      const AbstractBasePtrList &args_abs_list);
template std::shared_ptr<abstract::AbstractCOOTensor> InferSparseAttr(const PrimitivePtr &primitive,
                                                                      const AbstractBasePtrList &args_abs_list);

template <typename T>
AbstractBasePtr TensorToSequenceInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  constexpr size_t input_0_index = 0;

  auto x_shape = input_args[kInputIndex0]->GetShape()->GetShapeVector();
  if (x_shape.size() > 1) {
    MS_EXCEPTION(ValueError) << "For Primitive[" << prim_name << "], the input must be a 1-D Tensor, but got Tensor "
                             << "with shape: " << x_shape << ".";
  }

  auto x_type = input_args[input_0_index]->GetType();
  MS_EXCEPTION_IF_NULL(x_type);
  if (!x_type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input must be a Tensor but got "
                            << x_type->ToString() << ".";
  }
  auto tensor_type = x_type->cast<TensorTypePtr>();
  const auto &element_type = tensor_type->element();
  MS_EXCEPTION_IF_NULL(element_type);
  AbstractBasePtrList abs_list;
  if (IsDynamic(x_shape)) {
    abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kValueAny, element_type));
    auto abs = std::make_shared<T>(abs_list);
    abs->CheckAndConvertToDynamicLenSequence();
    return abs;
  }
  if (x_shape.empty()) {
    abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kValueAny, element_type));
  } else {
    for (int64_t i = 0; i < x_shape[0]; i++) {
      abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kValueAny, element_type));
    }
  }
  auto abs = std::make_shared<T>(abs_list);
  return abs;
}

void CheckDynamicLengthSequenceSetItem(const std::string &op_name, const abstract::AbstractSequencePtr &queue,
                                       const AbstractBasePtr &target) {
  auto element_abs = queue->dynamic_len_element_abs();
  if (element_abs == nullptr) {
    MS_LOG(EXCEPTION) << "Cannot setitem on an empty dynamic-length sequence.";
  }
  const auto precondition_log = "For " + op_name + ", when the queue is dynamic length";
  const auto standard_abs_description = "element within dynamic length sequence";
  const auto differ_abs_description = "target element";
  CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{element_abs, target},
                                                      precondition_log, standard_abs_description,
                                                      differ_abs_description);
}

template <typename T>
AbstractBasePtr InferSequenceSetItem(const PrimitivePtr &primitive, const AbstractBasePtrList &args_abs_list) {
  // Inputs: a tuple or list, a scalar whose value is an int64 number, and an object of a subclass of AbstractBase.
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  constexpr int args_spec_size = 3;
  constexpr size_t kIndex2 = 2;
  abstract::CheckArgsSize(op_name, args_abs_list, args_spec_size);
  auto queue = abstract::CheckArg<T>(op_name, args_abs_list, 0);
  auto index = abstract::CheckArg<abstract::AbstractScalar>(op_name, args_abs_list, 1);

  auto index_type = index->GetType();
  MS_EXCEPTION_IF_NULL(index_type);
  if (index_type->type_id() != kInt64->type_id()) {
    MS_EXCEPTION(TypeError) << op_name << " evaluator index should be an int64 number, but got a "
                            << index_type->ToString() << " number.";
  }
  ValuePtr index_value = index->GetValue();
  MS_EXCEPTION_IF_NULL(index_value);
  auto target = args_abs_list[kIndex2];
  MS_EXCEPTION_IF_NULL(target);
  if (queue->dynamic_len()) {
    CheckDynamicLengthSequenceSetItem(op_name, queue, target);
    return queue->Clone();
  }
  if (index_value->ContainsValueAny()) {
    // If the index is variable and the sequence is constant length, then all of the elements within the sequence
    // should have the same type and shape as the target input. The elements within the returned sequence are
    // all broadened.
    const auto &elements = queue->elements();
    if (elements.size() == 0) {
      MS_LOG(EXCEPTION) << "Cannot setitem on an empty sequence.";
    }
    const auto precondition_log = "For " + op_name + ", when the index is variable and the queue is constant length";
    CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(elements, precondition_log);
    auto first_element = elements[0];
    const auto standard_abs_description = "element within constant length sequence";
    const auto differ_abs_description = "target element";
    CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{first_element, target},
                                                        precondition_log, standard_abs_description,
                                                        differ_abs_description);
    return CheckAndConvertUtils::BroadenAllSequenceElements(queue);
  }
  auto index_int64_value = GetValue<int64_t>(index_value);
  AbstractBasePtrList elements = queue->elements();
  std::size_t nelems = elements.size();
  if (nelems == 0) {
    MS_EXCEPTION(ValueError) << "Cannot setitem on an empty sequence.";
  }
  int64_t index_positive_value = index_int64_value >= 0 ? index_int64_value : index_int64_value + SizeToLong(nelems);
  if (index_positive_value < 0 || index_positive_value >= SizeToLong(nelems)) {
    MS_EXCEPTION(IndexError) << op_name << " evaluator: the index " << index_int64_value << " to set is out of range: [-"
                             << nelems << ", " << (nelems - 1) << "].";
  }
  size_t index_unsigned_value = LongToSize(index_positive_value);
  elements[index_unsigned_value] = args_abs_list[kIndex2];
  MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
  return std::make_shared<T>(elements, queue->sequence_nodes());
}

template AbstractBasePtr InferSequenceSetItem<abstract::AbstractList>(const PrimitivePtr &primitive,
                                                                      const AbstractBasePtrList &args_abs_list);
template AbstractBasePtr InferSequenceSetItem<abstract::AbstractTuple>(const PrimitivePtr &primitive,
                                                                       const AbstractBasePtrList &args_abs_list);

template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractList>(const PrimitivePtr &primitive,
                                                                       const std::vector<AbstractBasePtr> &input_args);

template AbstractBasePtr TensorToSequenceInfer<abstract::AbstractTuple>(const PrimitivePtr &primitive,
                                                                        const std::vector<AbstractBasePtr> &input_args);

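// Casts an arbitrary scalar Value (any signed/unsigned integer, floating-point or bool Imm)
// to the requested arithmetic type T via static_cast, raising a TypeError otherwise.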
template <typename T>
T GetScalarCastValue(const std::string &op_name, const ValuePtr &elem) {
  T res;
  MS_EXCEPTION_IF_NULL(elem);
  if (elem->isa<Int64Imm>()) {
    auto elem_value = GetValue<int64_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<Int32Imm>()) {
    auto elem_value = GetValue<int32_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<Int16Imm>()) {
    auto elem_value = GetValue<int16_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<Int8Imm>()) {
    auto elem_value = GetValue<int8_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<UInt64Imm>()) {
    auto elem_value = GetValue<uint64_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<UInt32Imm>()) {
    auto elem_value = GetValue<uint32_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<UInt16Imm>()) {
    auto elem_value = GetValue<uint16_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<UInt8Imm>()) {
    auto elem_value = GetValue<uint8_t>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<FP64Imm>()) {
    auto elem_value = GetValue<double>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<FP32Imm>()) {
    auto elem_value = GetValue<float>(elem);
    res = static_cast<T>(elem_value);
  } else if (elem->isa<BoolImm>()) {
    auto elem_value = GetValue<bool>(elem);
    res = static_cast<T>(elem_value);
  } else {
    MS_EXCEPTION(TypeError) << "For op '" << op_name
                            << "', the input must be an integer, float or bool scalar, but got " << elem->ToString();
  }
  return res;
}

template int64_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template int32_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template int16_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template int8_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template uint64_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template uint32_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template uint16_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template uint8_t GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template double GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template float GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);
template bool GetScalarCastValue(const std::string &op_name, const ValuePtr &elem);

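// Returns the higher-priority type between x_type and y_type following the implicit
// promotion order float64 > float32 > int64 > int32 > bool. For example, float32 combined
// with int64 promotes to float32, while bool combined with bool promotes to int32.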
TypePtr HighPriorityType(const TypePtr &x_type, const TypePtr &y_type, const std::string &op_name) {
  static std::map<TypeId, size_t> prio_map = {{kNumberTypeFloat64, 1},
                                              {kNumberTypeFloat32, 2},
                                              {kNumberTypeInt64, 3},
                                              {kNumberTypeInt32, 4},
                                              {kNumberTypeBool, 5}};
  auto x_iter = prio_map.find(x_type->type_id());
  auto y_iter = prio_map.find(y_type->type_id());
  if (x_iter == prio_map.end() || y_iter == prio_map.end()) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the x and y type should be bool, int or float, but got x type: "
                             << x_type->ToString() << " y type: " << y_type->ToString();
  }
  if (x_iter->second < y_iter->second) {
    return x_type;
  }
  if (x_iter->second == y_iter->second && x_iter->first == kNumberTypeBool) {
    return kInt32;
  }
  return y_type;
}

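// Collects the indices of inputs whose values (not just shapes) are needed at infer time.
// Ops defined in yaml take the list from their OpDef func_impl_, defaulting to every
// non-Tensor input; legacy primitives fall back to the registered infer implementation.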
std::set<int64_t> GetInputDependValueList(const PrimitivePtr &op_prim) {
  MS_EXCEPTION_IF_NULL(op_prim);
  std::set<int64_t> depend_list;
  mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_prim->name());
  if (op_def == nullptr) {
    // Use the old Primitive infer.
    auto op_infer_opt = abstract::GetPrimitiveInferImpl(op_prim);
    if (!op_infer_opt.has_value()) {
      if (op_prim->HasAttr(kAttrMeOpName)) {
        auto ori_prim_name = GetValue<std::string>(op_prim->GetAttr(kAttrMeOpName));
        op_infer_opt = abstract::GetPrimitiveInferImpl(std::make_shared<Primitive>(ori_prim_name));
      }
    }
    if (op_infer_opt.has_value()) {
      auto op_infer = op_infer_opt.value().Get();
      if (op_infer != nullptr && depend_list.empty()) {
        depend_list = op_infer->GetValueDependArgIndices();
      }
    }
    return depend_list;
  }

  depend_list = op_def->func_impl_.GetValueDependArgIndices();
  if (!depend_list.empty()) {
    return depend_list;
  }
  // If GetValueDependArgIndices() is not overridden for this op's infer, treat every
  // non-Tensor input as value-depend.
  auto args = op_def->args_;
  for (size_t i = 0; i < args.size(); i++) {
    if (args[i].arg_dtype_ != mindspore::ops::OP_DTYPE::DT_TENSOR &&
        args[i].arg_dtype_ != mindspore::ops::OP_DTYPE::DT_TUPLE_TENSOR &&
        args[i].arg_dtype_ != mindspore::ops::OP_DTYPE::DT_LIST_TENSOR) {
      (void)depend_list.insert(i);
    }
  }
  return depend_list;
}

size_t GetInputIndexByName(const std::string &op_name, const std::string &input_name) {
  mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    MS_LOG(INFO) << op_name << " is not defined in opdef.";
    return SIZE_MAX;
  }
  auto ks_iter = op_def->indexes_.find(input_name);
  if (ks_iter != op_def->indexes_.end()) {
    size_t index = ks_iter->second;
    MS_LOG(INFO) << "Found " << input_name << " in the " << index << "th input of OP " << op_name;
    return index;
  }
  MS_LOG(INFO) << "Did not find " << input_name << " in OP " << op_name;
  return SIZE_MAX;
}

std::string GetInputNameByIndex(const std::string &op_name, size_t index) {
  mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    return "";
  }
  if (index >= op_def->args_.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Get input name by index out of range, index: " << index
                               << ", size: " << op_def->args_.size() << ", op name: " << op_name;
  }
  auto input = op_def->args_[index];
  return input.arg_name_;
}

bool HasOpDef(const std::string &op_name) {
  mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
  return op_def != nullptr;
}

size_t GetOpInputsNum(const std::string &op_name) {
  mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    MS_LOG(INFO) << op_name << " is not defined in opdef.";
    return SIZE_MAX;
  }
  return op_def->indexes_.size();
}

// This is used to convert an arg with 'prim_init' of a cnode to an attr of the primitive.
// A CNode in the new mindir can be converted to the old mindir by this function.
// For example, {PrimAvgPool, x, kernel_size, strides, pad_mode, data_format} =>
// {PrimAvgPool, x}
CNodePtr ConvertArgsToAttr(const CNodePtr &cnode) {
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto op_def = mindspore::ops::GetOpDef(prim_name);
  if (op_def == nullptr) {
    MS_LOG(DEBUG) << "Prim:" << prim->ToString()
                  << " is not a primitive defined in yaml, cannot convert args to attr, cnode:"
                  << cnode->DebugString();
    return nullptr;
  }
  std::vector<AnfNodePtr> new_node_inputs = {cnode->input(0)};
  for (size_t arg_index = 0; arg_index < op_def->args_.size(); ++arg_index) {
    auto arg = op_def->args_[arg_index];
    if (!arg.as_init_arg_) {
      // This arg is an ordinary input; keep it in the new node's input list.
      (void)new_node_inputs.emplace_back(cnode->input(arg_index + 1));
      continue;
    }

    auto arg_input_node = cnode->input(arg_index + 1);
    if (!arg_input_node->isa<ValueNode>()) {
      // The arg is not a ValueNode: the network has dynamic args, which is not supported.
      MS_LOG(INTERNAL_EXCEPTION) << "Node " << cnode->DebugString() << " with arg " << arg_input_node->DebugString()
                                 << " is dynamic, not supported now.";
    }
    auto arg_value_node = arg_input_node->cast<ValueNodePtr>();
    auto arg_value = arg_value_node->value();
    prim->AddAttr(arg.arg_name_, arg_value);
  }

  auto func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto new_node = func_graph->NewCNode(new_node_inputs);
  new_node->set_abstract(cnode->abstract());
  new_node->set_fullname_with_scope(cnode->fullname_with_scope());
  return new_node;
}

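// Reads a scalar of type T out of a Value. Returns std::nullopt for unknown values
// (ValueAny); for a KernelTensorValue the scalar is read from its raw buffer after a size
// check, and any other value goes through GetValue<T>.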
template <typename T>
std::optional<T> GetScalarValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueAny>()) {
    return std::nullopt;
  }

  if (value->isa<KernelTensorValue>()) {
    auto kernel_tensor_value = value->cast<KernelTensorValuePtr>();
    MS_EXCEPTION_IF_NULL(kernel_tensor_value);

    MS_EXCEPTION_IF_CHECK_FAIL((kernel_tensor_value->GetDataSize() == sizeof(T)),
                               "The data size in kernel tensor value which contains a scalar [" +
                                 std::to_string(kernel_tensor_value->GetDataSize()) +
                                 "] is not equal to the data type size [" + std::to_string(sizeof(T)) + "]");

    const T *data_ptr = reinterpret_cast<const T *>(kernel_tensor_value->GetDataPtr());
    MS_EXCEPTION_IF_NULL(data_ptr);
    return *data_ptr;
  }

  return GetValue<T>(value);
}

// Specialization for std::string type.
template <>
MS_CORE_API std::optional<std::string> GetScalarValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueAny>()) {
    return std::nullopt;
  }

  if (value->isa<KernelTensorValue>()) {
    auto kernel_tensor_value = value->cast<KernelTensorValuePtr>();
    MS_EXCEPTION_IF_NULL(kernel_tensor_value);
    const char *data_ptr = reinterpret_cast<const char *>(kernel_tensor_value->GetDataPtr());
    MS_EXCEPTION_IF_NULL(data_ptr);
    size_t str_len = kernel_tensor_value->GetDataSize();

    return std::string(data_ptr, data_ptr + str_len);
  }

  return GetValue<std::string>(value);
}

template MS_CORE_API std::optional<int64_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<int32_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<int16_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<int8_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<uint64_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<uint32_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<uint16_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<uint8_t> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<double> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<float> GetScalarValue(const ValuePtr &value);
template MS_CORE_API std::optional<bool> GetScalarValue(const ValuePtr &value);

// This interface is only used to convert values of type Sequence or Tensor to std::vector.
template <typename T>
std::optional<ArrayValue<T>> GetArrayValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueAny>()) {
    return std::nullopt;
  }

  std::vector<T> array_data;
  if (value->isa<KernelTensorValue>()) {
    auto kernel_tensor_value = value->cast<KernelTensorValuePtr>();
    MS_EXCEPTION_IF_NULL(kernel_tensor_value);

    if (kernel_tensor_value->GetDataSize() % sizeof(T) != 0) {
      MS_LOG(EXCEPTION) << "The size is incompatible, kernel tensor value size: " << kernel_tensor_value->GetDataSize()
                        << ", expected element size: " << sizeof(T);
    }

    size_t element_size = kernel_tensor_value->GetDataSize() / sizeof(T);
    if (element_size != 0) {
      const T *data_ptr = reinterpret_cast<const T *>(kernel_tensor_value->GetDataPtr());
      MS_EXCEPTION_IF_NULL(data_ptr);
      array_data.assign(data_ptr, data_ptr + element_size);
    }
  } else if (value->isa<ValueSequence>()) {
    // Sequence structure: Data is stored discretely.
    auto value_seq = value->cast<ValueSequencePtr>();
    MS_EXCEPTION_IF_NULL(value_seq);

    const auto &element_values = value_seq->value();
    size_t element_size = element_values.size();
    array_data.reserve(element_size);
    for (size_t i = 0; i < element_size; i++) {
      const auto &element = element_values[i];
      MS_EXCEPTION_IF_NULL(element);
      if (element->isa<ValueAny>() || element->isa<None>()) {
        return std::nullopt;
      }
      if constexpr (std::is_same_v<T, float16>) {
        MS_LOG(EXCEPTION) << "For ValueSequence, the float16 type is not supported!";
      } else {
        array_data.push_back(GetValue<T>(element));
      }
    }
  } else if (value->isa<tensor::BaseTensor>()) {
    // Tensor structure: Data is stored continuously.
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    size_t element_size = tensor->DataSize();
    T *data = reinterpret_cast<T *>(tensor->data_c());
    array_data.assign(data, data + element_size);
  } else {
    MS_LOG(EXCEPTION) << "Failed to get array value, expect sequence or tensor type, but got: " << value->type_name();
  }
  return std::optional<ArrayValue<T>>(std::in_place, std::move(array_data), std::set<size_t>());
}

template <typename T>
std::optional<ArrayValue<T>> GetArrayValue(const AbstractBasePtr &abs_base) {
  MS_EXCEPTION_IF_NULL(abs_base);
  auto value = abs_base->GetValue();
  // If value is constant or is a value sequence with some constant elements.
  if (!value->isa<ValueAny>()) {
    return GetArrayValue<T>(value);
  }

  // If value is ValueAny, check whether the abstract is an AbstractSequence; this happens in the frontend.
  std::vector<T> array_data;
  std::set<size_t> unknown_value_indexes;
  if (abs_base->isa<abstract::AbstractSequence>()) {
    auto abs_sequence = abs_base->cast<abstract::AbstractSequencePtr>();
    if (abs_sequence->dynamic_len()) {
      return std::nullopt;
    }
    for (size_t i = 0; i < abs_sequence->size(); ++i) {
      auto elem_value = abs_sequence->elements()[i]->GetValue();
      if (elem_value->isa<ValueAny>() || elem_value->isa<None>()) {
        array_data.push_back(static_cast<T>(0));
        (void)unknown_value_indexes.insert(i);
        continue;
      }
      if constexpr (std::is_same_v<T, float16>) {
        MS_LOG(EXCEPTION) << "For ValueSequence, the float16 type is not supported!";
      } else {
        array_data.push_back(GetValue<T>(elem_value));
      }
    }
    return std::optional<ArrayValue<T>>(std::in_place, std::move(array_data), std::move(unknown_value_indexes));
  }
  // Only an abstract sequence holding ValueAny needs handling here; any other situation returns nullopt.
  return std::nullopt;
}

template MS_CORE_API std::optional<ArrayValue<int64_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<int32_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<int16_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<int8_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<uint64_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<uint32_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<uint16_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<uint8_t>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<double>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<float>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<bool>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<std::string>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<float16>> GetArrayValue(const ValuePtr &value);
template MS_CORE_API std::optional<ArrayValue<bfloat16>> GetArrayValue(const ValuePtr &value);

template MS_CORE_API std::optional<ArrayValue<int64_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<int32_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<int16_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<int8_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<uint64_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<uint32_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<uint16_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<uint8_t>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<double>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<float>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<bool>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<std::string>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<float16>> GetArrayValue(const AbstractBasePtr &abs_base);
template MS_CORE_API std::optional<ArrayValue<bfloat16>> GetArrayValue(const AbstractBasePtr &abs_base);

void CheckTensorScalarRank(const PrimitivePtr &primitive, const AbstractBasePtr input_arg,
                           const std::string &arg_name) {
  MS_EXCEPTION_IF_NULL(input_arg);
  auto shape_ptr = input_arg->GetShape();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  const auto &input_shape = shape_ptr->GetShapeVector();
  const int64_t kDimZero = 0;
  if (MS_LIKELY(!IsDynamic(input_shape))) {
    MS_CHECK_VALUE(input_shape.size() == LongToSize(kDimZero),
                   CheckAndConvertUtils::FormatCheckIntegerMsg("rank of " + arg_name, SizeToLong(input_shape.size()),
                                                               kEqual, kDimZero, primitive));
  }
}
}  // namespace ops
}  // namespace mindspore