• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
20 #define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
21 #include <string>
22 #include <memory>
23 #include "abstract/abstract_value.h"
24 #include "abstract/param_validator.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 namespace mindspore {
27 namespace abstract {
28 MIND_API AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
29                                          const AbstractBasePtrList &args_abs_list);
30 MIND_API AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
31                                          const AbstractBasePtrList &args_abs_list);
32 MIND_API AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
33                                               const AbstractBasePtrList &args_abs_list);
34 MIND_API AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
35                                       const AbstractBasePtrList &args_abs_list);
36 MIND_API AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
37                                         const AbstractBasePtrList &args_abs_list);
38 MIND_API AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
39                                          const AbstractBasePtrList &args_abs_list);
40 MIND_API AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
41                                             const AbstractBasePtrList &args_abs_list);
42 MIND_API AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
43                                              const AbstractBasePtrList &args_abs_list);
44 MIND_API AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
45                                             const AbstractBasePtrList &args_abs_list);
46 MIND_API AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &,
47                                            const AbstractBasePtrList &args_abs_list);
48 MIND_API AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
49                                                 const AbstractBasePtrList &args_abs_list);
50 MIND_API AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
51                                                 const AbstractBasePtrList &args_abs_list);
52 MIND_API AbstractBasePtr InferImplBroadcastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
53                                                  const AbstractBasePtrList &args_abs_list);
54 MIND_API AbstractBasePtr InferImplidentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
55                                            const AbstractBasePtrList &args_abs_list);
56 
57 MIND_API AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
58                                            const AbstractBasePtrList &args_abs_list);
59 MIND_API AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
60                                                      const AbstractBasePtrList &args_abs_list);
61 MIND_API AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
62                                                      const AbstractBasePtrList &args_abs_list);
63 MIND_API AbstractBasePtr InferImplMakeKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
64                                                  const AbstractBasePtrList &args_abs_list);
65 MIND_API AbstractBasePtr InferImplExtractKeywordArg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
66                                                     const AbstractBasePtrList &args_abs_list);
67 MIND_API AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
68                                               const AbstractBasePtrList &args_abs_list);
69 MIND_API AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
70                                               const AbstractBasePtrList &args_abs_list);
71 MIND_API AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
72                                               const AbstractBasePtrList &args_abs_list);
73 MIND_API AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
74                                                 const AbstractBasePtrList &args_abs_list);
75 MIND_API AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
76                                             const AbstractBasePtrList &args_abs_list);
77 MIND_API AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78                                            const AbstractBasePtrList &args_abs_list);
79 MIND_API AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
80                                           const AbstractBasePtrList &args_abs_list);
81 MIND_API AbstractBasePtr InferImplGetGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
82                                           const AbstractBasePtrList &args_abs_list);
83 MIND_API AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
84                                              const AbstractBasePtrList &args_abs_list);
85 MIND_API AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
86                                                const AbstractBasePtrList &args_abs_list);
87 MIND_API AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
88                                          const AbstractBasePtrList &args_abs_list);
89 MIND_API AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
90                                               const AbstractBasePtrList &args_abs_list);
91 MIND_API AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
92                                         const AbstractBasePtrList &args_abs_list);
93 MIND_API AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
94                                                 const AbstractBasePtrList &args_abs_list);
95 MIND_API AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
96                                                      const AbstractBasePtrList &args_abs_list);
97 MIND_API AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
98                                                       const AbstractBasePtrList &args_abs_list);
99 MIND_API AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
100                                                          const AbstractBasePtrList &args_abs_list);
101 MIND_API AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
102                                                const AbstractBasePtrList &args_abs_list);
103 MIND_API AbstractBasePtr InferImplScatterSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
104                                              const AbstractBasePtrList &args_abs_list);
105 MIND_API AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
106                                       const AbstractBasePtrList &args_abs_list);
107 MIND_API AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
108                                                const AbstractBasePtrList &args_abs_list);
109 MIND_API AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
110                                               const AbstractBasePtrList &args_abs_list);
111 MIND_API AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
112                                                  const AbstractBasePtrList &args_abs_list);
113 MIND_API AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
114                                                              const AbstractBasePtrList &args_abs_list);
115 MIND_API AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
116                                           const AbstractBasePtrList &args_abs_list);
117 MIND_API AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
118                                             const AbstractBasePtrList &args_abs_list);
119 MIND_API AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
120                                                 const AbstractBasePtrList &args_abs_list);
121 MIND_API AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
122                                       const AbstractBasePtrList &args_abs_list);
123 MIND_API AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
124                                             const AbstractBasePtrList &args_abs_list);
125 MIND_API AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
126                                           const AbstractBasePtrList &args_abs_list);
127 MIND_API AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
128                                            const AbstractBasePtrList &args_abs_list);
129 MIND_API AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
130                                            const AbstractBasePtrList &args_abs_list);
131 MIND_API AbstractBasePtr InferImplIsDimUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
132                                                const AbstractBasePtrList &args_abs_list);
133 MIND_API AbstractBasePtr InferImplIsShapeUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
134                                                  const AbstractBasePtrList &args_abs_list);
135 MIND_API AbstractBasePtr InferImplIsElementUnknown(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
136                                                    const AbstractBasePtrList &args_abs_list);
137 MIND_API AbstractBasePtr InferImplIsTensorBoolCond(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
138                                                    const AbstractBasePtrList &args_abs_list);
139 MIND_API AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
140                                       const AbstractBasePtrList &args_abs_list);
141 MIND_API AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
142                                              const AbstractBasePtrList &args_abs_list);
143 MIND_API AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
144                                                const AbstractBasePtrList &args_abs_list);
145 MIND_API AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
146                                                 const AbstractBasePtrList &args_abs_list);
147 MIND_API AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
148                                        const AbstractBasePtrList &args_abs_list);
149 MIND_API AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
150                                             const AbstractBasePtrList &args_abs_list);
151 MIND_API AbstractBasePtr InferImplTensorMove(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
152                                              const AbstractBasePtrList &args_abs_list);
153 MIND_API AbstractBasePtr InferImplRealInner(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
154                                             const AbstractBasePtrList &args_abs_list);
155 MIND_API AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
156                                                            const AbstractBasePtrList &args_abs_list);
157 MIND_API AbstractBasePtr InferImplMapTensorGetPermitFilterValue(const AnalysisEnginePtr &,
158                                                                 const PrimitivePtr &primitive,
159                                                                 const AbstractBasePtrList &args_abs_list);
160 MIND_API AbstractBasePtr InferImplMapTensorGetEvictFilterValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
161                                                                const AbstractBasePtrList &args_abs_list);
162 
163 template <typename T>
InferTupleOrListOrDictLen(const std::string & op_name,const AbstractBasePtrList & args_abs_list)164 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_abs_list) {
165   // Inputs: a tuple or list or dict.
166   constexpr size_t len_input_size = 1;
167   CheckArgsSize(op_name, args_abs_list, len_input_size);
168   auto arg = CheckArg<T>(op_name, args_abs_list, 0);
169   auto abs = dyn_cast<AbstractSequence>(args_abs_list[0]);
170   if (abs != nullptr && abs->dynamic_len()) {
171     // If the sequence is dynamic length, return any value scalar.
172     return std::make_shared<AbstractScalar>(kValueAny, kInt64);
173   }
174   return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
175 }
// REG_PRIM_INFER_FUNC(Foo, flag) defines a static variable `helper_eval_Foo`. The variable is initialized by
// calling abstract::RegisterStandardPrimitiveEvalHelper with these arguments, in order:
//   - the deprecated primitive infer map;
//   - prim::kPrimFoo;
//   - InferImplFoo;
//   - nullptr;
//   - `flag`.
// The macro argument must therefore be the common suffix of a prim::kPrim* symbol and an InferImpl* function.
#define REG_PRIM_INFER_FUNC(op_suffix, white_list_flag)                                \
  static auto helper_eval_##op_suffix = abstract::RegisterStandardPrimitiveEvalHelper( \
    abstract::GetDeprecatedPrimitiveInferMapPtr(), prim::kPrim##op_suffix, InferImpl##op_suffix, nullptr, white_list_flag);
179 }  // namespace abstract
180 }  // namespace mindspore
181 #endif  // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
182