1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
20 #define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
21 #include <string>
22 #include <memory>
23 #include "abstract/abstract_value.h"
24 #include "abstract/param_validator.h"
25 #include "base/core_ops.h"
26 namespace mindspore {
27 namespace abstract {
// Per-primitive abstract-inference entry points. Only the declarations are in this header;
// the definitions are not visible here.
// Every function shares one shape: it receives the analysis engine, the primitive
// (PrimitivePtr), and the abstract values of the call's arguments (`args_spec_list`). It
// returns an AbstractBasePtr, presumably the inferred abstract output (confirm against the
// definitions).
//
// Return / Switch / SwitchLayer, identity tests (Is_, IsNot), dict membership
// (InDict, NotInDict) and IsConstant. These declarations leave the PrimitivePtr parameter
// unnamed.
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
    const AbstractBasePtrList &args_spec_list);
// Inference declarations for Pooling, FusedSparseAdam, BatchNorm, Conv2D, BiasAdd, HSigmoid,
// BpropCut and Dropout, plus the gradient variants declared here (PoolingGrad, BiasAddGrad,
// HSigmoidGrad). Each takes (engine, primitive, args_spec_list) and returns an
// AbstractBasePtr; the definitions are not in this header.
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
66
// Inference declarations for MinOrMaxGrad, Add, Sqrt and SqrtGrad.
// Signature: (engine, primitive, args_spec_list) -> AbstractBasePtr; defined elsewhere.
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
75
// Inference declarations for scalar/array conversion (ScalarToArray, ArrayToScalar),
// BroadCastShape and Stack.
// Note: BroadCastShape is distinct from InferImplBroadcast, which is declared separately.
// Signature: (engine, primitive, args_spec_list) -> AbstractBasePtr; defined elsewhere.
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
84
// Inference declarations for container construction (tuple, list, dict, slice, kwarg,
// record), the UnsortedSegment{Sum,Max,Min} reductions, item get/set on tuples, lists and
// dicts, dict key/value extraction, list append, length queries and Identity.
// Signature: (engine, primitive, args_spec_list) -> AbstractBasePtr; defined elsewhere.
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
131
// Inference declarations for Env item get/set/add, RefKey/Ref construction and access,
// StateSetItem, Depend, UpdateState, Debug, and SparseTensor construction plus its
// values / indices / dense-shape accessors.
// Signature: (engine, primitive, args_spec_list) -> AbstractBasePtr; defined elsewhere.
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
162
// Inference declarations for RowTensor construction, its values / indices / dense-shape
// accessors, and RowTensorAdd.
// Signature: (engine, primitive, args_spec_list) -> AbstractBasePtr; defined elsewhere.
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
173
// Inference declarations for Unique / UniqueGrad, CTCGreedyDecoder, DynamicStitch,
// Scatter{Add,Sub,Update}, Div / RealDiv, SubAndFilter, MapCacheIdx, CacheSwapTable,
// UpdateCache, ComputeAccidentalHits, PadAndShift and DynamicAssign.
// Signature: (engine, primitive, args_spec_list) -> AbstractBasePtr; defined elsewhere.
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
// Inference declarations for the remaining primitives handled in this header. Each takes
// (engine, primitive, args_spec_list) and returns an AbstractBasePtr; the definitions are
// not in this header.
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
// Collective-communication primitives (AllSwap, AllReduce, Broadcast, AllGather,
// ReduceScatter).
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
// Optimizer, shape-manipulation, arithmetic, reduction and other tensor primitives.
AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    const AbstractBasePtrList &args_spec_list);
280 template <typename T>
InferTupleOrListOrDictLen(const std::string & op_name,const AbstractBasePtrList & args_spec_list)281 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
282 // Inputs: a tuple or list or dict.
283 CheckArgsSize(op_name, args_spec_list, 1);
284 auto arg = CheckArg<T>(op_name, args_spec_list, 0);
285 return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
286 }
287 } // namespace abstract
288 } // namespace mindspore
289 #endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
290