1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "abstract/primitive_infer_map.h"
20 #include <map>
21 #include <string>
22 #include <vector>
23 #include "ops/exp.h"
24 #include "ops/log.h"
25 #include "ops/reciprocal.h"
26 #include "ops/real_div.h"
27 #include "ops/add.h"
28 #include "ops/equal.h"
29 #include "ops/not_equal.h"
30 #include "ops/neg.h"
31 #include "ops/mul.h"
32 #include "ops/sub.h"
33 #include "ops/strided_slice.h"
34 #include "ops/reduce_sum.h"
35 #include "abstract/abstract_function.h"
36 #include "abstract/infer_functions.h"
37 #include "utils/ms_context.h"
38 #include "ops/tile.h"
39
40 namespace mindspore {
41 namespace abstract {
GetDependsFormMap(const CNodePtr & cnode)42 std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
43 const auto kOneHot = prim::kPrimOneHot->name();
44 const auto kDropoutGenMask = prim::kPrimDropoutGenMask->name();
45 const auto kTranspose = prim::kPrimTranspose->name();
46 const auto kReduceSum = prim::kPrimReduceSum->name();
47 const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
48 const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
49 const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
50 const auto kGather = prim::kPrimGather->name();
51 const auto kGatherV2 = prim::kPrimGatherV2->name();
52 const auto kDynamicShape = prim::kPrimDynamicShape->name();
53 const auto kRange = prim::kPrimRange->name();
54 const auto kConv2DBackpropFilter = prim::kPrimConv2DBackpropFilter->name();
55 const auto kConv2DBackpropInput = prim::kPrimConv2DBackpropInput->name();
56 // common dynamic shape depends
57 static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {{kUnsortedSegmentSum, {2}},
58 {kUnsortedSegmentMin, {2}},
59 {kUnsortedSegmentMax, {2}},
60 {kGather, {2}},
61 {kGatherV2, {2}},
62 {kDynamicShape, {0}},
63 {kRange, {0, 1, 2}},
64 {kConv2DBackpropFilter, {2}},
65 {kConv2DBackpropInput, {2}},
66 {kOneHot, {1, 3}},
67 {kDropoutGenMask, {0}}};
68
69 auto ms_context = MsContext::GetInstance();
70 MS_EXCEPTION_IF_NULL(ms_context);
71 auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
72 if (device == kAscendDevice) {
73 (void)dynamic_shape_depends.emplace(kReduceSum, std::vector<int64_t>{1});
74 (void)dynamic_shape_depends.emplace(kTranspose, std::vector<int64_t>{1});
75 }
76
77 MS_EXCEPTION_IF_NULL(cnode);
78 if (cnode->inputs().empty()) {
79 MS_LOG(EXCEPTION) << "Invalid inputs";
80 }
81 auto primitive = GetValueNode<PrimitivePtr>(cnode->inputs()[0]);
82 MS_EXCEPTION_IF_NULL(primitive);
83 auto iter = dynamic_shape_depends.find(primitive->ToString());
84 if (iter != dynamic_shape_depends.end()) {
85 return iter->second;
86 }
87 return {};
88 }
89
GetPrimitiveToEvalImplMap()90 PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
91 static PrimitiveEvalImplMap prim_eval_implement_map = {
92 // Statements
93 {prim::kPrimReturn, {InferImplReturn, nullptr, true}},
94 {prim::kPrimSwitch, {InferImplSwitch, nullptr, true}},
95 {prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}},
96 {prim::kPrimIs_, {InferImplIs_, nullptr, true}},
97 {prim::kPrimIsNot, {InferImplIsNot, nullptr, true}},
98 {prim::kPrimInDict, {InferImplInDict, nullptr, true}},
99 {prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}},
100 {prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}},
101 // Maths
102 {prim::kPrimMatMul, {InferImplMatMul, nullptr, true}},
103 {prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}},
104 {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
105 {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
106 {prim::kPrimSqrt, {InferImplSqrt, nullptr, true}},
107 // Array
108 {prim::kPrimRange, {InferImplRange, nullptr, true}},
109 {prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}},
110 {prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}},
111 {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}},
112 {prim::kPrimUnique, {InferImplUnique, nullptr, true}},
113 {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}},
114 {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}},
115 {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}},
116 {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}},
117 {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}},
118 {prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}},
119 {prim::kPrimScatterSub, {InferImplScatterSub, nullptr, true}},
120 {prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}},
121 {prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}},
122 {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}},
123 {prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}},
124 {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
125 {prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
126 {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
127 {prim::kPrimDynamicStitch, {InferImplDynamicStitch, nullptr, true}},
128 {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
129 {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
130 {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
131 {prim::kPrimSplit, {InferImplSplit, nullptr, true}},
132 {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
133 {prim::kPrimSort, {InferImplSort, nullptr, true}},
134 {prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}},
135 {prim::kPrimTensorCopySlices, {InferImplTensorCopySlices, nullptr, true}},
136 // Structure
137 {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
138 {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},
139 {prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}},
140 {prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}},
141 {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}},
142 {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}},
143 {prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}},
144 {prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}},
145 {prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}},
146 {prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}},
147 {prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}},
148 {prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}},
149 {prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}},
150 {prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}},
151 {prim::kPrimListAppend, {InferImplListAppend, nullptr, true}},
152 {prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}},
153 {prim::kPrimListLen, {InferImplListLen, nullptr, true}},
154 {prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}},
155 // NN
156 {prim::kPrimPooling, {InferImplPooling, nullptr, true}},
157 {prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}},
158 {prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}},
159 {prim::kPrimConv2D, {InferImplConv2D, nullptr, true}},
160 {prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}},
161 {prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}},
162 {prim::kPrimDropout, {InferImplDropout, nullptr, true}},
163 {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}},
164 {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}},
165 {prim::kPrimSGD, {InferImplSGD, nullptr, true}},
166 {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}},
167 {prim::kPrimHSigmoid, {InferImplHSigmoid, nullptr, true}},
168 {prim::kPrimHSigmoidGrad, {InferImplHSigmoidGrad, nullptr, true}},
169 // Others
170 {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
171 {prim::kPrimLoad, {InferImplLoad, nullptr, true}},
172 // Set impl to null as it will use PartialEvaluator;
173 {prim::kPrimPartial, {nullptr, nullptr, true}},
174 {prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}},
175 {prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}},
176 {prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}},
177 {prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}},
178 {prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}},
179 {prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}},
180 {prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}},
181 {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}},
182 {prim::kPrimDepend, {InferImplDepend, nullptr, true}},
183 {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}},
184 // Debug
185 {prim::kPrimDebug, {InferImplDebug, nullptr, true}},
186 // Dynamic shape testing
187 {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}},
188 // SparseTensor
189 {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}},
190 {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}},
191 {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}},
192 {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}},
193 // RowTensor
194 {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}},
195
196 {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}},
197 {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}},
198 {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}},
199 {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}},
200 // Comm Ops
201 {prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}},
202 {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}},
203 {prim::kPrimFusedPushWeight, {nullptr, nullptr, true}},
204 {prim::kPrimFusedPullWeight, {nullptr, nullptr, true}},
205 };
206 return prim_eval_implement_map;
207 }
208
GetPrimitiveToBackendEvalImplMap()209 PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
210 static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
211 {prim::kPrimMul, {ops::MulInfer, nullptr, true}},
212 {prim::kPrimAdd, {ops::AddInfer, nullptr, false}},
213 {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
214 {prim::kPrimSub, {ops::SubInfer, nullptr, false}},
215 {prim::kPrimNeg, {ops::NegInfer, nullptr, false}},
216 {prim::kPrimTile, {ops::TileInfer, nullptr, true}},
217 {prim::kPrimEqual, {ops::EqualInfer, nullptr, true}},
218 {prim::kPrimNotEqual, {ops::NotEqualInfer, nullptr, true}},
219 {prim::kPrimLog, {ops::LogInfer, nullptr, true}},
220 {prim::kPrimReciprocal, {ops::ReciprocalInfer, nullptr, true}},
221 {prim::kPrimReduceSum, {ops::ReduceSumInfer, nullptr, true}},
222 {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
223 {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
224 {prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}},
225 {prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}},
226 {prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}},
227 {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}},
228 {prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}},
229 {prim::kPrimCast, {InferImplCast, nullptr, true}},
230 {prim::kPrimExp, {ops::ExpInfer, nullptr, true}},
231 {prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}},
232 {prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}},
233 {prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}},
234 {prim::kPrimAllGather, {InferImplAllGather, nullptr, true}},
235 {prim::kPrimMinimum, {InferImplMinimum, nullptr, true}},
236 {prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}},
237 {prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}},
238
239 {prim::kPrimLess, {InferImplLess, nullptr, true}},
240 {prim::kPrimStack, {InferImplStack, nullptr, true}},
241 {prim::kPrimPad, {InferImplPad, nullptr, true}},
242 {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
243 {prim::kPrimDiv, {InferImplDiv, nullptr, true}},
244 {prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}},
245 {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
246 {prim::kPrimStridedSlice, {ops::StridedSliceInfer, nullptr, true}},
247 {prim::kPrimReshape, {InferImplReshape, nullptr, true}},
248 {prim::kPrimConcat, {InferImplConcat, nullptr, true}},
249 {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}},
250 {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}},
251 {prim::kPrimTransData, {InferImplTransData, nullptr, true}},
252 };
253 return prim_backend_eval_implement_map;
254 }
255
GetPrimitiveInferImpl(const PrimitivePtr & primitive)256 StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
257 MS_EXCEPTION_IF_NULL(primitive);
258 auto iter = GetPrimitiveToEvalImplMap().find(primitive);
259 if (iter == GetPrimitiveToEvalImplMap().end()) {
260 return {nullptr, nullptr, false};
261 }
262 return iter->second;
263 }
264
RegisterStandardPrimitiveImpl(const PrimitivePtr & primitive,const StandardPrimitiveImplReg & impl_reg)265 void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
266 auto &prim_eval_map = GetPrimitiveToEvalImplMap();
267 prim_eval_map[primitive] = impl_reg;
268 }
269 } // namespace abstract
270 } // namespace mindspore
271