1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
20 #define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
21 #include <string>
22 #include <memory>
23 #include "abstract/abstract_value.h"
24 #include "abstract/param_validator.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 namespace mindspore {
27 namespace abstract {
28 MIND_API AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
29 const AbstractBasePtrList &args_abs_list);
30 MIND_API AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
31 const AbstractBasePtrList &args_abs_list);
32 MIND_API AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
33 const AbstractBasePtrList &args_abs_list);
34 MIND_API AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
35 const AbstractBasePtrList &args_abs_list);
36 MIND_API AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
37 const AbstractBasePtrList &args_abs_list);
38 MIND_API AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
39 const AbstractBasePtrList &args_abs_list);
40 MIND_API AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
41 const AbstractBasePtrList &args_abs_list);
42 MIND_API AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
43 const AbstractBasePtrList &args_abs_list);
44 MIND_API AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
45 const AbstractBasePtrList &args_abs_list);
46 MIND_API AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &,
47 const AbstractBasePtrList &args_abs_list);
48 MIND_API AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
49 const AbstractBasePtrList &args_abs_list);
50 MIND_API AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
51 const AbstractBasePtrList &args_abs_list);
52 MIND_API AbstractBasePtr InferImplBroadcastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
53 const AbstractBasePtrList &args_abs_list);
54 MIND_API AbstractBasePtr InferImplidentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
55 const AbstractBasePtrList &args_abs_list);
56
57 MIND_API AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
58 const AbstractBasePtrList &args_abs_list);
59 MIND_API AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
60 const AbstractBasePtrList &args_abs_list);
61 MIND_API AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
62 const AbstractBasePtrList &args_abs_list);
63 MIND_API AbstractBasePtr InferImplMakeKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
64 const AbstractBasePtrList &args_abs_list);
65 MIND_API AbstractBasePtr InferImplExtractKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
66 const AbstractBasePtrList &args_abs_list);
67 MIND_API AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
68 const AbstractBasePtrList &args_abs_list);
69 MIND_API AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
70 const AbstractBasePtrList &args_abs_list);
71 MIND_API AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
72 const AbstractBasePtrList &args_abs_list);
73 MIND_API AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
74 const AbstractBasePtrList &args_abs_list);
75 MIND_API AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
76 const AbstractBasePtrList &args_abs_list);
77 MIND_API AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78 const AbstractBasePtrList &args_abs_list);
79 MIND_API AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
80 const AbstractBasePtrList &args_abs_list);
81 MIND_API AbstractBasePtr InferImplGetGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
82 const AbstractBasePtrList &args_abs_list);
83 MIND_API AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
84 const AbstractBasePtrList &args_abs_list);
85 MIND_API AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
86 const AbstractBasePtrList &args_abs_list);
87 MIND_API AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
88 const AbstractBasePtrList &args_abs_list);
89 MIND_API AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
90 const AbstractBasePtrList &args_abs_list);
91 MIND_API AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
92 const AbstractBasePtrList &args_abs_list);
93 MIND_API AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
94 const AbstractBasePtrList &args_abs_list);
95 MIND_API AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
96 const AbstractBasePtrList &args_abs_list);
97 MIND_API AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
98 const AbstractBasePtrList &args_abs_list);
99 MIND_API AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
100 const AbstractBasePtrList &args_abs_list);
101 MIND_API AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
102 const AbstractBasePtrList &args_abs_list);
103 MIND_API AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
104 const AbstractBasePtrList &args_abs_list);
105 MIND_API AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
106 const AbstractBasePtrList &args_abs_list);
107 MIND_API AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
108 const AbstractBasePtrList &args_abs_list);
109 MIND_API AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
110 const AbstractBasePtrList &args_abs_list);
111 MIND_API AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
112 const AbstractBasePtrList &args_abs_list);
113 MIND_API AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
114 const AbstractBasePtrList &args_abs_list);
115 MIND_API AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
116 const AbstractBasePtrList &args_abs_list);
117 MIND_API AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
118 const AbstractBasePtrList &args_abs_list);
119 MIND_API AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
120 const AbstractBasePtrList &args_abs_list);
121 MIND_API AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
122 const AbstractBasePtrList &args_abs_list);
123 MIND_API AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
124 const AbstractBasePtrList &args_abs_list);
125 MIND_API AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
126 const AbstractBasePtrList &args_abs_list);
127 MIND_API AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
128 const AbstractBasePtrList &args_abs_list);
129 MIND_API AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
130 const AbstractBasePtrList &args_abs_list);
131 MIND_API AbstractBasePtr InferImplIsDimUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
132 const AbstractBasePtrList &args_abs_list);
133 MIND_API AbstractBasePtr InferImplIsShapeUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
134 const AbstractBasePtrList &args_abs_list);
135 MIND_API AbstractBasePtr InferImplIsElementUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
136 const AbstractBasePtrList &args_abs_list);
137 MIND_API AbstractBasePtr InferImplIsTensorBoolCond(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
138 const AbstractBasePtrList &args_abs_list);
139 MIND_API AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
140 const AbstractBasePtrList &args_abs_list);
141 MIND_API AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
142 const AbstractBasePtrList &args_abs_list);
143 MIND_API AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
144 const AbstractBasePtrList &args_abs_list);
145 MIND_API AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
146 const AbstractBasePtrList &args_abs_list);
147 MIND_API AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
148 const AbstractBasePtrList &args_abs_list);
149 MIND_API AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
150 const AbstractBasePtrList &args_abs_list);
151 MIND_API AbstractBasePtr InferImplTensorMove(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
152 const AbstractBasePtrList &args_abs_list);
153 MIND_API AbstractBasePtr InferImplRealInner(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
154 const AbstractBasePtrList &args_abs_list);
155 MIND_API AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
156 const AbstractBasePtrList &args_abs_list);
157 MIND_API AbstractBasePtr InferImplMapTensorGetPermitFilterValue(const AnalysisEnginePtr &,
158 const PrimitivePtr &primitive,
159 const AbstractBasePtrList &args_abs_list);
160 MIND_API AbstractBasePtr InferImplMapTensorGetEvictFilterValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
161 const AbstractBasePtrList &args_abs_list);
162
163 template <typename T>
InferTupleOrListOrDictLen(const std::string & op_name,const AbstractBasePtrList & args_abs_list)164 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_abs_list) {
165 // Inputs: a tuple or list or dict.
166 constexpr size_t len_input_size = 1;
167 CheckArgsSize(op_name, args_abs_list, len_input_size);
168 auto arg = CheckArg<T>(op_name, args_abs_list, 0);
169 auto abs = dyn_cast<AbstractSequence>(args_abs_list[0]);
170 if (abs != nullptr && abs->dynamic_len()) {
171 // If the sequence is dynamic length, return any value scalar.
172 return std::make_shared<AbstractScalar>(kValueAny, kInt64);
173 }
174 return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
175 }
176 #define REG_PRIM_INFER_FUNC(name, in_white_list) \
177 static auto helper_eval_##name = abstract::RegisterStandardPrimitiveEvalHelper( \
178 abstract::GetDeprecatedPrimitiveInferMapPtr(), prim::kPrim##name, InferImpl##name, nullptr, in_white_list);
179 } // namespace abstract
180 } // namespace mindspore
181 #endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
182