• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_OP_UTILS_H
18 #define MINDSPORE_CORE_OPS_OP_UTILS_H
19 #include <string>
20 #include <set>
21 #include <vector>
22 #include <algorithm>
23 #include <memory>
24 #include "abstract/primitive_infer_map.h"
25 #include "utils/check_convert_utils.h"
26 
27 namespace mindspore::ops {
28 constexpr auto kAlpha = "alpha";
29 constexpr auto kActivation = "activation";
30 constexpr auto kActivationType = "activation_type";
31 constexpr auto kAttentionQActType = "attention_q_act_type";
32 constexpr auto kAttentionKActType = "attention_k_act_type";
33 constexpr auto kAttentionVActType = "attention_v_act_type";
34 constexpr auto kAddress = "address";
35 constexpr auto kAlignCorners = "align_corners";
36 constexpr auto kAttr = "attr";
37 constexpr auto kAspectRatios = "aspect_ratios";
38 constexpr auto kAxes = "axes";
39 constexpr auto kAxis = "axis";
40 constexpr auto kAxisType = "axis_type";
41 constexpr auto kBaseSize = "base_size";
42 constexpr auto kBatchDim = "batch_dim";
43 constexpr auto kBeginMask = "begin_mask";
44 constexpr auto kBeginNormAxis = "begin_norm_axis";
45 constexpr auto kBeginParamsAxis = "begin_params_axis";
46 constexpr auto kBeta = "beta";
47 constexpr auto kBias = "bias";
48 constexpr auto kBidirectional = "bidirectional";
49 constexpr auto kBlockSize = "block_size";
50 constexpr auto kBlockShape = "block_shape";
51 constexpr auto kCellClip = "cell_clip";
52 constexpr auto kCellDepth = "cell_depth";
53 constexpr auto kCenterPointBox = "center_point_box";
54 constexpr auto kClip = "clip";
55 constexpr auto kCondition = "condition";
56 constexpr auto kCrops = "crops";
57 constexpr auto kCustom = "custom";
58 constexpr auto kDampening = "dampening";
59 constexpr auto kDataType = "data_type";
60 constexpr auto kDctCoeffNum = "dct_coeff_num";
61 constexpr auto kDelta = "delta";
62 constexpr auto kDependMode = "depend_mode";
63 constexpr auto kDepthRadius = "depth_radius";
64 constexpr auto kDetectionsPerClass = "detections_per_class";
65 constexpr auto kDilation = "dilation";
66 constexpr auto kDropout = "dropout";
67 constexpr auto kDstT = "dst_t";
68 constexpr auto kDType = "d_type";
69 constexpr auto kEllipsisMask = "ellipsis_mask";
70 constexpr auto kEndMask = "end_mask";
71 constexpr auto kEps = "eps";
72 constexpr auto kEpsilon = "epsilon";
73 constexpr auto kElement_dtype = "element_dtype";
74 constexpr auto kFeatStride = "feat_stride";
75 constexpr auto kFftLength = "fft_length";
76 constexpr auto kFilterBankChannelNum = "filter_bank_channel_num";
77 constexpr auto kFlip = "flip";
78 constexpr auto kFormat = "format";
79 constexpr auto kOriginalFormat = "OriginalFormat";
80 constexpr auto kFreqLowerLimit = "freq_lower_limit";
81 constexpr auto kFreqUpperLimit = "freq_upper_limit";
82 constexpr auto kFreezeBn = "freeze_bn";
83 constexpr auto kGateOrder = "gate_order";
84 constexpr auto kGlobal = "global";
85 constexpr auto kGrad = "grad";
86 constexpr auto kIsGrad = "is_grad";
87 constexpr auto kGradientScale = "gradient_scale";
88 constexpr auto kGradX = "grad_x";
89 constexpr auto kGradY = "grad_y";
90 constexpr auto kGroup = "group";
91 constexpr auto kHasBias = "has_bias";
92 constexpr auto kAttentionHasMask = "attention_has_mask";
93 constexpr auto kHiddenSize = "hidden_size";
94 constexpr auto kId = "id";
95 constexpr auto kImageSizeH = "image_size_h";
96 constexpr auto kImageSizeW = "image_size_w";
97 constexpr auto kIncludeALLGrams = "include_all_grams";
98 constexpr auto kInputSize = "input_size";
99 constexpr auto kInChannel = "in_channel";
100 constexpr auto kInputShape = "input_shape";
101 constexpr auto kIoFormat = "io_format";
102 constexpr auto kIsScale = "is_scale";
103 constexpr auto kIsTraining = "is_training";
104 constexpr auto kKeepDims = "keep_dims";
105 constexpr auto kKeepProb = "keep_prob";
106 constexpr auto kKernelSize = "kernel_size";
107 constexpr auto kLimit = "limit";
108 constexpr auto kMagSquare = "mag_square";
109 constexpr auto kMax = "max";
110 constexpr auto kMaxSizes = "max_sizes";
111 constexpr auto kMaxSkipSize = "max_skip_size";
112 constexpr auto kMaxClassesPerDetection = "max_classes_per_detection";
113 constexpr auto kMaxDetections = "max_detections";
114 constexpr auto kMaxNorm = "max_norm";
115 constexpr auto kMin = "min";
116 constexpr auto kMinSize = "min_size";
117 constexpr auto kMinSizes = "min_sizes";
118 constexpr auto kMode = "mode";
119 constexpr auto kMomentum = "momentum";
120 constexpr auto kN = "n";
121 constexpr auto kNarrowRange = "narrow_range";
122 constexpr auto kNesterov = "nesterov";
123 constexpr auto kNewAxisMask = "new_axis_mask";
124 constexpr auto kNgramSize = "ngram_size";
125 constexpr auto kNmsThresh = "nms_thresh";
126 constexpr auto kNormRegion = "norm_region";
127 constexpr auto kNumLayers = "num_layers";
128 constexpr auto kNumElements = "num_elements";
129 constexpr auto kNumBits = "num_bits";
130 constexpr auto kNumDirections = "num_directions";
131 constexpr auto kNumProj = "num_proj";
132 constexpr auto kAttentionNumHeads = "attention_num_heads";
133 constexpr auto kAttentionSizePerHead = "attention_size_per_head";
134 constexpr auto kAttentionFromSeqLen = "attention_from_seq_len";
135 constexpr auto kAttentionToSeqLen = "attention_to_seq_len";
136 constexpr auto kOffset = "offset";
137 constexpr auto kNmsIouThreshold = "nms_iou_threshold";
138 constexpr auto kNmsScoreThreshold = "nms_score_threshold";
139 constexpr auto kNumClasses = "num_classes";
140 constexpr auto kOffsets = "offsets";
141 constexpr auto kOffsetA = "offset_a";
142 constexpr auto kOrder = "order";
143 constexpr auto kOutChannel = "out_channel";
144 constexpr auto kOutMaxValue = "out_max_value";
145 constexpr auto kOutputChannel = "output_channel";
146 constexpr auto kOutputNum = "output_num";
147 constexpr auto kOutputPaddings = "output_paddings";
148 constexpr auto kOutputType = "output_type";
149 constexpr auto kOutQuantized = "out_quantized";
150 constexpr auto kP = "p";
151 constexpr auto kPad = "pad";
152 constexpr auto kPadding = "padding";
153 constexpr auto kPaddingsElementSize = "paddings_element_size";
154 constexpr auto kPaddingsSize = "paddings_size";
155 constexpr auto kPadItem = "pad_item";
156 constexpr auto kPadList = "pad_list";
157 constexpr auto kPadMode = "pad_mode";
158 constexpr auto kPads = "pads";
159 constexpr auto kPadSize = "pad_size";
160 constexpr auto kPooledH = "pooled_h";
161 constexpr auto kPooledW = "pooled_w";
162 constexpr auto kPoolMode = "pool_mode";
163 constexpr auto kCeilMode = "ceil_mode";
164 constexpr auto kCountIncludePad = "count_include_pad";
165 constexpr auto kDivisorOverride = "divisor_override";
166 constexpr auto kPostNmsTopn = "post_nms_topn";
167 constexpr auto kPower = "power";
168 constexpr auto kPreNmsTopn = "pre_nms_topn";
169 constexpr auto kRatio = "ratio";
170 constexpr auto kReduction = "reduction";
171 constexpr auto kRootRank = "root_rank";
172 constexpr auto kRoundMode = "round_mode";
173 constexpr auto kSame = "same";
174 constexpr auto kScale = "scale";
175 constexpr auto kSeed = "seed";
176 constexpr auto kSeed2 = "seed2";
177 constexpr auto kSeqDim = "seq_dim";
178 constexpr auto kSetattrFlag = "setattr_flag";
179 constexpr auto kShape = "shape";
180 constexpr auto kShapeGamma = "shape_gamma";
181 constexpr auto kShapeSize = "shape_size";
182 constexpr auto kShift = "shift";
183 constexpr auto kShrinkAxisMask = "shrink_axis_mask";
184 constexpr auto kSize = "size";
185 constexpr auto kSorted = "sorted";
186 constexpr auto kSrcT = "src_t";
187 constexpr auto kStart = "start";
188 constexpr auto kStepH = "step_h";
189 constexpr auto kStepW = "step_w";
190 constexpr auto kStride = "stride";
191 constexpr auto kStrides = "strides";
192 constexpr auto kShapeType = "shape_type";
193 constexpr auto kSubGraphIndex = "sub_graph_index";
194 constexpr auto kSummarize = "summarize";
195 constexpr auto kTimeMajor = "time_major";
196 constexpr auto kTopK = "top_k";
197 constexpr auto kTransposeA = "transpose_a";
198 constexpr auto kTransposeB = "transpose_b";
199 constexpr auto kNegativeSlope = "negative_slope";
200 constexpr auto kType = "type";
201 constexpr auto kUseAxis = "use_axis";
202 constexpr auto kUseLocking = "use_locking";
203 constexpr auto kUseNesterov = "use_nesterov";
204 constexpr auto kUseNesteroy = "use_nesteroy";
205 constexpr auto kUseRegularNms = "use_regular_nms";
206 constexpr auto kValid = "valid";
207 constexpr auto kValue = "value";
208 constexpr auto kVariances = "variances";
209 constexpr auto kWeightDecay = "weight_decay";
210 constexpr auto kWeightThreshold = "weight_threshold";
211 constexpr auto kWindow = "window";
212 constexpr auto kWindowSize = "window_size";
213 constexpr auto kPaddings = "paddings";
214 constexpr auto kInput_size = "input_size";
215 constexpr auto kHidden_size = "hidden_size";
216 constexpr auto kChannelShared = "channel_shared";
217 constexpr auto kSlope = "slope";
218 constexpr auto kBase = "base";
219 constexpr auto kConstantValue = "constant_value";
220 constexpr auto kSizeSplits = "size_splits";
221 constexpr auto kDims = "dims";
222 constexpr auto kPaddingMode = "padding_mode";
223 constexpr auto kLargest = "largest";
224 constexpr auto kElementwiseAffine = "elementwise_affine";
225 constexpr auto kMinVal = "min_val";
226 constexpr auto kMaxVal = "max_val";
227 constexpr auto kMethod = "method";
228 constexpr auto kNewHeight = "new_height";
229 constexpr auto kNewWidth = "new_width";
230 constexpr auto kPreserveAspectRatio = "preserve_aspect_ratio";
231 constexpr auto kCoordinateTransformMode = "coordinate_transform_mode";
232 constexpr auto kCubicCoeff = "cubic_coeff";
233 constexpr auto kExcludeOutside = "exclude_outside";
234 constexpr auto kExtrapolationValue = "extrapolation_value";
235 constexpr auto kNearestMode = "nearest_mode";
236 constexpr auto kReduceToEnd = "reduce_to_end";
237 constexpr auto kResetAfter = "reset_after";
238 constexpr auto kCoeff = "coeff";
239 constexpr auto kIsDepthWise = "is_depth_wise";
240 constexpr auto kZoneoutCell = "zoneout_cell";
241 constexpr auto kZoneoutHidden = "zoneout_hidden";
242 constexpr auto kSpliceContext = "context";
243 constexpr auto kSpliceForwardIndexes = "forward_indexes";
244 constexpr auto kSpliceOutputDims = "output_dim";
245 constexpr auto kSideEffectIO = "side_effect_io";
246 constexpr auto kDeviceType = "device_type";
247 constexpr auto kExclusive = "exclusive";
248 constexpr auto kReverse = "reverse";
249 constexpr auto kSplitStride = "split_stride";
250 constexpr auto kExtendTop = "extend_top";
251 constexpr auto kExtendBottom = "extend_bottom";
252 constexpr auto kNumberSplit = "number_split";
253 constexpr auto kSplitDim = "split_dim";
254 constexpr auto kPadTop = "pad_top";
255 constexpr auto kTransFormat = "trans_format";
256 constexpr auto kApproximate = "approximate";
257 constexpr auto kNumOutput = "num_output";
258 constexpr auto kUseGlobalStats = "use_global_stats";
259 constexpr auto kFmkType = "fmk_type";
260 
261 enum Index : size_t {
262   kInputIndex0 = 0,
263   kInputIndex1,
264   kInputIndex2,
265   kInputIndex3,
266   kInputIndex4,
267   kInputIndex5,
268   kInputIndex6,
269   kInputIndex7,
270   kInputIndex8,
271   kInputIndex9,
272   kInputIndex10,
273   kInputIndex11,
274   kInputIndex12,
275   kInputIndex13,
276   kInputIndex14,
277   kInputIndex15,
278   kInputIndex16,
279 };
280 
281 const std::set<TypePtr> common_valid_types = {kInt8,   kInt16,  kInt32,   kInt64,   kUInt8,  kUInt16,
282                                               kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
283 
284 const std::set<TypePtr> all_types = {
285   kBool,   kInt,    kInt8,   kInt16, kInt32,   kInt64,   kUInt,    kUInt8,
286   kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64,
287 };
288 
289 std::vector<int64_t> CalBroadCastShape(std::vector<int64_t> x_shape, std::vector<int64_t> y_shape,
290                                        const std::string &op_name, const std::string &op_x_name = "input1",
291                                        const std::string &op_y_name = "input2");
292 abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args);
293 }  // namespace mindspore::ops
294 #endif  // MINDSPORE_CORE_OPS_OP_UTILS_H
295