1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/set_size.h"
18
19 #include <set>
20
21 #include "abstract/abstract_value.h"
22 #include "abstract/ops/primitive_infer_map.h"
23 #include "mindapi/src/helper.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ops/op_utils.h"
26 #include "utils/check_convert_utils.h"
27
28 namespace mindspore {
29 namespace ops {
30 namespace {
SetSizeInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)31 abstract::ShapePtr SetSizeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
32 MS_EXCEPTION_IF_NULL(primitive);
33 auto op_name = primitive->name();
34 auto set_indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
35 auto set_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape())[kShape];
36 auto set_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShape())[kShape];
37 // support dynamic rank
38 if (IsDynamicRank(set_indices_shape) || IsDynamicRank(set_values_shape) || IsDynamicRank(set_shape_shape)) {
39 return std::make_shared<abstract::TensorShape>(ShapeVector({abstract::TensorShape::kShapeRankAny}));
40 }
41
42 auto set_indices_shape_num = 2;
43 (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_indices",
44 SizeToLong(set_indices_shape.size()), kEqual, set_indices_shape_num,
45 op_name);
46 (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_values", SizeToLong(set_values_shape.size()),
47 kEqual, 1, op_name);
48 (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_shape", SizeToLong(set_shape_shape.size()),
49 kEqual, 1, op_name);
50
51 if (IsDynamic(set_shape_shape)) {
52 return std::make_shared<abstract::TensorShape>(ShapeVector({abstract::TensorShape::kShapeRankAny}));
53 }
54
55 if (!IsDynamic(set_indices_shape) && !IsDynamic(set_values_shape)) {
56 (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_indices or set_shape",
57 set_indices_shape[1], kEqual, set_shape_shape[0], op_name);
58 (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_indices or set_values",
59 set_indices_shape[0], kEqual, set_values_shape[0], op_name);
60 }
61
62 MS_EXCEPTION_IF_NULL(primitive->GetAttr("validate_indices"));
63 auto shape_size_dim = set_shape_shape[0];
64 bool gen_value_succ = false;
65 std::vector<int64_t> set_shape_value_vec(shape_size_dim);
66 auto set_shape_tensor = input_args[2];
67 MS_EXCEPTION_IF_NULL(set_shape_tensor);
68 if (CheckAndConvertUtils::IsTensor(set_shape_tensor)) {
69 const std::set<TypePtr> output_size_valid_types = {kInt64};
70 (void)CheckAndConvertUtils::CheckTensorTypeValid("set_shape", set_shape_tensor->GetType(), output_size_valid_types,
71 op_name);
72 if (IsValueKnown(set_shape_tensor)) {
73 auto value = GetArrayValue<int64_t>(set_shape_tensor).value().ToVector();
74 for (size_t i = 0; i < LongToSize(shape_size_dim); ++i) {
75 set_shape_value_vec[i] = value[i];
76 }
77 gen_value_succ = true;
78 }
79 }
80 if (!gen_value_succ) {
81 auto dense_size = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShape())[kShape];
82 ShapeVector max_shape(dense_size[0] - 1);
83 auto max_length_ptr = primitive->GetAttr("max_length");
84 MS_EXCEPTION_IF_NULL(max_length_ptr);
85 int64_t max_length = GetValue<int64_t>(max_length_ptr);
86 for (int64_t i = 1; i <= dense_size[0] - 1; ++i) {
87 max_shape.end()[-i] = max_length;
88 }
89 return std::make_shared<abstract::TensorShape>(max_shape);
90 } else {
91 ShapeVector output_shape;
92 auto set_values_index = 2;
93 auto dense_size = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShape())[kShape];
94 if (dense_size.size() == 1 && dense_size[0] < set_values_index) {
95 output_shape.push_back(1);
96 } else {
97 for (unsigned int i = 0; i < dense_size[0] - 1; ++i) {
98 output_shape.push_back(set_shape_value_vec[i]);
99 }
100 }
101 return std::make_shared<abstract::TensorShape>(output_shape);
102 }
103 }
104
SetSizeInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)105 TypePtr SetSizeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
106 auto prim_name = prim->name();
107 const std::set<TypePtr> valid_types = {kInt64};
108 const std::set<TypePtr> set_values_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16};
109 (void)CheckAndConvertUtils::CheckTensorTypeValid("set_indices", input_args[kInputIndex0]->GetType(), valid_types,
110 prim_name);
111 (void)CheckAndConvertUtils::CheckTensorTypeValid("set_values", input_args[kInputIndex1]->GetType(),
112 set_values_valid_types, prim_name);
113 (void)CheckAndConvertUtils::CheckTensorTypeValid("set_shape", input_args[kInputIndex2]->GetType(), valid_types,
114 prim_name);
115 return std::make_shared<TensorType>(kInt32);
116 }
117 } // namespace
118
// Expands the MIND_API_OPERATOR_IMPL boilerplate for the SetSize operator with BaseOperator as its base class.
// NOTE(review): the exact members generated are defined by the macro, not visible in this file.
MIND_API_OPERATOR_IMPL(SetSize, BaseOperator);
120
Init(const bool validate_indices)121 void SetSize::Init(const bool validate_indices) { set_validate_indices(validate_indices); }
122
set_validate_indices(const bool & validate_indices)123 void SetSize::set_validate_indices(const bool &validate_indices) {
124 (void)AddAttr(kValidateIndices, api::MakeValue(validate_indices));
125 }
126
get_validate_indices() const127 bool SetSize::get_validate_indices() const {
128 auto value_ptr = GetAttr(kValidateIndices);
129 return GetValue<bool>(value_ptr);
130 }
131
SetSizeInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)132 AbstractBasePtr SetSizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
133 const std::vector<AbstractBasePtr> &input_args) {
134 MS_EXCEPTION_IF_NULL(primitive);
135 auto prim_name = primitive->name();
136 const int64_t input_num = 3;
137 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
138 for (const auto &item : input_args) {
139 MS_EXCEPTION_IF_NULL(item);
140 }
141 auto infer_type = SetSizeInferType(primitive, input_args);
142 auto infer_shape = SetSizeInferShape(primitive, input_args);
143 return abstract::MakeAbstract(infer_shape, infer_type);
144 }
145
146 // AG means auto generated
147 class MIND_API AGSetSizeInfer : public abstract::OpInferBase {
148 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const149 BaseShapePtr InferShape(const PrimitivePtr &primitive,
150 const std::vector<AbstractBasePtr> &input_args) const override {
151 return SetSizeInferShape(primitive, input_args);
152 }
153
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const154 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
155 return SetSizeInferType(primitive, input_args);
156 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const157 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
158 const std::vector<AbstractBasePtr> &input_args) const override {
159 return SetSizeInfer(engine, primitive, input_args);
160 }
161
GetValueDependArgIndices() const162 std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
163 };
164
165 REGISTER_PRIMITIVE_OP_INFER_IMPL(SetSize, prim::kPrimSetSize, AGSetSizeInfer, false);
166 } // namespace ops
167 } // namespace mindspore
168