1 /**
2 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 /*!
18 * \file util.cpp
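 * \brief Shared operator-inference helpers: input dtype/shape validation, broadcast shape and
 *        shape-range inference, and extraction of attribute and const-tensor values.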
20 */
#include "util.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "error_util.h"
#include "vector_proto_profiling.h"
#include "op_common_util.h"
33
34 namespace ge {
35 using namespace std;
36
GetInputDataType(const ge::DataType & data_type,const std::vector<ge::DataType> & supportList)37 bool GetInputDataType(const ge::DataType &data_type, const std::vector<ge::DataType> &supportList) {
38 std::vector<ge::DataType>::const_iterator supportIter = find(supportList.begin(), supportList.end(), data_type);
39 if (supportIter == supportList.end()) {
40 return false;
41 }
42 return true;
43 }
44
CheckInputDtypeAndShape(const Operator & op,const std::map<std::string,std::vector<DataType>> & inputTensorMap)45 bool CheckInputDtypeAndShape(const Operator &op, const std::map<std::string, std::vector<DataType>> &inputTensorMap) {
46 auto iter = inputTensorMap.begin();
47 auto first_name = iter->first;
48 auto first_shape_dims = op.GetInputDescByName(iter->first.c_str()).GetShape().GetDims();
49 auto first_input_dtype = op.GetInputDescByName(iter->first.c_str()).GetDataType();
50 for (; iter != inputTensorMap.end(); ++iter) {
51 const TensorDesc input_desc = op.GetInputDescByName(iter->first.c_str());
52 // check input dtype
53 auto input_type = input_desc.GetDataType();
54 if (input_type != first_input_dtype) {
55 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
56 TbeGetName(op),
57 OtherErrMsg(ConcatString("the op type of param ", iter->first, " must equal with param ", first_name)));
58 return false;
59 }
60 auto dims = input_desc.GetShape().GetDims();
61 if (dims != first_shape_dims) {
62 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
63 TbeGetName(op),
64 OtherErrMsg(ConcatString("the op shape of param ", iter->first, " must equal with param ", first_name)));
65 return false;
66 }
67 }
68 return true;
69 }
70
CheckInputDataType(const Operator & op,const std::string & input_name,const std::vector<ge::DataType> & support_list)71 bool CheckInputDataType(const Operator &op, const std::string &input_name,
72 const std::vector<ge::DataType> &support_list) {
73 bool valid = false;
74 DataType input_type = op.GetInputDescByName(input_name.c_str()).GetDataType();
75 do {
76 const auto &found_list = find(support_list.begin(), support_list.end(), input_type);
77
78 if (found_list == support_list.end()) {
79 break;
80 }
81
82 const auto &found_map = DTYPE_STR_MAP.find(input_type);
83 if (found_map == DTYPE_STR_MAP.end()) {
84 break;
85 }
86
87 valid = true;
88 } while (0);
89
90 if (!valid) {
91 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
92 TbeGetName(op), OtherErrMsg(ConcatString("The op do not support the dtype", GeDataTypeToString(input_type))));
93 return false;
94 }
95
96 return true;
97 }
98
CheckTwoInputDtypeSame(const Operator & op,const string & input_name1,const string & input_name2)99 bool CheckTwoInputDtypeSame(const Operator &op, const string &input_name1, const string &input_name2) {
100 DataType input_type_x1 = op.GetInputDesc(input_name1).GetDataType();
101 DataType input_type_x2 = op.GetInputDesc(input_name2).GetDataType();
102 if (input_type_x1 != input_type_x2) {
103 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
104 TbeGetName(op), OtherErrMsg(ConcatString("The ", TbeGetName(op),
105 " op dtype is not same, type1:", GeDataTypeToString(input_type_x1),
106 ", type2:", GeDataTypeToString(input_type_x2))));
107 return false;
108 }
109
110 return true;
111 }
112
CheckInputDtypeSame(const Operator & op,const std::vector<std::string> & input_names)113 bool CheckInputDtypeSame(const Operator &op, const std::vector<std::string> &input_names) {
114 auto first_name = input_names.begin();
115 auto first_input_dtype = op.GetInputDescByName((*first_name).c_str()).GetDataType();
116 for (const string &input_name : input_names) {
117 const TensorDesc input_desc = op.GetInputDescByName(input_name.c_str());
118 auto input_dtype = input_desc.GetDataType();
119 if (input_dtype != first_input_dtype) {
120 auto error_ms = ConcatString("dtype of inputs must be same, ", input_name, ":", GeDataTypeToString(input_dtype),
121 ", ", (*first_name), ":", GeDataTypeToString(first_input_dtype), ".");
122 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg(error_ms));
123 return false;
124 }
125 }
126 return true;
127 }
128
CheckInputsShapeDtypeSame(const Operator & op,const std::vector<std::string> & input_names)129 bool CheckInputsShapeDtypeSame(const Operator &op, const std::vector<std::string> &input_names) {
130 auto first_input_name = input_names.begin();
131 auto first_input_des = op.GetInputDescByName((*first_input_name).c_str());
132 auto input_name = first_input_name;
133 for (++input_name; input_name != input_names.end(); ++input_name) {
134 auto input_des = op.GetInputDescByName((*first_input_name).c_str());
135 if (input_des.GetDataType() != first_input_des.GetDataType() ||
136 input_des.GetShape().GetDims() != first_input_des.GetShape().GetDims()) {
137 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
138 TbeGetName(op), OtherErrMsg(ConcatString("the dtype and shape of param ", first_input_name->c_str(),
139 " must be same as param ", input_name->c_str())));
140 return false;
141 }
142 }
143
144 return true;
145 }
146
TwoShapeAndRangeBroadcastIntegration(const Operator & op,std::vector<int64_t> & dimVec,std::vector<std::pair<int64_t,int64_t>> & Vec_range,std::vector<int64_t> dims,std::vector<std::pair<int64_t,int64_t>> range,const string & input_name1,const string & input_name2)147 bool TwoShapeAndRangeBroadcastIntegration(const Operator &op, std::vector<int64_t> &dimVec,
148 std::vector<std::pair<int64_t, int64_t>> &Vec_range,
149 std::vector<int64_t> dims, std::vector<std::pair<int64_t, int64_t>> range,
150 const string &input_name1, const string &input_name2) {
151 if (dimVec.size() < dims.size()) {
152 std::vector<int64_t> dimsTmp = dimVec;
153 dimVec = dims;
154 dims = dimsTmp;
155 std::vector<std::pair<int64_t, int64_t>> range_temp = Vec_range;
156 Vec_range = range;
157 range = range_temp;
158 }
159 if (dimVec.size() != dims.size()) {
160 int dec = static_cast<int>(dimVec.size() - dims.size());
161 for (int i = 0; i < dec; i++) {
162 dims.insert(dims.begin(), static_cast<int64_t>(1));
163 }
164 }
165 for (size_t i = 0; i < dimVec.size(); i++) {
166 CHECK((dimVec[i] != dims[i]) && (dimVec[i] != 1) && (dims[i] != 1) && (dimVec[i] != -1) && (dims[i] != -1),
167 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
168 TbeGetName(op),
169 OtherErrMsg(ConcatString("The ", TbeGetName(op), "'s dimensions does not match the broadcast rule(",
170 dimVec[i], dims[i], ")."))),
171 return false);
172 }
173 dimVec = TwoBroadcastShape(dimVec, dims);
174 if (IsUnknown(dimVec)) {
175 MakeUpShapeRange(dims, range);
176 Vec_range = TwoShapeAndRangeBroadcast(dimVec, Vec_range, range);
177 }
178 return true;
179 }
180
/*!
 * \brief Broadcasts two equal-rank dim vectors element-wise.
 *
 * Per dimension (x, y):
 *  - both -1                    -> -1
 *  - one -1, other > 1          -> the other dim
 *  - one -1, other == 0         -> 0
 *  - one -1, other == 1         -> -1
 *  - one -1, other any other
 *    negative value             -> -1 (the original pushed nothing here, leaving the result
 *                                  misaligned and a later `dimVec[i] = -1` out of bounds)
 *  - neither -1                 -> 0 if either is 0, otherwise max(x, y)
 * The "neither -1" rule is exactly what the original static fast path computed, so one loop covers
 * both the static and the dynamic case.
 *
 * \param dimsX first shape
 * \param dimsY second shape; must have at least dimsX.size() entries (callers pad beforehand)
 * \return one broadcast dim per entry of dimsX
 */
std::vector<int64_t> TwoBroadcastShape(const std::vector<int64_t> &dimsX, const std::vector<int64_t> &dimsY) {
  const auto broadcast_dim = [](int64_t x, int64_t y) -> int64_t {
    const bool x_unknown = (x == -1);
    const bool y_unknown = (y == -1);
    if (x_unknown && y_unknown) {
      return -1;
    }
    if (x_unknown || y_unknown) {
      const int64_t known = x_unknown ? y : x;
      if (known > 1) {
        return known;
      }
      if (known == 0) {
        return 0;
      }
      return -1;  // known == 1 (result follows the unknown side) or another negative sentinel
    }
    return (x == 0 || y == 0) ? 0 : std::max(x, y);
  };

  std::vector<int64_t> dimVec;
  dimVec.reserve(dimsX.size());
  for (size_t i = 0; i < dimsX.size(); i++) {
    dimVec.push_back(broadcast_dim(dimsX[i], dimsY[i]));
  }
  return dimVec;
}
233
TwoShapeAndRangeBroadcast(const std::vector<int64_t> & dims_out,const std::vector<std::pair<int64_t,int64_t>> & shape_range_x,std::vector<std::pair<int64_t,int64_t>> & shape_range_y)234 std::vector<std::pair<int64_t, int64_t>> TwoShapeAndRangeBroadcast(
235 const std::vector<int64_t> &dims_out, const std::vector<std::pair<int64_t, int64_t>> &shape_range_x,
236 std::vector<std::pair<int64_t, int64_t>> &shape_range_y) {
237 size_t size_shape_out = dims_out.size();
238 std::vector<std::pair<int64_t, int64_t>> out_range;
239 if (!IsUnknownRankShape(dims_out)) {
240 while (shape_range_x.size() > shape_range_y.size()) {
241 shape_range_y.insert(shape_range_y.begin(), std::pair<int64_t, int64_t>(1, 1));
242 }
243 for (size_t i = 0; i < size_shape_out; i++) {
244 if (dims_out[i] != -1) {
245 out_range.push_back(std::pair<int64_t, int64_t>(dims_out[i], dims_out[i]));
246 continue;
247 }
248 if (i < shape_range_x.size() && i < shape_range_y.size()) {
249 if (shape_range_x[i].second == -1 && shape_range_y[i].second == 1) {
250 out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
251 } else if (shape_range_x[i].second == 1 && shape_range_y[i].second == -1) {
252 out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
253 } else if (shape_range_x[i].first == 1 || shape_range_y[i].first == 1) {
254 // one shape size maybe 1, so will support broadcast
255 // first_range == max first
256 int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first);
257 int64_t second_range = shape_range_x[i].first == 1 ? shape_range_y[i].second : shape_range_x[i].second;
258 if (shape_range_x[i].first == 1 && shape_range_y[i].first == 1) {
259 second_range = std::max(shape_range_x[i].second, shape_range_y[i].second);
260 second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1) ? -1 : second_range;
261 }
262 out_range.push_back(std::pair<int64_t, int64_t>(first_range, second_range));
263 } else {
264 // no 1 in range.first, mean no broadcast for range
265 // get intersect range
266 int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first);
267 int64_t second_range = std::min(shape_range_x[i].second, shape_range_y[i].second);
268 second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1)
269 ? std::max(shape_range_x[i].second, shape_range_y[i].second)
270 : second_range;
271 out_range.push_back(std::pair<int64_t, int64_t>(first_range, second_range));
272 }
273 }
274 }
275 }
276 return out_range;
277 }
278
InferBroadcastshapeForStatic(const Shape & shape_x,const Shape & shape_y,Shape & shape_output)279 bool InferBroadcastshapeForStatic(const Shape &shape_x, const Shape &shape_y, Shape &shape_output) {
280 auto shape_x_len = shape_x.GetDimNum();
281 auto shape_y_len = shape_y.GetDimNum();
282
283 OP_LOGI("BroadcastInfer", "input1 shape is: %s, input2 shape is: %s.", to_string(shape_x).c_str(),
284 to_string(shape_y).c_str());
285 std::vector<int64_t> output_shape;
286 if (shape_x_len >= shape_y_len) {
287 // when inputx len >= inputy len
288 // input_x = [128, 128, 128] Vs input_y = [128]
289 auto len_sub = shape_x_len - shape_y_len;
290 for (size_t i = 0; i < len_sub; i++) {
291 (void)output_shape.emplace_back(shape_x.GetDim(i));
292 }
293 for (size_t i = 0; i < shape_y_len; i++) {
294 int64_t dim_size = std::max(shape_x.GetDim(len_sub + i), shape_y.GetDim(i));
295 // if one dim is 0, the output dim is 0
296 dim_size = (shape_x.GetDim(len_sub + i) == 0 || shape_y.GetDim(i) == 0) ? 0 : dim_size;
297 (void)output_shape.emplace_back(dim_size);
298 }
299 } else {
300 // when inputx len < inputy len
301 // input_x = [128] Vs input_y = [128, 128, 128]
302 auto len_sub = shape_y_len - shape_x_len;
303 for (size_t i = 0; i < len_sub; i++) {
304 (void)output_shape.emplace_back(shape_y.GetDim(i));
305 }
306 for (size_t i = 0; i < shape_x_len; i++) {
307 int64_t dim_size = std::max(shape_y.GetDim(len_sub + i), shape_x.GetDim(i));
308 // if one dim is 0, the output dim is 0
309 dim_size = (shape_y.GetDim(len_sub + i) == 0 || shape_x.GetDim(i) == 0) ? 0 : dim_size;
310 (void)output_shape.emplace_back(dim_size);
311 }
312 }
313 shape_output = Shape(output_shape);
314 OP_LOGI("BroadcastInfer", "output1 shape is: %s.", to_string(shape_output).c_str());
315 return true;
316 }
317
InferShapeAndTypeTwoInOneOutBroadcast(Operator & op,const string & input_name1,const string & input_name2,const string & output_name,bool & is_dynamic)318 bool InferShapeAndTypeTwoInOneOutBroadcast(Operator &op, const string &input_name1, const string &input_name2,
319 const string &output_name, bool &is_dynamic) {
320 PROFILING_PROTO_INIT(TbeGetName(op).c_str());
321 DataType input_dtype = op.GetInputDesc(input_name1).GetDataType();
322
323 // output Desc
324 auto tensordesc_output = op.GetOutputDesc(output_name);
325 tensordesc_output.SetDataType(input_dtype);
326
327 ge::Shape shapeX = op.GetInputDesc(input_name1).GetShape();
328 ge::Shape shapeY = op.GetInputDesc(input_name2).GetShape();
329 OP_LOGI(TbeGetName(op).c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(),
330 input_name2.c_str(), to_string(shapeY).c_str());
331 std::vector<int64_t> dimsX = shapeX.GetDims();
332 std::vector<int64_t> dimsY = shapeY.GetDims();
333 PROFILING_PROTO_AFTER_GET_SHAPE_REG();
334 // swap based on shape size
335 if (dimsX.size() < dimsY.size()) {
336 std::vector<int64_t> dimsTmp = dimsX;
337 dimsX = dimsY;
338 dimsY = dimsTmp;
339 }
340
341 // unknown rank
342 if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) {
343 tensordesc_output.SetShape(ge::Shape(UNKNOWN_RANK));
344 OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%d.",
345 to_string(ge::Shape(UNKNOWN_RANK)).c_str(), input_dtype);
346 is_dynamic = false;
347 op.UpdateOutputDesc(output_name, tensordesc_output);
348 return true;
349 }
350
351 // pad 1 for small shape
352 if (dimsX.size() != dimsY.size()) {
353 int dec = static_cast<int>(dimsX.size() - dimsY.size());
354 for (int i = 0; i < dec; i++) {
355 dimsY.insert(dimsY.begin(), (int64_t)1);
356 }
357 }
358
359 // when not dynamic case, do infer shape only
360 if (!IsUnKnownShape(dimsY) && !IsUnKnownShape(dimsX)) {
361 std::vector<int64_t> dimVec(dimsX.size(), 0);
362 for (size_t i = 0; i < dimsX.size(); i++) {
363 dimVec[i] = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
364 dimVec[i] = (dimsY[i] == 0 || dimsX[i] == 0) ? 0 : dimVec[i];
365 }
366
367 PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
368 tensordesc_output.SetShape(ge::Shape(dimVec));
369 is_dynamic = false;
370 op.UpdateOutputDesc(output_name, tensordesc_output);
371 PROFILING_PROTO_END();
372 return true;
373 }
374
375 std::vector<int64_t> dimVec;
376 // dynamic case
377 for (size_t i = 0; i < dimsX.size(); i++) {
378 CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1),
379 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
380 TbeGetName(op),
381 OtherErrMsg(ConcatString("The ", TbeGetName(op), "'s dimensions does not match the broadcast rule(",
382 dimsX[i], dimsY[i], ")."))),
383 return false);
384
385 if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
386 if (dimsY[i] > 1) {
387 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
388 dimVec.push_back(dims);
389 } else if (dimsY[i] == 1) {
390 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
391 dimVec.push_back(dims);
392 dimVec[i] = -1;
393 } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
394 dimVec.push_back(-1);
395 }
396 } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
397 if (dimsX[i] > 1) {
398 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
399 dimVec.push_back(dims);
400 } else if (dimsX[i] == 0) {
401 dimVec.push_back(-1);
402 } else if (dimsX[i] == 1) {
403 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
404 dimVec.push_back(dims);
405 dimVec[i] = -1;
406 }
407 } else {
408 if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
409 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
410 dimVec.push_back(dims);
411 dimVec[i] = -1;
412 } else {
413 if (dimsY[i] == 0 || dimsX[i] == 0) {
414 dimVec.push_back(0);
415 } else {
416 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
417 dimVec.push_back(dims);
418 }
419 }
420 }
421 }
422 ge::Shape outputShape = ge::Shape(dimVec);
423 tensordesc_output.SetShape(outputShape);
424
425 OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(),
426 GeDataTypeToString(input_dtype).c_str());
427 is_dynamic = IsUnknown(dimVec);
428 if (is_dynamic) {
429 if (!InferShapeRangeTwoInOneOutBroadcast(op, input_name1, input_name2, output_name)) {
430 return false;
431 }
432 }
433 op.UpdateOutputDesc(output_name, tensordesc_output);
434 return true;
435 }
436
InferShapeAndTypeTwoInOneOutBroadcast(Operator & op,const string & input_name1,const string & input_name2,const string & output_name)437 bool InferShapeAndTypeTwoInOneOutBroadcast(Operator &op, const string &input_name1, const string &input_name2,
438 const string &output_name) {
439 DataType input_dtype = op.GetInputDesc(input_name1).GetDataType();
440
441 auto tensordesc_output = op.GetOutputDesc(output_name);
442
443 ge::Shape shapeX = op.GetInputDesc(input_name1).GetShape();
444 ge::Shape shapeY = op.GetInputDesc(input_name2).GetShape();
445 OP_LOGI(TbeGetName(op).c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(),
446 input_name2.c_str(), to_string(shapeY).c_str());
447 std::vector<int64_t> dimsX = shapeX.GetDims();
448 std::vector<int64_t> dimsY = shapeY.GetDims();
449 // swap based on shape size
450 if (dimsX.size() < dimsY.size()) {
451 std::vector<int64_t> dimsTmp = dimsX;
452 dimsX = dimsY;
453 dimsY = dimsTmp;
454 }
455
456 std::vector<int64_t> dimVec;
457
458 // unknown rank
459 if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) {
460 tensordesc_output.SetShape(ge::Shape(UNKNOWN_RANK));
461 tensordesc_output.SetDataType(input_dtype);
462 OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%d.",
463 to_string(ge::Shape(UNKNOWN_RANK)).c_str(), input_dtype);
464 op.UpdateOutputDesc(output_name, tensordesc_output);
465 return true;
466 }
467
468 // pad 1 for small shape
469 if (dimsX.size() != dimsY.size()) {
470 int dec = static_cast<int>(dimsX.size() - dimsY.size());
471 for (int i = 0; i < dec; i++) {
472 dimsY.insert(dimsY.begin(), (int64_t)1);
473 }
474 }
475
476 for (size_t i = 0; i < dimsX.size(); i++) {
477 CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1),
478 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
479 TbeGetName(op),
480 OtherErrMsg(ConcatString("The ", TbeGetName(op), "'s dimensions does not match the broadcast rule(",
481 dimsX[i], dimsY[i], ")."))),
482 return false);
483
484 if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
485 if (dimsY[i] > 1) {
486 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
487 dimVec.push_back(dims);
488 } else if (dimsY[i] == 1) {
489 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
490 dimVec.push_back(dims);
491 dimVec[i] = -1;
492 } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
493 dimVec.push_back(0);
494 }
495 } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
496 if (dimsX[i] > 1) {
497 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
498 dimVec.push_back(dims);
499 } else if (dimsX[i] == 0) {
500 dimVec.push_back(0);
501 } else if (dimsX[i] == 1) {
502 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
503 dimVec.push_back(dims);
504 dimVec[i] = -1;
505 }
506 } else {
507 if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
508 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
509 dimVec.push_back(dims);
510 dimVec[i] = -1;
511 } else {
512 if (dimsY[i] == 0 || dimsX[i] == 0) {
513 dimVec.push_back(0);
514 } else {
515 int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
516 dimVec.push_back(dims);
517 }
518 }
519 }
520 }
521 ge::Shape outputShape = ge::Shape(dimVec);
522
523 tensordesc_output.SetShape(outputShape);
524 tensordesc_output.SetDataType(input_dtype);
525 OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(),
526 GeDataTypeToString(input_dtype).c_str());
527 op.UpdateOutputDesc(output_name, tensordesc_output);
528
529 return true;
530 }
531
ToFormatString(ge::Format format)532 std::string ToFormatString(ge::Format format) { return GeFormatToString(format); }
533
/*!
 * \brief Appends the broadcast range of one output dim, combining two input ranges.
 *
 * Lower bound: 0 if either input may be 0-sized, otherwise the larger lower bound.
 * Upper bound (-1 is treated as unbounded):
 *  - one side unbounded, the other capped at 1  -> unbounded
 *  - both sides may be 1                        -> larger upper bound (unbounded if either is)
 *  - exactly one side may be 1                  -> the other side's upper bound
 *  - neither may be 1 (no broadcast)            -> intersection of the upper bounds
 *
 * The original tested `x.first * y.first == 0` and `x.second * y.second == -1`. Those products can
 * overflow int64_t (undefined behavior; e.g. 2^32 * 2^32 wraps to 0 and yields a bogus lower bound
 * of 0). The equivalent explicit comparisons below cannot overflow.
 *
 * \param out_range     receives exactly one appended range
 * \param shape_range_x range of the first input dim
 * \param shape_range_y range of the second input dim
 */
static void AddToOutputRange(std::vector<std::pair<int64_t, int64_t>> &out_range,
                             const std::pair<int64_t, int64_t> &shape_range_x,
                             const std::pair<int64_t, int64_t> &shape_range_y) {
  const int64_t x_lower = shape_range_x.first;
  const int64_t x_upper = shape_range_x.second;
  const int64_t y_lower = shape_range_y.first;
  const int64_t y_upper = shape_range_y.second;

  // first_range == max first, unless either side may be 0-sized
  const int64_t first_range = (x_lower == 0 || y_lower == 0) ? 0 : std::max(x_lower, y_lower);
  const bool either_unbounded = (x_upper == -1 || y_upper == -1);

  int64_t second_range = 0;
  if ((x_upper == -1 && y_upper == 1) || (x_upper == 1 && y_upper == -1)) {
    second_range = -1;
  } else if (x_lower == 1 && y_lower == 1) {
    second_range = either_unbounded ? -1 : std::max(x_upper, y_upper);
  } else if (x_lower == 1 || y_lower == 1) {
    // one shape size maybe 1, so will support broadcast
    second_range = (x_lower == 1) ? y_upper : x_upper;
  } else {
    // no 1 in range.first, mean no broadcast for range: get intersect range
    second_range = either_unbounded ? std::max(x_upper, y_upper) : std::min(x_upper, y_upper);
  }
  out_range.emplace_back(first_range, second_range);
}
562
InferShapeRangeTwoInOneOutBroadcast(Operator & op,const string & input_name1,const string & input_name2,const string & output_name)563 bool InferShapeRangeTwoInOneOutBroadcast(Operator &op, const string &input_name1, const string &input_name2,
564 const string &output_name) {
565 ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape();
566 ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape();
567
568 std::vector<int64_t> dims_x = shape_x.GetDims();
569 std::vector<int64_t> dims_y = shape_y.GetDims();
570
571 std::vector<std::pair<int64_t, int64_t>> shape_range_x;
572 op.GetInputDesc(input_name1).GetShapeRange(shape_range_x);
573 std::vector<std::pair<int64_t, int64_t>> shape_range_y;
574 op.GetInputDesc(input_name2).GetShapeRange(shape_range_y);
575
576 MakeUpShapeRange(dims_x, shape_range_x);
577 MakeUpShapeRange(dims_y, shape_range_y);
578
579 ge::Shape shape_out = op.GetOutputDesc(output_name).GetShape();
580 std::vector<int64_t> dims_out = shape_out.GetDims();
581 size_t size_shape_out = dims_out.size();
582
583 std::vector<std::pair<int64_t, int64_t>> out_range;
584
585 if (!IsUnknownRankShape(dims_out)) {
586 // shape switch by shape dim size
587 if (dims_x.size() < dims_y.size()) {
588 std::vector<int64_t> dims_tmp = dims_x;
589 dims_x = dims_y;
590 dims_y = dims_tmp;
591
592 std::vector<std::pair<int64_t, int64_t>> range_temp = shape_range_x;
593 shape_range_x = shape_range_y;
594 shape_range_y = range_temp;
595 }
596
597 while (dims_x.size() > shape_range_y.size()) {
598 shape_range_y.insert(shape_range_y.begin(), std::pair<int64_t, int64_t>(1, 1));
599 }
600
601 for (size_t i = 0; i < size_shape_out; i++) {
602 if (dims_out[i] != -1) {
603 out_range.push_back(std::pair<int64_t, int64_t>(dims_out[i], dims_out[i]));
604 continue;
605 }
606 if (i < shape_range_x.size() && i < shape_range_y.size()) {
607 AddToOutputRange(out_range, shape_range_x[i], shape_range_y[i]);
608 }
609 }
610 }
611 OP_LOGI(TbeGetName(op).c_str(), "elewise out range is %s", to_string(out_range).c_str());
612 auto tensor_out = op.GetOutputDesc(output_name);
613 tensor_out.SetShapeRange(out_range);
614 op.UpdateOutputDesc(output_name, tensor_out);
615
616 return true;
617 }
618
GetInputDataType(const ge::DataType & dataType,const std::vector<ge::DataType> & supportList,std::string & dType)619 bool GetInputDataType(const ge::DataType &dataType, const std::vector<ge::DataType> &supportList, std::string &dType) {
620 std::vector<ge::DataType>::const_iterator supportIter = find(supportList.begin(), supportList.end(), dataType);
621 if (supportIter == supportList.end()) {
622 return false;
623 }
624
625 std::map<ge::DataType, std::string>::const_iterator totalIter = DTYPE_STR_MAP.find(dataType);
626 if (totalIter == DTYPE_STR_MAP.end()) {
627 return false;
628 }
629
630 dType = totalIter->second;
631 return true;
632 }
633
CheckInputDataType(const Operator & op,std::string * data_type,const std::string & input_name,const std::vector<ge::DataType> & supportList)634 bool CheckInputDataType(const Operator &op, std::string *data_type, const std::string &input_name,
635 const std::vector<ge::DataType> &supportList) {
636 DataType input_type = op.GetInputDescByName(input_name.c_str()).GetDataType();
637 if (false == GetInputDataType(input_type, supportList, *data_type)) {
638 LOG_ERROR("[ERROR]op [%s] [%s] do not supported dtype [%s]!\n", TbeGetName(op).c_str(), input_name.c_str(),
639 data_type->c_str());
640 return false;
641 }
642 return true;
643 }
644
GetConstValue(const ge::Operator & op,const std::string & key_name,float & attr_value)645 bool GetConstValue(const ge::Operator &op, const std::string &key_name, float &attr_value) {
646 if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
647 LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
648 return false;
649 }
650 return true;
651 }
652
GetConstValue(const ge::Operator & op,const std::string & key_name,int64_t & attr_value)653 bool GetConstValue(const ge::Operator &op, const std::string &key_name, int64_t &attr_value) {
654 if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
655 LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
656 return false;
657 }
658 return true;
659 }
660
GetConstValue(const ge::Operator & op,const std::string & key_name,bool & attr_value)661 bool GetConstValue(const ge::Operator &op, const std::string &key_name, bool &attr_value) {
662 if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
663 LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
664 return false;
665 }
666 return true;
667 }
668
GetConstValue(const ge::Operator & op,const std::string & key_name,std::vector<int32_t> & attr_value)669 bool GetConstValue(const ge::Operator &op, const std::string &key_name, std::vector<int32_t> &attr_value) {
670 if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
671 LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
672 return false;
673 }
674 return true;
675 }
676
/*!
 * \brief Decodes a raw byte buffer of T elements into int64 values.
 *
 * Elements are copied with memcpy because the buffer is not guaranteed to be aligned for T.
 * Dereferencing a reinterpret_cast pointer to a misaligned address is undefined behavior.
 *
 * \param const_data raw bytes; may be null, which yields an empty result
 * \param data_size  buffer size in bytes; a trailing partial element is ignored
 * \return data_size / sizeof(T) values widened to int64_t
 */
template <typename T>
static std::vector<int64_t> GetConstIntData(const uint8_t *const_data, size_t data_size) {
  std::vector<int64_t> result;
  if (const_data == nullptr) {
    return result;
  }
  const size_t count = data_size / sizeof(T);
  result.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    T value{};
    std::memcpy(&value, const_data + i * sizeof(T), sizeof(T));
    result.push_back(static_cast<int64_t>(value));
  }
  return result;
}
688
GetConstIntData(const Tensor & data,DataType data_type,std::vector<int64_t> & const_values)689 bool GetConstIntData(const Tensor &data, DataType data_type, std::vector<int64_t> &const_values) {
690 using std::placeholders::_1;
691 using std::placeholders::_2;
692 const std::map<DataType, std::function<std::vector<int64_t>(const uint8_t *, size_t)>> type_call_map = {
693 {DT_INT8, std::bind(GetConstIntData<int8_t>, _1, _2)},
694 {DT_INT16, std::bind(GetConstIntData<int16_t>, _1, _2)},
695 {DT_INT32, std::bind(GetConstIntData<int32_t>, _1, _2)},
696 {DT_INT64, std::bind(GetConstIntData<int64_t>, _1, _2)},
697 };
698
699 auto found = type_call_map.find(data_type);
700 if (found == type_call_map.end()) {
701 USER_GE_LOGE("[ERROR]GetConstIntData is not support data_type[%s]!", GeDataTypeToString(data_type).c_str());
702 return false;
703 }
704
705 const_values = found->second(data.GetData(), data.GetSize());
706
707 return true;
708 }
709
GetConstValue(const Operator & op,const Tensor & const_tensor,const DataType & dtype,std::vector<int64_t> & const_data)710 bool GetConstValue(const Operator &op, const Tensor &const_tensor, const DataType &dtype,
711 std::vector<int64_t> &const_data) {
712 CHECK(dtype != ge::DT_INT32 && dtype != ge::DT_INT64,
713 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("not support this type")), return false);
714 if (dtype == ge::DT_INT32) {
715 const int32_t *const_data_ptr = reinterpret_cast<const int32_t *>(const_tensor.GetData());
716 size_t size = const_tensor.GetSize() / sizeof(int32_t);
717 for (size_t i = 0; i < size; ++i) {
718 const_data.push_back(static_cast<int32_t>(*(const_data_ptr + i)));
719 OP_LOGD(TbeGetName(op).c_str(), "const data int32 fusion pass ====== %d",
720 static_cast<int32_t>(*(const_data_ptr + i)));
721 }
722 } else if (dtype == ge::DT_INT64) {
723 const int64_t *const_data_ptr = reinterpret_cast<const int64_t *>(const_tensor.GetData());
724 size_t size = const_tensor.GetSize() / sizeof(int64_t);
725 for (size_t i = 0; i < size; ++i) {
726 const_data.push_back(static_cast<int64_t>(*(const_data_ptr + i)));
727 OP_LOGD(TbeGetName(op).c_str(), "const data int64 fusion pass ====== %ld",
728 static_cast<int64_t>(*(const_data_ptr + i)));
729 }
730 }
731 return true;
732 }
733
GetConstValue(const Operator & op,const Tensor & const_tensor,const DataType & dtype,std::vector<uint64_t> & const_data)734 bool GetConstValue(const Operator &op, const Tensor &const_tensor, const DataType &dtype,
735 std::vector<uint64_t> &const_data) {
736 size_t size = 0;
737 CHECK(dtype != ge::DT_UINT64,
738 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("not support this type")), return false);
739 const uint64_t *const_data_ptr = reinterpret_cast<const uint64_t *>(const_tensor.GetData());
740 size = const_tensor.GetSize() / sizeof(uint64_t);
741 for (size_t i = 0; i < size; ++i) {
742 const_data.push_back(static_cast<uint64_t>(*(const_data_ptr + i)));
743 OP_LOGD(TbeGetName(op).c_str(), "const data uint64 fusion pass, const_data[%lu]",
744 static_cast<uint64_t>(*(const_data_ptr + i)));
745 }
746 return true;
747 }
748
GetScalerValue(const Operator & op,const Tensor & const_tensor,const DataType & dtype,std::int64_t & const_data)749 bool GetScalerValue(const Operator &op, const Tensor &const_tensor, const DataType &dtype, std::int64_t &const_data) {
750 if (dtype == ge::DT_INT32) {
751 const int32_t *const_data_ptr = reinterpret_cast<const int32_t *>(const_tensor.GetData());
752 const_data = static_cast<int32_t>(*const_data_ptr);
753 } else if (dtype == ge::DT_INT64) {
754 const int64_t *const_data_ptr = reinterpret_cast<const int64_t *>(const_tensor.GetData());
755 const_data = static_cast<int64_t>(*const_data_ptr);
756 } else {
757 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg(ConcatString("not support this type:", dtype)));
758 return false;
759 }
760 return true;
761 }
762
to_string(const std::vector<int64_t> & shape)763 std::string to_string(const std::vector<int64_t> &shape) { return ops::to_string(shape); }
764
to_string(const ge::Shape & shape)765 std::string to_string(const ge::Shape &shape) { return to_string(shape.GetDims()); }
766
to_string(const std::vector<std::pair<int64_t,int64_t>> & ranges)767 std::string to_string(const std::vector<std::pair<int64_t, int64_t>> &ranges) { return ops::to_string(ranges); }
768
769 static std::map<ge::DataType, std::string> kDataTypeToStringMap = {{ge::DataType::DT_FLOAT, "float"},
770 {ge::DataType::DT_FLOAT16, "float16"},
771 {ge::DataType::DT_INT8, "int8"},
772 {ge::DataType::DT_INT16, "int16"},
773 {ge::DataType::DT_UINT16, "uint16"},
774 {ge::DataType::DT_UINT8, "uint8"},
775 {ge::DataType::DT_INT32, "int32"},
776 {ge::DataType::DT_INT64, "int64"},
777 {ge::DataType::DT_UINT32, "uint32"},
778 {ge::DataType::DT_UINT64, "uint64"},
779 {ge::DataType::DT_BOOL, "bool"},
780 {ge::DataType::DT_DOUBLE, "double"},
781 {ge::DataType::DT_STRING, "string"},
782 {ge::DataType::DT_DUAL_SUB_INT8, "dual_sub_int8"},
783 {ge::DataType::DT_DUAL_SUB_UINT8, "dual_sub_uint8"},
784 {ge::DataType::DT_COMPLEX64, "complex64"},
785 {ge::DataType::DT_COMPLEX128, "complex128"},
786 {ge::DataType::DT_DUAL, "dual"},
787 {ge::DataType::DT_QINT8, "qint8"},
788 {ge::DataType::DT_QINT16, "qint16"},
789 {ge::DataType::DT_QINT32, "qint32"},
790 {ge::DataType::DT_QUINT8, "quint8"},
791 {ge::DataType::DT_QUINT16, "quint16"},
792 {ge::DataType::DT_RESOURCE, "resource"},
793 {ge::DataType::DT_STRING_REF, "string ref"},
794 {ge::DataType::DT_VARIANT, "dt_variant"},
795 {ge::DataType::DT_UNDEFINED, "undefined"},
796 {ge::DataType::DT_INT4, "int4"},
797 {ge::DataType::DT_UINT1, "uint1"},
798 {ge::DataType::DT_INT2, "int2"},
799 {ge::DataType::DT_UINT2, "uint2"},
800 {ge::DataType::DT_COMPLEX32, "complex32"},
801 {ge::DataType::DT_BF16, "bf16"}};
802
803 static std::map<ge::Format, std::string> kFormatToStringMap = {
804 {ge::Format::FORMAT_NCHW, "NCHW"},
805 {ge::Format::FORMAT_NHWC, "NHWC"},
806 {ge::Format::FORMAT_ND, "Nd"},
807 {ge::Format::FORMAT_NC1HWC0, "NC1HWC0"},
808 {ge::Format::FORMAT_FRACTAL_Z, "FRACTAL_Z"},
809 {ge::Format::FORMAT_NC1C0HWPAD, "NC1C0HWPAD"},
810 {ge::Format::FORMAT_NHWC1C0, "NHWC1C0"},
811 {ge::Format::FORMAT_FSR_NCHW, "FSR_NCHW"},
812 {ge::Format::FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"},
813 {ge::Format::FORMAT_C1HWNC0, "C1HWNC0"},
814 {ge::Format::FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
815 {ge::Format::FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
816 {ge::Format::FORMAT_NC1HWC0_C04, "NC1HWC0_C04"},
817 {ge::Format::FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
818 {ge::Format::FORMAT_CHWN, "CHWN"},
819 {ge::Format::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "FRACTAL_DECONV_SP_STRIDE8_TRANS"},
820 {ge::Format::FORMAT_HWCN, "HWCN"},
821 {ge::Format::FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"},
822 {ge::Format::FORMAT_BN_WEIGHT, "BN_WEIGHT"},
823 {ge::Format::FORMAT_FILTER_HWCK, "FILTER_HWCK"},
824 {ge::Format::FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "HASHTABLE_LOOKUP_LOOKUPS"},
825 {ge::Format::FORMAT_HASHTABLE_LOOKUP_KEYS, "HASHTABLE_LOOKUP_KEYS"},
826 {ge::Format::FORMAT_HASHTABLE_LOOKUP_VALUE, "HASHTABLE_LOOKUP_VALUE"},
827 {ge::Format::FORMAT_HASHTABLE_LOOKUP_OUTPUT, "HASHTABLE_LOOKUP_OUTPUT"},
828 {ge::Format::FORMAT_HASHTABLE_LOOKUP_HITS, "HASHTABLE_LOOKUP_HITS"},
829 {ge::Format::FORMAT_C1HWNCoC0, "C1HWNCoC0"},
830 {ge::Format::FORMAT_MD, "MD"},
831 {ge::Format::FORMAT_NDHWC, "NDHWC"},
832 {ge::Format::FORMAT_FRACTAL_ZZ, "FRACTAL_ZZ"},
833 {ge::Format::FORMAT_FRACTAL_NZ, "FRACTAL_NZ"},
834 {ge::Format::FORMAT_NCDHW, "NCDHW"},
835 {ge::Format::FORMAT_DHWCN, "DHWCN"},
836 {ge::Format::FORMAT_NDC1HWC0, "NDC1HWC0"},
837 {ge::Format::FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
838 {ge::Format::FORMAT_CN, "CN"},
839 {ge::Format::FORMAT_NC, "NC"},
840 {ge::Format::FORMAT_DHWNC, "DHWNC"},
841 {ge::Format::FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
842 {ge::Format::FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
843 {ge::Format::FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"},
844 {ge::Format::FORMAT_RESERVED, "RESERVED"},
845 {ge::Format::FORMAT_ALL, "ALL"},
846 {ge::Format::FORMAT_NULL, "NULL"},
847 {ge::Format::FORMAT_ND_RNN_BIAS, "ND_RNN_BIAS"},
848 {ge::Format::FORMAT_FRACTAL_ZN_RNN, "FRACTAL_ZN_RNN"},
849 {ge::Format::FORMAT_NYUV, "NYUV"},
850 {ge::Format::FORMAT_NYUV_A, "NYUV_A"},
851 {ge::Format::FORMAT_NCL, "NCL"}};
852
GeDataTypeToString(const ge::DataType datatype)853 std::string GeDataTypeToString(const ge::DataType datatype) {
854 auto iter = kDataTypeToStringMap.find(datatype);
855 if (iter != kDataTypeToStringMap.end()) {
856 return iter->second;
857 }
858 return "";
859 }
860
GeFormatToString(const ge::Format format)861 std::string GeFormatToString(const ge::Format format) {
862 auto iter = kFormatToStringMap.find(format);
863 if (iter != kFormatToStringMap.end()) {
864 return iter->second;
865 }
866 return "";
867 }
868
/**
 * Returns true only for the exact dim vector [0].
 *
 * Note: this is narrower than the ge::Shape overload in this file, which treats any zero dim
 * as empty; e.g. [2, 0, 3] is NOT reported as empty here.
 */
bool IsEmptyTensor(const std::vector<int64_t> &dims) {
  return (dims.size() == 1) && (dims.front() == 0);
}
876
IsUnknownRank(const Operator & op,const std::string & tensor_name,const std::string & types)877 bool IsUnknownRank(const Operator &op, const std::string &tensor_name, const std::string &types) {
878 TensorDesc tensor_desc;
879 if (types == "input") {
880 tensor_desc = op.GetInputDesc(tensor_name);
881 } else if (types == "output") {
882 tensor_desc = op.GetOutputDesc(tensor_name);
883 } else {
884 VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
885 OtherErrMsg(ConcatString("invalid params:", types, " of types to judge.")));
886 return false;
887 }
888
889 std::vector<int64_t> shape_vec = tensor_desc.GetShape().GetDims();
890 if (shape_vec.size() == 1 && shape_vec[0] == INPUT_NEGATIVE_NUM2) {
891 return true;
892 }
893 return false;
894 }
895
IsUnknownRankShape(const std::vector<int64_t> & shape_vec)896 bool IsUnknownRankShape(const std::vector<int64_t> &shape_vec) {
897 if (shape_vec.size() == 1 && shape_vec[0] == ge::UNKNOWN_DIM_NUM) {
898 return true;
899 }
900 return false;
901 }
902
IsUnknownRankShape(const Shape & input_shape)903 bool IsUnknownRankShape(const Shape &input_shape) {
904 auto dims = input_shape.GetDims();
905 return (dims.size() == 1UL) && (dims[0UL] == UNKNOWN_DIM_NUM);
906 }
907
/**
 * Returns true if any dim of shape_vec is -1 (a dynamic dim). Unknown-rank markers are not
 * detected here.
 */
bool IsUnKnownShape(const std::vector<int64_t> &shape_vec) {
  return std::any_of(shape_vec.cbegin(), shape_vec.cend(), [](int64_t dim) { return dim == -1; });
}
912
IsUnknown(const std::vector<int64_t> & shape_vec)913 bool IsUnknown(const std::vector<int64_t> &shape_vec) {
914 return (IsUnKnownShape(shape_vec) || IsUnknownRankShape(shape_vec));
915 }
916
/**
 * Returns true if any element of shape_vec is -1. The vector is only read, despite the
 * non-const reference parameter (kept for interface compatibility).
 */
bool IsUnknownVec(std::vector<int64_t> &shape_vec) {
  for (const int64_t dim : shape_vec) {
    if (dim == -1) {
      return true;
    }
  }
  return false;
}
925
MakeUpShapeRange(const std::vector<int64_t> & shape,std::vector<std::pair<int64_t,int64_t>> & range)926 void MakeUpShapeRange(const std::vector<int64_t> &shape, std::vector<std::pair<int64_t, int64_t>> &range) {
927 if (IsUnknownRankShape(shape)) {
928 return;
929 }
930
931 if (range.empty()) {
932 for (size_t i = 0; i < shape.size(); i++) {
933 if (shape[i] == -1) {
934 range.push_back(std::pair<int64_t, int64_t>(0, -1));
935 } else {
936 range.push_back(std::pair<int64_t, int64_t>(shape[i], shape[i]));
937 }
938 }
939 }
940 }
941
MakeUpShapeRange(const ge::Shape & shape,std::vector<std::pair<int64_t,int64_t>> & range)942 void MakeUpShapeRange(const ge::Shape &shape, std::vector<std::pair<int64_t, int64_t>> &range) {
943 if (IsUnknownRankShape(shape)) {
944 return;
945 }
946
947 if (range.empty()) {
948 for (size_t i = 0; i < shape.GetDimNum(); i++) {
949 int64_t dim = shape.GetDim(i);
950 if (dim == -1) {
951 range.push_back(std::pair<int64_t, int64_t>(0, -1));
952 } else {
953 range.push_back(std::pair<int64_t, int64_t>(dim, dim));
954 }
955 }
956 }
957 }
958
DataTypeToStringDesc(const ge::DataType & dataType)959 std::string DataTypeToStringDesc(const ge::DataType &dataType) {
960 std::map<ge::DataType, std::string>::const_iterator totalIter = DTYPE_STR_MAP.find(dataType);
961 if (totalIter == DTYPE_STR_MAP.end()) {
962 return "UNDEFINED";
963 }
964 return totalIter->second;
965 }
966
OneInOneOutDynamicInfer(Operator & op,const std::string & input_name,const std::vector<std::string> & output_name_list)967 bool OneInOneOutDynamicInfer(Operator &op, const std::string &input_name,
968 const std::vector<std::string> &output_name_list) {
969 // get input desc
970 PROFILING_PROTO_INIT(TbeGetName(op).c_str());
971 auto input_desc = op.GetInputDesc(input_name);
972 vector<int64_t> input_shape = input_desc.GetShape().GetDims();
973 DataType input_dtype = input_desc.GetDataType();
974
975 if (IsUnknown(input_shape)) {
976 std::vector<std::pair<int64_t, int64_t>> input_range;
977 input_desc.GetShapeRange(input_range);
978 MakeUpShapeRange(input_shape, input_range);
979
980 auto output_desc = op.GetOutputDesc(0);
981 for (const string &output_name : output_name_list) {
982 output_desc = op.GetOutputDesc(output_name);
983 output_desc.SetShape(Shape(input_shape));
984 output_desc.SetOriginShape(Shape(input_shape));
985 output_desc.SetShapeRange(input_range);
986 output_desc.SetDataType(input_dtype);
987 op.UpdateOutputDesc(output_name, output_desc);
988 }
989 } else {
990 auto output_desc = op.GetOutputDesc(0);
991 PROFILING_PROTO_AFTER_GET_SHAPE_REG();
992 PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
993 for (const string &output_name : output_name_list) {
994 output_desc = op.GetOutputDesc(output_name);
995 output_desc.SetShape(Shape(input_shape));
996 output_desc.SetDataType(input_dtype);
997 op.UpdateOutputDesc(output_name, output_desc);
998 }
999 PROFILING_PROTO_END();
1000 }
1001 return true;
1002 }
1003
FixShapeRangeWithDims(const std::vector<int64_t> & dims,std::vector<int64_t> & shape_1,std::vector<int64_t> & shape_2,std::vector<std::pair<int64_t,int64_t>> & range_1,std::vector<std::pair<int64_t,int64_t>> & range_2)1004 void FixShapeRangeWithDims(const std::vector<int64_t> &dims, std::vector<int64_t> &shape_1,
1005 std::vector<int64_t> &shape_2, std::vector<std::pair<int64_t, int64_t>> &range_1,
1006 std::vector<std::pair<int64_t, int64_t>> &range_2) {
1007 MakeUpShapeRange(shape_1, range_1);
1008 MakeUpShapeRange(shape_2, range_2);
1009 bool is_all_fix = dims.empty();
1010
1011 if (shape_1 == UNKNOWN_RANK && shape_2 == UNKNOWN_RANK) {
1012 return;
1013 }
1014 if (shape_1 == UNKNOWN_RANK) {
1015 shape_1 = shape_2;
1016 range_1 = range_2;
1017 return;
1018 }
1019 if (shape_2 == UNKNOWN_RANK) {
1020 shape_2 = shape_1;
1021 range_2 = range_1;
1022 return;
1023 }
1024 if ((shape_1.size() != shape_2.size()) || (range_1.size() != range_2.size())) {
1025 return;
1026 }
1027 auto loop_size = is_all_fix ? shape_1.size() : dims.size();
1028 for (size_t i = 0; i < loop_size; i++) {
1029 auto dim_num = is_all_fix ? i : dims[i];
1030 if (shape_1[dim_num] != -1) {
1031 shape_2[dim_num] = shape_1[dim_num];
1032 range_1[dim_num] = std::pair<int64_t, int64_t>(shape_1[dim_num], shape_1[dim_num]);
1033 range_2[dim_num] = std::pair<int64_t, int64_t>(shape_1[dim_num], shape_1[dim_num]);
1034 continue;
1035 }
1036 if (shape_2[dim_num] != -1) {
1037 shape_1[dim_num] = shape_2[dim_num];
1038 range_1[dim_num] = std::pair<int64_t, int64_t>(shape_2[dim_num], shape_2[dim_num]);
1039 range_2[dim_num] = std::pair<int64_t, int64_t>(shape_2[dim_num], shape_2[dim_num]);
1040 continue;
1041 }
1042 // both the dim in shape1 and shape2 are -1
1043 auto range_1_min = range_1[dim_num].first;
1044 auto range_2_min = range_2[dim_num].first;
1045 auto range_1_max = range_1[dim_num].second;
1046 auto range_2_max = range_2[dim_num].second;
1047 auto range_fisrt = range_1_min > range_2_min ? range_1_min : range_2_min;
1048 auto range_second_min = range_1_max > range_2_max ? range_2_max : range_1_max;
1049 auto range_second_max = range_1_max > range_2_max ? range_1_max : range_2_max;
1050 range_second_min = range_second_min == -1 ? range_second_max : range_second_min;
1051 range_1[dim_num] = std::pair<int64_t, int64_t>(range_fisrt, range_second_min);
1052 range_2[dim_num] = std::pair<int64_t, int64_t>(range_fisrt, range_second_min);
1053 }
1054 }
1055
TwoInOneOutDynamicInferNoBroadcast(Operator & op,const string & input1_name,const string & input2_name,const std::vector<string> & output_name_list)1056 bool TwoInOneOutDynamicInferNoBroadcast(Operator &op, const string &input1_name, const string &input2_name,
1057 const std::vector<string> &output_name_list) {
1058 // get input1 desc
1059 auto input1_desc = op.GetInputDesc(input1_name);
1060 vector<int64_t> input1_shape = input1_desc.GetShape().GetDims();
1061 DataType input_dtype = input1_desc.GetDataType();
1062
1063 // get input2 desc
1064 auto input2_desc = op.GetInputDesc(input2_name);
1065 vector<int64_t> input2_shape = input2_desc.GetShape().GetDims();
1066
1067 if (IsUnknown(input1_shape) || IsUnknown(input2_shape)) {
1068 std::vector<std::pair<int64_t, int64_t>> input1_range;
1069 input1_desc.GetShapeRange(input1_range);
1070 std::vector<std::pair<int64_t, int64_t>> input2_range;
1071 input2_desc.GetShapeRange(input2_range);
1072
1073 vector<int64_t> dim_size = {};
1074 FixShapeRangeWithDims(dim_size, input1_shape, input2_shape, input1_range, input2_range);
1075
1076 // update output desc
1077 for (const string &output_name : output_name_list) {
1078 auto output_desc = op.GetOutputDesc(output_name);
1079 output_desc.SetShape(Shape(input1_shape));
1080 output_desc.SetOriginShape(Shape(input1_shape));
1081 output_desc.SetShapeRange(input1_range);
1082 output_desc.SetDataType(input_dtype);
1083 op.UpdateOutputDesc(output_name, output_desc);
1084 }
1085 } else {
1086 for (const string &output_name : output_name_list) {
1087 auto output_desc = op.GetOutputDesc(output_name);
1088 output_desc.SetShape(Shape(input1_shape));
1089 output_desc.SetDataType(input_dtype);
1090 op.UpdateOutputDesc(output_name, output_desc);
1091 }
1092 }
1093 return true;
1094 }
1095
IsEmptyTensor(TensorDesc tensor_desc)1096 bool IsEmptyTensor(TensorDesc tensor_desc) { return IsEmptyTensor(tensor_desc.GetShape()); }
1097
IsEmptyTensor(const Shape & ge_shape)1098 bool IsEmptyTensor(const Shape &ge_shape) {
1099 bool is_empty = false;
1100 for (const auto &dim : ge_shape.GetDims()) {
1101 if (dim == 0) {
1102 is_empty = true;
1103 break;
1104 }
1105 }
1106 return is_empty;
1107 }
1108
IsUnknownShape(const ge::Shape & shape)1109 bool IsUnknownShape(const ge::Shape &shape) {
1110 const auto &dims = shape.GetDims();
1111 return std::any_of(dims.begin(), dims.end(),
1112 [](const int64_t &dim) { return (dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM); });
1113 }
1114
IsUnknownDimNum(const ge::Shape & shape)1115 bool IsUnknownDimNum(const ge::Shape &shape) {
1116 const auto &dims = shape.GetDims();
1117 return (dims.size() == 1UL) && (dims[0UL] == UNKNOWN_DIM_NUM);
1118 }
1119
IsScalar(const ge::Shape & shape)1120 bool IsScalar(const ge::Shape &shape) {
1121 const auto &dims = shape.GetDims();
1122 return dims.empty();
1123 }
1124
SetOpInferDepends(Operator & op,const std::vector<std::string> & depend_names)1125 void SetOpInferDepends(Operator &op, const std::vector<std::string> &depend_names) {
1126 op.SetAttr(ATTR_NAME_OP_INFER_DEPENDS, depend_names);
1127 }
1128
SetIsUnknownDimNum(ge::Shape & shape)1129 void SetIsUnknownDimNum(ge::Shape &shape) {
1130 std::vector<int64_t> dims(1UL, UNKNOWN_DIM_NUM);
1131 dims[0UL] = UNKNOWN_DIM_NUM;
1132 shape = ge::Shape(dims);
1133 }
1134
1135 namespace array_ops {
1136 // If not overflow return true
CheckInt64MulOverflow(int64_t a,int64_t b)1137 bool CheckInt64MulOverflow(int64_t a, int64_t b) {
1138 if (a > 0) {
1139 if (b > 0) {
1140 if (a > (INT64_MAX / b)) {
1141 return false;
1142 }
1143 } else {
1144 if (b < (INT64_MIN / a)) {
1145 return false;
1146 }
1147 }
1148 } else {
1149 if (b > 0) {
1150 if (a < (INT64_MIN / b)) {
1151 return false;
1152 }
1153 } else {
1154 if ((a != 0) && (b < (INT64_MAX / a))) {
1155 return false;
1156 }
1157 }
1158 }
1159
1160 return true;
1161 }
1162
CalcMaxElementsCount(const Operator & op,const std::vector<std::pair<int64_t,int64_t>> & x_shape_range,const Shape & x_shape)1163 int64_t CalcMaxElementsCount(const Operator &op, const std::vector<std::pair<int64_t, int64_t>> &x_shape_range,
1164 const Shape &x_shape) {
1165 int64_t max_elements_count = 1;
1166 auto x_shape_size = x_shape.GetShapeSize();
1167 if (x_shape_size > 0) {
1168 // when known dim, x_shape_size is max_elements_count
1169 max_elements_count = x_shape_size;
1170 } else {
1171 // unknown dim
1172 if (x_shape_range.empty()) {
1173 max_elements_count = -1;
1174 }
1175 for (const auto &x_range_i : x_shape_range) {
1176 if (x_range_i.second <= 0) {
1177 max_elements_count = -1;
1178 break;
1179 }
1180 if (array_ops::CheckInt64MulOverflow(max_elements_count, x_range_i.second)) {
1181 max_elements_count *= x_range_i.second;
1182 } else {
1183 max_elements_count = -1;
1184 break;
1185 }
1186 }
1187 }
1188
1189 return max_elements_count;
1190 }
1191
GenerateWorstYShapeAndYShapeRange(int64_t y_rank,int64_t max_elements_count,std::vector<std::pair<int64_t,int64_t>> & y_shape_range,Shape & y_shape)1192 void GenerateWorstYShapeAndYShapeRange(int64_t y_rank, int64_t max_elements_count,
1193 std::vector<std::pair<int64_t, int64_t>> &y_shape_range, Shape &y_shape) {
1194 y_shape = Shape(std::vector<int64_t>(y_rank, UNKNOWN_DIM));
1195 y_shape_range.clear();
1196 for (int64_t i = 0; i < y_rank; ++i) {
1197 y_shape_range.emplace_back(std::pair<int64_t, int64_t>(1, max_elements_count));
1198 }
1199 }
1200
RepairAndCheckRange(const std::vector<std::pair<int64_t,int64_t>> & x_shape_range,std::vector<std::pair<int64_t,int64_t>> & value_range)1201 bool RepairAndCheckRange(const std::vector<std::pair<int64_t, int64_t>> &x_shape_range,
1202 std::vector<std::pair<int64_t, int64_t>> &value_range) {
1203 bool has_zero_in_range = false;
1204 for (auto &range_i : value_range) {
1205 if (range_i.first < 0) {
1206 range_i.first = 1;
1207 }
1208 if (range_i.second < 0) {
1209 range_i.second = -1;
1210 }
1211 if (range_i.first == 0) {
1212 has_zero_in_range = true;
1213 }
1214 }
1215
1216 for (auto &range_i : x_shape_range) {
1217 if (range_i.first == 0) {
1218 has_zero_in_range = true;
1219 break;
1220 }
1221 }
1222 return has_zero_in_range;
1223 }
1224
InferShapeRangeForEmptyTensor(int64_t y_rank,int64_t max_elements_count,const std::vector<std::pair<int64_t,int64_t>> & value_range,std::vector<std::pair<int64_t,int64_t>> & y_shape_range,Shape & y_shape)1225 void InferShapeRangeForEmptyTensor(int64_t y_rank, int64_t max_elements_count,
1226 const std::vector<std::pair<int64_t, int64_t>> &value_range,
1227 std::vector<std::pair<int64_t, int64_t>> &y_shape_range, Shape &y_shape) {
1228 y_shape_range = value_range;
1229 int64_t known_dims_product = 1;
1230 std::vector<int64_t> y_dims = y_shape.GetDims();
1231 for (int64_t i = 0; i < y_rank; ++i) {
1232 if (y_shape_range[i].first == y_shape_range[i].second) {
1233 y_dims[i] = y_shape_range[i].first;
1234 if (max_elements_count != -1 && y_dims[i] != 0) {
1235 known_dims_product *= y_dims[i];
1236 }
1237 }
1238 }
1239 y_shape = Shape(y_dims);
1240
1241 if (known_dims_product != 1) {
1242 auto cur_dim_max_elements_count = (max_elements_count - 1) / known_dims_product + 1;
1243 for (int64_t i = 0; i < y_rank; ++i) {
1244 if (y_dims[i] == -1) {
1245 if (y_shape_range[i].second != -1) {
1246 y_shape_range[i].second = std::min(cur_dim_max_elements_count, y_shape_range[i].second);
1247 } else {
1248 y_shape_range[i].second = cur_dim_max_elements_count;
1249 }
1250 }
1251 }
1252 }
1253 }
1254
UpdateDimsAndShapeRange(const Operator & op,int64_t max_elements_count,const std::vector<std::pair<int64_t,int64_t>> & value_range,std::vector<int64_t> & y_dims,std::vector<std::pair<int64_t,int64_t>> & y_shape_range)1255 void UpdateDimsAndShapeRange(const Operator &op, int64_t max_elements_count,
1256 const std::vector<std::pair<int64_t, int64_t>> &value_range, std::vector<int64_t> &y_dims,
1257 std::vector<std::pair<int64_t, int64_t>> &y_shape_range) {
1258 size_t y_rank = y_dims.size();
1259 for (size_t i = 0; i < y_rank; ++i) {
1260 if (value_range[i].first == value_range[i].second) {
1261 y_dims[i] = value_range[i].first;
1262 y_shape_range[i] = std::pair<int64_t, int64_t>(y_dims[i], y_dims[i]);
1263 } else {
1264 if (max_elements_count == -1) {
1265 // while max_elements_count = -1, y shape range i is always value_range[i].second;
1266 y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, value_range[i].second);
1267 continue;
1268 }
1269 int64_t other_dims_range_lower_boundary = 1;
1270 for (size_t j = 0; j < y_rank; ++j) {
1271 if (i != j) {
1272 other_dims_range_lower_boundary *= value_range[j].first;
1273 }
1274 }
1275 int64_t cur_dim_range_max = (max_elements_count - 1) / other_dims_range_lower_boundary + 1;
1276 if (value_range[i].second > 0) {
1277 cur_dim_range_max = std::min(cur_dim_range_max, value_range[i].second);
1278 }
1279 y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, cur_dim_range_max);
1280 }
1281 }
1282 }
1283
CalculateMaxInputDims(const std::vector<std::pair<int64_t,int64_t>> & x_range,const Operator & op)1284 int64_t CalculateMaxInputDims(const std::vector<std::pair<int64_t, int64_t>> &x_range, const Operator &op) {
1285 int64_t max_input_dims = 1;
1286 for (const auto &pair : x_range) {
1287 if (pair.second < 0) {
1288 max_input_dims = -1;
1289 break;
1290 }
1291
1292 if (array_ops::CheckInt64MulOverflow(max_input_dims, pair.second)) {
1293 max_input_dims *= pair.second;
1294 } else {
1295 max_input_dims = INT64_MAX;
1296 GE_OP_LOGW(TbeGetName(op).c_str(), "Range Infer out of int64 max!Do set int64max!");
1297 break;
1298 }
1299 }
1300 return max_input_dims;
1301 }
1302 } // namespace array_ops
1303
IsSliceUnknownShape(const std::vector<int64_t> & dim_vec,const int64_t & begin,const int64_t & end)1304 bool IsSliceUnknownShape(const std::vector<int64_t> &dim_vec, const int64_t &begin, const int64_t &end) {
1305 if (begin < 0 || end >= static_cast<int64_t>(dim_vec.size())) {
1306 GE_OP_LOGE("FlattenV2", "index is out of range");
1307 return false;
1308 }
1309 for (int64_t i = begin; i < end + 1; i++) {
1310 if (dim_vec[i] == -1) {
1311 return true;
1312 }
1313 }
1314 return false;
1315 }
1316
1317 void SetOpInferDepends(Operator &op, const std::vector<std::string> &depend_names);
1318 } // namespace ge
1319