• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "ops/sparse_fill_empty_rows.h"
17 
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23 
24 #include "abstract/abstract_value.h"
25 #include "abstract/dshape.h"
26 #include "abstract/ops/op_infer.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "base/base.h"
30 #include "ir/anf.h"
31 #include "ir/dtype/container.h"
32 #include "ir/dtype/number.h"
33 #include "ir/dtype/tensor_type.h"
34 #include "ir/named.h"
35 #include "ir/primitive.h"
36 #include "ir/value.h"
37 #include "mindapi/base/shape_vector.h"
38 #include "mindapi/src/helper.h"
39 #include "mindspore/core/ops/math_ops.h"
40 #include "mindspore/core/ops/sparse_ops.h"
41 #include "ops/op_name.h"
42 #include "ops/primitive_c.h"
43 #include "utils/check_convert_utils.h"
44 #include "utils/convert_utils_base.h"
45 #include "utils/log_adapter.h"
46 #include "utils/shape_utils.h"
47 
48 namespace mindspore {
49 namespace ops {
50 namespace {
CheckSparseFillEmptyRowsInputs(const std::vector<AbstractBasePtr> & input_args,const std::string & op_name)51 bool CheckSparseFillEmptyRowsInputs(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name) {
52   auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
53   auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
54   auto dense_shape_shape =
55     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
56   auto default_value_shape =
57     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
58   if (IsDynamic(indices_shape) || IsDynamic(values_shape) || IsDynamic(dense_shape_shape) ||
59       IsDynamic(default_value_shape)) {
60     return false;
61   }
62 
63   const int64_t indice_rank = 2;
64   const int64_t values_rank = 1;
65   const int64_t dense_shape_rank = 1;
66   const int64_t default_value_rank = 0;
67   const int64_t dense_rank = 2;
68 
69   (void)CheckAndConvertUtils::CheckInteger("indices rank", SizeToLong(indices_shape.size()), kEqual, indice_rank,
70                                            op_name);
71   if (indices_shape[1] != dense_rank) {
72     MS_EXCEPTION(ValueError) << "For SparseFillEmptyRows, "
73                              << "the last dim of the indices must be 2, but got " << indices_shape[1] << ".";
74   }
75   (void)CheckAndConvertUtils::CheckInteger("values rank", SizeToLong(values_shape.size()), kEqual, values_rank,
76                                            op_name);
77   (void)CheckAndConvertUtils::CheckInteger("dense_shape rank", SizeToLong(dense_shape_shape.size()), kEqual,
78                                            dense_shape_rank, op_name);
79   (void)CheckAndConvertUtils::CheckInteger("dense_shape size", dense_shape_shape[0], kEqual, dense_rank, op_name);
80   (void)CheckAndConvertUtils::CheckInteger("default_value rank", SizeToLong(default_value_shape.size()), kEqual,
81                                            default_value_rank, op_name);
82   if (indices_shape[0] != values_shape[0]) {
83     MS_EXCEPTION(ValueError) << "For SparseFillEmptyRows, "
84                              << "the size of indices must be equal to values, but got " << indices_shape[0] << " and "
85                              << values_shape[0] << ".";
86   }
87   return true;
88 }
89 
SparseFillEmptyRowsInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)90 abstract::TupleShapePtr SparseFillEmptyRowsInferShape(const PrimitivePtr &primitive,
91                                                       const std::vector<AbstractBasePtr> &input_args) {
92   auto op_name = primitive->name();
93   abstract::ShapePtr output_indices_shape;
94   abstract::ShapePtr output_values_shape;
95   abstract::ShapePtr output_empty_row_indicator_shape;
96   abstract::ShapePtr output_reverse_index_map_shape;
97 
98   const int64_t rank = 2;
99   auto input_shape_value = input_args[kInputIndex2]->GetValue();
100   MS_EXCEPTION_IF_NULL(input_shape_value);
101 
102   if (CheckSparseFillEmptyRowsInputs(input_args, op_name) && !input_shape_value->isa<ValueAny>() &&
103       !input_shape_value->isa<None>()) {
104     auto indice_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
105     const int64_t input_nnz = indice_shape[0];
106 
107     auto dense_row = CheckAndConvertUtils::CheckTensorIntValue("x_dense_shape", input_shape_value, op_name,
108                                                                input_args[kInputIndex2]->GetType())[0];
109     output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz + dense_row, rank}));
110     output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz + dense_row}));
111     output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({dense_row}));
112     output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz}));
113   } else {
114     output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny, rank}));
115     output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
116     output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
117     output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
118   }
119 
120   return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
121     output_indices_shape, output_values_shape, output_empty_row_indicator_shape, output_reverse_index_map_shape});
122 }
123 
SparseFillEmptyRowsFrontendInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)124 abstract::TupleShapePtr SparseFillEmptyRowsFrontendInferShape(const PrimitivePtr &primitive,
125                                                               const std::vector<AbstractBasePtr> &input_args) {
126   auto op_name = primitive->name();
127   abstract::ShapePtr output_indices_shape;
128   abstract::ShapePtr output_values_shape;
129   abstract::ShapePtr output_empty_row_indicator_shape;
130   abstract::ShapePtr output_reverse_index_map_shape;
131 
132   const int64_t rank = 2;
133   auto input_shape_value = input_args[kInputIndex2]->GetValue();
134   MS_EXCEPTION_IF_NULL(input_shape_value);
135 
136   if (CheckSparseFillEmptyRowsInputs(input_args, op_name) && !input_shape_value->isa<ValueAny>() &&
137       !input_shape_value->isa<None>()) {
138     auto indice_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
139     const int64_t input_nnz = indice_shape[0];
140 
141     auto dense_row = CheckAndConvertUtils::CheckTensorIntValue("x_dense_shape", input_shape_value, op_name,
142                                                                input_args[kInputIndex2]->GetType())[0];
143     output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny, rank}));
144     output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
145     output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({dense_row}));
146     output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({input_nnz}));
147   } else {
148     output_indices_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny, rank}));
149     output_values_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
150     output_empty_row_indicator_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
151     output_reverse_index_map_shape = std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeDimAny}));
152   }
153 
154   return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
155     output_indices_shape, output_values_shape, output_empty_row_indicator_shape, output_reverse_index_map_shape});
156 }
157 
SparseFillEmptyRowsInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)158 TypePtr SparseFillEmptyRowsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
159   auto op_name = primitive->name();
160   const std::set<TypePtr> common_valid_types_with_bool_and_complex = {
161     kInt8,   kInt16,   kInt32,   kInt64,   kUInt8, kUInt16,    kUInt32,
162     kUInt64, kFloat16, kFloat32, kFloat64, kBool,  kComplex64, kComplex128};
163   auto indices_type = input_args[kInputIndex0]->GetType();
164   auto values_type = input_args[kInputIndex1]->GetType();
165   auto dense_shape_type = input_args[kInputIndex2]->GetType();
166   auto default_value_type = input_args[kInputIndex3]->GetType();
167   std::map<std::string, TypePtr> types;
168   (void)types.emplace("values", values_type);
169   (void)types.emplace("default_value", default_value_type);
170   (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_bool_and_complex, op_name);
171   (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, {kInt64}, op_name);
172   (void)CheckAndConvertUtils::CheckTensorTypeValid("dense_shape", dense_shape_type, {kInt64}, op_name);
173   return std::make_shared<Tuple>(std::vector<TypePtr>{std::make_shared<TensorType>(kInt64), values_type,
174                                                       std::make_shared<TensorType>(kBool),
175                                                       std::make_shared<TensorType>(kInt64)});
176 }
177 }  // namespace
178 
179 MIND_API_OPERATOR_IMPL(SparseFillEmptyRows, BaseOperator);
SparseFillEmptyRowsInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)180 AbstractBasePtr SparseFillEmptyRowsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
181                                          const std::vector<AbstractBasePtr> &input_args) {
182   MS_EXCEPTION_IF_NULL(primitive);
183   const int64_t kInputNum = 4;
184   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
185   auto infer_type = SparseFillEmptyRowsInferType(primitive, input_args);
186   auto infer_shape = SparseFillEmptyRowsFrontendInferShape(primitive, input_args);
187   return abstract::MakeAbstract(infer_shape, infer_type);
188 }
189 
190 // AG means auto generated
191 class MIND_API AGSparseFillEmptyRowsInfer : public abstract::OpInferBase {
192  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const193   BaseShapePtr InferShape(const PrimitivePtr &primitive,
194                           const std::vector<AbstractBasePtr> &input_args) const override {
195     return SparseFillEmptyRowsInferShape(primitive, input_args);
196   }
197 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const198   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
199     return SparseFillEmptyRowsInferType(primitive, input_args);
200   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const201   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
202                                     const std::vector<AbstractBasePtr> &input_args) const override {
203     return SparseFillEmptyRowsInfer(engine, primitive, input_args);
204   }
205 
GetValueDependArgIndices() const206   std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
207 };
208 
209 REGISTER_PRIMITIVE_OP_INFER_IMPL(SparseFillEmptyRows, prim::kPrimSparseFillEmptyRows, AGSparseFillEmptyRowsInfer,
210                                  false);
211 }  // namespace ops
212 }  // namespace mindspore
213