• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/set_size.h"
18 
19 #include <set>
20 
21 #include "abstract/abstract_value.h"
22 #include "abstract/ops/primitive_infer_map.h"
23 #include "mindapi/src/helper.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ops/op_utils.h"
26 #include "utils/check_convert_utils.h"
27 
28 namespace mindspore {
29 namespace ops {
30 namespace {
SetSizeInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)31 abstract::ShapePtr SetSizeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
32   MS_EXCEPTION_IF_NULL(primitive);
33   auto op_name = primitive->name();
34   auto set_indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
35   auto set_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape())[kShape];
36   auto set_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShape())[kShape];
37   // support dynamic rank
38   if (IsDynamicRank(set_indices_shape) || IsDynamicRank(set_values_shape) || IsDynamicRank(set_shape_shape)) {
39     return std::make_shared<abstract::TensorShape>(ShapeVector({abstract::TensorShape::kShapeRankAny}));
40   }
41 
42   auto set_indices_shape_num = 2;
43   (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_indices",
44                                            SizeToLong(set_indices_shape.size()), kEqual, set_indices_shape_num,
45                                            op_name);
46   (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_values", SizeToLong(set_values_shape.size()),
47                                            kEqual, 1, op_name);
48   (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_shape", SizeToLong(set_shape_shape.size()),
49                                            kEqual, 1, op_name);
50 
51   if (IsDynamic(set_shape_shape)) {
52     return std::make_shared<abstract::TensorShape>(ShapeVector({abstract::TensorShape::kShapeRankAny}));
53   }
54 
55   if (!IsDynamic(set_indices_shape) && !IsDynamic(set_values_shape)) {
56     (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_indices or set_shape",
57                                              set_indices_shape[1], kEqual, set_shape_shape[0], op_name);
58     (void)CheckAndConvertUtils::CheckInteger("dimension of SetSize input set_indices or set_values",
59                                              set_indices_shape[0], kEqual, set_values_shape[0], op_name);
60   }
61 
62   MS_EXCEPTION_IF_NULL(primitive->GetAttr("validate_indices"));
63   auto shape_size_dim = set_shape_shape[0];
64   bool gen_value_succ = false;
65   std::vector<int64_t> set_shape_value_vec(shape_size_dim);
66   auto set_shape_tensor = input_args[2];
67   MS_EXCEPTION_IF_NULL(set_shape_tensor);
68   if (CheckAndConvertUtils::IsTensor(set_shape_tensor)) {
69     const std::set<TypePtr> output_size_valid_types = {kInt64};
70     (void)CheckAndConvertUtils::CheckTensorTypeValid("set_shape", set_shape_tensor->GetType(), output_size_valid_types,
71                                                      op_name);
72     if (IsValueKnown(set_shape_tensor)) {
73       auto value = GetArrayValue<int64_t>(set_shape_tensor).value().ToVector();
74       for (size_t i = 0; i < LongToSize(shape_size_dim); ++i) {
75         set_shape_value_vec[i] = value[i];
76       }
77       gen_value_succ = true;
78     }
79   }
80   if (!gen_value_succ) {
81     auto dense_size = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShape())[kShape];
82     ShapeVector max_shape(dense_size[0] - 1);
83     auto max_length_ptr = primitive->GetAttr("max_length");
84     MS_EXCEPTION_IF_NULL(max_length_ptr);
85     int64_t max_length = GetValue<int64_t>(max_length_ptr);
86     for (int64_t i = 1; i <= dense_size[0] - 1; ++i) {
87       max_shape.end()[-i] = max_length;
88     }
89     return std::make_shared<abstract::TensorShape>(max_shape);
90   } else {
91     ShapeVector output_shape;
92     auto set_values_index = 2;
93     auto dense_size = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShape())[kShape];
94     if (dense_size.size() == 1 && dense_size[0] < set_values_index) {
95       output_shape.push_back(1);
96     } else {
97       for (unsigned int i = 0; i < dense_size[0] - 1; ++i) {
98         output_shape.push_back(set_shape_value_vec[i]);
99       }
100     }
101     return std::make_shared<abstract::TensorShape>(output_shape);
102   }
103 }
104 
SetSizeInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)105 TypePtr SetSizeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
106   auto prim_name = prim->name();
107   const std::set<TypePtr> valid_types = {kInt64};
108   const std::set<TypePtr> set_values_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16};
109   (void)CheckAndConvertUtils::CheckTensorTypeValid("set_indices", input_args[kInputIndex0]->GetType(), valid_types,
110                                                    prim_name);
111   (void)CheckAndConvertUtils::CheckTensorTypeValid("set_values", input_args[kInputIndex1]->GetType(),
112                                                    set_values_valid_types, prim_name);
113   (void)CheckAndConvertUtils::CheckTensorTypeValid("set_shape", input_args[kInputIndex2]->GetType(), valid_types,
114                                                    prim_name);
115   return std::make_shared<TensorType>(kInt32);
116 }
117 }  // namespace
118 
119 MIND_API_OPERATOR_IMPL(SetSize, BaseOperator);
120 
Init(const bool validate_indices)121 void SetSize::Init(const bool validate_indices) { set_validate_indices(validate_indices); }
122 
set_validate_indices(const bool & validate_indices)123 void SetSize::set_validate_indices(const bool &validate_indices) {
124   (void)AddAttr(kValidateIndices, api::MakeValue(validate_indices));
125 }
126 
get_validate_indices() const127 bool SetSize::get_validate_indices() const {
128   auto value_ptr = GetAttr(kValidateIndices);
129   return GetValue<bool>(value_ptr);
130 }
131 
SetSizeInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)132 AbstractBasePtr SetSizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
133                              const std::vector<AbstractBasePtr> &input_args) {
134   MS_EXCEPTION_IF_NULL(primitive);
135   auto prim_name = primitive->name();
136   const int64_t input_num = 3;
137   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
138   for (const auto &item : input_args) {
139     MS_EXCEPTION_IF_NULL(item);
140   }
141   auto infer_type = SetSizeInferType(primitive, input_args);
142   auto infer_shape = SetSizeInferShape(primitive, input_args);
143   return abstract::MakeAbstract(infer_shape, infer_type);
144 }
145 
146 // AG means auto generated
147 class MIND_API AGSetSizeInfer : public abstract::OpInferBase {
148  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const149   BaseShapePtr InferShape(const PrimitivePtr &primitive,
150                           const std::vector<AbstractBasePtr> &input_args) const override {
151     return SetSizeInferShape(primitive, input_args);
152   }
153 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const154   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
155     return SetSizeInferType(primitive, input_args);
156   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const157   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
158                                     const std::vector<AbstractBasePtr> &input_args) const override {
159     return SetSizeInfer(engine, primitive, input_args);
160   }
161 
GetValueDependArgIndices() const162   std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
163 };
164 
165 REGISTER_PRIMITIVE_OP_INFER_IMPL(SetSize, prim::kPrimSetSize, AGSetSizeInfer, false);
166 }  // namespace ops
167 }  // namespace mindspore
168