/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/shape_calc.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "abstract/abstract_value.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ir/anf.h"
#include "ir/dtype/number.h"
#include "ir/kernel_tensor_value.h"
#include "ir/value.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/array_ops.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/hash_set.h"
#include "utils/log_adapter.h"
#include "utils/anf_utils.h"

namespace mindspore::ops {
namespace {
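// Converts one element's base shape into a ShapeVector: a TensorShape yields its shape vector,
// a NoShape (scalar) yields an empty vector, and any other shape kind raises a TypeError.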
ShapeVector GetShapeFromScalarOrTensor(const abstract::BaseShapePtr &base_shape) {
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::TensorShape>()) {
    return base_shape->GetShapeVector();
  } else if (!base_shape->isa<abstract::NoShape>()) {
    MS_EXCEPTION(TypeError) << "For Primitive[ShapeCalc], only support tuple of scalar or tensor now, but got "
                            << base_shape;
  }

  return ShapeVector{};
}

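// For a value-depended input, record its position in `pos_idx` and try to append its concrete
// value to `args` (a scalar becomes a one-element ShapeVector, a tensor/sequence becomes the
// vector of its elements). Returns false when the value cannot be determined yet.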
bool TryGetValueArg(const AbstractBasePtr &abs, ShapeArray *args, std::vector<std::vector<size_t>> *pos_idx) {
  size_t offset_base = args->size();
  pos_idx->push_back(std::vector<size_t>{offset_base});
  auto value_ptr = abs->GetValue();
  MS_EXCEPTION_IF_NULL(value_ptr);
  if (!ops::IsValueKnown(value_ptr)) {
    args->push_back(ShapeVector{});
    return false;
  }
  if (value_ptr->isa<Int64Imm>()) {
    auto scalar_optional = ops::GetScalarValue<int64_t>(value_ptr);
    if (scalar_optional.has_value()) {
      args->push_back(ShapeVector{scalar_optional.value()});
      return true;
    }
  } else if (value_ptr->isa<tensor::Tensor>() || value_ptr->isa<KernelTensorValue>() ||
             value_ptr->isa<ValueSequeue>()) {
    auto shape_value_optional = ops::GetArrayValue<int64_t>(abs);
    if (shape_value_optional.has_value()) {
      auto shape_array_value = shape_value_optional.value();
      if (!shape_array_value.HasUnknownValue()) {
        args->push_back(shape_array_value.ToVector());
        return true;
      }
    }
  } else {
    MS_EXCEPTION(TypeError) << "For ShapeCalc, the shape input type must be Tensor/Scalar/Tuple/List, but got "
                            << value_ptr->ToString() << ".";
  }

  return false;
}

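// Builds the abstract of a single ShapeCalc output: a tuple of `num` int64 scalars, or a
// dynamic-length int64 tuple when `num` is -1 (the output length is unknown at compile time).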
AbstractBasePtr CreateAbstractInt64TupleByNum(int64_t num) {
  AbstractBasePtrList abs_list;
  if (num == -1) {
    const auto &abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
    abstract->set_dynamic_len(true);
    abstract->set_dynamic_len_element_abs(std::make_shared<abstract::AbstractScalar>(kInt64));
    return abstract;
  }
  abs_list.reserve(LongToSize(num));
  for (size_t i = 0; i < LongToSize(num); ++i) {
    abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kInt64));
  }
  return std::make_shared<abstract::AbstractTuple>(abs_list);
}
}  // namespace

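// For a shape-depended input, append its shape(s) to `args` and record their positions in
// `pos_idx`: a scalar or tensor contributes one entry, while a tuple/list contributes one entry
// per element (e.g. a tuple of three tensors adds positions {offset, offset + 1, offset + 2}).
// Returns false for a dynamic-length sequence, whose element count is not yet known.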
bool TryGetShapeArg(const AbstractBasePtr &abs, ShapeArray *args, std::vector<std::vector<size_t>> *pos_idx) {
  MS_EXCEPTION_IF_NULL(args);
  MS_EXCEPTION_IF_NULL(pos_idx);

  size_t offset_base = args->size();
  std::vector<size_t> pos;
  auto base_shape = abs->GetShape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::NoShape>() || base_shape->isa<abstract::TensorShape>()) {
    args->push_back(GetShapeFromScalarOrTensor(base_shape));
    pos.push_back(offset_base);
  } else if (base_shape->isa<abstract::SequenceShape>()) {
    auto sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    MS_EXCEPTION_IF_NULL(sequence_shape);
    for (size_t i = 0; i < sequence_shape->size(); ++i) {
      args->push_back(GetShapeFromScalarOrTensor((*sequence_shape)[i]));
      pos.push_back(offset_base + i);
    }
  } else {
    if (base_shape->isa<abstract::DynamicSequenceShape>()) {
      auto dynamic_sequence = base_shape->cast<abstract::DynamicSequenceShapePtr>();
      MS_EXCEPTION_IF_NULL(dynamic_sequence);
      auto element_base_shape = dynamic_sequence->element_shape();
      args->push_back(GetShapeFromScalarOrTensor(element_base_shape));
      pos.push_back(offset_base);
    }
    pos_idx->push_back(pos);
    return false;
  }

  pos_idx->push_back(pos);
  return true;
}

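// Accessors for attributes attached to the ShapeCalc primitive: the functor that performs the
// shape calculation and the cached calculation result.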
ShapeCalcBaseFunctorPtr ShapeCalc::get_functor() const {
  auto attr = api::ToRef<mindspore::Primitive>(impl_).GetAttr(kAttrFunctor);
  MS_EXCEPTION_IF_NULL(attr);
  return attr->cast<ShapeCalcBaseFunctorPtr>();
}

ShapeArray ShapeCalc::get_calc_result() const { return GetValue<ShapeArray>(GetAttr(kAttrCalcResult)); }

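// Shape/type inference for ShapeCalc: gathers one argument per input (its shape when the input is
// only shape-depended, its value otherwise), then either runs the functor's Calc to obtain the
// real result or, when some inputs are unknown or dynamic, its Infer to obtain only the output
// lengths.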
class MIND_API ShapeCalcInfer : public abstract::OpInferBase {
 public:
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    auto only_depend_shape = GetValue<std::vector<bool>>(primitive->GetAttr(kAttrOnlyDependShape));
    ShapeArray args;
    HashSet<size_t> unknown_inputs;
    bool is_any_dynamic_shape = false;
    std::vector<std::vector<size_t>> pos_idx;
    pos_idx.reserve(input_args.size());
    (void)primitive->AddAttr(kInputRealTuple, MakeValue(true));
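    // Collect one argument per input: the shape for shape-only-depended inputs, the concrete value
    // otherwise. Inputs that cannot be resolved yet are recorded in unknown_inputs.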
    for (size_t i = 0; i < input_args.size(); ++i) {
      const auto &abs = input_args[i];
      MS_EXCEPTION_IF_NULL(abs);
      if (only_depend_shape[i]) {
        // If it is not value depend, use shape as arg.
        size_t offset_base = args.size();
        if (!TryGetShapeArg(abs, &args, &pos_idx)) {
          (void)unknown_inputs.insert(i);
        } else {
          auto is_new_dynamic = std::any_of(args.begin() + offset_base, args.end(), IsDynamic);
          is_any_dynamic_shape = is_any_dynamic_shape || is_new_dynamic;
        }
      } else {
        // Value depended, try to get value from input abstract.
        if (!TryGetValueArg(abs, &args, &pos_idx)) {
          (void)unknown_inputs.insert(i);
        }
      }
    }

    auto functor_attr = primitive->GetAttr(kAttrFunctor);
    MS_EXCEPTION_IF_NULL(functor_attr);
    auto functor = functor_attr->cast<ShapeCalcBaseFunctorPtr>();
    MS_EXCEPTION_IF_NULL(functor);

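    // With unknown or dynamic inputs, only the length of each output can be inferred (an entry of
    // -1 later yields a dynamic-length tuple); otherwise the functor computes the real result,
    // which is cached in the kAttrCalcResult attribute.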
    ShapeVector out;
    bool is_dynamic_sequence = false;
    if (!unknown_inputs.empty() || is_any_dynamic_shape) {
      auto infer_res = functor->Infer(args, unknown_inputs, pos_idx);
      out = infer_res.first;
      is_dynamic_sequence = infer_res.second;
    } else {
      auto ans = functor->Calc(args, pos_idx);
      primitive->set_attr(kAttrCalcResult, MakeValue(ans));
      out.reserve(ans.size());
      std::transform(ans.cbegin(), ans.cend(), std::back_inserter(out),
                     [](const ShapeVector &shape) { return SizeToLong(shape.size()); });
    }
    if (!is_dynamic_sequence && out.size() == 1) {
      // single output does not use AbstractTuple to avoid TupleGetItem
      (void)primitive->AddAttr(kOutputRealTuple, MakeValue(true));
      return CreateAbstractInt64TupleByNum(out[0]);
    }

    // multiple outputs
    if (!is_dynamic_sequence && primitive->HasAttr(kOutputRealTuple) && !out.empty()) {
      auto first_len = out[0];
      if (std::any_of(out.begin() + 1, out.end(), [first_len](int64_t len) { return first_len != len; })) {
        MS_LOG(EXCEPTION) << "For 'ShapeCalc', each output should have same size in dynamic length case.";
      }
    }

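    // Assemble the output abstract: one int64 tuple per entry of `out`, converted to a
    // dynamic-length sequence when the functor reported a dynamic sequence.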
    AbstractBasePtrList abs_list;
    abs_list.reserve(out.size());
    (void)std::transform(out.begin(), out.end(), std::back_inserter(abs_list),
                         [](int64_t s) { return CreateAbstractInt64TupleByNum(s); });
    auto output_abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
    if (is_dynamic_sequence) {
      (void)primitive->AddAttr(kOutputRealTuple, MakeValue(true));
      output_abstract->CheckAndConvertToDynamicLenSequence();
    }
    return output_abstract;
  }

  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return InferShapeAndType(nullptr, primitive, input_args)->GetShape();
  }

  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return InferShapeAndType(nullptr, primitive, input_args)->GetType();
  }
};

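// Register the infer implementation so shape/type inference of ShapeCalc is dispatched to ShapeCalcInfer.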
MIND_API_OPERATOR_IMPL(ShapeCalc, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ShapeCalc, prim::kPrimShapeCalc, ShapeCalcInfer, false);
}  // namespace mindspore::ops