• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "ops/range_v2.h"
17 
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <set>
#include <type_traits>
#include <vector>
22 
23 #include "abstract/abstract_value.h"
24 #include "abstract/dshape.h"
25 #include "abstract/ops/op_infer.h"
26 #include "abstract/ops/primitive_infer_map.h"
27 #include "abstract/utils.h"
28 #include "base/base.h"
29 #include "ir/anf.h"
30 #include "ir/dtype/number.h"
31 #include "ir/dtype/type.h"
32 #include "ir/named.h"
33 #include "ir/primitive.h"
34 #include "ir/tensor.h"
35 #include "ir/value.h"
36 #include "mindapi/base/shape_vector.h"
37 #include "mindapi/src/helper.h"
38 #include "mindspore/core/ops/array_ops.h"
39 #include "ops/op_name.h"
40 #include "ops/op_utils.h"
41 #include "ops/primitive_c.h"
42 #include "utils/check_convert_utils.h"
43 #include "utils/log_adapter.h"
44 #include "utils/shape_utils.h"
45 
46 namespace mindspore {
47 namespace ops {
48 namespace {
// Compares two type pointers by asking `cmp_type` whether it equals `source_type`.
// Both arguments must be pointer-like (the macro applies `->` to cmp_type).
#define IsSameType(source_type, cmp_type) (cmp_type->equal(source_type))
// True when `value_ptr` is a None value or reports ContainsValueAny().
// NOTE(review): not referenced anywhere in this file.
#define IsNoneOrAnyValue(value_ptr) ((value_ptr->isa<None>()) || (value_ptr->ContainsValueAny()))
// Operator name handed to the CheckAndConvertUtils helpers in this file.
constexpr char op_name[] = "RangeV2";
52 
53 template <typename T>
RangeV2CalculateShape(const AbstractBasePtr & start_ptr,const AbstractBasePtr limit_ptr,const AbstractBasePtr delta_ptr)54 int64_t RangeV2CalculateShape(const AbstractBasePtr &start_ptr, const AbstractBasePtr limit_ptr,
55                               const AbstractBasePtr delta_ptr) {
56   auto start_array = GetArrayValue<T>(start_ptr);
57   auto limit_array = GetArrayValue<T>(limit_ptr);
58   auto delta_array = GetArrayValue<T>(delta_ptr);
59   if (!start_array.has_value() || start_array.value().size() != 1) {
60     MS_EXCEPTION(TypeError) << "For RangeV2, start must a scalar but element number more than 1.";
61   }
62   if (!limit_array.has_value() || limit_array.value().size() != 1) {
63     MS_EXCEPTION(TypeError) << "For RangeV2, limit must a scalar but element number more than 1.";
64   }
65   if (!delta_array.has_value() || delta_array.value().size() != 1) {
66     MS_EXCEPTION(TypeError) << "For RangeV2, delta must a scalar but element number more than 1.";
67   }
68   T start = start_array.value()[0];
69   T limit = limit_array.value()[0];
70   T delta = delta_array.value()[0];
71   bool valid_value = (delta == T(0) || (delta > 0 && start > limit) || (delta < 0 && start < limit));
72   if (valid_value) {
73     if (delta == T(0)) {
74       MS_EXCEPTION(ValueError) << "For RangeV2, delta cannot be equal to zero.";
75     }
76     if (delta > 0 && start > limit) {
77       MS_EXCEPTION(ValueError) << "For RangeV2, delta cannot be positive when limit < start.";
78     }
79     if (delta < 0 && start < limit) {
80       MS_EXCEPTION(ValueError) << "For RangeV2, delta cannot be negative when limit > start.";
81     }
82   }
83   int64_t shape_size = 0;
84   if (std::is_integral<T>::value) {
85     shape_size = static_cast<int64_t>((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta));
86   } else {
87     shape_size = static_cast<int64_t>(std::ceil(std::abs((limit - start) / delta)));
88   }
89   return shape_size;
90 }
91 
RangeV2CheckAndInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)92 abstract::ShapePtr RangeV2CheckAndInferShape(const PrimitivePtr &primitive,
93                                              const std::vector<AbstractBasePtr> &input_args) {
94   MS_EXCEPTION_IF_NULL(primitive->GetAttr(kMaxLen));
95   auto start = input_args[kInputIndex0];
96   auto limit = input_args[kInputIndex1];
97   auto delta = input_args[kInputIndex2];
98   // support dynamic rank
99   auto start_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(start->GetShape())[kShape];
100   auto limit_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(limit->GetShape())[kShape];
101   auto delta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(delta->GetShape())[kShape];
102   if (IsDynamicRank(start_shape) || IsDynamicRank(limit_shape) || IsDynamicRank(delta_shape)) {
103     return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
104   }
105   int64_t shape_size = abstract::Shape::kShapeDimAny;
106 
107   bool is_compile = !IsValueKnown(start) || !IsValueKnown(limit) || !IsValueKnown(delta);
108   // not in compile, need inferShape
109   if (!is_compile) {
110     auto dtype = CheckAndConvertUtils::GetTensorInputType(op_name, input_args, kInputIndex0);
111     if (IsSameType(dtype, kInt) || IsSameType(dtype, kInt32)) {
112       shape_size = RangeV2CalculateShape<int32_t>(start, limit, delta);
113     } else if (IsSameType(dtype, kInt64)) {
114       shape_size = RangeV2CalculateShape<int64_t>(start, limit, delta);
115     } else if (IsSameType(dtype, kFloat) || IsSameType(dtype, kFloat32)) {
116       shape_size = RangeV2CalculateShape<float>(start, limit, delta);
117     } else if (IsSameType(dtype, kFloat64)) {
118       shape_size = RangeV2CalculateShape<double>(start, limit, delta);
119     } else {
120       MS_EXCEPTION(TypeError) << "For RangeV2, the dtype of input must be int32, int64, float32, float64, but got "
121                               << dtype->meta_type() << ".";
122     }
123     if (shape_size < 0) {
124       MS_EXCEPTION(ValueError) << "For RangeV2, infer shape error, shape_size [" << shape_size << "] is negative.";
125     }
126   }
127 
128   ShapeVector out_shape = {};
129   if (is_compile) {
130     (void)out_shape.emplace_back(abstract::Shape::kShapeDimAny);
131     return std::make_shared<abstract::Shape>(out_shape);
132   }
133 
134   (void)out_shape.emplace_back(shape_size);
135   return std::make_shared<abstract::Shape>(out_shape);
136 }
137 
RangeV2CheckAndInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)138 TypePtr RangeV2CheckAndInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
139   std::set<TypePtr> support_types = {kInt32, kInt64, kFloat32, kFloat64};
140   auto start_type = CheckAndConvertUtils::CheckTensorTypeValid("start", input_args[kInputIndex0]->GetType(),
141                                                                support_types, prim->name());
142   auto limit_type = CheckAndConvertUtils::CheckTensorTypeValid("limit", input_args[kInputIndex1]->GetType(),
143                                                                support_types, prim->name());
144   auto delta_type = CheckAndConvertUtils::CheckTensorTypeValid("delta", input_args[kInputIndex1]->GetType(),
145                                                                support_types, prim->name());
146   MS_EXCEPTION_IF_NULL(start_type);
147   MS_EXCEPTION_IF_NULL(limit_type);
148   MS_EXCEPTION_IF_NULL(delta_type);
149   bool same_type = IsSameType(start_type, limit_type) && IsSameType(limit_type, delta_type);
150   if (!same_type) {
151     MS_EXCEPTION(TypeError) << "For RangeV2, start, limit delta should have same type, but get start["
152                             << start_type->meta_type() << "], limit[" << limit_type->meta_type() << "], delta["
153                             << delta_type->meta_type() << "].";
154   }
155   return start_type;
156 }
157 }  // namespace
158 
RangeV2Infer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)159 AbstractBasePtr RangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
160                              const std::vector<AbstractBasePtr> &input_args) {
161   MS_EXCEPTION_IF_NULL(primitive);
162   const int64_t input_num = 3;
163   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
164   (void)CheckAndConvertUtils::CheckArgsType(op_name, input_args, kInputIndex0, kObjectTypeTensorType);
165   (void)CheckAndConvertUtils::CheckArgsType(op_name, input_args, kInputIndex1, kObjectTypeTensorType);
166   (void)CheckAndConvertUtils::CheckArgsType(op_name, input_args, kInputIndex2, kObjectTypeTensorType);
167   // infer type must in before
168   auto infer_type = RangeV2CheckAndInferType(primitive, input_args);
169   auto infer_shape = RangeV2CheckAndInferShape(primitive, input_args);
170   return abstract::MakeAbstract(infer_shape, infer_type);
171 }
172 
173 MIND_API_OPERATOR_IMPL(RangeV2, BaseOperator);
174 
175 // AG means auto generated
176 class MIND_API AGRangeV2Infer : public abstract::OpInferBase {
177  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const178   BaseShapePtr InferShape(const PrimitivePtr &primitive,
179                           const std::vector<AbstractBasePtr> &input_args) const override {
180     return RangeV2CheckAndInferShape(primitive, input_args);
181   }
182 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const183   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
184     return RangeV2CheckAndInferType(primitive, input_args);
185   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const186   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
187                                     const std::vector<AbstractBasePtr> &input_args) const override {
188     return RangeV2Infer(engine, primitive, input_args);
189   }
190 
GetValueDependArgIndices() const191   std::set<int64_t> GetValueDependArgIndices() const override { return {0, 1, 2}; }
192 };
193 
194 REGISTER_PRIMITIVE_OP_INFER_IMPL(RangeV2, prim::kPrimRangeV2, AGRangeV2Infer, false);
195 }  // namespace ops
196 }  // namespace mindspore
197