• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
20 #define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
21 #include <string>
22 #include <memory>
23 #include "abstract/abstract_value.h"
24 #include "abstract/param_validator.h"
25 #include "base/core_ops.h"
26 namespace mindspore {
27 namespace abstract {
28 AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
29                                 const AbstractBasePtrList &args_spec_list);
30 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
31                                 const AbstractBasePtrList &args_spec_list);
32 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
33                                      const AbstractBasePtrList &args_spec_list);
34 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
35                              const AbstractBasePtrList &args_spec_list);
36 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
37                                const AbstractBasePtrList &args_spec_list);
38 AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
39                                 const AbstractBasePtrList &args_spec_list);
40 AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
41                                    const AbstractBasePtrList &args_spec_list);
42 AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
43                                     const AbstractBasePtrList &args_spec_list);
44 AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
45                                  const AbstractBasePtrList &args_spec_list);
46 AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
47                                      const AbstractBasePtrList &args_spec_list);
48 AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
49                                          const AbstractBasePtrList &args_spec_list);
50 AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
51                                    const AbstractBasePtrList &args_spec_list);
52 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
53                                 const AbstractBasePtrList &args_spec_list);
54 AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
55                                  const AbstractBasePtrList &args_spec_list);
56 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
57                                      const AbstractBasePtrList &args_spec_list);
58 AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
59                                   const AbstractBasePtrList &args_spec_list);
60 AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
61                                       const AbstractBasePtrList &args_spec_list);
62 AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
63                                   const AbstractBasePtrList &args_spec_list);
64 AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
65                                  const AbstractBasePtrList &args_spec_list);
66 
67 AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
68                                       const AbstractBasePtrList &args_spec_list);
69 AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
70                              const AbstractBasePtrList &args_spec_list);
71 AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
72                               const AbstractBasePtrList &args_spec_list);
73 AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
74                                   const AbstractBasePtrList &args_spec_list);
75 
76 AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
77                                        const AbstractBasePtrList &args_spec_list);
78 AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
79                                        const AbstractBasePtrList &args_spec_list);
80 AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
81                                         const AbstractBasePtrList &args_spec_list);
82 AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
83                                const AbstractBasePtrList &args_spec_list);
84 
85 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
86                                    const AbstractBasePtrList &args_spec_list);
87 AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
88                                   const AbstractBasePtrList &args_spec_list);
89 AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
90                                   const AbstractBasePtrList &args_spec_list);
91 AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
92                                             const AbstractBasePtrList &args_spec_list);
93 AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
94                                             const AbstractBasePtrList &args_spec_list);
95 AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
96                                             const AbstractBasePtrList &args_spec_list);
97 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
98                                    const AbstractBasePtrList &args_spec_list);
99 AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
100                                    const AbstractBasePtrList &args_spec_list);
101 AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
102                                       const AbstractBasePtrList &args_spec_list);
103 AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
104                                     const AbstractBasePtrList &args_spec_list);
105 AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
106                                       const AbstractBasePtrList &args_spec_list);
107 AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
108                                      const AbstractBasePtrList &args_spec_list);
109 AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
110                                       const AbstractBasePtrList &args_spec_list);
111 AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
112                                      const AbstractBasePtrList &args_spec_list);
113 AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
114                                      const AbstractBasePtrList &args_spec_list);
115 AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
116                                      const AbstractBasePtrList &args_spec_list);
117 AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
118                                      const AbstractBasePtrList &args_spec_list);
119 AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
120                                        const AbstractBasePtrList &args_spec_list);
121 AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
122                                     const AbstractBasePtrList &args_spec_list);
123 AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
124                                   const AbstractBasePtrList &args_spec_list);
125 AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
126                                  const AbstractBasePtrList &args_spec_list);
127 AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
128                                   const AbstractBasePtrList &args_spec_list);
129 AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
130                                   const AbstractBasePtrList &args_spec_list);
131 
132 AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
133                                     const AbstractBasePtrList &args_spec_list);
134 AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
135                                     const AbstractBasePtrList &args_spec_list);
136 AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
137                                 const AbstractBasePtrList &args_spec_list);
138 AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
139                                     const AbstractBasePtrList &args_spec_list);
140 AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
141                                  const AbstractBasePtrList &args_spec_list);
142 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
143                                    const AbstractBasePtrList &args_spec_list);
144 AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
145                                      const AbstractBasePtrList &args_spec_list);
146 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
147                                       const AbstractBasePtrList &args_spec_list);
148 AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
149                                 const AbstractBasePtrList &args_spec_list);
150 AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
151                                      const AbstractBasePtrList &args_spec_list);
152 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
153                                const AbstractBasePtrList &args_spec_list);
154 AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
155                                           const AbstractBasePtrList &args_spec_list);
156 AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
157                                                const AbstractBasePtrList &args_spec_list);
158 AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
159                                                 const AbstractBasePtrList &args_spec_list);
160 AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
161                                                    const AbstractBasePtrList &args_spec_list);
162 
163 AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
164                                        const AbstractBasePtrList &args_spec_list);
165 AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
166                                             const AbstractBasePtrList &args_spec_list);
167 AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
168                                              const AbstractBasePtrList &args_spec_list);
169 AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
170                                                 const AbstractBasePtrList &args_spec_list);
171 AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
172                                       const AbstractBasePtrList &args_spec_list);
173 
174 AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
175                                     const AbstractBasePtrList &args_spec_list);
176 AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
177                                 const AbstractBasePtrList &args_spec_list);
178 AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
179                                           const AbstractBasePtrList &args_spec_list);
180 AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
181                                        const AbstractBasePtrList &args_spec_list);
182 AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
183                                     const AbstractBasePtrList &args_spec_list);
184 AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
185                                     const AbstractBasePtrList &args_spec_list);
186 AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
187                                        const AbstractBasePtrList &args_spec_list);
188 AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
189                              const AbstractBasePtrList &args_spec_list);
190 AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
191                                  const AbstractBasePtrList &args_spec_list);
192 AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
193                                       const AbstractBasePtrList &args_spec_list);
194 AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
195                                      const AbstractBasePtrList &args_spec_list);
196 AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
197                                         const AbstractBasePtrList &args_spec_list);
198 AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
199                                      const AbstractBasePtrList &args_spec_list);
200 AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
201                                                const AbstractBasePtrList &args_spec_list);
202 AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
203                                      const AbstractBasePtrList &args_spec_list);
204 AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
205                                        const AbstractBasePtrList &args_spec_list);
206 AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
207                                   const AbstractBasePtrList &args_spec_list);
208 AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
209                                       const AbstractBasePtrList &args_spec_list);
210 AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
211                                          const AbstractBasePtrList &args_spec_list);
212 AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
213                                                     const AbstractBasePtrList &args_spec_list);
214 AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
215                                  const AbstractBasePtrList &args_spec_list);
216 AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
217                                    const AbstractBasePtrList &args_spec_list);
218 AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
219                                    const AbstractBasePtrList &args_spec_list);
220 AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
221                                    const AbstractBasePtrList &args_spec_list);
222 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
223                                        const AbstractBasePtrList &args_spec_list);
224 AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
225                              const AbstractBasePtrList &args_spec_list);
226 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
227                                    const AbstractBasePtrList &args_spec_list);
228 AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
229                                  const AbstractBasePtrList &args_spec_list);
230 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
231                                      const AbstractBasePtrList &args_spec_list);
232 AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
233                                          const AbstractBasePtrList &args_spec_list);
234 AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
235                              const AbstractBasePtrList &args_spec_list);
236 AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
237                                     const AbstractBasePtrList &args_spec_list);
238 AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
239                               const AbstractBasePtrList &args_spec_list);
240 AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
241                                  const AbstractBasePtrList &args_spec_list);
242 AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
243                                   const AbstractBasePtrList &args_spec_list);
244 AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
245                                   const AbstractBasePtrList &args_spec_list);
246 AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
247                                     const AbstractBasePtrList &args_spec_list);
248 AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
249                                                   const AbstractBasePtrList &args_spec_list);
250 AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
251                              const AbstractBasePtrList &args_spec_list);
252 AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
253                                     const AbstractBasePtrList &args_spec_list);
254 AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
255                                const AbstractBasePtrList &args_spec_list);
256 AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
257                                       const AbstractBasePtrList &args_spec_list);
258 AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
259                                 const AbstractBasePtrList &args_spec_list);
260 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
261                                const AbstractBasePtrList &args_spec_list);
262 AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
263                                 const AbstractBasePtrList &args_spec_list);
264 AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
265                                      const AbstractBasePtrList &args_spec_list);
266 AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
267                               const AbstractBasePtrList &args_spec_list);
268 AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
269                                          const AbstractBasePtrList &args_spec_list);
270 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
271                               const AbstractBasePtrList &args_spec_list);
272 AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
273                               const AbstractBasePtrList &args_spec_list);
274 AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
275                                       const AbstractBasePtrList &args_spec_list);
276 AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
277                                    const AbstractBasePtrList &args_spec_list);
278 AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
279                                           const AbstractBasePtrList &args_spec_list);
280 template <typename T>
InferTupleOrListOrDictLen(const std::string & op_name,const AbstractBasePtrList & args_spec_list)281 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
282   // Inputs: a tuple or list or dict.
283   CheckArgsSize(op_name, args_spec_list, 1);
284   auto arg = CheckArg<T>(op_name, args_spec_list, 0);
285   return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
286 }
287 }  // namespace abstract
288 }  // namespace mindspore
289 #endif  // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
290