1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/range_v2.h"
17
18 #include <memory>
19 #include <set>
20 #include <type_traits>
21 #include <vector>
22
23 #include "abstract/abstract_value.h"
24 #include "abstract/dshape.h"
25 #include "abstract/ops/op_infer.h"
26 #include "abstract/ops/primitive_infer_map.h"
27 #include "abstract/utils.h"
28 #include "base/base.h"
29 #include "ir/anf.h"
30 #include "ir/dtype/number.h"
31 #include "ir/dtype/type.h"
32 #include "ir/named.h"
33 #include "ir/primitive.h"
34 #include "ir/tensor.h"
35 #include "ir/value.h"
36 #include "mindapi/base/shape_vector.h"
37 #include "mindapi/src/helper.h"
38 #include "mindspore/core/ops/array_ops.h"
39 #include "ops/op_name.h"
40 #include "ops/op_utils.h"
41 #include "ops/primitive_c.h"
42 #include "utils/check_convert_utils.h"
43 #include "utils/log_adapter.h"
44 #include "utils/shape_utils.h"
45
46 namespace mindspore {
47 namespace ops {
48 namespace {
// Returns true when the two TypePtr operands denote the same type (delegates to Type::equal).
#define IsSameType(source_type, cmp_type) (cmp_type->equal(source_type))
// Returns true when the value is None or still an unresolved "any" value at compile time.
#define IsNoneOrAnyValue(value_ptr) ((value_ptr->isa<None>()) || (value_ptr->ContainsValueAny()))
constexpr auto op_name = "RangeV2";
52
53 template <typename T>
RangeV2CalculateShape(const AbstractBasePtr & start_ptr,const AbstractBasePtr limit_ptr,const AbstractBasePtr delta_ptr)54 int64_t RangeV2CalculateShape(const AbstractBasePtr &start_ptr, const AbstractBasePtr limit_ptr,
55 const AbstractBasePtr delta_ptr) {
56 auto start_array = GetArrayValue<T>(start_ptr);
57 auto limit_array = GetArrayValue<T>(limit_ptr);
58 auto delta_array = GetArrayValue<T>(delta_ptr);
59 if (!start_array.has_value() || start_array.value().size() != 1) {
60 MS_EXCEPTION(TypeError) << "For RangeV2, start must a scalar but element number more than 1.";
61 }
62 if (!limit_array.has_value() || limit_array.value().size() != 1) {
63 MS_EXCEPTION(TypeError) << "For RangeV2, limit must a scalar but element number more than 1.";
64 }
65 if (!delta_array.has_value() || delta_array.value().size() != 1) {
66 MS_EXCEPTION(TypeError) << "For RangeV2, delta must a scalar but element number more than 1.";
67 }
68 T start = start_array.value()[0];
69 T limit = limit_array.value()[0];
70 T delta = delta_array.value()[0];
71 bool valid_value = (delta == T(0) || (delta > 0 && start > limit) || (delta < 0 && start < limit));
72 if (valid_value) {
73 if (delta == T(0)) {
74 MS_EXCEPTION(ValueError) << "For RangeV2, delta cannot be equal to zero.";
75 }
76 if (delta > 0 && start > limit) {
77 MS_EXCEPTION(ValueError) << "For RangeV2, delta cannot be positive when limit < start.";
78 }
79 if (delta < 0 && start < limit) {
80 MS_EXCEPTION(ValueError) << "For RangeV2, delta cannot be negative when limit > start.";
81 }
82 }
83 int64_t shape_size = 0;
84 if (std::is_integral<T>::value) {
85 shape_size = static_cast<int64_t>((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta));
86 } else {
87 shape_size = static_cast<int64_t>(std::ceil(std::abs((limit - start) / delta)));
88 }
89 return shape_size;
90 }
91
RangeV2CheckAndInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)92 abstract::ShapePtr RangeV2CheckAndInferShape(const PrimitivePtr &primitive,
93 const std::vector<AbstractBasePtr> &input_args) {
94 MS_EXCEPTION_IF_NULL(primitive->GetAttr(kMaxLen));
95 auto start = input_args[kInputIndex0];
96 auto limit = input_args[kInputIndex1];
97 auto delta = input_args[kInputIndex2];
98 // support dynamic rank
99 auto start_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(start->GetShape())[kShape];
100 auto limit_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(limit->GetShape())[kShape];
101 auto delta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(delta->GetShape())[kShape];
102 if (IsDynamicRank(start_shape) || IsDynamicRank(limit_shape) || IsDynamicRank(delta_shape)) {
103 return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
104 }
105 int64_t shape_size = abstract::Shape::kShapeDimAny;
106
107 bool is_compile = !IsValueKnown(start) || !IsValueKnown(limit) || !IsValueKnown(delta);
108 // not in compile, need inferShape
109 if (!is_compile) {
110 auto dtype = CheckAndConvertUtils::GetTensorInputType(op_name, input_args, kInputIndex0);
111 if (IsSameType(dtype, kInt) || IsSameType(dtype, kInt32)) {
112 shape_size = RangeV2CalculateShape<int32_t>(start, limit, delta);
113 } else if (IsSameType(dtype, kInt64)) {
114 shape_size = RangeV2CalculateShape<int64_t>(start, limit, delta);
115 } else if (IsSameType(dtype, kFloat) || IsSameType(dtype, kFloat32)) {
116 shape_size = RangeV2CalculateShape<float>(start, limit, delta);
117 } else if (IsSameType(dtype, kFloat64)) {
118 shape_size = RangeV2CalculateShape<double>(start, limit, delta);
119 } else {
120 MS_EXCEPTION(TypeError) << "For RangeV2, the dtype of input must be int32, int64, float32, float64, but got "
121 << dtype->meta_type() << ".";
122 }
123 if (shape_size < 0) {
124 MS_EXCEPTION(ValueError) << "For RangeV2, infer shape error, shape_size [" << shape_size << "] is negative.";
125 }
126 }
127
128 ShapeVector out_shape = {};
129 if (is_compile) {
130 (void)out_shape.emplace_back(abstract::Shape::kShapeDimAny);
131 return std::make_shared<abstract::Shape>(out_shape);
132 }
133
134 (void)out_shape.emplace_back(shape_size);
135 return std::make_shared<abstract::Shape>(out_shape);
136 }
137
RangeV2CheckAndInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)138 TypePtr RangeV2CheckAndInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
139 std::set<TypePtr> support_types = {kInt32, kInt64, kFloat32, kFloat64};
140 auto start_type = CheckAndConvertUtils::CheckTensorTypeValid("start", input_args[kInputIndex0]->GetType(),
141 support_types, prim->name());
142 auto limit_type = CheckAndConvertUtils::CheckTensorTypeValid("limit", input_args[kInputIndex1]->GetType(),
143 support_types, prim->name());
144 auto delta_type = CheckAndConvertUtils::CheckTensorTypeValid("delta", input_args[kInputIndex1]->GetType(),
145 support_types, prim->name());
146 MS_EXCEPTION_IF_NULL(start_type);
147 MS_EXCEPTION_IF_NULL(limit_type);
148 MS_EXCEPTION_IF_NULL(delta_type);
149 bool same_type = IsSameType(start_type, limit_type) && IsSameType(limit_type, delta_type);
150 if (!same_type) {
151 MS_EXCEPTION(TypeError) << "For RangeV2, start, limit delta should have same type, but get start["
152 << start_type->meta_type() << "], limit[" << limit_type->meta_type() << "], delta["
153 << delta_type->meta_type() << "].";
154 }
155 return start_type;
156 }
157 } // namespace
158
RangeV2Infer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)159 AbstractBasePtr RangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
160 const std::vector<AbstractBasePtr> &input_args) {
161 MS_EXCEPTION_IF_NULL(primitive);
162 const int64_t input_num = 3;
163 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
164 (void)CheckAndConvertUtils::CheckArgsType(op_name, input_args, kInputIndex0, kObjectTypeTensorType);
165 (void)CheckAndConvertUtils::CheckArgsType(op_name, input_args, kInputIndex1, kObjectTypeTensorType);
166 (void)CheckAndConvertUtils::CheckArgsType(op_name, input_args, kInputIndex2, kObjectTypeTensorType);
167 // infer type must in before
168 auto infer_type = RangeV2CheckAndInferType(primitive, input_args);
169 auto infer_shape = RangeV2CheckAndInferShape(primitive, input_args);
170 return abstract::MakeAbstract(infer_shape, infer_type);
171 }
172
MIND_API_OPERATOR_IMPL(RangeV2, BaseOperator);

// AG means auto generated
// Thin OpInferBase adapter that forwards to the file-local infer helpers above.
class MIND_API AGRangeV2Infer : public abstract::OpInferBase {
 public:
  // Forwards shape inference to RangeV2CheckAndInferShape.
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return RangeV2CheckAndInferShape(primitive, input_args);
  }

  // Forwards type inference to RangeV2CheckAndInferType.
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return RangeV2CheckAndInferType(primitive, input_args);
  }
  // Forwards combined shape-and-type inference to RangeV2Infer.
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return RangeV2Infer(engine, primitive, input_args);
  }

  // All three inputs (start, limit, delta) are value-depend: their concrete values
  // are needed to compute the output shape.
  std::set<int64_t> GetValueDependArgIndices() const override { return {0, 1, 2}; }
};

REGISTER_PRIMITIVE_OP_INFER_IMPL(RangeV2, prim::kPrimRangeV2, AGRangeV2Infer, false);
195 } // namespace ops
196 } // namespace mindspore
197