1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/sparse_matrix_sparse_mat_mul.h"
18
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23
24 #include "mindapi/src/helper.h"
25 #include "mindspore/core/ops/math_ops.h"
26 #include "mindspore/core/ops/sparse_ops.h"
27 #include "utils/check_convert_utils.h"
28
29 namespace mindspore {
30 namespace ops {
31 namespace {
// Upper bound used by SparseMatrixSparseMatMulInferShape as the (static) length of the output col_indices and values
// tensors.
constexpr int MAX_LENGTH = 200'000'000;
33
SparseMatrixSparseMatMulCheckInteger(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)34 void SparseMatrixSparseMatMulCheckInteger(const PrimitivePtr &primitive,
35 const std::vector<AbstractBasePtr> &input_args) {
36 MS_EXCEPTION_IF_NULL(primitive);
37 auto prim_name = primitive->name();
38 const int kOne = 1;
39
40 std::vector<int64_t> x1_dense_shape =
41 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
42 std::vector<int64_t> x1_batch_pointer =
43 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
44 std::vector<int64_t> x1_row_pointer =
45 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
46 std::vector<int64_t> x1_col_indices =
47 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
48 std::vector<int64_t> x1_values =
49 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShape())[kShape];
50 std::vector<int64_t> x2_dense_shape =
51 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
52 std::vector<int64_t> x2_batch_pointer =
53 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
54 std::vector<int64_t> x2_row_pointer =
55 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->GetShape())[kShape];
56 std::vector<int64_t> x2_col_indices =
57 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->GetShape())[kShape];
58 std::vector<int64_t> x2_values =
59 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShape())[kShape];
60
61 (void)CheckAndConvertUtils::CheckInteger("rank of x1_dense_shape", SizeToLong(x1_dense_shape.size()), kEqual, kOne,
62 prim_name);
63 (void)CheckAndConvertUtils::CheckInteger("rank of x1_batch_pointer", SizeToLong(x1_batch_pointer.size()), kEqual,
64 kOne, prim_name);
65 (void)CheckAndConvertUtils::CheckInteger("rank of x1_row_pointer", SizeToLong(x1_row_pointer.size()), kEqual, kOne,
66 prim_name);
67 (void)CheckAndConvertUtils::CheckInteger("rank of x1_col_indices", SizeToLong(x1_col_indices.size()), kEqual, kOne,
68 prim_name);
69 (void)CheckAndConvertUtils::CheckInteger("rank of x1_values", SizeToLong(x1_values.size()), kEqual, kOne, prim_name);
70 (void)CheckAndConvertUtils::CheckInteger("rank of x2_dense_shape", SizeToLong(x2_dense_shape.size()), kEqual, kOne,
71 prim_name);
72 (void)CheckAndConvertUtils::CheckInteger("rank of x2_batch_pointer", SizeToLong(x2_batch_pointer.size()), kEqual,
73 kOne, prim_name);
74 (void)CheckAndConvertUtils::CheckInteger("rank of x2_row_pointer", SizeToLong(x2_row_pointer.size()), kEqual, kOne,
75 prim_name);
76 (void)CheckAndConvertUtils::CheckInteger("rank of x2_col_indices", SizeToLong(x2_col_indices.size()), kEqual, kOne,
77 prim_name);
78 (void)CheckAndConvertUtils::CheckInteger("rank of x2_values", SizeToLong(x2_values.size()), kEqual, kOne, prim_name);
79 }
80
FrontendSparseMatrixSparseMatMulInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)81 abstract::TupleShapePtr FrontendSparseMatrixSparseMatMulInferShape(const PrimitivePtr &primitive,
82 const std::vector<AbstractBasePtr> &input_args) {
83 std::vector<int64_t> x1_dense_shape =
84 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
85 const int64_t rank_x1 = x1_dense_shape[0];
86 std::vector<int64_t> x2_dense_shape =
87 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
88 const int64_t rank_x2 = x2_dense_shape[0];
89 if (rank_x1 != rank_x2) {
90 MS_EXCEPTION(ValueError)
91 << "For SparseMatrixSparseMatMul, x1_dense_shape.shape[0] and rank of x2_dense must be the "
92 "same, but got x1_dense_shape.shape[0] = "
93 << rank_x1 << ", and rank of x2_dense = " << rank_x2 << ".";
94 }
95
96 SparseMatrixSparseMatMulCheckInteger(primitive, input_args);
97
98 const int kInputNoBatch = 2;
99 const int kInputWithBatch = 3;
100 if (rank_x1 != kInputNoBatch && rank_x1 != kInputWithBatch) {
101 MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, rank of x1_dense_shape must be (2,) or (3,), but got "
102 << rank_x1 << ".";
103 }
104
105 std::vector<int64_t> x1_batch_shape =
106 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
107 std::vector<int64_t> x2_batch_shape =
108 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
109
110 if (x1_batch_shape[0] != x2_batch_shape[0]) {
111 MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, x1_batch_shape[0] and x2_batch_shape[0] must be the "
112 "same, but got x1_batch_shape[0] = "
113 << x1_batch_shape[0] << ", and x2_batch_shape[0] = " << x2_batch_shape[0] << ".";
114 }
115
116 ShapeVector dense_shape = {x1_dense_shape[0]};
117 ShapeVector batch_shape = {x1_batch_shape[0]};
118 abstract::ShapePtr y_dense_shape = std::make_shared<abstract::Shape>(dense_shape);
119 abstract::ShapePtr y_batch_shape = std::make_shared<abstract::Shape>(batch_shape);
120 abstract::ShapePtr y_row_shape = nullptr;
121 abstract::ShapePtr y_col_shape = nullptr;
122 abstract::ShapePtr y_values_shape = nullptr;
123
124 ShapeVector col_shape = {abstract::Shape::kShapeDimAny};
125 ShapeVector values_shape = {abstract::Shape::kShapeDimAny};
126 y_col_shape = std::make_shared<abstract::Shape>(col_shape);
127 y_values_shape = std::make_shared<abstract::Shape>(values_shape);
128
129 if (CheckAndConvertUtils::IsTensor(input_args[0]) && !input_args[0]->GetValue()->isa<ValueAny>() &&
130 !input_args[0]->GetValue()->isa<None>()) {
131 auto dense_shape_type_ptr = input_args[0]->GetType();
132 MS_EXCEPTION_IF_NULL(dense_shape_type_ptr);
133 auto dense_shape_value_ptr = input_args[0]->GetValue();
134 MS_EXCEPTION_IF_NULL(dense_shape_value_ptr);
135 auto dense_shape_value_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue(
136 "dense_shape", dense_shape_value_ptr, primitive->name(), dense_shape_type_ptr);
137 auto row_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 2));
138 auto col_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 1));
139
140 auto transpose_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
141 auto transpose_b = GetValue<bool>(primitive->GetAttr(kTransposeB));
142 auto adjoint_a = GetValue<bool>(primitive->GetAttr("adjoint_a"));
143 auto adjoint_b = GetValue<bool>(primitive->GetAttr("adjoint_b"));
144
145 if (adjoint_a && transpose_a) {
146 MS_EXCEPTION(ValueError)
147 << "For SparseMatrixSparseMatMul, only one of adjoint_a and transpose_a may be true, but got adjoint_a="
148 << adjoint_a << " and transpose_a=" << transpose_a << ".";
149 }
150 if (adjoint_b && transpose_b) {
151 MS_EXCEPTION(ValueError)
152 << "For SparseMatrixSparseMatMul, only one of adjoint_b and transpose_b may be true, but got adjoint_b="
153 << adjoint_b << " and transpose_b=" << transpose_b << ".";
154 }
155 if (adjoint_a || transpose_a) {
156 row_value = col_value;
157 }
158
159 ShapeVector row_shape = {(x1_batch_shape[0] - 1) * (row_value + 1)};
160 y_row_shape = std::make_shared<abstract::Shape>(row_shape);
161 return std::make_shared<abstract::TupleShape>(
162 std::vector<abstract::BaseShapePtr>{y_dense_shape, y_batch_shape, y_row_shape, y_col_shape, y_values_shape});
163 } else {
164 ShapeVector row_shape = {-1};
165 y_row_shape = std::make_shared<abstract::Shape>(row_shape);
166 return std::make_shared<abstract::TupleShape>(
167 std::vector<abstract::BaseShapePtr>{y_dense_shape, y_batch_shape, y_row_shape, y_col_shape, y_values_shape});
168 }
169 }
170
SparseMatrixSparseMatMulInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)171 abstract::TupleShapePtr SparseMatrixSparseMatMulInferShape(const PrimitivePtr &primitive,
172 const std::vector<AbstractBasePtr> &input_args) {
173 std::vector<int64_t> x1_dense_shape =
174 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
175 const int64_t rank_x1 = x1_dense_shape[0];
176 std::vector<int64_t> x2_dense_shape =
177 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
178 const int64_t rank_x2 = x2_dense_shape[0];
179 if (rank_x1 != rank_x2) {
180 MS_EXCEPTION(ValueError)
181 << "For SparseMatrixSparseMatMul, x1_dense_shape.shape[0] and rank of x2_dense must be the "
182 "same, but got x1_dense_shape.shape[0] = "
183 << rank_x1 << ", and rank of x2_dense = " << rank_x2 << ".";
184 }
185
186 SparseMatrixSparseMatMulCheckInteger(primitive, input_args);
187
188 const int kInputNoBatch = 2;
189 const int kInputWithBatch = 3;
190 if (rank_x1 != kInputNoBatch && rank_x1 != kInputWithBatch) {
191 MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, rank of x1_dense_shape must be (2,) or (3,), but got "
192 << rank_x1 << ".";
193 }
194
195 std::vector<int64_t> x1_batch_shape =
196 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
197 std::vector<int64_t> x2_batch_shape =
198 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
199
200 if (x1_batch_shape[0] != x2_batch_shape[0]) {
201 MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, x1_batch_shape[0] and x2_batch_shape[0] must be the "
202 "same, but got x1_batch_shape[0] = "
203 << x1_batch_shape[0] << ", and x2_batch_shape[0] = " << x2_batch_shape[0] << ".";
204 }
205
206 ShapeVector dense_shape = {x1_dense_shape[0]};
207 ShapeVector batch_shape = {x1_batch_shape[0]};
208 abstract::ShapePtr y_dense_shape = std::make_shared<abstract::Shape>(dense_shape);
209 abstract::ShapePtr y_batch_shape = std::make_shared<abstract::Shape>(batch_shape);
210
211 int64_t max_length = MAX_LENGTH;
212 ShapeVector infer_shape_max = {max_length};
213 abstract::ShapePtr y_col_shape = std::make_shared<abstract::Shape>(infer_shape_max);
214 abstract::ShapePtr y_values_shape = std::make_shared<abstract::Shape>(infer_shape_max);
215
216 auto dense_shape_type_ptr = input_args[0]->GetType();
217 MS_EXCEPTION_IF_NULL(dense_shape_type_ptr);
218 auto dense_shape_value_ptr = input_args[0]->GetValue();
219 MS_EXCEPTION_IF_NULL(dense_shape_value_ptr);
220 auto dense_shape_value_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue(
221 "dense_shape", dense_shape_value_ptr, primitive->name(), dense_shape_type_ptr);
222 auto row_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 2));
223 auto col_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 1));
224
225 auto transpose_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
226 auto transpose_b = GetValue<bool>(primitive->GetAttr(kTransposeB));
227 auto adjoint_a = GetValue<bool>(primitive->GetAttr("adjoint_a"));
228 auto adjoint_b = GetValue<bool>(primitive->GetAttr("adjoint_b"));
229
230 if (adjoint_a && transpose_a) {
231 MS_EXCEPTION(ValueError)
232 << "For SparseMatrixSparseMatMul, only one of adjoint_a and transpose_a may be true, but got adjoint_a="
233 << adjoint_a << " and transpose_a=" << transpose_a << ".";
234 }
235 if (adjoint_b && transpose_b) {
236 MS_EXCEPTION(ValueError)
237 << "For SparseMatrixSparseMatMul, only one of adjoint_b and transpose_b may be true, but got adjoint_b="
238 << adjoint_b << " and transpose_b=" << transpose_b << ".";
239 }
240 if (adjoint_a || transpose_a) {
241 row_value = col_value;
242 }
243
244 ShapeVector row_shape = {(x1_batch_shape[0] - 1) * (row_value + 1)};
245 abstract::ShapePtr y_row_shape = std::make_shared<abstract::Shape>(row_shape);
246 return std::make_shared<abstract::TupleShape>(
247 std::vector<abstract::BaseShapePtr>{y_dense_shape, y_batch_shape, y_row_shape, y_col_shape, y_values_shape});
248 }
249
SparseMatrixSparseMatMulInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)250 TuplePtr SparseMatrixSparseMatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
251 const std::set<TypePtr> index_valid_types = {kInt32, kInt64};
252 const std::set<TypePtr> values_valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
253 auto x1_dense_type = input_args[kInputIndex0]->GetType();
254 auto x1_batch_type = input_args[kInputIndex1]->GetType();
255 auto x1_row_type = input_args[kInputIndex2]->GetType();
256 auto x1_col_type = input_args[kInputIndex3]->GetType();
257 auto x1_values_type = input_args[kInputIndex4]->GetType();
258
259 auto x2_dense_type = input_args[kInputIndex5]->GetType();
260 auto x2_batch_type = input_args[kInputIndex6]->GetType();
261 auto x2_row_type = input_args[kInputIndex7]->GetType();
262 auto x2_col_type = input_args[kInputIndex8]->GetType();
263
264 std::map<std::string, TypePtr> types;
265 (void)types.emplace("x1_dense_shape", x1_dense_type);
266 (void)types.emplace("x1_batch_pointers", x1_batch_type);
267 (void)types.emplace("x1_row_pointers", x1_row_type);
268 (void)types.emplace("x1_col_indices", x1_col_type);
269 (void)types.emplace("x2_dense_shape", x2_dense_type);
270 (void)types.emplace("x2_batch_pointers", x2_batch_type);
271 (void)types.emplace("x2_row_pointers", x2_row_type);
272 (void)types.emplace("x2_col_indices", x2_col_type);
273 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, index_valid_types, prim->name());
274 (void)CheckAndConvertUtils::CheckTensorTypeValid("x1_values", x1_values_type, values_valid_types, prim->name());
275
276 return std::make_shared<Tuple>(
277 std::vector<TypePtr>{x1_dense_type, x1_batch_type, x1_row_type, x1_col_type, x1_values_type});
278 }
279 } // namespace
280
281 MIND_API_OPERATOR_IMPL(SparseMatrixSparseMatMul, BaseOperator);
SparseMatrixSparseMatMulInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)282 AbstractBasePtr SparseMatrixSparseMatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
283 const std::vector<AbstractBasePtr> &input_args) {
284 MS_EXCEPTION_IF_NULL(primitive);
285 for (const auto &item : input_args) {
286 MS_EXCEPTION_IF_NULL(item);
287 }
288 auto prim_name = primitive->name();
289 const int64_t input_num = 10;
290 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
291 auto infer_type = SparseMatrixSparseMatMulInferType(primitive, input_args);
292 auto infer_shape = FrontendSparseMatrixSparseMatMulInferShape(primitive, input_args);
293 return abstract::MakeAbstract(infer_shape, infer_type);
294 }
295 // AG means auto generated
296 class MIND_API AGSparseMatrixSparseMatMulInfer : public abstract::OpInferBase {
297 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const298 BaseShapePtr InferShape(const PrimitivePtr &primitive,
299 const std::vector<AbstractBasePtr> &input_args) const override {
300 return SparseMatrixSparseMatMulInferShape(primitive, input_args);
301 }
302
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const303 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
304 return SparseMatrixSparseMatMulInferType(primitive, input_args);
305 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const306 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
307 const std::vector<AbstractBasePtr> &input_args) const override {
308 return SparseMatrixSparseMatMulInfer(engine, primitive, input_args);
309 }
310
GetValueDependArgIndices() const311 std::set<int64_t> GetValueDependArgIndices() const override { return {0}; }
312 };
313
314 REGISTER_PRIMITIVE_OP_INFER_IMPL(SparseMatrixSparseMatMul, prim::kPrimSparseMatrixSparseMatMul,
315 AGSparseMatrixSparseMatMulInfer, false);
316 } // namespace ops
317 } // namespace mindspore
318