• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/sparse_matrix_sparse_mat_mul.h"
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 
24 #include "mindapi/src/helper.h"
25 #include "mindspore/core/ops/math_ops.h"
26 #include "mindspore/core/ops/sparse_ops.h"
27 #include "utils/check_convert_utils.h"
28 
29 namespace mindspore {
30 namespace ops {
31 namespace {
// Fixed length reported for the col_indices and values outputs by SparseMatrixSparseMatMulInferShape.
constexpr int MAX_LENGTH = 200'000'000;
33 
SparseMatrixSparseMatMulCheckInteger(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)34 void SparseMatrixSparseMatMulCheckInteger(const PrimitivePtr &primitive,
35                                           const std::vector<AbstractBasePtr> &input_args) {
36   MS_EXCEPTION_IF_NULL(primitive);
37   auto prim_name = primitive->name();
38   const int kOne = 1;
39 
40   std::vector<int64_t> x1_dense_shape =
41     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
42   std::vector<int64_t> x1_batch_pointer =
43     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
44   std::vector<int64_t> x1_row_pointer =
45     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
46   std::vector<int64_t> x1_col_indices =
47     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
48   std::vector<int64_t> x1_values =
49     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShape())[kShape];
50   std::vector<int64_t> x2_dense_shape =
51     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
52   std::vector<int64_t> x2_batch_pointer =
53     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
54   std::vector<int64_t> x2_row_pointer =
55     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->GetShape())[kShape];
56   std::vector<int64_t> x2_col_indices =
57     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->GetShape())[kShape];
58   std::vector<int64_t> x2_values =
59     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShape())[kShape];
60 
61   (void)CheckAndConvertUtils::CheckInteger("rank of x1_dense_shape", SizeToLong(x1_dense_shape.size()), kEqual, kOne,
62                                            prim_name);
63   (void)CheckAndConvertUtils::CheckInteger("rank of x1_batch_pointer", SizeToLong(x1_batch_pointer.size()), kEqual,
64                                            kOne, prim_name);
65   (void)CheckAndConvertUtils::CheckInteger("rank of x1_row_pointer", SizeToLong(x1_row_pointer.size()), kEqual, kOne,
66                                            prim_name);
67   (void)CheckAndConvertUtils::CheckInteger("rank of x1_col_indices", SizeToLong(x1_col_indices.size()), kEqual, kOne,
68                                            prim_name);
69   (void)CheckAndConvertUtils::CheckInteger("rank of x1_values", SizeToLong(x1_values.size()), kEqual, kOne, prim_name);
70   (void)CheckAndConvertUtils::CheckInteger("rank of x2_dense_shape", SizeToLong(x2_dense_shape.size()), kEqual, kOne,
71                                            prim_name);
72   (void)CheckAndConvertUtils::CheckInteger("rank of x2_batch_pointer", SizeToLong(x2_batch_pointer.size()), kEqual,
73                                            kOne, prim_name);
74   (void)CheckAndConvertUtils::CheckInteger("rank of x2_row_pointer", SizeToLong(x2_row_pointer.size()), kEqual, kOne,
75                                            prim_name);
76   (void)CheckAndConvertUtils::CheckInteger("rank of x2_col_indices", SizeToLong(x2_col_indices.size()), kEqual, kOne,
77                                            prim_name);
78   (void)CheckAndConvertUtils::CheckInteger("rank of x2_values", SizeToLong(x2_values.size()), kEqual, kOne, prim_name);
79 }
80 
FrontendSparseMatrixSparseMatMulInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)81 abstract::TupleShapePtr FrontendSparseMatrixSparseMatMulInferShape(const PrimitivePtr &primitive,
82                                                                    const std::vector<AbstractBasePtr> &input_args) {
83   std::vector<int64_t> x1_dense_shape =
84     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
85   const int64_t rank_x1 = x1_dense_shape[0];
86   std::vector<int64_t> x2_dense_shape =
87     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
88   const int64_t rank_x2 = x2_dense_shape[0];
89   if (rank_x1 != rank_x2) {
90     MS_EXCEPTION(ValueError)
91       << "For SparseMatrixSparseMatMul, x1_dense_shape.shape[0] and rank of x2_dense must be the "
92          "same, but got x1_dense_shape.shape[0] = "
93       << rank_x1 << ", and rank of x2_dense = " << rank_x2 << ".";
94   }
95 
96   SparseMatrixSparseMatMulCheckInteger(primitive, input_args);
97 
98   const int kInputNoBatch = 2;
99   const int kInputWithBatch = 3;
100   if (rank_x1 != kInputNoBatch && rank_x1 != kInputWithBatch) {
101     MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, rank of x1_dense_shape must be (2,) or (3,), but got "
102                              << rank_x1 << ".";
103   }
104 
105   std::vector<int64_t> x1_batch_shape =
106     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
107   std::vector<int64_t> x2_batch_shape =
108     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
109 
110   if (x1_batch_shape[0] != x2_batch_shape[0]) {
111     MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, x1_batch_shape[0] and x2_batch_shape[0] must be the "
112                                 "same, but got x1_batch_shape[0] = "
113                              << x1_batch_shape[0] << ", and x2_batch_shape[0] = " << x2_batch_shape[0] << ".";
114   }
115 
116   ShapeVector dense_shape = {x1_dense_shape[0]};
117   ShapeVector batch_shape = {x1_batch_shape[0]};
118   abstract::ShapePtr y_dense_shape = std::make_shared<abstract::Shape>(dense_shape);
119   abstract::ShapePtr y_batch_shape = std::make_shared<abstract::Shape>(batch_shape);
120   abstract::ShapePtr y_row_shape = nullptr;
121   abstract::ShapePtr y_col_shape = nullptr;
122   abstract::ShapePtr y_values_shape = nullptr;
123 
124   ShapeVector col_shape = {abstract::Shape::kShapeDimAny};
125   ShapeVector values_shape = {abstract::Shape::kShapeDimAny};
126   y_col_shape = std::make_shared<abstract::Shape>(col_shape);
127   y_values_shape = std::make_shared<abstract::Shape>(values_shape);
128 
129   if (CheckAndConvertUtils::IsTensor(input_args[0]) && !input_args[0]->GetValue()->isa<ValueAny>() &&
130       !input_args[0]->GetValue()->isa<None>()) {
131     auto dense_shape_type_ptr = input_args[0]->GetType();
132     MS_EXCEPTION_IF_NULL(dense_shape_type_ptr);
133     auto dense_shape_value_ptr = input_args[0]->GetValue();
134     MS_EXCEPTION_IF_NULL(dense_shape_value_ptr);
135     auto dense_shape_value_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue(
136       "dense_shape", dense_shape_value_ptr, primitive->name(), dense_shape_type_ptr);
137     auto row_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 2));
138     auto col_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 1));
139 
140     auto transpose_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
141     auto transpose_b = GetValue<bool>(primitive->GetAttr(kTransposeB));
142     auto adjoint_a = GetValue<bool>(primitive->GetAttr("adjoint_a"));
143     auto adjoint_b = GetValue<bool>(primitive->GetAttr("adjoint_b"));
144 
145     if (adjoint_a && transpose_a) {
146       MS_EXCEPTION(ValueError)
147         << "For SparseMatrixSparseMatMul, only one of adjoint_a and transpose_a may be true, but got adjoint_a="
148         << adjoint_a << " and transpose_a=" << transpose_a << ".";
149     }
150     if (adjoint_b && transpose_b) {
151       MS_EXCEPTION(ValueError)
152         << "For SparseMatrixSparseMatMul, only one of adjoint_b and transpose_b  may be true, but got adjoint_b="
153         << adjoint_b << " and transpose_b=" << transpose_b << ".";
154     }
155     if (adjoint_a || transpose_a) {
156       row_value = col_value;
157     }
158 
159     ShapeVector row_shape = {(x1_batch_shape[0] - 1) * (row_value + 1)};
160     y_row_shape = std::make_shared<abstract::Shape>(row_shape);
161     return std::make_shared<abstract::TupleShape>(
162       std::vector<abstract::BaseShapePtr>{y_dense_shape, y_batch_shape, y_row_shape, y_col_shape, y_values_shape});
163   } else {
164     ShapeVector row_shape = {-1};
165     y_row_shape = std::make_shared<abstract::Shape>(row_shape);
166     return std::make_shared<abstract::TupleShape>(
167       std::vector<abstract::BaseShapePtr>{y_dense_shape, y_batch_shape, y_row_shape, y_col_shape, y_values_shape});
168   }
169 }
170 
SparseMatrixSparseMatMulInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)171 abstract::TupleShapePtr SparseMatrixSparseMatMulInferShape(const PrimitivePtr &primitive,
172                                                            const std::vector<AbstractBasePtr> &input_args) {
173   std::vector<int64_t> x1_dense_shape =
174     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
175   const int64_t rank_x1 = x1_dense_shape[0];
176   std::vector<int64_t> x2_dense_shape =
177     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
178   const int64_t rank_x2 = x2_dense_shape[0];
179   if (rank_x1 != rank_x2) {
180     MS_EXCEPTION(ValueError)
181       << "For SparseMatrixSparseMatMul, x1_dense_shape.shape[0] and rank of x2_dense must be the "
182          "same, but got x1_dense_shape.shape[0] = "
183       << rank_x1 << ", and rank of x2_dense = " << rank_x2 << ".";
184   }
185 
186   SparseMatrixSparseMatMulCheckInteger(primitive, input_args);
187 
188   const int kInputNoBatch = 2;
189   const int kInputWithBatch = 3;
190   if (rank_x1 != kInputNoBatch && rank_x1 != kInputWithBatch) {
191     MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, rank of x1_dense_shape must be (2,) or (3,), but got "
192                              << rank_x1 << ".";
193   }
194 
195   std::vector<int64_t> x1_batch_shape =
196     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
197   std::vector<int64_t> x2_batch_shape =
198     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
199 
200   if (x1_batch_shape[0] != x2_batch_shape[0]) {
201     MS_EXCEPTION(ValueError) << "For SparseMatrixSparseMatMul, x1_batch_shape[0] and x2_batch_shape[0] must be the "
202                                 "same, but got x1_batch_shape[0] = "
203                              << x1_batch_shape[0] << ", and x2_batch_shape[0] = " << x2_batch_shape[0] << ".";
204   }
205 
206   ShapeVector dense_shape = {x1_dense_shape[0]};
207   ShapeVector batch_shape = {x1_batch_shape[0]};
208   abstract::ShapePtr y_dense_shape = std::make_shared<abstract::Shape>(dense_shape);
209   abstract::ShapePtr y_batch_shape = std::make_shared<abstract::Shape>(batch_shape);
210 
211   int64_t max_length = MAX_LENGTH;
212   ShapeVector infer_shape_max = {max_length};
213   abstract::ShapePtr y_col_shape = std::make_shared<abstract::Shape>(infer_shape_max);
214   abstract::ShapePtr y_values_shape = std::make_shared<abstract::Shape>(infer_shape_max);
215 
216   auto dense_shape_type_ptr = input_args[0]->GetType();
217   MS_EXCEPTION_IF_NULL(dense_shape_type_ptr);
218   auto dense_shape_value_ptr = input_args[0]->GetValue();
219   MS_EXCEPTION_IF_NULL(dense_shape_value_ptr);
220   auto dense_shape_value_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue(
221     "dense_shape", dense_shape_value_ptr, primitive->name(), dense_shape_type_ptr);
222   auto row_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 2));
223   auto col_value = static_cast<int64_t>(*(dense_shape_value_ptr_tensor.end() - 1));
224 
225   auto transpose_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
226   auto transpose_b = GetValue<bool>(primitive->GetAttr(kTransposeB));
227   auto adjoint_a = GetValue<bool>(primitive->GetAttr("adjoint_a"));
228   auto adjoint_b = GetValue<bool>(primitive->GetAttr("adjoint_b"));
229 
230   if (adjoint_a && transpose_a) {
231     MS_EXCEPTION(ValueError)
232       << "For SparseMatrixSparseMatMul, only one of adjoint_a and transpose_a may be true, but got adjoint_a="
233       << adjoint_a << " and transpose_a=" << transpose_a << ".";
234   }
235   if (adjoint_b && transpose_b) {
236     MS_EXCEPTION(ValueError)
237       << "For SparseMatrixSparseMatMul, only one of adjoint_b and transpose_b  may be true, but got adjoint_b="
238       << adjoint_b << " and transpose_b=" << transpose_b << ".";
239   }
240   if (adjoint_a || transpose_a) {
241     row_value = col_value;
242   }
243 
244   ShapeVector row_shape = {(x1_batch_shape[0] - 1) * (row_value + 1)};
245   abstract::ShapePtr y_row_shape = std::make_shared<abstract::Shape>(row_shape);
246   return std::make_shared<abstract::TupleShape>(
247     std::vector<abstract::BaseShapePtr>{y_dense_shape, y_batch_shape, y_row_shape, y_col_shape, y_values_shape});
248 }
249 
SparseMatrixSparseMatMulInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)250 TuplePtr SparseMatrixSparseMatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
251   const std::set<TypePtr> index_valid_types = {kInt32, kInt64};
252   const std::set<TypePtr> values_valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
253   auto x1_dense_type = input_args[kInputIndex0]->GetType();
254   auto x1_batch_type = input_args[kInputIndex1]->GetType();
255   auto x1_row_type = input_args[kInputIndex2]->GetType();
256   auto x1_col_type = input_args[kInputIndex3]->GetType();
257   auto x1_values_type = input_args[kInputIndex4]->GetType();
258 
259   auto x2_dense_type = input_args[kInputIndex5]->GetType();
260   auto x2_batch_type = input_args[kInputIndex6]->GetType();
261   auto x2_row_type = input_args[kInputIndex7]->GetType();
262   auto x2_col_type = input_args[kInputIndex8]->GetType();
263 
264   std::map<std::string, TypePtr> types;
265   (void)types.emplace("x1_dense_shape", x1_dense_type);
266   (void)types.emplace("x1_batch_pointers", x1_batch_type);
267   (void)types.emplace("x1_row_pointers", x1_row_type);
268   (void)types.emplace("x1_col_indices", x1_col_type);
269   (void)types.emplace("x2_dense_shape", x2_dense_type);
270   (void)types.emplace("x2_batch_pointers", x2_batch_type);
271   (void)types.emplace("x2_row_pointers", x2_row_type);
272   (void)types.emplace("x2_col_indices", x2_col_type);
273   (void)CheckAndConvertUtils::CheckTensorTypeSame(types, index_valid_types, prim->name());
274   (void)CheckAndConvertUtils::CheckTensorTypeValid("x1_values", x1_values_type, values_valid_types, prim->name());
275 
276   return std::make_shared<Tuple>(
277     std::vector<TypePtr>{x1_dense_type, x1_batch_type, x1_row_type, x1_col_type, x1_values_type});
278 }
279 }  // namespace
280 
281 MIND_API_OPERATOR_IMPL(SparseMatrixSparseMatMul, BaseOperator);
SparseMatrixSparseMatMulInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)282 AbstractBasePtr SparseMatrixSparseMatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
283                                               const std::vector<AbstractBasePtr> &input_args) {
284   MS_EXCEPTION_IF_NULL(primitive);
285   for (const auto &item : input_args) {
286     MS_EXCEPTION_IF_NULL(item);
287   }
288   auto prim_name = primitive->name();
289   const int64_t input_num = 10;
290   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
291   auto infer_type = SparseMatrixSparseMatMulInferType(primitive, input_args);
292   auto infer_shape = FrontendSparseMatrixSparseMatMulInferShape(primitive, input_args);
293   return abstract::MakeAbstract(infer_shape, infer_type);
294 }
295 // AG means auto generated
296 class MIND_API AGSparseMatrixSparseMatMulInfer : public abstract::OpInferBase {
297  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const298   BaseShapePtr InferShape(const PrimitivePtr &primitive,
299                           const std::vector<AbstractBasePtr> &input_args) const override {
300     return SparseMatrixSparseMatMulInferShape(primitive, input_args);
301   }
302 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const303   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
304     return SparseMatrixSparseMatMulInferType(primitive, input_args);
305   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const306   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
307                                     const std::vector<AbstractBasePtr> &input_args) const override {
308     return SparseMatrixSparseMatMulInfer(engine, primitive, input_args);
309   }
310 
GetValueDependArgIndices() const311   std::set<int64_t> GetValueDependArgIndices() const override { return {0}; }
312 };
313 
314 REGISTER_PRIMITIVE_OP_INFER_IMPL(SparseMatrixSparseMatMul, prim::kPrimSparseMatrixSparseMatMul,
315                                  AGSparseMatrixSparseMatMulInfer, false);
316 }  // namespace ops
317 }  // namespace mindspore
318