• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "nnacl/infer/infer_register.h"
17 
18 #ifdef _MSC_VER
19 #include "nnacl/infer/activation_grad_infer.h"
20 #include "nnacl/infer/adam_infer.h"
21 #include "nnacl/infer/adam_weight_decay_infer.h"
22 #include "nnacl/infer/add_sub_grad_infer.h"
23 #include "nnacl/infer/addn_infer.h"
24 #include "nnacl/infer/affine_infer.h"
25 #include "nnacl/infer/all_gather_infer.h"
26 #include "nnacl/infer/apply_momentum_infer.h"
27 #include "nnacl/infer/argmin_max_infer.h"
28 #include "nnacl/infer/arithmetic_compare_infer.h"
29 #include "nnacl/infer/arithmetic_grad_infer.h"
30 #include "nnacl/infer/arithmetic_infer.h"
31 #include "nnacl/infer/assert_op_infer.h"
32 #include "nnacl/infer/assign_add_infer.h"
33 #include "nnacl/infer/assign_infer.h"
34 #include "nnacl/infer/attention_infer.h"
35 #include "nnacl/infer/encoder_layer_infer.h"
36 #include "nnacl/infer/audio_spectrogram_infer.h"
37 #include "nnacl/infer/batch_to_space_infer.h"
38 #include "nnacl/infer/bias_grad_infer.h"
39 #include "nnacl/infer/binary_cross_entropy_infer.h"
40 #include "nnacl/infer/bn_grad_infer.h"
41 #include "nnacl/infer/broadcast_to_infer.h"
42 #include "nnacl/infer/cast_infer.h"
43 #include "nnacl/infer/common_infer.h"
44 #include "nnacl/infer/concat_infer.h"
45 #include "nnacl/infer/constant_of_shape_infer.h"
46 #include "nnacl/infer/decoder_layer_infer.h"
47 
48 #ifdef MSLITE_ENABLE_CONTROLFLOW
49 #include "nnacl/infer/control/tensor_array_infer.h"
50 #include "nnacl/infer/control/tensor_array_read_infer.h"
51 #include "nnacl/infer/control/tensor_array_write_infer.h"
52 #include "nnacl/infer/control/tensorlist_fromtensor_infer.h"
53 #include "nnacl/infer/control/tensorlist_getitem_infer.h"
54 #include "nnacl/infer/control/tensorlist_reserve_infer.h"
55 #include "nnacl/infer/control/tensorlist_setitem_infer.h"
56 #include "nnacl/infer/control/tensorlist_stack_infer.h"
57 #endif
58 #include "nnacl/infer/conv2d_grad_filter_infer.h"
59 #include "nnacl/infer/conv2d_grad_input_infer.h"
60 #include "nnacl/infer/conv2d_infer.h"
61 #include "nnacl/infer/crop_and_resize_infer.h"
62 #include "nnacl/infer/crop_infer.h"
63 #include "nnacl/infer/cumsum_infer.h"
64 #include "nnacl/infer/deconv2d_infer.h"
65 #include "nnacl/infer/depth_to_space_infer.h"
66 #include "nnacl/infer/depthwise_conv2d_infer.h"
67 #include "nnacl/infer/detection_post_process_infer.h"
68 #include "nnacl/infer/dropout_grad_infer.h"
69 #include "nnacl/infer/dropout_infer.h"
70 #include "nnacl/infer/dynamic_quant_infer.h"
71 #include "nnacl/infer/embedding_lookup_infer.h"
72 #include "nnacl/infer/expand_dims_infer.h"
73 #include "nnacl/infer/fft_imag_infer.h"
74 #include "nnacl/infer/fft_real_infer.h"
75 #include "nnacl/infer/fill_infer.h"
76 #include "nnacl/infer/fillv2_infer.h"
77 #include "nnacl/infer/flatten_grad_infer.h"
78 #include "nnacl/infer/flatten_infer.h"
79 #include "nnacl/infer/full_connection_infer.h"
80 #include "nnacl/infer/fused_batchnorm_infer.h"
81 #include "nnacl/infer/gather_infer.h"
82 #include "nnacl/infer/gather_nd_infer.h"
83 #include "nnacl/infer/glu_infer.h"
84 #include "nnacl/infer/group_conv2d_grad_input_infer.h"
85 #include "nnacl/infer/gru_infer.h"
86 #include "nnacl/infer/instance_norm_infer.h"
87 #include "nnacl/infer/invert_permutation_infer.h"
88 #include "nnacl/infer/layer_norm_grad_infer.h"
89 #include "nnacl/infer/layer_norm_infer.h"
90 #include "nnacl/infer/lin_space_infer.h"
91 #include "nnacl/infer/log_softmax_infer.h"
92 #include "nnacl/infer/lstm_grad_data_infer.h"
93 #include "nnacl/infer/lstm_grad_infer.h"
94 #include "nnacl/infer/lstm_grad_weight_infer.h"
95 #include "nnacl/infer/lstm_infer.h"
96 #include "nnacl/infer/matmul_infer.h"
97 #include "nnacl/infer/max_min_grad_infer.h"
98 #include "nnacl/infer/mfcc_infer.h"
99 #include "nnacl/infer/nllloss_grad_infer.h"
100 #include "nnacl/infer/nllloss_infer.h"
101 #include "nnacl/infer/non_max_suppression_infer.h"
102 #include "nnacl/infer/one_hot_infer.h"
103 #include "nnacl/infer/pad_infer.h"
104 #include "nnacl/infer/pooling_grad_infer.h"
105 #include "nnacl/infer/pooling_infer.h"
106 #include "nnacl/infer/power_infer.h"
107 #include "nnacl/infer/prior_box_infer.h"
108 #include "nnacl/infer/quant_dtype_cast_infer.h"
109 #include "nnacl/infer/ragged_range_infer.h"
110 #include "nnacl/infer/random_normal_infer.h"
111 #include "nnacl/infer/random_standard_normal_infer.h"
112 #include "nnacl/infer/range_infer.h"
113 #include "nnacl/infer/rank_infer.h"
114 #include "nnacl/infer/reduce_infer.h"
115 #include "nnacl/infer/reduce_scatter_infer.h"
116 #include "nnacl/infer/reshape_infer.h"
117 #include "nnacl/infer/resize_grad_infer.h"
118 #include "nnacl/infer/resize_infer.h"
119 #include "nnacl/infer/rfft_infer.h"
120 #include "nnacl/infer/roi_pooling_infer.h"
121 #include "nnacl/infer/scatter_nd_infer.h"
122 #include "nnacl/infer/scatter_nd_update_infer.h"
123 #include "nnacl/infer/select_infer.h"
124 #include "nnacl/infer/sgd_infer.h"
125 #include "nnacl/infer/invalid_infer.h"
126 #ifndef RUNTIME_PASS_CLIP
127 #include "nnacl/infer/shape_fusion_infer.h"
128 #endif
129 #include "nnacl/infer/shape_infer.h"
130 #include "nnacl/infer/size_infer.h"
131 #include "nnacl/infer/slice_infer.h"
132 #include "nnacl/infer/softmax_cross_entropy_infer.h"
133 #include "nnacl/infer/softmax_infer.h"
134 #include "nnacl/infer/space_to_batch_infer.h"
135 #include "nnacl/infer/space_to_batch_nd_infer.h"
136 #include "nnacl/infer/space_to_depth_infer.h"
137 #include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h"
138 #include "nnacl/infer/sparse_to_dense_infer.h"
139 #include "nnacl/infer/splice_infer.h"
140 #include "nnacl/infer/split_infer.h"
141 #include "nnacl/infer/split_with_over_lap_infer.h"
142 #include "nnacl/infer/squeeze_infer.h"
143 #include "nnacl/infer/stack_infer.h"
144 #include "nnacl/infer/strided_slice_grad_infer.h"
145 #include "nnacl/infer/strided_slice_infer.h"
146 #ifdef MSLITE_ENABLE_STRING_KERNEL
147 #include "nnacl/infer/string/custom_extract_features_infer.h"
148 #include "nnacl/infer/string/custom_normalize_infer.h"
149 #include "nnacl/infer/string/custom_predict_infer.h"
150 #include "nnacl/infer/string/hashtable_lookup_infer.h"
151 #include "nnacl/infer/string/lsh_projection_infer.h"
152 #include "nnacl/infer/string/skip_gram_infer.h"
153 #endif
154 #include "nnacl/infer/tile_infer.h"
155 #include "nnacl/infer/topk_infer.h"
156 #include "nnacl/infer/transpose_infer.h"
157 #include "nnacl/infer/uniform_real_infer.h"
158 #include "nnacl/infer/unique_infer.h"
159 #include "nnacl/infer/unsorted_segment_sum_infer.h"
160 #include "nnacl/infer/unsqueeze_infer.h"
161 #include "nnacl/infer/unstack_infer.h"
162 #include "nnacl/infer/where_infer.h"
163 #include "nnacl/infer/isfinite_infer.h"
164 #include "nnacl/infer/fse_decoder_infer.h"
165 #include "nnacl/infer/custom_gru_infer.h"
166 
167 InferShape g_infer_func[PrimType_MAX] = {0};
168 InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0};
RegAllInferFunc1()169 void RegAllInferFunc1() {
170   g_infer_func[PrimType_NONE] = NULL;
171   g_infer_func[PrimType_Abs] = CommonInferShape;
172   g_infer_func[PrimType_AbsGrad] = CommonGradInferShape;
173   g_infer_func[PrimType_Activation] = CommonInferShape;
174   g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape;
175   g_infer_func[PrimType_Adam] = AdamInferShape;
176   g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape;
177   g_infer_func[PrimType_AdderFusion] = Conv2dInferShape;
178   g_infer_func[PrimType_AddFusion] = ArithmeticInferShape;
179   g_infer_func[PrimType_AddGrad] = AddSubGradInferShape;
180   g_infer_func[PrimType_AddN] = AddnInferShape;
181   g_infer_func[PrimType_Affine] = AffineInferShape;
182   g_infer_func[PrimType_All] = NULL;
183   g_infer_func[PrimType_AllGather] = AllGatherInferShape;
184   g_infer_func[PrimType_ApplyMomentum] = ApplyMomentumInferShape;
185   g_infer_func[PrimType_ArgMaxFusion] = ArgMinMaxInferShape;
186   g_infer_func[PrimType_ArgMinFusion] = ArgMinMaxInferShape;
187   g_infer_func[PrimType_Assert] = AssertOpInferShape;
188   g_infer_func[PrimType_Assign] = AssignInferShape;
189   g_infer_func[PrimType_AssignAdd] = AssignAddInferShape;
190   g_infer_func[PrimType_Attention] = AttentionInferShape;
191   g_infer_func[PrimType_AudioSpectrogram] = AudioSpectrogramInferShape;
192   g_infer_func[PrimType_AvgPoolFusion] = PoolingInferShape;
193   g_infer_func[PrimType_AvgPoolGrad] = PoolingGradInferShape;
194   g_infer_func[PrimType_BatchNorm] = CommonInferShape;
195   g_infer_func[PrimType_BatchNormGrad] = BnGradInferShape;
196   g_infer_func[PrimType_BatchToSpace] = BatchToSpaceInferShape;
197   g_infer_func[PrimType_BatchToSpaceND] = NULL;
198   g_infer_func[PrimType_BiasAdd] = ArithmeticInferShape;
199   g_infer_func[PrimType_BiasAddGrad] = BiasGradInferShape;
200   g_infer_func[PrimType_BinaryCrossEntropy] = BinaryCrossEntropyInferShape;
201   g_infer_func[PrimType_BinaryCrossEntropyGrad] = CommonInferShape;
202   g_infer_func[PrimType_BroadcastTo] = BroadcastToInferShape;
203   g_infer_func[PrimType_Call] = InvalidInferShape;
204   g_infer_func[PrimType_Cast] = CastInferShape;
205   g_infer_func[PrimType_Ceil] = CommonInferShape;
206   g_infer_func[PrimType_Clip] = CommonInferShape;
207   g_infer_func[PrimType_Concat] = ConcatInferShape;
208   g_infer_func[PrimType_ConstantOfShape] = ConstantOfShapeInferShape;
209   g_infer_func[PrimType_Conv2DBackpropFilterFusion] = Conv2dGradFilterInferShape;
210   g_infer_func[PrimType_Conv2DBackpropInputFusion] = Conv2dGradInputInferShape;
211   g_infer_func[PrimType_Conv2DFusion] = Conv2dInferShape;
212   g_infer_func[PrimType_Conv2dTransposeFusion] = Deconv2dInferShape;
213   g_infer_func[PrimType_Cos] = CommonInferShape;
214   g_infer_func[PrimType_Crop] = CropInferShape;
215   g_infer_func[PrimType_CropAndResize] = CropAndResizeInferShape;
216   g_infer_func[PrimType_CumSum] = CumsumInferShape;
217   g_infer_func[PrimType_Custom] = NULL;
218 #ifdef MSLITE_ENABLE_STRING_KERNEL
219   g_infer_func[PrimType_CustomExtractFeatures] = CustomExtractFeaturesInferShape;
220 #endif
221 }
222 
RegAllInferFunc2()223 void RegAllInferFunc2() {
224 #ifdef MSLITE_ENABLE_STRING_KERNEL
225   g_infer_func[PrimType_CustomNormalize] = CustomNormalizeInferShape;
226   g_infer_func[PrimType_CustomPredict] = CustomPredictInferShape;
227 #endif
228   g_infer_func[PrimType_DeConv2DGradFilter] = NULL;
229   g_infer_func[PrimType_Depend] = CommonInferShape;
230   g_infer_func[PrimType_DepthToSpace] = DepthToSpaceInferShape;
231   g_infer_func[PrimType_DetectionPostProcess] = DetectionPostProcessInferShape;
232   g_infer_func[PrimType_DivFusion] = ArithmeticInferShape;
233   g_infer_func[PrimType_DivGrad] = ArithmeticGradInferShape;
234   g_infer_func[PrimType_Dropout] = DropoutInferShape;
235   g_infer_func[PrimType_DropoutGrad] = DropoutGradInferShape;
236   g_infer_func[PrimType_DynamicQuant] = DynamicQuantInferShape;
237   g_infer_func[PrimType_Eltwise] = ArithmeticInferShape;
238   g_infer_func[PrimType_Elu] = CommonInferShape;
239   g_infer_func[PrimType_EmbeddingLookupFusion] = EmbeddingLookupInferShape;
240   g_infer_func[PrimType_Equal] = ArithmeticCompareInferShape;
241   g_infer_func[PrimType_Erf] = CommonInferShape;
242   g_infer_func[PrimType_ExpandDims] = ExpandDimsInferShape;
243   g_infer_func[PrimType_ExpFusion] = CommonInferShape;
244   g_infer_func[PrimType_FakeQuantWithMinMaxVars] = CommonInferShape;
245   g_infer_func[PrimType_FakeQuantWithMinMaxVarsPerChannel] = NULL;
246   g_infer_func[PrimType_FftImag] = FftImagInferShape;
247   g_infer_func[PrimType_FftReal] = FftRealInferShape;
248   g_infer_func[PrimType_Fill] = FillInferShape;
249   g_infer_func[PrimType_FillV2] = FillInferShape;
250   g_infer_func[PrimType_Flatten] = FlattenInferShape;
251   g_infer_func[PrimType_FlattenGrad] = FlattenGradInferShape;
252   g_infer_func[PrimType_Floor] = CommonInferShapeWithOneInput;
253   g_infer_func[PrimType_FloorDiv] = ArithmeticInferShape;
254   g_infer_func[PrimType_FloorMod] = ArithmeticInferShape;
255   g_infer_func[PrimType_FullConnection] = FullConnectionInferShape;
256   g_infer_func[PrimType_FusedBatchNorm] = FusedBatchNormInferShape;
257   g_infer_func[PrimType_Gather] = GatherInferShape;
258   g_infer_func[PrimType_GatherNd] = GatherNdInferShape;
259   g_infer_func[PrimType_GenOP] = NULL;
260   g_infer_func[PrimType_GLU] = GluInferShape;
261   g_infer_func[PrimType_Greater] = ArithmeticCompareInferShape;
262   g_infer_func[PrimType_GreaterEqual] = ArithmeticCompareInferShape;
263   g_infer_func[PrimType_GRU] = GruInferShape;
264 #ifdef MSLITE_ENABLE_STRING_KERNEL
265   g_infer_func[PrimType_HashtableLookup] = HashtableLoopupInferShape;
266 #endif
267   g_infer_func[PrimType_InstanceNorm] = InstanceNormInferShape;
268   g_infer_func[PrimType_InvertPermutation] = InvertPermutationInferShape;
269   g_infer_func[PrimType_IsFinite] = IsFiniteInferShape;
270   g_infer_func[PrimType_L2NormalizeFusion] = CommonInferShape;
271   g_infer_func[PrimType_LayerNormFusion] = LayerNormInferShape;
272   g_infer_func[PrimType_LayerNormGrad] = LayerNormGradInferShape;
273   g_infer_func[PrimType_LeakyRelu] = CommonInferShape;
274   g_infer_func[PrimType_Less] = ArithmeticCompareInferShape;
275   g_infer_func[PrimType_LessEqual] = ArithmeticCompareInferShape;
276   g_infer_func[PrimType_LinSpace] = LinSpaceInferShape;
277 }
278 
RegAllInferFunc3()279 void RegAllInferFunc3() {
280   g_infer_func[PrimType_Log] = CommonInferShape;
281   g_infer_func[PrimType_LogGrad] = CommonGradInferShape;
282   g_infer_func[PrimType_LogicalAnd] = ArithmeticInferShape;
283   g_infer_func[PrimType_LogicalNot] = CommonInferShape;
284   g_infer_func[PrimType_LogicalOr] = ArithmeticInferShape;
285   g_infer_func[PrimType_LogSoftmax] = LogSoftmaxInferShape;
286   g_infer_func[PrimType_LpNormalization] = NULL;
287   g_infer_func[PrimType_LRN] = CommonInferShapeWithNHWC;
288 #ifdef MSLITE_ENABLE_STRING_KERNEL
289   g_infer_func[PrimType_LshProjection] = LshProjectionInferShape;
290 #endif
291   g_infer_func[PrimType_LSTM] = LstmInferShape;
292   g_infer_func[PrimType_LSTMGrad] = LstmGradInferShape;
293   g_infer_func[PrimType_LSTMGradData] = LstmGradDataInferShape;
294   g_infer_func[PrimType_LSTMGradWeight] = LstmGradWeightInferShape;
295   g_infer_func[PrimType_MatMulFusion] = MatmulInferShape;
296   g_infer_func[PrimType_Maximum] = ArithmeticInferShape;
297   g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape;
298   g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape;
299   g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape;
300   g_infer_func[PrimType_SwitchLayer] = InvalidInferShape;
301   g_infer_func[PrimType_Mfcc] = MfccInferShape;
302   g_infer_func[PrimType_MIN] = NULL;
303   g_infer_func[PrimType_Minimum] = ArithmeticInferShape;
304   g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape;
305   g_infer_func[PrimType_Mod] = ArithmeticInferShape;
306   g_infer_func[PrimType_MulFusion] = ArithmeticInferShape;
307   g_infer_func[PrimType_MulGrad] = ArithmeticGradInferShape;
308   g_infer_func[PrimType_Neg] = CommonInferShape;
309   g_infer_func[PrimType_NegGrad] = CommonGradInferShape;
310   g_infer_func[PrimType_NLLLoss] = NLLLossInferShape;
311   g_infer_func[PrimType_NLLLossGrad] = NLLLossGradInferShape;
312   g_infer_func[PrimType_NonMaxSuppression] = NonMaxSuppressionInferShape;
313   g_infer_func[PrimType_NonZero] = NULL;
314   g_infer_func[PrimType_NotEqual] = ArithmeticCompareInferShape;
315   g_infer_func[PrimType_OneHot] = OneHotInferShape;
316   g_infer_func[PrimType_OnesLike] = NULL;
317   g_infer_func[PrimType_PadFusion] = PadInferShape;
318   g_infer_func[PrimType_PartialFusion] = InvalidInferShape;
319   g_infer_func[PrimType_PowerGrad] = CommonGradInferShape;
320   g_infer_func[PrimType_PowFusion] = PowerInferShape;
321   g_infer_func[PrimType_PReLUFusion] = CommonInferShape;
322   g_infer_func[PrimType_PriorBox] = PriorBoxInferShape;
323   g_infer_func[PrimType_QuantDTypeCast] = QuantDtypeCastInferShape;
324   g_infer_func[PrimType_RaggedRange] = RaggedRangeInferShape;
325   g_infer_func[PrimType_RandomNormal] = RandomNormalInferShape;
326   g_infer_func[PrimType_RandomStandardNormal] = RandomStandardNormalInferShape;
327   g_infer_func[PrimType_Range] = RangeInferShape;
328   g_infer_func[PrimType_Rank] = RankInferShape;
329 }
330 
RegAllInferFunc4()331 void RegAllInferFunc4() {
332   g_infer_func[PrimType_RealDiv] = ArithmeticInferShape;
333   g_infer_func[PrimType_Reciprocal] = CommonInferShape;
334   g_infer_func[PrimType_ReduceFusion] = ReduceInferShape;
335   g_infer_func[PrimType_ReduceScatter] = ReduceScatterInferShape;
336   g_infer_func[PrimType_Reshape] = ReshapeInferShape;
337   g_infer_func[PrimType_Resize] = ResizeInferShape;
338   g_infer_func[PrimType_ResizeGrad] = ResizeGradInferShape;
339   g_infer_func[PrimType_ReverseSequence] = CommonInferShape;
340   g_infer_func[PrimType_ReverseV2] = CommonInferShape;
341   g_infer_func[PrimType_Rfft] = RfftInferShape;
342   g_infer_func[PrimType_ROIPooling] = ROIPoolingInferShape;
343   g_infer_func[PrimType_Round] = CommonInferShape;
344   g_infer_func[PrimType_Rsqrt] = CommonInferShape;
345   g_infer_func[PrimType_RsqrtGrad] = NULL;
346   g_infer_func[PrimType_ScaleFusion] = CommonInferShape;
347   g_infer_func[PrimType_ScatterNd] = ScatterNdInferShape;
348   g_infer_func[PrimType_ScatterNdUpdate] = ScatterNdUpdateInferShape;
349   g_infer_func[PrimType_TensorScatterAdd] = ScatterNdUpdateInferShape;
350   g_infer_func[PrimType_Select] = SelectInferShape;
351   g_infer_func[PrimType_SGD] = SgdInferShape;
352   g_infer_func[PrimType_Shape] = ShapeInferShape;
353   g_infer_func[PrimType_SigmoidCrossEntropyWithLogits] = CommonInferShape;
354   g_infer_func[PrimType_SigmoidCrossEntropyWithLogitsGrad] = CommonInferShape;
355   g_infer_func[PrimType_Sin] = CommonInferShape;
356   g_infer_func[PrimType_Size] = SizeInferShape;
357 #ifdef MSLITE_ENABLE_STRING_KERNEL
358   g_infer_func[PrimType_SkipGram] = SkipGramInferShape;
359 #endif
360   g_infer_func[PrimType_SliceFusion] = SliceInferShape;
361   g_infer_func[PrimType_SmoothL1Loss] = CommonInferShape;
362   g_infer_func[PrimType_SmoothL1LossGrad] = CommonInferShape;
363   g_infer_func[PrimType_Softmax] = SoftMaxInferShape;
364   g_infer_func[PrimType_SoftmaxCrossEntropyWithLogits] = SoftmaxCrossEntropyInferShape;
365   g_infer_func[PrimType_SpaceToBatch] = SpaceToBatchInferShape;
366   g_infer_func[PrimType_SpaceToBatchND] = SpaceToBatchNdInferShape;
367   g_infer_func[PrimType_SpaceToDepth] = SpaceToDepthInferShape;
368   g_infer_func[PrimType_SparseSoftmaxCrossEntropyWithLogits] = SparseSoftmaxCrossEntropyWithLogitsInferShape;
369   g_infer_func[PrimType_SparseToDense] = SparseToDenseInferShape;
370   g_infer_func[PrimType_Splice] = SpliceInferShape;
371   g_infer_func[PrimType_Split] = SplitInferShape;
372   g_infer_func[PrimType_SplitWithOverlap] = SplitWithOverlapInferShape;
373   g_infer_func[PrimType_Sqrt] = CommonInferShape;
374   g_infer_func[PrimType_SqrtGrad] = NULL;
375   g_infer_func[PrimType_Square] = CommonInferShape;
376   g_infer_func[PrimType_SquaredDifference] = ArithmeticInferShape;
377   g_infer_func[PrimType_Squeeze] = SqueezeInferShape;
378   g_infer_func[PrimType_Stack] = StackInferShape;
379   g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape;
380   g_infer_func[PrimType_StridedSliceGrad] = StridedSliceGradInferShape;
381   g_infer_func[PrimType_SubFusion] = ArithmeticInferShape;
382   g_infer_func[PrimType_SubGrad] = AddSubGradInferShape;
383 }
384 
RegAllInferFunc5()385 void RegAllInferFunc5() {
386   g_infer_func[PrimType_Switch] = InvalidInferShape;
387 #ifdef MSLITE_ENABLE_CONTROLFLOW
388   g_infer_func[PrimType_TensorArray] = TensorArrayInferShape;
389   g_infer_func[PrimType_TensorArrayRead] = TensorArrayReadInferShape;
390   g_infer_func[PrimType_TensorArrayWrite] = TensorArrayWriteInferShape;
391   g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape;
392   g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape;
393   g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape;
394   g_infer_func[PrimType_TensorListSetItem] = TensorListSetItemInferShape;
395   g_infer_func[PrimType_TensorListStack] = TensorListStackInferShape;
396 #endif
397   g_infer_func[PrimType_TileFusion] = TileInferShape;
398   g_infer_func[PrimType_TopKFusion] = TopKInferShape;
399   g_infer_func[PrimType_Transpose] = TransposeInferShape;
400   g_infer_func[PrimType_UniformReal] = UniformRealInferShape;
401   g_infer_func[PrimType_Unique] = UniqueInferShape;
402   g_infer_func[PrimType_UnsortedSegmentSum] = UnsortedSegmentSumInferShape;
403   g_infer_func[PrimType_Unsqueeze] = UnsqueezeInferShape;
404   g_infer_func[PrimType_Unstack] = UnstackInferShape;
405   g_infer_func[PrimType_Where] = WhereInferShape;
406   g_infer_func[PrimType_ZerosLike] = CommonInferShape;
407 
408   // fused operators.
409   g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL;
410   g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL;
411 #ifndef RUNTIME_PASS_CLIP
412   g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape;
413   g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape;
414   g_inner_op_infer_func[PrimType_Inner_DecoderLayer - PrimType_InnerOpMin] = DecoderLayerInferShape;
415   g_inner_op_infer_func[PrimType_Inner_FseDecode - PrimType_InnerOpMin] = FseDecoderInferShape;
416 #endif
417   g_inner_op_infer_func[PrimType_Inner_CustomGru - PrimType_InnerOpMin] = CustomGruInferShape;
418   g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL;
419 }
420 
421 #else
422 InferShape g_infer_func[PrimType_MAX] = {0};
423 InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0};
424 #endif  // _MSC_VER
425 
GetInferFunc(int prim_type)426 InferShape GetInferFunc(int prim_type) {
427 #ifdef _MSC_VER
428   if (g_infer_func[PrimType_Abs] == NULL) {
429     RegAllInferFunc1();
430     RegAllInferFunc2();
431     RegAllInferFunc3();
432     RegAllInferFunc4();
433     RegAllInferFunc5();
434   }
435 #endif
436   if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) {
437     return g_infer_func[prim_type];
438   } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) {
439     return g_inner_op_infer_func[prim_type - PrimType_InnerOpMin];
440   }
441   return NULL;
442 }
443 
RegInfer(int prim_type,InferShape func)444 void RegInfer(int prim_type, InferShape func) {
445   if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) {
446     g_infer_func[prim_type] = func;
447   } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) {
448     g_inner_op_infer_func[prim_type - PrimType_InnerOpMin] = func;
449   }
450 }
451