1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/sparse_segment_sum.h"
18
19 #include <map>
20 #include <memory>
21 #include <set>
22
23 #include "abstract/ops/primitive_infer_map.h"
24 #include "mindapi/src/helper.h"
25 #include "mindspore/core/ops/sparse_ops.h"
26 #include "ops/op_name.h"
27
28 namespace mindspore {
29 namespace ops {
30 namespace {
SparseSegmentSumInferShape(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)31 abstract::ShapePtr SparseSegmentSumInferShape(const PrimitivePtr &prim,
32 const std::vector<AbstractBasePtr> &input_args) {
33 MS_EXCEPTION_IF_NULL(prim);
34 auto prim_name = prim->name();
35 auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
36 auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
37 auto segment_ids_shape =
38 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
39 // support dynamic rank
40 if (IsDynamicRank(x_shape) || IsDynamicRank(indices_shape) || IsDynamicRank(segment_ids_shape)) {
41 return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
42 }
43 (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual,
44 SizeToLong(kInputIndex1), prim->name());
45 (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
46 SizeToLong(kInputIndex1), prim->name());
47 if (x_shape.size() < kInputIndex1) {
48 MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
49 << "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
50 }
51 if (!(IsDynamic(indices_shape) || IsDynamic(segment_ids_shape)) &&
52 indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
53 MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
54 << "but got indices [" << indices_shape[kInputIndex0] << "] "
55 << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
56 }
57 if ((indices_shape[kInputIndex0] == kInputIndex0) || (segment_ids_shape[kInputIndex0] == kInputIndex0)) {
58 MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must greater than 0, "
59 << "but got indices [" << indices_shape[kInputIndex0] << "] "
60 << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
61 }
62 if (!input_args[kInputIndex2]->GetValue()->isa<ValueAny>() && !input_args[kInputIndex2]->GetValue()->isa<None>()) {
63 auto segment_ids_value_ptr = input_args[kInputIndex2]->GetValue();
64 MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
65 auto segment_ids_type_ptr = input_args[kInputIndex2]->GetType();
66 MS_EXCEPTION_IF_NULL(segment_ids_type_ptr);
67 auto segment_ids_value_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr,
68 prim->name(), segment_ids_type_ptr);
69 size_t dim_zero = static_cast<size_t>(segment_ids_value_ptr_tensor.back()) + kInputIndex1;
70 if (dim_zero < kInputIndex1) {
71 MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, "
72 << "but got [" << dim_zero << "].";
73 } else {
74 ShapeVector y_shape = x_shape;
75 y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
76 return std::make_shared<abstract::Shape>(y_shape);
77 }
78 } else {
79 ShapeVector output_shape = x_shape;
80 output_shape[kInputIndex0] = -1;
81 return std::make_shared<abstract::Shape>(output_shape);
82 }
83 }
84
SparseSegmentSumInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)85 TypePtr SparseSegmentSumInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
86 MS_EXCEPTION_IF_NULL(prim);
87 auto prim_name = prim->name();
88 auto x_type = input_args[kInputIndex0]->GetType();
89 auto indices_type = input_args[kInputIndex1]->GetType();
90 auto segment_ids_type = input_args[kInputIndex2]->GetType();
91 const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64};
92 const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
93 (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
94 std::map<std::string, TypePtr> types;
95 (void)types.emplace("indices", indices_type);
96 (void)types.emplace("segment_ids", segment_ids_type);
97 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
98 return input_args[kInputIndex0]->GetType();
99 }
100 } // namespace
101
102 MIND_API_OPERATOR_IMPL(SparseSegmentSum, BaseOperator);
SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)103 AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
104 const std::vector<AbstractBasePtr> &input_args) {
105 MS_EXCEPTION_IF_NULL(prim);
106 auto prim_name = prim->name();
107 const int64_t input_num = static_cast<int64_t>(kInputIndex3);
108 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
109 auto types = SparseSegmentSumInferType(prim, input_args);
110 auto shapes = SparseSegmentSumInferShape(prim, input_args);
111 return abstract::MakeAbstract(shapes, types);
112 }
113
114 // AG means auto generated
115 class MIND_API AGSparseSegmentSumInfer : public abstract::OpInferBase {
116 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const117 BaseShapePtr InferShape(const PrimitivePtr &primitive,
118 const std::vector<AbstractBasePtr> &input_args) const override {
119 return SparseSegmentSumInferShape(primitive, input_args);
120 }
121
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const122 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
123 return SparseSegmentSumInferType(primitive, input_args);
124 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const125 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
126 const std::vector<AbstractBasePtr> &input_args) const override {
127 return SparseSegmentSumInfer(engine, primitive, input_args);
128 }
129
GetValueDependArgIndices() const130 std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
131 };
132
133 REGISTER_PRIMITIVE_OP_INFER_IMPL(SparseSegmentSum, prim::kPrimSparseSegmentSum, AGSparseSegmentSumInfer, false);
134 } // namespace ops
135 } // namespace mindspore
136