1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/sparse_fill_empty_rows.h"
17
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23
24 #include "abstract/abstract_value.h"
25 #include "abstract/dshape.h"
26 #include "abstract/ops/op_infer.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "base/base.h"
30 #include "ir/anf.h"
31 #include "ir/dtype/container.h"
32 #include "ir/dtype/number.h"
33 #include "ir/dtype/tensor_type.h"
34 #include "ir/named.h"
35 #include "ir/primitive.h"
36 #include "ir/value.h"
37 #include "mindapi/base/shape_vector.h"
38 #include "mindapi/src/helper.h"
39 #include "mindspore/core/ops/math_ops.h"
40 #include "mindspore/core/ops/sparse_ops.h"
41 #include "ops/op_name.h"
42 #include "ops/primitive_c.h"
43 #include "utils/check_convert_utils.h"
44 #include "utils/convert_utils_base.h"
45 #include "utils/log_adapter.h"
46 #include "utils/shape_utils.h"
47
48 namespace mindspore {
49 namespace ops {
50 namespace {
CheckSparseFillEmptyRowsInputs(const std::vector<AbstractBasePtr> & input_args,const std::string & op_name)51 bool CheckSparseFillEmptyRowsInputs(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name) {
52 auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
53 auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
54 auto dense_shape_shape =
55 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
56 auto default_value_shape =
57 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
58 if (IsDynamic(indices_shape) || IsDynamic(values_shape) || IsDynamic(dense_shape_shape) ||
59 IsDynamic(default_value_shape)) {
60 return false;
61 }
62
63 const int64_t indice_rank = 2;
64 const int64_t values_rank = 1;
65 const int64_t dense_shape_rank = 1;
66 const int64_t default_value_rank = 0;
67 const int64_t dense_rank = 2;
68
69 (void)CheckAndConvertUtils::CheckInteger("indices rank", SizeToLong(indices_shape.size()), kEqual, indice_rank,
70 op_name);
71 if (indices_shape[1] != dense_rank) {
72 MS_EXCEPTION(ValueError) << "For SparseFillEmptyRows, "
73 << "the last dim of the indices must be 2, but got " << indices_shape[1] << ".";
74 }
75 (void)CheckAndConvertUtils::CheckInteger("values rank", SizeToLong(values_shape.size()), kEqual, values_rank,
76 op_name);
77 (void)CheckAndConvertUtils::CheckInteger("dense_shape rank", SizeToLong(dense_shape_shape.size()), kEqual,
78 dense_shape_rank, op_name);
79 (void)CheckAndConvertUtils::CheckInteger("dense_shape size", dense_shape_shape[0], kEqual, dense_rank, op_name);
80 (void)CheckAndConvertUtils::CheckInteger("default_value rank", SizeToLong(default_value_shape.size()), kEqual,
81 default_value_rank, op_name);
82 if (indices_shape[0] != values_shape[0]) {
83 MS_EXCEPTION(ValueError) << "For SparseFillEmptyRows, "
84 << "the size of indices must be equal to values, but got " << indices_shape[0] << " and "
85 << values_shape[0] << ".";
86 }
87 return true;
88 }
89
SparseFillEmptyRowsInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)90 abstract::TupleShapePtr SparseFillEmptyRowsInferShape(const PrimitivePtr &primitive,
91 const std::vector<AbstractBasePtr> &input_args) {
92 auto op_name = primitive->name();
93 abstract::ShapePtr output_indices_shape;
94 abstract::ShapePtr output_values_shape;
95 abstract::ShapePtr output_empty_row_indicator_shape;
96 abstract::ShapePtr output_reverse_index_map_shape;
97
98 const int64_t rank = 2;
99 auto input_shape_value = input_args[kInputIndex2]->GetValue();
100 MS_EXCEPTION_IF_NULL(input_shape_value);
101
102 if (CheckSparseFillEmptyRowsInputs(input_args, op_name) && !input_shape_value->isa<ValueAny>() &&
103 !input_shape_value->isa<None>()) {
104 auto indice_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
105 const int64_t input_nnz = indice_shape[0];
106
107 auto dense_row = CheckAndConvertUtils::CheckTensorIntValue("x_dense_shape", input_shape_value, op_name,
108 input_args[kInputIndex2]->GetType())[0];
109 output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz + dense_row, rank}));
110 output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz + dense_row}));
111 output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({dense_row}));
112 output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz}));
113 } else {
114 output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny, rank}));
115 output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
116 output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
117 output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
118 }
119
120 return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
121 output_indices_shape, output_values_shape, output_empty_row_indicator_shape, output_reverse_index_map_shape});
122 }
123
SparseFillEmptyRowsFrontendInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)124 abstract::TupleShapePtr SparseFillEmptyRowsFrontendInferShape(const PrimitivePtr &primitive,
125 const std::vector<AbstractBasePtr> &input_args) {
126 auto op_name = primitive->name();
127 abstract::ShapePtr output_indices_shape;
128 abstract::ShapePtr output_values_shape;
129 abstract::ShapePtr output_empty_row_indicator_shape;
130 abstract::ShapePtr output_reverse_index_map_shape;
131
132 const int64_t rank = 2;
133 auto input_shape_value = input_args[kInputIndex2]->GetValue();
134 MS_EXCEPTION_IF_NULL(input_shape_value);
135
136 if (CheckSparseFillEmptyRowsInputs(input_args, op_name) && !input_shape_value->isa<ValueAny>() &&
137 !input_shape_value->isa<None>()) {
138 auto indice_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
139 const int64_t input_nnz = indice_shape[0];
140
141 auto dense_row = CheckAndConvertUtils::CheckTensorIntValue("x_dense_shape", input_shape_value, op_name,
142 input_args[kInputIndex2]->GetType())[0];
143 output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny, rank}));
144 output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
145 output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({dense_row}));
146 output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz}));
147 } else {
148 output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny, rank}));
149 output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
150 output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
151 output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
152 }
153
154 return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
155 output_indices_shape, output_values_shape, output_empty_row_indicator_shape, output_reverse_index_map_shape});
156 }
157
SparseFillEmptyRowsInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)158 TypePtr SparseFillEmptyRowsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
159 auto op_name = primitive->name();
160 const std::set<TypePtr> common_valid_types_with_bool_and_complex = {
161 kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
162 kUInt64, kFloat16, kFloat32, kFloat64, kBool, kComplex64, kComplex128};
163 auto indices_type = input_args[kInputIndex0]->GetType();
164 auto values_type = input_args[kInputIndex1]->GetType();
165 auto dense_shape_type = input_args[kInputIndex2]->GetType();
166 auto default_value_type = input_args[kInputIndex3]->GetType();
167 std::map<std::string, TypePtr> types;
168 (void)types.emplace("values", values_type);
169 (void)types.emplace("default_value", default_value_type);
170 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_bool_and_complex, op_name);
171 (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, {kInt64}, op_name);
172 (void)CheckAndConvertUtils::CheckTensorTypeValid("dense_shape", dense_shape_type, {kInt64}, op_name);
173 return std::make_shared<Tuple>(std::vector<TypePtr>{std::make_shared<TensorType>(kInt64), values_type,
174 std::make_shared<TensorType>(kBool),
175 std::make_shared<TensorType>(kInt64)});
176 }
177 } // namespace
178
179 MIND_API_OPERATOR_IMPL(SparseFillEmptyRows, BaseOperator);
SparseFillEmptyRowsInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)180 AbstractBasePtr SparseFillEmptyRowsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
181 const std::vector<AbstractBasePtr> &input_args) {
182 MS_EXCEPTION_IF_NULL(primitive);
183 const int64_t kInputNum = 4;
184 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
185 auto infer_type = SparseFillEmptyRowsInferType(primitive, input_args);
186 auto infer_shape = SparseFillEmptyRowsFrontendInferShape(primitive, input_args);
187 return abstract::MakeAbstract(infer_shape, infer_type);
188 }
189
190 // AG means auto generated
191 class MIND_API AGSparseFillEmptyRowsInfer : public abstract::OpInferBase {
192 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const193 BaseShapePtr InferShape(const PrimitivePtr &primitive,
194 const std::vector<AbstractBasePtr> &input_args) const override {
195 return SparseFillEmptyRowsInferShape(primitive, input_args);
196 }
197
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const198 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
199 return SparseFillEmptyRowsInferType(primitive, input_args);
200 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const201 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
202 const std::vector<AbstractBasePtr> &input_args) const override {
203 return SparseFillEmptyRowsInfer(engine, primitive, input_args);
204 }
205
GetValueDependArgIndices() const206 std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
207 };
208
209 REGISTER_PRIMITIVE_OP_INFER_IMPL(SparseFillEmptyRows, prim::kPrimSparseFillEmptyRows, AGSparseFillEmptyRowsInfer,
210 false);
211 } // namespace ops
212 } // namespace mindspore
213