• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "abstract/primitive_infer_map.h"
20 #include <map>
21 #include <string>
22 #include <vector>
23 #include "ops/exp.h"
24 #include "ops/log.h"
25 #include "ops/reciprocal.h"
26 #include "ops/real_div.h"
27 #include "ops/add.h"
28 #include "ops/equal.h"
29 #include "ops/not_equal.h"
30 #include "ops/neg.h"
31 #include "ops/mul.h"
32 #include "ops/sub.h"
33 #include "ops/strided_slice.h"
34 #include "ops/reduce_sum.h"
35 #include "abstract/abstract_function.h"
36 #include "abstract/infer_functions.h"
37 #include "utils/ms_context.h"
38 #include "ops/tile.h"
39 
40 namespace mindspore {
41 namespace abstract {
GetDependsFormMap(const CNodePtr & cnode)42 std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
43   const auto kOneHot = prim::kPrimOneHot->name();
44   const auto kDropoutGenMask = prim::kPrimDropoutGenMask->name();
45   const auto kTranspose = prim::kPrimTranspose->name();
46   const auto kReduceSum = prim::kPrimReduceSum->name();
47   const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
48   const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
49   const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
50   const auto kGather = prim::kPrimGather->name();
51   const auto kGatherV2 = prim::kPrimGatherV2->name();
52   const auto kDynamicShape = prim::kPrimDynamicShape->name();
53   const auto kRange = prim::kPrimRange->name();
54   const auto kConv2DBackpropFilter = prim::kPrimConv2DBackpropFilter->name();
55   const auto kConv2DBackpropInput = prim::kPrimConv2DBackpropInput->name();
56   // common dynamic shape depends
57   static std::map<std::string, std::vector<int64_t>> dynamic_shape_depends = {{kUnsortedSegmentSum, {2}},
58                                                                               {kUnsortedSegmentMin, {2}},
59                                                                               {kUnsortedSegmentMax, {2}},
60                                                                               {kGather, {2}},
61                                                                               {kGatherV2, {2}},
62                                                                               {kDynamicShape, {0}},
63                                                                               {kRange, {0, 1, 2}},
64                                                                               {kConv2DBackpropFilter, {2}},
65                                                                               {kConv2DBackpropInput, {2}},
66                                                                               {kOneHot, {1, 3}},
67                                                                               {kDropoutGenMask, {0}}};
68 
69   auto ms_context = MsContext::GetInstance();
70   MS_EXCEPTION_IF_NULL(ms_context);
71   auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
72   if (device == kAscendDevice) {
73     (void)dynamic_shape_depends.emplace(kReduceSum, std::vector<int64_t>{1});
74     (void)dynamic_shape_depends.emplace(kTranspose, std::vector<int64_t>{1});
75   }
76 
77   MS_EXCEPTION_IF_NULL(cnode);
78   if (cnode->inputs().empty()) {
79     MS_LOG(EXCEPTION) << "Invalid inputs";
80   }
81   auto primitive = GetValueNode<PrimitivePtr>(cnode->inputs()[0]);
82   MS_EXCEPTION_IF_NULL(primitive);
83   auto iter = dynamic_shape_depends.find(primitive->ToString());
84   if (iter != dynamic_shape_depends.end()) {
85     return iter->second;
86   }
87   return {};
88 }
89 
GetPrimitiveToEvalImplMap()90 PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
91   static PrimitiveEvalImplMap prim_eval_implement_map = {
92     // Statements
93     {prim::kPrimReturn, {InferImplReturn, nullptr, true}},
94     {prim::kPrimSwitch, {InferImplSwitch, nullptr, true}},
95     {prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}},
96     {prim::kPrimIs_, {InferImplIs_, nullptr, true}},
97     {prim::kPrimIsNot, {InferImplIsNot, nullptr, true}},
98     {prim::kPrimInDict, {InferImplInDict, nullptr, true}},
99     {prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}},
100     {prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}},
101     // Maths
102     {prim::kPrimMatMul, {InferImplMatMul, nullptr, true}},
103     {prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}},
104     {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
105     {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
106     {prim::kPrimSqrt, {InferImplSqrt, nullptr, true}},
107     // Array
108     {prim::kPrimRange, {InferImplRange, nullptr, true}},
109     {prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}},
110     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}},
111     {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}},
112     {prim::kPrimUnique, {InferImplUnique, nullptr, true}},
113     {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}},
114     {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}},
115     {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}},
116     {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}},
117     {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}},
118     {prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}},
119     {prim::kPrimScatterSub, {InferImplScatterSub, nullptr, true}},
120     {prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}},
121     {prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}},
122     {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}},
123     {prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}},
124     {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
125     {prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
126     {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
127     {prim::kPrimDynamicStitch, {InferImplDynamicStitch, nullptr, true}},
128     {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
129     {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
130     {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
131     {prim::kPrimSplit, {InferImplSplit, nullptr, true}},
132     {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
133     {prim::kPrimSort, {InferImplSort, nullptr, true}},
134     {prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}},
135     {prim::kPrimTensorCopySlices, {InferImplTensorCopySlices, nullptr, true}},
136     // Structure
137     {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
138     {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},
139     {prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}},
140     {prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}},
141     {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}},
142     {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}},
143     {prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}},
144     {prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}},
145     {prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}},
146     {prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}},
147     {prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}},
148     {prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}},
149     {prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}},
150     {prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}},
151     {prim::kPrimListAppend, {InferImplListAppend, nullptr, true}},
152     {prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}},
153     {prim::kPrimListLen, {InferImplListLen, nullptr, true}},
154     {prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}},
155     // NN
156     {prim::kPrimPooling, {InferImplPooling, nullptr, true}},
157     {prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}},
158     {prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}},
159     {prim::kPrimConv2D, {InferImplConv2D, nullptr, true}},
160     {prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}},
161     {prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}},
162     {prim::kPrimDropout, {InferImplDropout, nullptr, true}},
163     {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}},
164     {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}},
165     {prim::kPrimSGD, {InferImplSGD, nullptr, true}},
166     {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}},
167     {prim::kPrimHSigmoid, {InferImplHSigmoid, nullptr, true}},
168     {prim::kPrimHSigmoidGrad, {InferImplHSigmoidGrad, nullptr, true}},
169     // Others
170     {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
171     {prim::kPrimLoad, {InferImplLoad, nullptr, true}},
172     // Set impl to null as it will use PartialEvaluator;
173     {prim::kPrimPartial, {nullptr, nullptr, true}},
174     {prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}},
175     {prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}},
176     {prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}},
177     {prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}},
178     {prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}},
179     {prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}},
180     {prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}},
181     {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}},
182     {prim::kPrimDepend, {InferImplDepend, nullptr, true}},
183     {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}},
184     // Debug
185     {prim::kPrimDebug, {InferImplDebug, nullptr, true}},
186     // Dynamic shape testing
187     {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}},
188     // SparseTensor
189     {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}},
190     {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}},
191     {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}},
192     {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}},
193     // RowTensor
194     {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}},
195 
196     {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}},
197     {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}},
198     {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}},
199     {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}},
200     // Comm Ops
201     {prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}},
202     {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}},
203     {prim::kPrimFusedPushWeight, {nullptr, nullptr, true}},
204     {prim::kPrimFusedPullWeight, {nullptr, nullptr, true}},
205   };
206   return prim_eval_implement_map;
207 }
208 
GetPrimitiveToBackendEvalImplMap()209 PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
210   static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
211     {prim::kPrimMul, {ops::MulInfer, nullptr, true}},
212     {prim::kPrimAdd, {ops::AddInfer, nullptr, false}},
213     {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
214     {prim::kPrimSub, {ops::SubInfer, nullptr, false}},
215     {prim::kPrimNeg, {ops::NegInfer, nullptr, false}},
216     {prim::kPrimTile, {ops::TileInfer, nullptr, true}},
217     {prim::kPrimEqual, {ops::EqualInfer, nullptr, true}},
218     {prim::kPrimNotEqual, {ops::NotEqualInfer, nullptr, true}},
219     {prim::kPrimLog, {ops::LogInfer, nullptr, true}},
220     {prim::kPrimReciprocal, {ops::ReciprocalInfer, nullptr, true}},
221     {prim::kPrimReduceSum, {ops::ReduceSumInfer, nullptr, true}},
222     {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
223     {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
224     {prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}},
225     {prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}},
226     {prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}},
227     {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}},
228     {prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}},
229     {prim::kPrimCast, {InferImplCast, nullptr, true}},
230     {prim::kPrimExp, {ops::ExpInfer, nullptr, true}},
231     {prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}},
232     {prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}},
233     {prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}},
234     {prim::kPrimAllGather, {InferImplAllGather, nullptr, true}},
235     {prim::kPrimMinimum, {InferImplMinimum, nullptr, true}},
236     {prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}},
237     {prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}},
238 
239     {prim::kPrimLess, {InferImplLess, nullptr, true}},
240     {prim::kPrimStack, {InferImplStack, nullptr, true}},
241     {prim::kPrimPad, {InferImplPad, nullptr, true}},
242     {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
243     {prim::kPrimDiv, {InferImplDiv, nullptr, true}},
244     {prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}},
245     {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
246     {prim::kPrimStridedSlice, {ops::StridedSliceInfer, nullptr, true}},
247     {prim::kPrimReshape, {InferImplReshape, nullptr, true}},
248     {prim::kPrimConcat, {InferImplConcat, nullptr, true}},
249     {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}},
250     {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}},
251     {prim::kPrimTransData, {InferImplTransData, nullptr, true}},
252   };
253   return prim_backend_eval_implement_map;
254 }
255 
GetPrimitiveInferImpl(const PrimitivePtr & primitive)256 StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
257   MS_EXCEPTION_IF_NULL(primitive);
258   auto iter = GetPrimitiveToEvalImplMap().find(primitive);
259   if (iter == GetPrimitiveToEvalImplMap().end()) {
260     return {nullptr, nullptr, false};
261   }
262   return iter->second;
263 }
264 
RegisterStandardPrimitiveImpl(const PrimitivePtr & primitive,const StandardPrimitiveImplReg & impl_reg)265 void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
266   auto &prim_eval_map = GetPrimitiveToEvalImplMap();
267   prim_eval_map[primitive] = impl_reg;
268 }
269 }  // namespace abstract
270 }  // namespace mindspore
271