1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_OPS_OP_UTILS_H 18 #define MINDSPORE_CORE_OPS_OP_UTILS_H 19 #include <string> 20 #include <set> 21 #include <vector> 22 #include <algorithm> 23 #include <memory> 24 #include "abstract/primitive_infer_map.h" 25 #include "utils/check_convert_utils.h" 26 27 namespace mindspore::ops { 28 constexpr auto kAlpha = "alpha"; 29 constexpr auto kActivation = "activation"; 30 constexpr auto kActivationType = "activation_type"; 31 constexpr auto kAttentionQActType = "attention_q_act_type"; 32 constexpr auto kAttentionKActType = "attention_k_act_type"; 33 constexpr auto kAttentionVActType = "attention_v_act_type"; 34 constexpr auto kAddress = "address"; 35 constexpr auto kAlignCorners = "align_corners"; 36 constexpr auto kAttr = "attr"; 37 constexpr auto kAspectRatios = "aspect_ratios"; 38 constexpr auto kAxes = "axes"; 39 constexpr auto kAxis = "axis"; 40 constexpr auto kAxisType = "axis_type"; 41 constexpr auto kBaseSize = "base_size"; 42 constexpr auto kBatchDim = "batch_dim"; 43 constexpr auto kBeginMask = "begin_mask"; 44 constexpr auto kBeginNormAxis = "begin_norm_axis"; 45 constexpr auto kBeginParamsAxis = "begin_params_axis"; 46 constexpr auto kBeta = "beta"; 47 constexpr auto kBias = "bias"; 48 constexpr auto kBidirectional = "bidirectional"; 49 constexpr auto kBlockSize = "block_size"; 50 constexpr auto kBlockShape = 
"block_shape"; 51 constexpr auto kCellClip = "cell_clip"; 52 constexpr auto kCellDepth = "cell_depth"; 53 constexpr auto kCenterPointBox = "center_point_box"; 54 constexpr auto kClip = "clip"; 55 constexpr auto kCondition = "condition"; 56 constexpr auto kCrops = "crops"; 57 constexpr auto kCustom = "custom"; 58 constexpr auto kDampening = "dampening"; 59 constexpr auto kDataType = "data_type"; 60 constexpr auto kDctCoeffNum = "dct_coeff_num"; 61 constexpr auto kDelta = "delta"; 62 constexpr auto kDependMode = "depend_mode"; 63 constexpr auto kDepthRadius = "depth_radius"; 64 constexpr auto kDetectionsPerClass = "detections_per_class"; 65 constexpr auto kDilation = "dilation"; 66 constexpr auto kDropout = "dropout"; 67 constexpr auto kDstT = "dst_t"; 68 constexpr auto kDType = "d_type"; 69 constexpr auto kEllipsisMask = "ellipsis_mask"; 70 constexpr auto kEndMask = "end_mask"; 71 constexpr auto kEps = "eps"; 72 constexpr auto kEpsilon = "epsilon"; 73 constexpr auto kElement_dtype = "element_dtype"; 74 constexpr auto kFeatStride = "feat_stride"; 75 constexpr auto kFftLength = "fft_length"; 76 constexpr auto kFilterBankChannelNum = "filter_bank_channel_num"; 77 constexpr auto kFlip = "flip"; 78 constexpr auto kFormat = "format"; 79 constexpr auto kOriginalFormat = "OriginalFormat"; 80 constexpr auto kFreqLowerLimit = "freq_lower_limit"; 81 constexpr auto kFreqUpperLimit = "freq_upper_limit"; 82 constexpr auto kFreezeBn = "freeze_bn"; 83 constexpr auto kGateOrder = "gate_order"; 84 constexpr auto kGlobal = "global"; 85 constexpr auto kGrad = "grad"; 86 constexpr auto kIsGrad = "is_grad"; 87 constexpr auto kGradientScale = "gradient_scale"; 88 constexpr auto kGradX = "grad_x"; 89 constexpr auto kGradY = "grad_y"; 90 constexpr auto kGroup = "group"; 91 constexpr auto kHasBias = "has_bias"; 92 constexpr auto kAttentionHasMask = "attention_has_mask"; 93 constexpr auto kHiddenSize = "hidden_size"; 94 constexpr auto kId = "id"; 95 constexpr auto kImageSizeH = 
"image_size_h"; 96 constexpr auto kImageSizeW = "image_size_w"; 97 constexpr auto kIncludeALLGrams = "include_all_grams"; 98 constexpr auto kInputSize = "input_size"; 99 constexpr auto kInChannel = "in_channel"; 100 constexpr auto kInputShape = "input_shape"; 101 constexpr auto kIoFormat = "io_format"; 102 constexpr auto kIsScale = "is_scale"; 103 constexpr auto kIsTraining = "is_training"; 104 constexpr auto kKeepDims = "keep_dims"; 105 constexpr auto kKeepProb = "keep_prob"; 106 constexpr auto kKernelSize = "kernel_size"; 107 constexpr auto kLimit = "limit"; 108 constexpr auto kMagSquare = "mag_square"; 109 constexpr auto kMax = "max"; 110 constexpr auto kMaxSizes = "max_sizes"; 111 constexpr auto kMaxSkipSize = "max_skip_size"; 112 constexpr auto kMaxClassesPerDetection = "max_classes_per_detection"; 113 constexpr auto kMaxDetections = "max_detections"; 114 constexpr auto kMaxNorm = "max_norm"; 115 constexpr auto kMin = "min"; 116 constexpr auto kMinSize = "min_size"; 117 constexpr auto kMinSizes = "min_sizes"; 118 constexpr auto kMode = "mode"; 119 constexpr auto kMomentum = "momentum"; 120 constexpr auto kN = "n"; 121 constexpr auto kNarrowRange = "narrow_range"; 122 constexpr auto kNesterov = "nesterov"; 123 constexpr auto kNewAxisMask = "new_axis_mask"; 124 constexpr auto kNgramSize = "ngram_size"; 125 constexpr auto kNmsThresh = "nms_thresh"; 126 constexpr auto kNormRegion = "norm_region"; 127 constexpr auto kNumLayers = "num_layers"; 128 constexpr auto kNumElements = "num_elements"; 129 constexpr auto kNumBits = "num_bits"; 130 constexpr auto kNumDirections = "num_directions"; 131 constexpr auto kNumProj = "num_proj"; 132 constexpr auto kAttentionNumHeads = "attention_num_heads"; 133 constexpr auto kAttentionSizePerHead = "attention_size_per_head"; 134 constexpr auto kAttentionFromSeqLen = "attention_from_seq_len"; 135 constexpr auto kAttentionToSeqLen = "attention_to_seq_len"; 136 constexpr auto kOffset = "offset"; 137 constexpr auto kNmsIouThreshold = 
"nms_iou_threshold"; 138 constexpr auto kNmsScoreThreshold = "nms_score_threshold"; 139 constexpr auto kNumClasses = "num_classes"; 140 constexpr auto kOffsets = "offsets"; 141 constexpr auto kOffsetA = "offset_a"; 142 constexpr auto kOrder = "order"; 143 constexpr auto kOutChannel = "out_channel"; 144 constexpr auto kOutMaxValue = "out_max_value"; 145 constexpr auto kOutputChannel = "output_channel"; 146 constexpr auto kOutputNum = "output_num"; 147 constexpr auto kOutputPaddings = "output_paddings"; 148 constexpr auto kOutputType = "output_type"; 149 constexpr auto kOutQuantized = "out_quantized"; 150 constexpr auto kP = "p"; 151 constexpr auto kPad = "pad"; 152 constexpr auto kPadding = "padding"; 153 constexpr auto kPaddingsElementSize = "paddings_element_size"; 154 constexpr auto kPaddingsSize = "paddings_size"; 155 constexpr auto kPadItem = "pad_item"; 156 constexpr auto kPadList = "pad_list"; 157 constexpr auto kPadMode = "pad_mode"; 158 constexpr auto kPads = "pads"; 159 constexpr auto kPadSize = "pad_size"; 160 constexpr auto kPooledH = "pooled_h"; 161 constexpr auto kPooledW = "pooled_w"; 162 constexpr auto kPoolMode = "pool_mode"; 163 constexpr auto kCeilMode = "ceil_mode"; 164 constexpr auto kCountIncludePad = "count_include_pad"; 165 constexpr auto kDivisorOverride = "divisor_override"; 166 constexpr auto kPostNmsTopn = "post_nms_topn"; 167 constexpr auto kPower = "power"; 168 constexpr auto kPreNmsTopn = "pre_nms_topn"; 169 constexpr auto kRatio = "ratio"; 170 constexpr auto kReduction = "reduction"; 171 constexpr auto kRootRank = "root_rank"; 172 constexpr auto kRoundMode = "round_mode"; 173 constexpr auto kSame = "same"; 174 constexpr auto kScale = "scale"; 175 constexpr auto kSeed = "seed"; 176 constexpr auto kSeed2 = "seed2"; 177 constexpr auto kSeqDim = "seq_dim"; 178 constexpr auto kSetattrFlag = "setattr_flag"; 179 constexpr auto kShape = "shape"; 180 constexpr auto kShapeGamma = "shape_gamma"; 181 constexpr auto kShapeSize = "shape_size"; 182 
// Shared string constants (continued). Each one is a `const char *const`
// deduced by `constexpr auto` from a string literal.
// NOTE(review): presumably operator attribute keys -- confirm against users.
constexpr auto kShift = "shift";
constexpr auto kShrinkAxisMask = "shrink_axis_mask";
constexpr auto kSize = "size";
constexpr auto kSorted = "sorted";
constexpr auto kSrcT = "src_t";
constexpr auto kStart = "start";
constexpr auto kStepH = "step_h";
constexpr auto kStepW = "step_w";
constexpr auto kStride = "stride";
constexpr auto kStrides = "strides";
constexpr auto kShapeType = "shape_type";
constexpr auto kSubGraphIndex = "sub_graph_index";
constexpr auto kSummarize = "summarize";
constexpr auto kTimeMajor = "time_major";
constexpr auto kTopK = "top_k";
constexpr auto kTransposeA = "transpose_a";
constexpr auto kTransposeB = "transpose_b";
constexpr auto kNegativeSlope = "negative_slope";
constexpr auto kType = "type";
constexpr auto kUseAxis = "use_axis";
constexpr auto kUseLocking = "use_locking";
constexpr auto kUseNesterov = "use_nesterov";
// NOTE(review): differs from kUseNesterov only in the last letter ("y" vs "v")
// and looks like a misspelling; both the identifier and the string are kept
// unchanged because external code may depend on either.
constexpr auto kUseNesteroy = "use_nesteroy";
constexpr auto kUseRegularNms = "use_regular_nms";
constexpr auto kValid = "valid";
constexpr auto kValue = "value";
constexpr auto kVariances = "variances";
constexpr auto kWeightDecay = "weight_decay";
constexpr auto kWeightThreshold = "weight_threshold";
constexpr auto kWindow = "window";
constexpr auto kWindowSize = "window_size";
constexpr auto kPaddings = "paddings";
// kInput_size and kHidden_size carry the same string values as kInputSize
// ("input_size") and kHiddenSize ("hidden_size") declared earlier in this
// header; they are separate identifiers with identical contents.
constexpr auto kInput_size = "input_size";
constexpr auto kHidden_size = "hidden_size";
constexpr auto kChannelShared = "channel_shared";
constexpr auto kSlope = "slope";
constexpr auto kBase = "base";
constexpr auto kConstantValue = "constant_value";
constexpr auto kSizeSplits = "size_splits";
constexpr auto kDims = "dims";
constexpr auto kPaddingMode = "padding_mode";
constexpr auto kLargest = "largest";
constexpr auto kElementwiseAffine = "elementwise_affine";
constexpr auto kMinVal = "min_val";
constexpr auto kMaxVal = "max_val";
227 constexpr auto kMethod = "method"; 228 constexpr auto kNewHeight = "new_height"; 229 constexpr auto kNewWidth = "new_width"; 230 constexpr auto kPreserveAspectRatio = "preserve_aspect_ratio"; 231 constexpr auto kCoordinateTransformMode = "coordinate_transform_mode"; 232 constexpr auto kCubicCoeff = "cubic_coeff"; 233 constexpr auto kExcludeOutside = "exclude_outside"; 234 constexpr auto kExtrapolationValue = "extrapolation_value"; 235 constexpr auto kNearestMode = "nearest_mode"; 236 constexpr auto kReduceToEnd = "reduce_to_end"; 237 constexpr auto kResetAfter = "reset_after"; 238 constexpr auto kCoeff = "coeff"; 239 constexpr auto kIsDepthWise = "is_depth_wise"; 240 constexpr auto kZoneoutCell = "zoneout_cell"; 241 constexpr auto kZoneoutHidden = "zoneout_hidden"; 242 constexpr auto kSpliceContext = "context"; 243 constexpr auto kSpliceForwardIndexes = "forward_indexes"; 244 constexpr auto kSpliceOutputDims = "output_dim"; 245 constexpr auto kSideEffectIO = "side_effect_io"; 246 constexpr auto kDeviceType = "device_type"; 247 constexpr auto kExclusive = "exclusive"; 248 constexpr auto kReverse = "reverse"; 249 constexpr auto kSplitStride = "split_stride"; 250 constexpr auto kExtendTop = "extend_top"; 251 constexpr auto kExtendBottom = "extend_bottom"; 252 constexpr auto kNumberSplit = "number_split"; 253 constexpr auto kSplitDim = "split_dim"; 254 constexpr auto kPadTop = "pad_top"; 255 constexpr auto kTransFormat = "trans_format"; 256 constexpr auto kApproximate = "approximate"; 257 constexpr auto kNumOutput = "num_output"; 258 constexpr auto kUseGlobalStats = "use_global_stats"; 259 constexpr auto kFmkType = "fmk_type"; 260 261 enum Index : size_t { 262 kInputIndex0 = 0, 263 kInputIndex1, 264 kInputIndex2, 265 kInputIndex3, 266 kInputIndex4, 267 kInputIndex5, 268 kInputIndex6, 269 kInputIndex7, 270 kInputIndex8, 271 kInputIndex9, 272 kInputIndex10, 273 kInputIndex11, 274 kInputIndex12, 275 kInputIndex13, 276 kInputIndex14, 277 kInputIndex15, 278 
kInputIndex16, 279 }; 280 281 const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, 282 kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; 283 284 const std::set<TypePtr> all_types = { 285 kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8, 286 kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64, 287 }; 288 289 std::vector<int64_t> CalBroadCastShape(std::vector<int64_t> x_shape, std::vector<int64_t> y_shape, 290 const std::string &op_name, const std::string &op_x_name = "input1", 291 const std::string &op_y_name = "input2"); 292 abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args); 293 } // namespace mindspore::ops 294 #endif // MINDSPORE_CORE_OPS_OP_UTILS_H 295