1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ops/unique_consecutive.h"

#include <algorithm>
#include <functional>
#include <iostream>
#include <numeric>

#include "abstract/dshape.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
30
31 namespace mindspore {
32 namespace ops {
33 namespace {
// UniqueConsecutive accepts exactly one input tensor.
constexpr int64_t kUniqueConsecutiveInputNum = 1;
// For aicpu, if axis is 1000, that represents None.
constexpr int64_t kAxisIsNone = 1000;
37
// Returns true if `shape` contains a zero dimension, i.e. the tensor has no elements.
// An empty (scalar) shape vector yields false: std::any_of over an empty range is
// false, so the original explicit size() != 0 guard was redundant and is dropped.
bool CheckNullInput(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == 0; });
}
46
UniqueConsecutiveInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)47 abstract::BaseShapePtr UniqueConsecutiveInferShape(const PrimitivePtr &primitive,
48 const std::vector<AbstractBasePtr> &input_args) {
49 auto op_name = primitive->name();
50 auto input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape());
51 auto input_shape_vec = input_shape_map[kShape];
52 if (CheckNullInput(input_shape_vec)) {
53 MS_LOG(EXCEPTION) << "For " << op_name << ", the shape of input cannot contain zero.";
54 }
55
56 auto axis_ptr = primitive->GetAttr(kAxis);
57 MS_EXCEPTION_IF_NULL(axis_ptr);
58 abstract::ShapePtr output_shape;
59 abstract::ShapePtr idx_shape;
60 abstract::ShapePtr counts_shape;
61 ShapeVector output_max_vec;
62 ShapeVector idx_shape_vec;
63 ShapeVector counts_max_vec;
64 // dynamic shape, the infershape function will be called two times. In the second time, the attribute
65 // axis may be deleted so as to axis_ptr is nullptr.
66 if (axis_ptr->isa<None>() || GetValue<int64_t>(axis_ptr) == kAxisIsNone) {
67 MS_LOG(INFO) << "node:" << op_name << " has no axis attribute or axis id None! Deal as flatten";
68 (void)primitive->SetAttrs({{"axis", MakeValue(kAxisIsNone)}});
69 idx_shape_vec = input_shape_vec;
70 auto input_total = std::accumulate(input_shape_vec.begin(), input_shape_vec.end(), 1, std::multiplies<int64_t>());
71 output_max_vec = {input_total};
72 counts_max_vec = {input_total};
73 } else {
74 int64_t axis = GetValue<int64_t>(axis_ptr);
75 int64_t ndims = SizeToLong(input_shape_vec.size());
76 if (axis >= ndims || axis < -ndims) {
77 MS_EXCEPTION(ValueError) << "For " << op_name << ", the axis must be in the range [-" << ndims << "," << ndims
78 << "), but got " << axis << ".";
79 }
80 if (axis < 0) {
81 axis = axis + ndims;
82 }
83 size_t axis_size = LongToSize(axis);
84 output_max_vec = input_shape_vec;
85 idx_shape_vec = {input_shape_vec[axis_size]};
86 counts_max_vec = {input_shape_vec[axis_size]};
87 }
88
89 auto idx_ptr = primitive->GetAttr("return_idx");
90 MS_EXCEPTION_IF_NULL(idx_ptr);
91 auto cnt_ptr = primitive->GetAttr("return_counts");
92 MS_EXCEPTION_IF_NULL(cnt_ptr);
93 const auto &return_idx = GetValue<bool>(idx_ptr);
94 if (!return_idx) {
95 idx_shape_vec = {0};
96 }
97
98 output_shape = std::make_shared<abstract::Shape>(output_max_vec);
99 counts_shape = std::make_shared<abstract::Shape>(counts_max_vec);
100 idx_shape = std::make_shared<abstract::Shape>(idx_shape_vec);
101
102 auto ret_shape_vec = std::vector<abstract::BaseShapePtr>{output_shape};
103 (void)ret_shape_vec.emplace_back(idx_shape);
104 (void)ret_shape_vec.emplace_back(counts_shape);
105 return std::make_shared<abstract::TupleShape>(ret_shape_vec);
106 }
107
// Frontend (graph-compile time) shape inference for UniqueConsecutive.
// Returns a TupleShape of (output, idx, counts). Because the number of unique
// consecutive elements is only known at runtime, the affected dims are reported
// as dynamic (kShapeDimAny / kShapeRankAny) rather than as max shapes.
abstract::BaseShapePtr UniqueConsecutiveFrontendInferShape(const PrimitivePtr &primitive,
                                                           const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  auto input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape());
  auto input_shape_vec = input_shape_map[kShape];
  // Empty tensors are rejected outright.
  if (CheckNullInput(input_shape_vec)) {
    MS_LOG(EXCEPTION) << "For " << op_name << ", the shape of input cannot contain zero.";
  }

  auto axis_ptr = primitive->GetAttr(kAxis);
  MS_EXCEPTION_IF_NULL(axis_ptr);
  abstract::ShapePtr output_shape;
  abstract::ShapePtr idx_shape;
  abstract::ShapePtr counts_shape;
  ShapeVector output_vec;
  ShapeVector idx_shape_vec;
  ShapeVector counts_shape_vec;
  // dynamic shape, the infershape function will be called two times. In the second time, the attribute
  // axis may be deleted so as to axis_ptr is nullptr.
  if (axis_ptr->isa<None>() || GetValue<int64_t>(axis_ptr) == kAxisIsNone) {
    MS_LOG(INFO) << "node:" << op_name << " has no axis attribute or axis id None! Deal as flatten";
    // Normalize the attribute to the aicpu sentinel so later calls see a concrete value.
    (void)primitive->SetAttrs({{"axis", MakeValue(kAxisIsNone)}});
    // Flattened mode: output/counts lengths are runtime-dependent; idx mirrors the input shape.
    output_vec = {abstract::Shape::kShapeDimAny};
    counts_shape_vec = {abstract::Shape::kShapeDimAny};
    idx_shape_vec = input_shape_vec;
  } else {
    int64_t axis = GetValue<int64_t>(axis_ptr);
    int64_t ndims = SizeToLong(input_shape_vec.size());
    // NOTE(review): for a dynamic-rank input, input_shape_vec is the rank-any sentinel of
    // size 1, so ndims is 1 here and this check may reject axes valid at runtime — confirm.
    if (axis >= ndims || axis < -ndims) {
      MS_EXCEPTION(ValueError) << "For " << op_name << ", the axis must be in the range [-" << ndims << "," << ndims
                               << "), but got " << axis << ".";
    }
    if (axis < 0) {
      axis = axis + ndims;  // Normalize negative axis to [0, ndims).
    }
    if (IsDynamicRank(input_shape_vec) || IsDynamicShape(input_shape_vec)) {
      // Any dynamic input dim makes all three outputs fully dynamic in rank.
      output_vec = {abstract::Shape::kShapeRankAny};
      counts_shape_vec = {abstract::Shape::kShapeRankAny};
      idx_shape_vec = {abstract::Shape::kShapeRankAny};
    } else {
      size_t axis_size = LongToSize(axis);
      // output keeps the input shape except along `axis`, whose extent is runtime-dependent;
      // idx is 1-D with one entry per slice of the original axis.
      output_vec = input_shape_vec;
      output_vec[axis_size] = abstract::Shape::kShapeDimAny;
      idx_shape_vec = {input_shape_vec[axis_size]};
      counts_shape_vec = {abstract::Shape::kShapeDimAny};
    }
  }

  auto idx_ptr = primitive->GetAttr("return_idx");
  MS_EXCEPTION_IF_NULL(idx_ptr);
  auto cnt_ptr = primitive->GetAttr("return_counts");
  MS_EXCEPTION_IF_NULL(cnt_ptr);
  const auto &return_idx = GetValue<bool>(idx_ptr);
  const auto &return_counts = GetValue<bool>(cnt_ptr);
  // Disabled outputs collapse to zero-length placeholder shapes.
  if (!return_idx) {
    idx_shape_vec = {0};
  }
  if (!return_counts) {
    counts_shape_vec = {0};
  }

  output_shape = std::make_shared<abstract::Shape>(output_vec);
  counts_shape = std::make_shared<abstract::Shape>(counts_shape_vec);
  idx_shape = std::make_shared<abstract::Shape>(idx_shape_vec);

  auto ret_shape_vec = std::vector<abstract::BaseShapePtr>{output_shape};
  (void)ret_shape_vec.emplace_back(idx_shape);
  (void)ret_shape_vec.emplace_back(counts_shape);
  return std::make_shared<abstract::TupleShape>(ret_shape_vec);
}
178
UniqueConsecutiveInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)179 TypePtr UniqueConsecutiveInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
180 MS_EXCEPTION_IF_NULL(primitive);
181 auto name = primitive->name();
182 const std::set<TypePtr> valid_types = {kComplex64, kComplex128, kFloat16, kFloat, kFloat64, kInt8, kInt16,
183 kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
184 auto input_type = CheckAndConvertUtils::CheckTypeValid("input", input_args[0]->GetType(), valid_types, name);
185 std::vector<TypePtr> ret_type_vec = {input_type, std::make_shared<TensorType>(kInt64),
186 std::make_shared<TensorType>(kInt64)};
187 return std::make_shared<Tuple>(ret_type_vec);
188 }
189 } // namespace
190
UniqueConsecutiveInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)191 AbstractBasePtr UniqueConsecutiveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
192 const std::vector<AbstractBasePtr> &input_args) {
193 MS_EXCEPTION_IF_NULL(primitive);
194 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kUniqueConsecutiveInputNum, primitive->name());
195 auto infer_type = UniqueConsecutiveInferType(primitive, input_args);
196 auto infer_shape = UniqueConsecutiveFrontendInferShape(primitive, input_args);
197 return abstract::MakeAbstract(infer_shape, infer_type);
198 }
199
// Bind the MindAPI operator class UniqueConsecutive to its BaseOperator backend.
MIND_API_OPERATOR_IMPL(UniqueConsecutive, BaseOperator);
201
202 // AG means auto generated
// Infer adapter: routes the OpInferBase interface to the UniqueConsecutive
// infer functions defined above. Pure delegation, no state.
class MIND_API AGUniqueConsecutiveInfer : public abstract::OpInferBase {
 public:
  // Backend shape path: reports max shapes for the runtime-sized outputs.
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return UniqueConsecutiveInferShape(primitive, input_args);
  }

  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return UniqueConsecutiveInferType(primitive, input_args);
  }
  // Frontend path: validates the arg count and reports dynamic shapes.
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return UniqueConsecutiveInfer(engine, primitive, input_args);
  }
};
218
// Register the infer implementation for the UniqueConsecutive primitive.
// NOTE(review): the trailing `false` flag's meaning (presumably "no value infer")
// should be confirmed against the REGISTER_PRIMITIVE_OP_INFER_IMPL macro definition.
REGISTER_PRIMITIVE_OP_INFER_IMPL(UniqueConsecutive, prim::kPrimUniqueConsecutive, AGUniqueConsecutiveInfer, false);
220 } // namespace ops
221 } // namespace mindspore
222