1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "abstract/ops/primitive_infer_map.h"
20 #include <string>
21 #include <vector>
22 #include <set>
23 #include <algorithm>
24 #include <cstdint>
25 #include <iterator>
26
27 #include "abstract/utils.h"
28 #include "ops/sparse_ops.h"
29 #include "ops/random_ops.h"
30 #include "ops/conv_pool_ops.h"
31 #include "ops/other_ops.h"
32 #include "ops/nn_ops.h"
33 #include "ops/math_ops.h"
34 #include "ops/image_ops.h"
35 #include "ops/array_ops.h"
36 #include "ops/framework_ops.h"
37 #include "ops/ops_frontend_func_impl.h"
38 #include "ops/op_def.h"
39 #include "ops/shape_calc.h"
40 #include "ops/op_utils.h"
41 #include "include/common/utils/utils.h"
42 #include "utils/ms_context.h"
43
44 namespace mindspore {
45 namespace abstract {
GetDependValueSize(const ValuePtr & value)46 int64_t GetDependValueSize(const ValuePtr &value) {
47 if (value->isa<Int64Imm>()) {
48 return GetValue<int64_t>(value);
49 }
50 if (!value->isa<ValueTuple>()) {
51 MS_LOG(EXCEPTION) << "the element of attr[dyn_input_size] should be all int64 of ValueTuple but got"
52 << value->ToString() << ", type :" << value->type_name();
53 }
54 int64_t size = 0;
55 auto value_tuple = value->cast_ptr<ValueTuple>();
56 MS_EXCEPTION_IF_NULL(value_tuple);
57 for (size_t i = 0; i < value_tuple->size(); ++i) {
58 size += GetDependValueSize((*value_tuple)[i]);
59 }
60 return size;
61 }
62
CheckScalarValid(const AbstractBasePtr & input_abstract)63 bool CheckScalarValid(const AbstractBasePtr &input_abstract) {
64 // Now, only scalar with int/float/uint will be used as the output of operator, so only add them to list.
65 if (input_abstract->isa<abstract::AbstractScalar>()) {
66 auto scalar_id = NormalizeTypeId(input_abstract->BuildType()->type_id());
67 return (scalar_id == kNumberTypeBool || scalar_id == kNumberTypeInt || scalar_id == kNumberTypeFloat ||
68 scalar_id == kNumberTypeUInt);
69 }
70 return false;
71 }
72
CheckNeedAddToDependList(const AbstractBasePtr & input_abstract)73 bool CheckNeedAddToDependList(const AbstractBasePtr &input_abstract) {
74 auto is_tensor = input_abstract->isa<abstract::AbstractTensor>();
75 bool is_integer = false;
76 bool is_tuple_scalar_or_tensor = false;
77 is_integer = CheckScalarValid(input_abstract);
78 if (input_abstract->isa<abstract::AbstractTuple>()) {
79 auto tuple_abs = input_abstract->cast_ptr<abstract::AbstractTuple>();
80 auto elements = tuple_abs->elements();
81 is_tuple_scalar_or_tensor = std::all_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
82 return (CheckScalarValid(element)) || element->isa<abstract::AbstractTensor>();
83 });
84 }
85 return is_tensor || is_integer || is_tuple_scalar_or_tensor;
86 }
87
RectifyDependListFromDynamicInputAttr(const CNodePtr & cnode,const PrimitivePtr & primitive,const std::set<int64_t> & ori_depend_list)88 std::set<int64_t> RectifyDependListFromDynamicInputAttr(const CNodePtr &cnode, const PrimitivePtr &primitive,
89 const std::set<int64_t> &ori_depend_list) {
90 std::set<int64_t> rec_depend_list = {};
91 constexpr auto all_tensor_inputs = -1;
92 if (ori_depend_list.size() == 1 && *(ori_depend_list.cbegin()) == all_tensor_inputs) {
93 for (size_t i = 1; i < cnode->size(); ++i) {
94 const auto &input = cnode->inputs()[i];
95 const auto &input_abstract = input->abstract();
96 if (input_abstract != nullptr) {
97 auto need_add_to_depend_list = CheckNeedAddToDependList(input_abstract);
98 if (need_add_to_depend_list) {
99 (void)rec_depend_list.emplace(SizeToLong(i - 1));
100 }
101 }
102 }
103 return rec_depend_list;
104 }
105
106 auto attr = primitive->GetAttr(kAttrDynInputSizes);
107 if (attr == nullptr) {
108 return ori_depend_list;
109 }
110
111 // mapping from input prototype index to corresponding start index of real input
112 std::vector<int64_t> dyn_input_sizes = GetValue<std::vector<int64_t>>(attr);
113 std::vector<int64_t> proto2real;
114 int64_t count = 0;
115 std::for_each(dyn_input_sizes.begin(), dyn_input_sizes.end(), [&count, &proto2real](int64_t dyn_size) {
116 proto2real.push_back(count);
117 count += dyn_size < 0 ? 1 : dyn_size;
118 });
119
120 std::for_each(ori_depend_list.begin(), ori_depend_list.end(),
121 [&proto2real, &dyn_input_sizes, &primitive, &rec_depend_list](int64_t proto_idx) {
122 if (proto_idx >= static_cast<int64_t>(dyn_input_sizes.size())) {
123 MS_LOG(EXCEPTION) << "The value depend index " << proto_idx << " of primitive " << primitive->name()
124 << " is out of range [0, " << dyn_input_sizes.size() << ").";
125 }
126 // value depend input is a normal input
127 if (dyn_input_sizes[proto_idx] < 0) {
128 rec_depend_list.insert(proto2real[proto_idx]);
129 }
130 // value depend input is is a dynamic input
131 for (int64_t i = 0; i < dyn_input_sizes[proto_idx]; ++i) {
132 rec_depend_list.insert(proto2real[proto_idx] + i);
133 }
134 });
135
136 return rec_depend_list;
137 }
138
GetValueDependArgIndices(const CNodePtr & cnode,bool is_proto)139 std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode, bool is_proto) {
140 MS_EXCEPTION_IF_NULL(cnode);
141 if (cnode->inputs().empty()) {
142 MS_LOG(EXCEPTION) << "Invalid inputs";
143 }
144 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
145 if (primitive == nullptr) {
146 return {};
147 }
148 auto prim_name = primitive->name();
149 std::set<int64_t> ori = {};
150 auto op_infer_opt = GetPrimitiveInferImpl(primitive);
151 if (!op_infer_opt.has_value()) {
152 // some operator will be mapped to new operator on Ascend like GatherV2, however they use same Infer information
153 if (primitive->HasAttr(kAttrMeOpName)) {
154 auto ori_prim_name = GetValue<std::string>(primitive->GetAttr(kAttrMeOpName));
155 op_infer_opt = GetPrimitiveInferImpl(std::make_shared<Primitive>(ori_prim_name));
156 }
157 }
158
159 if (op_infer_opt.has_value()) {
160 auto op_infer = op_infer_opt.value().Get();
161 if (op_infer != nullptr && ori.empty()) {
162 ori = op_infer->GetValueDependArgIndices();
163 }
164 if (prim_name == ops::kNameShapeCalc) {
165 auto only_depend_shape = GetValue<std::vector<bool>>(primitive->GetAttr(kAttrOnlyDependShape));
166 for (size_t i = 0; i < only_depend_shape.size(); i++) {
167 if (!only_depend_shape[i]) {
168 ori.insert(i);
169 }
170 }
171 }
172 } else if (ori.empty()) {
173 MS_LOG(DEBUG) << "Not find infer function GetValueDependArgIndices, prim name: " << prim_name;
174 // if not found in infer, consider all the non-tensor inputs as value depend args.
175 ori = ops::GetInputDependValueList(primitive);
176 if (prim_name == ops::kNameAvgPoolGrad && primitive->HasAttr(kAttrValueDepend)) {
177 auto value_depend_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrValueDepend));
178 ori.clear();
179 ori.insert(value_depend_vector.begin(), value_depend_vector.end());
180 }
181 }
182 if (ori.empty()) {
183 return ori;
184 }
185 size_t input_num = cnode->size() - 1;
186 std::set<int64_t> res = {};
187
188 (void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
189 [&](int64_t idx) { return idx < SizeToLong(input_num); });
190 if (is_proto) {
191 return res;
192 }
193 return RectifyDependListFromDynamicInputAttr(cnode, primitive, res);
194 }
195
GetPrimitiveInferMapPtr()196 PrimitiveEvalImplMap *GetPrimitiveInferMapPtr() {
197 static PrimitiveEvalImplMap prim_eval_implement_map{
198 // core/ops infer
199 // Do not add anything in this initializer anymore since it will be removed soon, core/ops prim should register its
200 // infer in its cc file.
201 };
202 return &prim_eval_implement_map;
203 }
GetPrimitiveInferMap()204 const PrimitiveEvalImplMap &GetPrimitiveInferMap() { return *GetPrimitiveInferMapPtr(); }
205
GetPrimitiveInferImpl(const PrimitivePtr & primitive)206 std::optional<StandardPrimitiveImplReg> GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
207 auto iter = GetPrimitiveInferMap().find(primitive);
208 if (iter != GetPrimitiveInferMap().end()) {
209 return iter->second;
210 }
211
212 iter = GetDeprecatedPrimitiveInferMap().find(primitive);
213 if (iter != GetDeprecatedPrimitiveInferMap().end()) {
214 return iter->second;
215 }
216 return std::optional<StandardPrimitiveImplReg>();
217 }
218
219 class OpInferCommon : public OpInferBase {
220 public:
221 OpInferCommon() = delete;
OpInferCommon(const InferAbstractImpl & infer_impl,const InferValueImpl & infer_value_impl)222 OpInferCommon(const InferAbstractImpl &infer_impl, const InferValueImpl &infer_value_impl)
223 : infer_impl_(infer_impl), infer_value_impl_(infer_value_impl) {}
224 ~OpInferCommon() = default;
225
226 BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
227 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
228 ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
229 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
230 const std::vector<AbstractBasePtr> &input_args) const override;
231
232 private:
233 InferAbstractImpl infer_impl_{nullptr};
234 InferValueImpl infer_value_impl_{nullptr};
235 };
236
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const237 BaseShapePtr OpInferCommon::InferShape(const PrimitivePtr &primitive,
238 const std::vector<AbstractBasePtr> &input_args) const {
239 if (!infer_impl_) {
240 return nullptr;
241 }
242
243 auto inferred_res = infer_impl_(nullptr, primitive, input_args);
244 if (inferred_res == nullptr) {
245 return nullptr;
246 }
247
248 return inferred_res->GetShape();
249 }
250
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const251 TypePtr OpInferCommon::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
252 if (!infer_impl_) {
253 return nullptr;
254 }
255
256 auto inferred_res = infer_impl_(nullptr, primitive, input_args);
257 if (inferred_res == nullptr) {
258 return nullptr;
259 }
260
261 return inferred_res->BuildType();
262 }
263
InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const264 ValuePtr OpInferCommon::InferValue(const PrimitivePtr &primitive,
265 const std::vector<AbstractBasePtr> &input_args) const {
266 if (!infer_value_impl_) {
267 return nullptr;
268 }
269 return infer_value_impl_(primitive, input_args);
270 }
271
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const272 AbstractBasePtr OpInferCommon::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
273 const PrimitivePtr &primitive,
274 const std::vector<AbstractBasePtr> &input_args) const {
275 if (!infer_impl_) {
276 return nullptr;
277 }
278
279 return infer_impl_(engine, primitive, input_args);
280 }
281
StandardPrimitiveImplReg(const InferAbstractImpl & infer_abstract,const InferValueImpl & infer_value,bool in_white_list)282 StandardPrimitiveImplReg::StandardPrimitiveImplReg(const InferAbstractImpl &infer_abstract,
283 const InferValueImpl &infer_value, bool in_white_list) {
284 op_infer_ = std::make_shared<OpInferCommon>(infer_abstract, infer_value);
285 is_impl_infer_shape_and_type_ = infer_abstract != nullptr;
286 is_impl_infer_value_ = infer_value != nullptr;
287 in_white_list_ = in_white_list;
288 }
289
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const290 AbstractBasePtr StandardPrimitiveImplReg::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
291 const PrimitivePtr &primitive,
292 const std::vector<AbstractBasePtr> &input_args) const {
293 if (op_infer_ == nullptr) {
294 return nullptr;
295 }
296
297 return op_infer_->InferShapeAndType(engine, primitive, input_args);
298 }
299
InferShape(const PrimitivePtr & prim,const AbstractBasePtrList & args) const300 BaseShapePtr StandardPrimitiveImplReg::InferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
301 if (op_infer_ == nullptr) {
302 return nullptr;
303 }
304
305 return op_infer_->InferShape(prim, args);
306 }
307
InferType(const PrimitivePtr & prim,const AbstractBasePtrList & args) const308 TypePtr StandardPrimitiveImplReg::InferType(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
309 if (op_infer_ == nullptr) {
310 return nullptr;
311 }
312
313 return op_infer_->InferType(prim, args);
314 }
315
InferValue(const PrimitivePtr & prim,const AbstractBasePtrList & args) const316 ValuePtr StandardPrimitiveImplReg::InferValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
317 if (op_infer_ == nullptr) {
318 return nullptr;
319 }
320
321 return op_infer_->InferValue(prim, args);
322 }
323
InferShapeByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args,bool compile_phase)324 std::optional<BaseShapePtr> InferShapeByFuncImpl(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args,
325 bool compile_phase) {
326 MS_EXCEPTION_IF_NULL(primitive);
327 auto op_name = primitive->name();
328 if (compile_phase) {
329 auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
330 if (frontend_func_impl != nullptr) {
331 auto infer_result = frontend_func_impl->InferAbstract(primitive, input_args);
332 if (infer_result != nullptr) {
333 return infer_result->GetShape();
334 }
335 }
336 }
337
338 auto op_def = ops::GetOpDef(op_name);
339 if (op_def == nullptr) {
340 return std::nullopt;
341 }
342 (void)op_def->func_impl_.CheckValidation(primitive, input_args);
343 return op_def->func_impl_.InferShape(primitive, input_args);
344 }
345
InferTypeByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args,bool compile_phase)346 std::optional<TypePtr> InferTypeByFuncImpl(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args,
347 bool compile_phase) {
348 MS_EXCEPTION_IF_NULL(primitive);
349 auto op_name = primitive->name();
350 if (compile_phase) {
351 auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
352 if (frontend_func_impl != nullptr) {
353 auto infer_result = frontend_func_impl->InferAbstract(primitive, input_args);
354 if (infer_result != nullptr) {
355 return infer_result->GetType();
356 }
357 }
358 }
359
360 auto op_def = ops::GetOpDef(op_name);
361 if (op_def == nullptr) {
362 return std::nullopt;
363 }
364 (void)op_def->func_impl_.CheckValidation(primitive, input_args);
365 return op_def->func_impl_.InferType(primitive, input_args);
366 }
367
InferAbstractByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args)368 std::optional<AbstractBasePtr> InferAbstractByFuncImpl(const PrimitivePtr &primitive,
369 const AbstractBasePtrList &input_args) {
370 MS_EXCEPTION_IF_NULL(primitive);
371 auto op_name = primitive->name();
372 auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
373 if (frontend_func_impl != nullptr) {
374 auto infer_result = frontend_func_impl->InferAbstract(primitive, input_args);
375 if (infer_result != nullptr) {
376 return infer_result;
377 }
378 }
379
380 auto op_def = ops::GetOpDef(op_name);
381 if (op_def == nullptr) {
382 return std::nullopt;
383 }
384 (void)op_def->func_impl_.CheckValidation(primitive, input_args);
385 auto shape = op_def->func_impl_.InferShape(primitive, input_args);
386 auto type = op_def->func_impl_.InferType(primitive, input_args);
387 return MakeAbstract(shape, type);
388 }
389
InferValueByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args)390 std::optional<ValuePtr> InferValueByFuncImpl(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args) {
391 MS_EXCEPTION_IF_NULL(primitive);
392 auto op_name = primitive->name();
393 auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
394 if (frontend_func_impl == nullptr) {
395 return std::nullopt;
396 }
397 return frontend_func_impl->InferValue(primitive, input_args);
398 }
399
TryInferAbstract(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args)400 std::optional<AbstractBasePtr> TryInferAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args) {
401 MS_EXCEPTION_IF_NULL(primitive);
402 auto abstract_optional = abstract::InferAbstractByFuncImpl(primitive, input_args);
403 if (abstract_optional.has_value()) {
404 return abstract_optional.value();
405 }
406
407 auto found = abstract::GetPrimitiveInferImpl(primitive);
408 if (!found.has_value() || !found.value().IsImplInferShapeAndType()) {
409 MS_LOG(DEBUG) << "The infer function of [" << primitive->name() << "] is not defined.";
410 return std::nullopt;
411 }
412 return found.value().InferShapeAndType(nullptr, primitive, input_args);
413 }
414 } // namespace abstract
415 } // namespace mindspore
416