• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "ops/shape_calc.h"
17 #include <algorithm>
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <unordered_set>
22 #include <vector>
23 #include "abstract/abstract_value.h"
24 #include "abstract/ops/primitive_infer_map.h"
25 #include "ir/anf.h"
26 #include "ir/dtype/number.h"
27 #include "ir/kernel_tensor_value.h"
28 #include "ir/value.h"
29 #include "mindapi/src/helper.h"
30 #include "mindspore/core/ops/array_ops.h"
31 #include "ops/op_utils.h"
32 #include "utils/check_convert_utils.h"
33 #include "utils/hash_set.h"
34 #include "utils/log_adapter.h"
35 #include "utils/anf_utils.h"
36 
37 namespace mindspore::ops {
38 namespace {
GetShapeFromScalarOrTensor(const abstract::BaseShapePtr & base_shape)39 ShapeVector GetShapeFromScalarOrTensor(const abstract::BaseShapePtr &base_shape) {
40   MS_EXCEPTION_IF_NULL(base_shape);
41   if (base_shape->isa<abstract::TensorShape>()) {
42     return base_shape->GetShapeVector();
43   } else if (!base_shape->isa<abstract::NoShape>()) {
44     MS_EXCEPTION(TypeError) << "For Primitive[ShapeCalc], only support tuple of scalar or tensor now, but got "
45                             << base_shape;
46   }
47 
48   return ShapeVector{};
49 }
50 
TryGetValueArg(const AbstractBasePtr & abs,ShapeArray * args,std::vector<std::vector<size_t>> * pos_idx)51 bool TryGetValueArg(const AbstractBasePtr &abs, ShapeArray *args, std::vector<std::vector<size_t>> *pos_idx) {
52   size_t offset_base = args->size();
53   pos_idx->push_back(std::vector<size_t>{offset_base});
54   auto value_ptr = abs->GetValue();
55   MS_EXCEPTION_IF_NULL(value_ptr);
56   if (!ops::IsValueKnown(value_ptr)) {
57     args->push_back(ShapeVector{});
58     return false;
59   }
60   if (value_ptr->isa<Int64Imm>()) {
61     auto scalar_optional = ops::GetScalarValue<int64_t>(value_ptr);
62     if (scalar_optional.has_value()) {
63       args->push_back(ShapeVector{scalar_optional.value()});
64       return true;
65     }
66   } else if (value_ptr->isa<tensor::Tensor>() || value_ptr->isa<KernelTensorValue>() ||
67              value_ptr->isa<ValueSequeue>()) {
68     auto shape_value_optional = ops::GetArrayValue<int64_t>(abs);
69     if (shape_value_optional.has_value()) {
70       auto shape_array_value = shape_value_optional.value();
71       if (!shape_array_value.HasUnknownValue()) {
72         args->push_back(shape_array_value.ToVector());
73         return true;
74       }
75     }
76   } else {
77     MS_EXCEPTION(TypeError) << "For ShapeCalc, the shape input type must be Tensor/Scalar/Tuple/List, but got "
78                             << value_ptr->ToString() << ".";
79   }
80 
81   return false;
82 }
83 
CreateAbstractInt64TupleByNum(int64_t num)84 AbstractBasePtr CreateAbstractInt64TupleByNum(int64_t num) {
85   AbstractBasePtrList abs_list;
86   if (num == -1) {
87     const auto &abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
88     abstract->set_dynamic_len(true);
89     abstract->set_dynamic_len_element_abs(std::make_shared<abstract::AbstractScalar>(kInt64));
90     return abstract;
91   }
92   abs_list.reserve(LongToSize(num));
93   for (size_t i = 0; i < LongToSize(num); ++i) {
94     abs_list.push_back(std::make_shared<abstract::AbstractScalar>(kInt64));
95   }
96   return std::make_shared<abstract::AbstractTuple>(abs_list);
97 }
98 }  // namespace
99 
TryGetShapeArg(const AbstractBasePtr & abs,ShapeArray * args,std::vector<std::vector<size_t>> * pos_idx)100 bool TryGetShapeArg(const AbstractBasePtr &abs, ShapeArray *args, std::vector<std::vector<size_t>> *pos_idx) {
101   MS_EXCEPTION_IF_NULL(args);
102   MS_EXCEPTION_IF_NULL(pos_idx);
103 
104   size_t offset_base = args->size();
105   std::vector<size_t> pos;
106   auto base_shape = abs->GetShape();
107   MS_EXCEPTION_IF_NULL(base_shape);
108   if (base_shape->isa<abstract::NoShape>() || base_shape->isa<abstract::TensorShape>()) {
109     args->push_back(GetShapeFromScalarOrTensor(base_shape));
110     pos.push_back(offset_base);
111   } else if (base_shape->isa<abstract::SequenceShape>()) {
112     auto sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
113     MS_EXCEPTION_IF_NULL(sequence_shape);
114     for (size_t i = 0; i < sequence_shape->size(); ++i) {
115       args->push_back(GetShapeFromScalarOrTensor((*sequence_shape)[i]));
116       pos.push_back(offset_base + i);
117     }
118   } else {
119     if (base_shape->isa<abstract::DynamicSequenceShape>()) {
120       auto dynamic_sequence = base_shape->cast<abstract::DynamicSequenceShapePtr>();
121       MS_EXCEPTION_IF_NULL(dynamic_sequence);
122       auto element_base_shape = dynamic_sequence->element_shape();
123       args->push_back(GetShapeFromScalarOrTensor(element_base_shape));
124       pos.push_back(offset_base);
125     }
126     pos_idx->push_back(pos);
127     return false;
128   }
129 
130   pos_idx->push_back(pos);
131   return true;
132 }
133 
get_functor() const134 ShapeCalcBaseFunctorPtr ShapeCalc::get_functor() const {
135   auto attr = api::ToRef<mindspore::Primitive>(impl_).GetAttr(kAttrFunctor);
136   MS_EXCEPTION_IF_NULL(attr);
137   return attr->cast<ShapeCalcBaseFunctorPtr>();
138 }
139 
get_calc_result() const140 ShapeArray ShapeCalc::get_calc_result() const { return GetValue<ShapeArray>(GetAttr(kAttrCalcResult)); }
141 
142 class MIND_API ShapeCalcInfer : public abstract::OpInferBase {
143  public:
InferShapeAndType(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const144   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
145                                     const std::vector<AbstractBasePtr> &input_args) const override {
146     MS_EXCEPTION_IF_NULL(primitive);
147     auto only_depend_shape = GetValue<std::vector<bool>>(primitive->GetAttr(kAttrOnlyDependShape));
148     ShapeArray args;
149     HashSet<size_t> unknown_inputs;
150     bool is_any_dynamic_shape = false;
151     std::vector<std::vector<size_t>> pos_idx;
152     pos_idx.reserve(input_args.size());
153     (void)primitive->AddAttr(kInputRealTuple, MakeValue(true));
154     for (size_t i = 0; i < input_args.size(); ++i) {
155       const auto &abs = input_args[i];
156       MS_EXCEPTION_IF_NULL(abs);
157       if (only_depend_shape[i]) {
158         // If it is not value depend, use shape as arg.
159         size_t offset_base = args.size();
160         if (!TryGetShapeArg(abs, &args, &pos_idx)) {
161           (void)unknown_inputs.insert(i);
162         } else {
163           auto is_new_dynamic = std::any_of(args.begin() + offset_base, args.end(), IsDynamic);
164           is_any_dynamic_shape = is_any_dynamic_shape || is_new_dynamic;
165         }
166       } else {
167         // Value depended, try to get value from input abstract.
168         if (!TryGetValueArg(abs, &args, &pos_idx)) {
169           (void)unknown_inputs.insert(i);
170         }
171       }
172     }
173 
174     auto functor_attr = primitive->GetAttr(kAttrFunctor);
175     MS_EXCEPTION_IF_NULL(functor_attr);
176     auto functor = functor_attr->cast<ShapeCalcBaseFunctorPtr>();
177     MS_EXCEPTION_IF_NULL(functor);
178 
179     ShapeVector out;
180     bool is_dynamic_sequence = false;
181     if (!unknown_inputs.empty() || is_any_dynamic_shape) {
182       auto infer_res = functor->Infer(args, unknown_inputs, pos_idx);
183       out = infer_res.first;
184       is_dynamic_sequence = infer_res.second;
185     } else {
186       auto ans = functor->Calc(args, pos_idx);
187       primitive->set_attr(kAttrCalcResult, MakeValue(ans));
188       out.reserve(ans.size());
189       std::transform(ans.cbegin(), ans.cend(), std::back_inserter(out),
190                      [](const ShapeVector &shape) { return SizeToLong(shape.size()); });
191     }
192     if (!is_dynamic_sequence && out.size() == 1) {
193       // single output does not use AbstractTuple to avoid TupleGetItem
194       (void)primitive->AddAttr(kOutputRealTuple, MakeValue(true));
195       return CreateAbstractInt64TupleByNum(out[0]);
196     }
197 
198     // multiple outputs
199     if (!is_dynamic_sequence && primitive->HasAttr(kOutputRealTuple) && !out.empty()) {
200       auto first_len = out[0];
201       if (std::any_of(out.begin() + 1, out.end(), [first_len](int64_t len) { return first_len != len; })) {
202         MS_LOG(EXCEPTION) << "For 'ShapeCalc', each output should have same size in dynamic length case.";
203       }
204     }
205 
206     AbstractBasePtrList abs_list;
207     abs_list.reserve(out.size());
208     (void)std::transform(out.begin(), out.end(), std::back_inserter(abs_list),
209                          [](int64_t s) { return CreateAbstractInt64TupleByNum(s); });
210     auto output_abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
211     if (is_dynamic_sequence) {
212       (void)primitive->AddAttr(kOutputRealTuple, MakeValue(true));
213       output_abstract->CheckAndConvertToDynamicLenSequence();
214     }
215     return output_abstract;
216   }
217 
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const218   BaseShapePtr InferShape(const PrimitivePtr &primitive,
219                           const std::vector<AbstractBasePtr> &input_args) const override {
220     return InferShapeAndType(nullptr, primitive, input_args)->GetShape();
221   }
222 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const223   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
224     return InferShapeAndType(nullptr, primitive, input_args)->GetType();
225   }
226 };
227 
// Operator boilerplate for ShapeCalc with BaseOperator as its base.
// NOTE(review): the macro is defined elsewhere; confirm exactly what it generates.
228 MIND_API_OPERATOR_IMPL(ShapeCalc, BaseOperator);
// Presumably registers ShapeCalcInfer as the infer implementation for prim::kPrimShapeCalc.
// NOTE(review): the meaning of the trailing `false` argument is not visible here; confirm against the macro.
229 REGISTER_PRIMITIVE_OP_INFER_IMPL(ShapeCalc, prim::kPrimShapeCalc, ShapeCalcInfer, false);
230 }  // namespace mindspore::ops
231