1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "nnacl/infer/infer_register.h" 17 18 #ifdef _MSC_VER 19 #include "nnacl/infer/activation_grad_infer.h" 20 #include "nnacl/infer/adam_infer.h" 21 #include "nnacl/infer/adam_weight_decay_infer.h" 22 #include "nnacl/infer/add_sub_grad_infer.h" 23 #include "nnacl/infer/addn_infer.h" 24 #include "nnacl/infer/affine_infer.h" 25 #include "nnacl/infer/all_gather_infer.h" 26 #include "nnacl/infer/apply_momentum_infer.h" 27 #include "nnacl/infer/argmin_max_infer.h" 28 #include "nnacl/infer/arithmetic_compare_infer.h" 29 #include "nnacl/infer/arithmetic_grad_infer.h" 30 #include "nnacl/infer/arithmetic_infer.h" 31 #include "nnacl/infer/assert_op_infer.h" 32 #include "nnacl/infer/assign_add_infer.h" 33 #include "nnacl/infer/assign_infer.h" 34 #include "nnacl/infer/attention_infer.h" 35 #include "nnacl/infer/encoder_layer_infer.h" 36 #include "nnacl/infer/audio_spectrogram_infer.h" 37 #include "nnacl/infer/batch_to_space_infer.h" 38 #include "nnacl/infer/bias_grad_infer.h" 39 #include "nnacl/infer/binary_cross_entropy_infer.h" 40 #include "nnacl/infer/bn_grad_infer.h" 41 #include "nnacl/infer/broadcast_to_infer.h" 42 #include "nnacl/infer/cast_infer.h" 43 #include "nnacl/infer/common_infer.h" 44 #include "nnacl/infer/concat_infer.h" 45 #include "nnacl/infer/constant_of_shape_infer.h" 46 #include "nnacl/infer/decoder_layer_infer.h" 47 48 #ifdef MSLITE_ENABLE_CONTROLFLOW 49 #include "nnacl/infer/control/tensor_array_infer.h" 50 #include "nnacl/infer/control/tensor_array_read_infer.h" 51 #include "nnacl/infer/control/tensor_array_write_infer.h" 52 #include "nnacl/infer/control/tensorlist_fromtensor_infer.h" 53 #include "nnacl/infer/control/tensorlist_getitem_infer.h" 54 #include "nnacl/infer/control/tensorlist_reserve_infer.h" 55 #include "nnacl/infer/control/tensorlist_setitem_infer.h" 56 #include "nnacl/infer/control/tensorlist_stack_infer.h" 57 #endif 58 #include "nnacl/infer/conv2d_grad_filter_infer.h" 59 #include "nnacl/infer/conv2d_grad_input_infer.h" 60 #include "nnacl/infer/conv2d_infer.h" 61 #include "nnacl/infer/crop_and_resize_infer.h" 62 #include "nnacl/infer/crop_infer.h" 63 #include "nnacl/infer/cumsum_infer.h" 64 #include "nnacl/infer/deconv2d_infer.h" 65 #include "nnacl/infer/depth_to_space_infer.h" 66 #include "nnacl/infer/depthwise_conv2d_infer.h" 67 #include "nnacl/infer/detection_post_process_infer.h" 68 #include "nnacl/infer/dropout_grad_infer.h" 69 #include "nnacl/infer/dropout_infer.h" 70 #include "nnacl/infer/dynamic_quant_infer.h" 71 #include "nnacl/infer/embedding_lookup_infer.h" 72 #include "nnacl/infer/expand_dims_infer.h" 73 #include "nnacl/infer/fft_imag_infer.h" 74 #include "nnacl/infer/fft_real_infer.h" 75 #include "nnacl/infer/fill_infer.h" 76 #include "nnacl/infer/fillv2_infer.h" 77 #include "nnacl/infer/flatten_grad_infer.h" 78 #include "nnacl/infer/flatten_infer.h" 79 #include "nnacl/infer/full_connection_infer.h" 80 #include "nnacl/infer/fused_batchnorm_infer.h" 81 #include "nnacl/infer/gather_infer.h" 82 #include "nnacl/infer/gather_nd_infer.h" 83 #include "nnacl/infer/glu_infer.h" 84 #include "nnacl/infer/group_conv2d_grad_input_infer.h" 85 #include "nnacl/infer/gru_infer.h" 86 #include "nnacl/infer/instance_norm_infer.h" 87 #include "nnacl/infer/invert_permutation_infer.h" 88 #include "nnacl/infer/layer_norm_grad_infer.h" 89 #include "nnacl/infer/layer_norm_infer.h" 90 #include "nnacl/infer/lin_space_infer.h" 91 #include "nnacl/infer/log_softmax_infer.h" 92 #include "nnacl/infer/lstm_grad_data_infer.h" 93 #include "nnacl/infer/lstm_grad_infer.h" 94 #include "nnacl/infer/lstm_grad_weight_infer.h" 95 #include "nnacl/infer/lstm_infer.h" 96 #include "nnacl/infer/matmul_infer.h" 97 #include "nnacl/infer/max_min_grad_infer.h" 98 #include "nnacl/infer/mfcc_infer.h" 99 #include "nnacl/infer/nllloss_grad_infer.h" 100 #include "nnacl/infer/nllloss_infer.h" 101 #include "nnacl/infer/non_max_suppression_infer.h" 102 #include "nnacl/infer/one_hot_infer.h" 103 #include "nnacl/infer/pad_infer.h" 104 #include "nnacl/infer/pooling_grad_infer.h" 105 #include "nnacl/infer/pooling_infer.h" 106 #include "nnacl/infer/power_infer.h" 107 #include "nnacl/infer/prior_box_infer.h" 108 #include "nnacl/infer/quant_dtype_cast_infer.h" 109 #include "nnacl/infer/ragged_range_infer.h" 110 #include "nnacl/infer/random_normal_infer.h" 111 #include "nnacl/infer/random_standard_normal_infer.h" 112 #include "nnacl/infer/range_infer.h" 113 #include "nnacl/infer/rank_infer.h" 114 #include "nnacl/infer/reduce_infer.h" 115 #include "nnacl/infer/reduce_scatter_infer.h" 116 #include "nnacl/infer/reshape_infer.h" 117 #include "nnacl/infer/resize_grad_infer.h" 118 #include "nnacl/infer/resize_infer.h" 119 #include "nnacl/infer/rfft_infer.h" 120 #include "nnacl/infer/roi_pooling_infer.h" 121 #include "nnacl/infer/scatter_nd_infer.h" 122 #include "nnacl/infer/scatter_nd_update_infer.h" 123 #include "nnacl/infer/select_infer.h" 124 #include "nnacl/infer/sgd_infer.h" 125 #include "nnacl/infer/invalid_infer.h" 126 #ifndef RUNTIME_PASS_CLIP 127 #include "nnacl/infer/shape_fusion_infer.h" 128 #endif 129 #include "nnacl/infer/shape_infer.h" 130 #include "nnacl/infer/size_infer.h" 131 #include "nnacl/infer/slice_infer.h" 132 #include "nnacl/infer/softmax_cross_entropy_infer.h" 133 #include "nnacl/infer/softmax_infer.h" 134 #include "nnacl/infer/space_to_batch_infer.h" 135 #include "nnacl/infer/space_to_batch_nd_infer.h" 136 #include "nnacl/infer/space_to_depth_infer.h" 137 #include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h" 138 #include "nnacl/infer/sparse_to_dense_infer.h" 139 #include "nnacl/infer/splice_infer.h" 140 #include "nnacl/infer/split_infer.h" 141 #include "nnacl/infer/split_with_over_lap_infer.h" 142 #include "nnacl/infer/squeeze_infer.h" 143 #include "nnacl/infer/stack_infer.h" 144 #include "nnacl/infer/strided_slice_grad_infer.h" 145 #include "nnacl/infer/strided_slice_infer.h" 146 #ifdef MSLITE_ENABLE_STRING_KERNEL 147 #include "nnacl/infer/string/custom_extract_features_infer.h" 148 #include "nnacl/infer/string/custom_normalize_infer.h" 149 #include "nnacl/infer/string/custom_predict_infer.h" 150 #include "nnacl/infer/string/hashtable_lookup_infer.h" 151 #include "nnacl/infer/string/lsh_projection_infer.h" 152 #include "nnacl/infer/string/skip_gram_infer.h" 153 #endif 154 #include "nnacl/infer/tile_infer.h" 155 #include "nnacl/infer/topk_infer.h" 156 #include "nnacl/infer/transpose_infer.h" 157 #include "nnacl/infer/uniform_real_infer.h" 158 #include "nnacl/infer/unique_infer.h" 159 #include "nnacl/infer/unsorted_segment_sum_infer.h" 160 #include "nnacl/infer/unsqueeze_infer.h" 161 #include "nnacl/infer/unstack_infer.h" 162 #include "nnacl/infer/where_infer.h" 163 #include "nnacl/infer/isfinite_infer.h" 164 #include "nnacl/infer/fse_decoder_infer.h" 165 #include "nnacl/infer/custom_gru_infer.h" 166 167 InferShape g_infer_func[PrimType_MAX] = {0}; 168 InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; RegAllInferFunc1()169void RegAllInferFunc1() { 170 g_infer_func[PrimType_NONE] = NULL; 171 g_infer_func[PrimType_Abs] = CommonInferShape; 172 g_infer_func[PrimType_AbsGrad] = CommonGradInferShape; 173 g_infer_func[PrimType_Activation] = CommonInferShape; 174 g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape; 175 g_infer_func[PrimType_Adam] = AdamInferShape; 176 g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape; 177 g_infer_func[PrimType_AdderFusion] = Conv2dInferShape; 178 g_infer_func[PrimType_AddFusion] = ArithmeticInferShape; 179 g_infer_func[PrimType_AddGrad] = AddSubGradInferShape; 180 g_infer_func[PrimType_AddN] = AddnInferShape; 181 g_infer_func[PrimType_Affine] = AffineInferShape; 182 g_infer_func[PrimType_All] = NULL; 183 g_infer_func[PrimType_AllGather] = AllGatherInferShape; 184 g_infer_func[PrimType_ApplyMomentum] = ApplyMomentumInferShape; 185 g_infer_func[PrimType_ArgMaxFusion] = ArgMinMaxInferShape; 186 g_infer_func[PrimType_ArgMinFusion] = ArgMinMaxInferShape; 187 g_infer_func[PrimType_Assert] = AssertOpInferShape; 188 g_infer_func[PrimType_Assign] = AssignInferShape; 189 g_infer_func[PrimType_AssignAdd] = AssignAddInferShape; 190 g_infer_func[PrimType_Attention] = AttentionInferShape; 191 g_infer_func[PrimType_AudioSpectrogram] = AudioSpectrogramInferShape; 192 g_infer_func[PrimType_AvgPoolFusion] = PoolingInferShape; 193 g_infer_func[PrimType_AvgPoolGrad] = PoolingGradInferShape; 194 g_infer_func[PrimType_BatchNorm] = CommonInferShape; 195 g_infer_func[PrimType_BatchNormGrad] = BnGradInferShape; 196 g_infer_func[PrimType_BatchToSpace] = BatchToSpaceInferShape; 197 g_infer_func[PrimType_BatchToSpaceND] = NULL; 198 g_infer_func[PrimType_BiasAdd] = ArithmeticInferShape; 199 g_infer_func[PrimType_BiasAddGrad] = BiasGradInferShape; 200 g_infer_func[PrimType_BinaryCrossEntropy] = BinaryCrossEntropyInferShape; 201 g_infer_func[PrimType_BinaryCrossEntropyGrad] = CommonInferShape; 202 g_infer_func[PrimType_BroadcastTo] = BroadcastToInferShape; 203 g_infer_func[PrimType_Call] = InvalidInferShape; 204 g_infer_func[PrimType_Cast] = CastInferShape; 205 g_infer_func[PrimType_Ceil] = CommonInferShape; 206 g_infer_func[PrimType_Clip] = CommonInferShape; 207 g_infer_func[PrimType_Concat] = ConcatInferShape; 208 g_infer_func[PrimType_ConstantOfShape] = ConstantOfShapeInferShape; 209 g_infer_func[PrimType_Conv2DBackpropFilterFusion] = Conv2dGradFilterInferShape; 210 g_infer_func[PrimType_Conv2DBackpropInputFusion] = Conv2dGradInputInferShape; 211 g_infer_func[PrimType_Conv2DFusion] = Conv2dInferShape; 212 g_infer_func[PrimType_Conv2dTransposeFusion] = Deconv2dInferShape; 213 g_infer_func[PrimType_Cos] = CommonInferShape; 214 g_infer_func[PrimType_Crop] = CropInferShape; 215 g_infer_func[PrimType_CropAndResize] = CropAndResizeInferShape; 216 g_infer_func[PrimType_CumSum] = CumsumInferShape; 217 g_infer_func[PrimType_Custom] = NULL; 218 #ifdef MSLITE_ENABLE_STRING_KERNEL 219 g_infer_func[PrimType_CustomExtractFeatures] = CustomExtractFeaturesInferShape; 220 #endif 221 } 222 RegAllInferFunc2()223void RegAllInferFunc2() { 224 #ifdef MSLITE_ENABLE_STRING_KERNEL 225 g_infer_func[PrimType_CustomNormalize] = CustomNormalizeInferShape; 226 g_infer_func[PrimType_CustomPredict] = CustomPredictInferShape; 227 #endif 228 g_infer_func[PrimType_DeConv2DGradFilter] = NULL; 229 g_infer_func[PrimType_Depend] = CommonInferShape; 230 g_infer_func[PrimType_DepthToSpace] = DepthToSpaceInferShape; 231 g_infer_func[PrimType_DetectionPostProcess] = DetectionPostProcessInferShape; 232 g_infer_func[PrimType_DivFusion] = ArithmeticInferShape; 233 g_infer_func[PrimType_DivGrad] = ArithmeticGradInferShape; 234 g_infer_func[PrimType_Dropout] = DropoutInferShape; 235 g_infer_func[PrimType_DropoutGrad] = DropoutGradInferShape; 236 g_infer_func[PrimType_DynamicQuant] = DynamicQuantInferShape; 237 g_infer_func[PrimType_Eltwise] = ArithmeticInferShape; 238 g_infer_func[PrimType_Elu] = CommonInferShape; 239 g_infer_func[PrimType_EmbeddingLookupFusion] = EmbeddingLookupInferShape; 240 g_infer_func[PrimType_Equal] = ArithmeticCompareInferShape; 241 g_infer_func[PrimType_Erf] = CommonInferShape; 242 g_infer_func[PrimType_ExpandDims] = ExpandDimsInferShape; 243 g_infer_func[PrimType_ExpFusion] = CommonInferShape; 244 g_infer_func[PrimType_FakeQuantWithMinMaxVars] = CommonInferShape; 245 g_infer_func[PrimType_FakeQuantWithMinMaxVarsPerChannel] = NULL; 246 g_infer_func[PrimType_FftImag] = FftImagInferShape; 247 g_infer_func[PrimType_FftReal] = FftRealInferShape; 248 g_infer_func[PrimType_Fill] = FillInferShape; 249 g_infer_func[PrimType_FillV2] = FillInferShape; 250 g_infer_func[PrimType_Flatten] = FlattenInferShape; 251 g_infer_func[PrimType_FlattenGrad] = FlattenGradInferShape; 252 g_infer_func[PrimType_Floor] = CommonInferShapeWithOneInput; 253 g_infer_func[PrimType_FloorDiv] = ArithmeticInferShape; 254 g_infer_func[PrimType_FloorMod] = ArithmeticInferShape; 255 g_infer_func[PrimType_FullConnection] = FullConnectionInferShape; 256 g_infer_func[PrimType_FusedBatchNorm] = FusedBatchNormInferShape; 257 g_infer_func[PrimType_Gather] = GatherInferShape; 258 g_infer_func[PrimType_GatherNd] = GatherNdInferShape; 259 g_infer_func[PrimType_GenOP] = NULL; 260 g_infer_func[PrimType_GLU] = GluInferShape; 261 g_infer_func[PrimType_Greater] = ArithmeticCompareInferShape; 262 g_infer_func[PrimType_GreaterEqual] = ArithmeticCompareInferShape; 263 g_infer_func[PrimType_GRU] = GruInferShape; 264 #ifdef MSLITE_ENABLE_STRING_KERNEL 265 g_infer_func[PrimType_HashtableLookup] = HashtableLoopupInferShape; 266 #endif 267 g_infer_func[PrimType_InstanceNorm] = InstanceNormInferShape; 268 g_infer_func[PrimType_InvertPermutation] = InvertPermutationInferShape; 269 g_infer_func[PrimType_IsFinite] = IsFiniteInferShape; 270 g_infer_func[PrimType_L2NormalizeFusion] = CommonInferShape; 271 g_infer_func[PrimType_LayerNormFusion] = LayerNormInferShape; 272 g_infer_func[PrimType_LayerNormGrad] = LayerNormGradInferShape; 273 g_infer_func[PrimType_LeakyRelu] = CommonInferShape; 274 g_infer_func[PrimType_Less] = ArithmeticCompareInferShape; 275 g_infer_func[PrimType_LessEqual] = ArithmeticCompareInferShape; 276 g_infer_func[PrimType_LinSpace] = LinSpaceInferShape; 277 } 278 RegAllInferFunc3()279void RegAllInferFunc3() { 280 g_infer_func[PrimType_Log] = CommonInferShape; 281 g_infer_func[PrimType_LogGrad] = CommonGradInferShape; 282 g_infer_func[PrimType_LogicalAnd] = ArithmeticInferShape; 283 g_infer_func[PrimType_LogicalNot] = CommonInferShape; 284 g_infer_func[PrimType_LogicalOr] = ArithmeticInferShape; 285 g_infer_func[PrimType_LogSoftmax] = LogSoftmaxInferShape; 286 g_infer_func[PrimType_LpNormalization] = NULL; 287 g_infer_func[PrimType_LRN] = CommonInferShapeWithNHWC; 288 #ifdef MSLITE_ENABLE_STRING_KERNEL 289 g_infer_func[PrimType_LshProjection] = LshProjectionInferShape; 290 #endif 291 g_infer_func[PrimType_LSTM] = LstmInferShape; 292 g_infer_func[PrimType_LSTMGrad] = LstmGradInferShape; 293 g_infer_func[PrimType_LSTMGradData] = LstmGradDataInferShape; 294 g_infer_func[PrimType_LSTMGradWeight] = LstmGradWeightInferShape; 295 g_infer_func[PrimType_MatMulFusion] = MatmulInferShape; 296 g_infer_func[PrimType_Maximum] = ArithmeticInferShape; 297 g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape; 298 g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape; 299 g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape; 300 g_infer_func[PrimType_SwitchLayer] = InvalidInferShape; 301 g_infer_func[PrimType_Mfcc] = MfccInferShape; 302 g_infer_func[PrimType_MIN] = NULL; 303 g_infer_func[PrimType_Minimum] = ArithmeticInferShape; 304 g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape; 305 g_infer_func[PrimType_Mod] = ArithmeticInferShape; 306 g_infer_func[PrimType_MulFusion] = ArithmeticInferShape; 307 g_infer_func[PrimType_MulGrad] = ArithmeticGradInferShape; 308 g_infer_func[PrimType_Neg] = CommonInferShape; 309 g_infer_func[PrimType_NegGrad] = CommonGradInferShape; 310 g_infer_func[PrimType_NLLLoss] = NLLLossInferShape; 311 g_infer_func[PrimType_NLLLossGrad] = NLLLossGradInferShape; 312 g_infer_func[PrimType_NonMaxSuppression] = NonMaxSuppressionInferShape; 313 g_infer_func[PrimType_NonZero] = NULL; 314 g_infer_func[PrimType_NotEqual] = ArithmeticCompareInferShape; 315 g_infer_func[PrimType_OneHot] = OneHotInferShape; 316 g_infer_func[PrimType_OnesLike] = NULL; 317 g_infer_func[PrimType_PadFusion] = PadInferShape; 318 g_infer_func[PrimType_PartialFusion] = InvalidInferShape; 319 g_infer_func[PrimType_PowerGrad] = CommonGradInferShape; 320 g_infer_func[PrimType_PowFusion] = PowerInferShape; 321 g_infer_func[PrimType_PReLUFusion] = CommonInferShape; 322 g_infer_func[PrimType_PriorBox] = PriorBoxInferShape; 323 g_infer_func[PrimType_QuantDTypeCast] = QuantDtypeCastInferShape; 324 g_infer_func[PrimType_RaggedRange] = RaggedRangeInferShape; 325 g_infer_func[PrimType_RandomNormal] = RandomNormalInferShape; 326 g_infer_func[PrimType_RandomStandardNormal] = RandomStandardNormalInferShape; 327 g_infer_func[PrimType_Range] = RangeInferShape; 328 g_infer_func[PrimType_Rank] = RankInferShape; 329 } 330 RegAllInferFunc4()331void RegAllInferFunc4() { 332 g_infer_func[PrimType_RealDiv] = ArithmeticInferShape; 333 g_infer_func[PrimType_Reciprocal] = CommonInferShape; 334 g_infer_func[PrimType_ReduceFusion] = ReduceInferShape; 335 g_infer_func[PrimType_ReduceScatter] = ReduceScatterInferShape; 336 g_infer_func[PrimType_Reshape] = ReshapeInferShape; 337 g_infer_func[PrimType_Resize] = ResizeInferShape; 338 g_infer_func[PrimType_ResizeGrad] = ResizeGradInferShape; 339 g_infer_func[PrimType_ReverseSequence] = CommonInferShape; 340 g_infer_func[PrimType_ReverseV2] = CommonInferShape; 341 g_infer_func[PrimType_Rfft] = RfftInferShape; 342 g_infer_func[PrimType_ROIPooling] = ROIPoolingInferShape; 343 g_infer_func[PrimType_Round] = CommonInferShape; 344 g_infer_func[PrimType_Rsqrt] = CommonInferShape; 345 g_infer_func[PrimType_RsqrtGrad] = NULL; 346 g_infer_func[PrimType_ScaleFusion] = CommonInferShape; 347 g_infer_func[PrimType_ScatterNd] = ScatterNdInferShape; 348 g_infer_func[PrimType_ScatterNdUpdate] = ScatterNdUpdateInferShape; 349 g_infer_func[PrimType_TensorScatterAdd] = ScatterNdUpdateInferShape; 350 g_infer_func[PrimType_Select] = SelectInferShape; 351 g_infer_func[PrimType_SGD] = SgdInferShape; 352 g_infer_func[PrimType_Shape] = ShapeInferShape; 353 g_infer_func[PrimType_SigmoidCrossEntropyWithLogits] = CommonInferShape; 354 g_infer_func[PrimType_SigmoidCrossEntropyWithLogitsGrad] = CommonInferShape; 355 g_infer_func[PrimType_Sin] = CommonInferShape; 356 g_infer_func[PrimType_Size] = SizeInferShape; 357 #ifdef MSLITE_ENABLE_STRING_KERNEL 358 g_infer_func[PrimType_SkipGram] = SkipGramInferShape; 359 #endif 360 g_infer_func[PrimType_SliceFusion] = SliceInferShape; 361 g_infer_func[PrimType_SmoothL1Loss] = CommonInferShape; 362 g_infer_func[PrimType_SmoothL1LossGrad] = CommonInferShape; 363 g_infer_func[PrimType_Softmax] = SoftMaxInferShape; 364 g_infer_func[PrimType_SoftmaxCrossEntropyWithLogits] = SoftmaxCrossEntropyInferShape; 365 g_infer_func[PrimType_SpaceToBatch] = SpaceToBatchInferShape; 366 g_infer_func[PrimType_SpaceToBatchND] = SpaceToBatchNdInferShape; 367 g_infer_func[PrimType_SpaceToDepth] = SpaceToDepthInferShape; 368 g_infer_func[PrimType_SparseSoftmaxCrossEntropyWithLogits] = SparseSoftmaxCrossEntropyWithLogitsInferShape; 369 g_infer_func[PrimType_SparseToDense] = SparseToDenseInferShape; 370 g_infer_func[PrimType_Splice] = SpliceInferShape; 371 g_infer_func[PrimType_Split] = SplitInferShape; 372 g_infer_func[PrimType_SplitWithOverlap] = SplitWithOverlapInferShape; 373 g_infer_func[PrimType_Sqrt] = CommonInferShape; 374 g_infer_func[PrimType_SqrtGrad] = NULL; 375 g_infer_func[PrimType_Square] = CommonInferShape; 376 g_infer_func[PrimType_SquaredDifference] = ArithmeticInferShape; 377 g_infer_func[PrimType_Squeeze] = SqueezeInferShape; 378 g_infer_func[PrimType_Stack] = StackInferShape; 379 g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape; 380 g_infer_func[PrimType_StridedSliceGrad] = StridedSliceGradInferShape; 381 g_infer_func[PrimType_SubFusion] = ArithmeticInferShape; 382 g_infer_func[PrimType_SubGrad] = AddSubGradInferShape; 383 } 384 RegAllInferFunc5()385void RegAllInferFunc5() { 386 g_infer_func[PrimType_Switch] = InvalidInferShape; 387 #ifdef MSLITE_ENABLE_CONTROLFLOW 388 g_infer_func[PrimType_TensorArray] = TensorArrayInferShape; 389 g_infer_func[PrimType_TensorArrayRead] = TensorArrayReadInferShape; 390 g_infer_func[PrimType_TensorArrayWrite] = TensorArrayWriteInferShape; 391 g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape; 392 g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape; 393 g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape; 394 g_infer_func[PrimType_TensorListSetItem] = TensorListSetItemInferShape; 395 g_infer_func[PrimType_TensorListStack] = TensorListStackInferShape; 396 #endif 397 g_infer_func[PrimType_TileFusion] = TileInferShape; 398 g_infer_func[PrimType_TopKFusion] = TopKInferShape; 399 g_infer_func[PrimType_Transpose] = TransposeInferShape; 400 g_infer_func[PrimType_UniformReal] = UniformRealInferShape; 401 g_infer_func[PrimType_Unique] = UniqueInferShape; 402 g_infer_func[PrimType_UnsortedSegmentSum] = UnsortedSegmentSumInferShape; 403 g_infer_func[PrimType_Unsqueeze] = UnsqueezeInferShape; 404 g_infer_func[PrimType_Unstack] = UnstackInferShape; 405 g_infer_func[PrimType_Where] = WhereInferShape; 406 g_infer_func[PrimType_ZerosLike] = CommonInferShape; 407 408 // fused operators. 409 g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL; 410 g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL; 411 #ifndef RUNTIME_PASS_CLIP 412 g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape; 413 g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape; 414 g_inner_op_infer_func[PrimType_Inner_DecoderLayer - PrimType_InnerOpMin] = DecoderLayerInferShape; 415 g_inner_op_infer_func[PrimType_Inner_FseDecode - PrimType_InnerOpMin] = FseDecoderInferShape; 416 #endif 417 g_inner_op_infer_func[PrimType_Inner_CustomGru - PrimType_InnerOpMin] = CustomGruInferShape; 418 g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL; 419 } 420 421 #else 422 InferShape g_infer_func[PrimType_MAX] = {0}; 423 InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; 424 #endif // _MSC_VER 425 GetInferFunc(int prim_type)426InferShape GetInferFunc(int prim_type) { 427 #ifdef _MSC_VER 428 if (g_infer_func[PrimType_Abs] == NULL) { 429 RegAllInferFunc1(); 430 RegAllInferFunc2(); 431 RegAllInferFunc3(); 432 RegAllInferFunc4(); 433 RegAllInferFunc5(); 434 } 435 #endif 436 if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { 437 return g_infer_func[prim_type]; 438 } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { 439 return g_inner_op_infer_func[prim_type - PrimType_InnerOpMin]; 440 } 441 return NULL; 442 } 443 RegInfer(int prim_type,InferShape func)444void RegInfer(int prim_type, InferShape func) { 445 if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { 446 g_infer_func[prim_type] = func; 447 } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { 448 g_inner_op_infer_func[prim_type - PrimType_InnerOpMin] = func; 449 } 450 } 451