• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd.  All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*!
18  * \file util.cpp
19  * \brief
20  */
#include "util.h"
#include <cstring>
#include <numeric>
#include <utility>
#include <string>
#include <vector>
#include <map>
#include <functional>
#include <algorithm>
#include <set>
#include "error_util.h"
#include "vector_proto_profiling.h"
#include "op_common_util.h"
33 
34 namespace ge {
35 using namespace std;
36 
GetInputDataType(const ge::DataType & data_type,const std::vector<ge::DataType> & supportList)37 bool GetInputDataType(const ge::DataType &data_type, const std::vector<ge::DataType> &supportList) {
38   std::vector<ge::DataType>::const_iterator supportIter = find(supportList.begin(), supportList.end(), data_type);
39   if (supportIter == supportList.end()) {
40     return false;
41   }
42   return true;
43 }
44 
CheckInputDtypeAndShape(const Operator & op,const std::map<std::string,std::vector<DataType>> & inputTensorMap)45 bool CheckInputDtypeAndShape(const Operator &op, const std::map<std::string, std::vector<DataType>> &inputTensorMap) {
46   auto iter = inputTensorMap.begin();
47   auto first_name = iter->first;
48   auto first_shape_dims = op.GetInputDescByName(iter->first.c_str()).GetShape().GetDims();
49   auto first_input_dtype = op.GetInputDescByName(iter->first.c_str()).GetDataType();
50   for (; iter != inputTensorMap.end(); ++iter) {
51     const TensorDesc input_desc = op.GetInputDescByName(iter->first.c_str());
52     // check input dtype
53     auto input_type = input_desc.GetDataType();
54     if (input_type != first_input_dtype) {
55       VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
56         TbeGetName(op),
57         OtherErrMsg(ConcatString("the op type of param ", iter->first, " must equal with param ", first_name)));
58       return false;
59     }
60     auto dims = input_desc.GetShape().GetDims();
61     if (dims != first_shape_dims) {
62       VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
63         TbeGetName(op),
64         OtherErrMsg(ConcatString("the op shape of param ", iter->first, " must equal with param ", first_name)));
65       return false;
66     }
67   }
68   return true;
69 }
70 
CheckInputDataType(const Operator & op,const std::string & input_name,const std::vector<ge::DataType> & support_list)71 bool CheckInputDataType(const Operator &op, const std::string &input_name,
72                         const std::vector<ge::DataType> &support_list) {
73   bool valid = false;
74   DataType input_type = op.GetInputDescByName(input_name.c_str()).GetDataType();
75   do {
76     const auto &found_list = find(support_list.begin(), support_list.end(), input_type);
77 
78     if (found_list == support_list.end()) {
79       break;
80     }
81 
82     const auto &found_map = DTYPE_STR_MAP.find(input_type);
83     if (found_map == DTYPE_STR_MAP.end()) {
84       break;
85     }
86 
87     valid = true;
88   } while (0);
89 
90   if (!valid) {
91     VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
92       TbeGetName(op), OtherErrMsg(ConcatString("The op do not support the dtype", GeDataTypeToString(input_type))));
93     return false;
94   }
95 
96   return true;
97 }
98 
CheckTwoInputDtypeSame(const Operator & op,const string & input_name1,const string & input_name2)99 bool CheckTwoInputDtypeSame(const Operator &op, const string &input_name1, const string &input_name2) {
100   DataType input_type_x1 = op.GetInputDesc(input_name1).GetDataType();
101   DataType input_type_x2 = op.GetInputDesc(input_name2).GetDataType();
102   if (input_type_x1 != input_type_x2) {
103     VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
104       TbeGetName(op), OtherErrMsg(ConcatString("The ", TbeGetName(op),
105                                                " op dtype is not same, type1:", GeDataTypeToString(input_type_x1),
106                                                ", type2:", GeDataTypeToString(input_type_x2))));
107     return false;
108   }
109 
110   return true;
111 }
112 
CheckInputDtypeSame(const Operator & op,const std::vector<std::string> & input_names)113 bool CheckInputDtypeSame(const Operator &op, const std::vector<std::string> &input_names) {
114   auto first_name = input_names.begin();
115   auto first_input_dtype = op.GetInputDescByName((*first_name).c_str()).GetDataType();
116   for (const string &input_name : input_names) {
117     const TensorDesc input_desc = op.GetInputDescByName(input_name.c_str());
118     auto input_dtype = input_desc.GetDataType();
119     if (input_dtype != first_input_dtype) {
120       auto error_ms = ConcatString("dtype of inputs must be same, ", input_name, ":", GeDataTypeToString(input_dtype),
121                                    ", ", (*first_name), ":", GeDataTypeToString(first_input_dtype), ".");
122       VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg(error_ms));
123       return false;
124     }
125   }
126   return true;
127 }
128 
CheckInputsShapeDtypeSame(const Operator & op,const std::vector<std::string> & input_names)129 bool CheckInputsShapeDtypeSame(const Operator &op, const std::vector<std::string> &input_names) {
130   auto first_input_name = input_names.begin();
131   auto first_input_des = op.GetInputDescByName((*first_input_name).c_str());
132   auto input_name = first_input_name;
133   for (++input_name; input_name != input_names.end(); ++input_name) {
134     auto input_des = op.GetInputDescByName((*first_input_name).c_str());
135     if (input_des.GetDataType() != first_input_des.GetDataType() ||
136         input_des.GetShape().GetDims() != first_input_des.GetShape().GetDims()) {
137       VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
138         TbeGetName(op), OtherErrMsg(ConcatString("the dtype and shape of param ", first_input_name->c_str(),
139                                                  " must be same as param ", input_name->c_str())));
140       return false;
141     }
142   }
143 
144   return true;
145 }
146 
TwoShapeAndRangeBroadcastIntegration(const Operator & op,std::vector<int64_t> & dimVec,std::vector<std::pair<int64_t,int64_t>> & Vec_range,std::vector<int64_t> dims,std::vector<std::pair<int64_t,int64_t>> range,const string & input_name1,const string & input_name2)147 bool TwoShapeAndRangeBroadcastIntegration(const Operator &op, std::vector<int64_t> &dimVec,
148                                           std::vector<std::pair<int64_t, int64_t>> &Vec_range,
149                                           std::vector<int64_t> dims, std::vector<std::pair<int64_t, int64_t>> range,
150                                           const string &input_name1, const string &input_name2) {
151   if (dimVec.size() < dims.size()) {
152     std::vector<int64_t> dimsTmp = dimVec;
153     dimVec = dims;
154     dims = dimsTmp;
155     std::vector<std::pair<int64_t, int64_t>> range_temp = Vec_range;
156     Vec_range = range;
157     range = range_temp;
158   }
159   if (dimVec.size() != dims.size()) {
160     int dec = static_cast<int>(dimVec.size() - dims.size());
161     for (int i = 0; i < dec; i++) {
162       dims.insert(dims.begin(), static_cast<int64_t>(1));
163     }
164   }
165   for (size_t i = 0; i < dimVec.size(); i++) {
166     CHECK((dimVec[i] != dims[i]) && (dimVec[i] != 1) && (dims[i] != 1) && (dimVec[i] != -1) && (dims[i] != -1),
167           VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
168             TbeGetName(op),
169             OtherErrMsg(ConcatString("The ", TbeGetName(op), "'s dimensions does not match the broadcast rule(",
170                                      dimVec[i], dims[i], ")."))),
171           return false);
172   }
173   dimVec = TwoBroadcastShape(dimVec, dims);
174   if (IsUnknown(dimVec)) {
175     MakeUpShapeRange(dims, range);
176     Vec_range = TwoShapeAndRangeBroadcast(dimVec, Vec_range, range);
177   }
178   return true;
179 }
180 
TwoBroadcastShape(const std::vector<int64_t> & dimsX,const std::vector<int64_t> & dimsY)181 std::vector<int64_t> TwoBroadcastShape(const std::vector<int64_t> &dimsX, const std::vector<int64_t> &dimsY) {
182   std::vector<int64_t> dimVec;
183   // when not dynamic case, do infer shape only
184   if (!IsUnknown(dimsY) && !IsUnknown(dimsX)) {
185     for (size_t i = 0; i < dimsX.size(); i++) {
186       int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
187       dims = (dimsY[i] == 0 || dimsX[i] == 0) ? 0 : dims;
188       dimVec.push_back(dims);
189     }
190     return dimVec;
191   }
192   // dynamic case
193   for (size_t i = 0; i < dimsX.size(); i++) {
194     if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
195       if (dimsY[i] > 1) {
196         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
197         dimVec.push_back(dims);
198       } else if (dimsY[i] == 1) {
199         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
200         dimVec.push_back(dims);
201         dimVec[i] = -1;
202       } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
203         dimVec.push_back(0);
204       }
205     } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
206       if (dimsX[i] > 1) {
207         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
208         dimVec.push_back(dims);
209       } else if (dimsX[i] == 0) {
210         dimVec.push_back(0);
211       } else if (dimsX[i] == 1) {
212         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
213         dimVec.push_back(dims);
214         dimVec[i] = -1;
215       }
216     } else {
217       if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
218         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
219         dimVec.push_back(dims);
220         dimVec[i] = -1;
221       } else {
222         if (dimsY[i] == 0 || dimsX[i] == 0) {
223           dimVec.push_back(0);
224         } else {
225           int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
226           dimVec.push_back(dims);
227         }
228       }
229     }
230   }
231   return dimVec;
232 }
233 
TwoShapeAndRangeBroadcast(const std::vector<int64_t> & dims_out,const std::vector<std::pair<int64_t,int64_t>> & shape_range_x,std::vector<std::pair<int64_t,int64_t>> & shape_range_y)234 std::vector<std::pair<int64_t, int64_t>> TwoShapeAndRangeBroadcast(
235   const std::vector<int64_t> &dims_out, const std::vector<std::pair<int64_t, int64_t>> &shape_range_x,
236   std::vector<std::pair<int64_t, int64_t>> &shape_range_y) {
237   size_t size_shape_out = dims_out.size();
238   std::vector<std::pair<int64_t, int64_t>> out_range;
239   if (!IsUnknownRankShape(dims_out)) {
240     while (shape_range_x.size() > shape_range_y.size()) {
241       shape_range_y.insert(shape_range_y.begin(), std::pair<int64_t, int64_t>(1, 1));
242     }
243     for (size_t i = 0; i < size_shape_out; i++) {
244       if (dims_out[i] != -1) {
245         out_range.push_back(std::pair<int64_t, int64_t>(dims_out[i], dims_out[i]));
246         continue;
247       }
248       if (i < shape_range_x.size() && i < shape_range_y.size()) {
249         if (shape_range_x[i].second == -1 && shape_range_y[i].second == 1) {
250           out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
251         } else if (shape_range_x[i].second == 1 && shape_range_y[i].second == -1) {
252           out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
253         } else if (shape_range_x[i].first == 1 || shape_range_y[i].first == 1) {
254           // one shape size maybe 1, so will support broadcast
255           // first_range == max first
256           int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first);
257           int64_t second_range = shape_range_x[i].first == 1 ? shape_range_y[i].second : shape_range_x[i].second;
258           if (shape_range_x[i].first == 1 && shape_range_y[i].first == 1) {
259             second_range = std::max(shape_range_x[i].second, shape_range_y[i].second);
260             second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1) ? -1 : second_range;
261           }
262           out_range.push_back(std::pair<int64_t, int64_t>(first_range, second_range));
263         } else {
264           // no 1 in range.first, mean no broadcast for range
265           // get intersect range
266           int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first);
267           int64_t second_range = std::min(shape_range_x[i].second, shape_range_y[i].second);
268           second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1)
269                            ? std::max(shape_range_x[i].second, shape_range_y[i].second)
270                            : second_range;
271           out_range.push_back(std::pair<int64_t, int64_t>(first_range, second_range));
272         }
273       }
274     }
275   }
276   return out_range;
277 }
278 
InferBroadcastshapeForStatic(const Shape & shape_x,const Shape & shape_y,Shape & shape_output)279 bool InferBroadcastshapeForStatic(const Shape &shape_x, const Shape &shape_y, Shape &shape_output) {
280   auto shape_x_len = shape_x.GetDimNum();
281   auto shape_y_len = shape_y.GetDimNum();
282 
283   OP_LOGI("BroadcastInfer", "input1 shape is: %s, input2 shape is: %s.", to_string(shape_x).c_str(),
284           to_string(shape_y).c_str());
285   std::vector<int64_t> output_shape;
286   if (shape_x_len >= shape_y_len) {
287     // when inputx len >= inputy len
288     // input_x = [128, 128, 128] Vs input_y = [128]
289     auto len_sub = shape_x_len - shape_y_len;
290     for (size_t i = 0; i < len_sub; i++) {
291       (void)output_shape.emplace_back(shape_x.GetDim(i));
292     }
293     for (size_t i = 0; i < shape_y_len; i++) {
294       int64_t dim_size = std::max(shape_x.GetDim(len_sub + i), shape_y.GetDim(i));
295       // if one dim is 0, the output dim is 0
296       dim_size = (shape_x.GetDim(len_sub + i) == 0 || shape_y.GetDim(i) == 0) ? 0 : dim_size;
297       (void)output_shape.emplace_back(dim_size);
298     }
299   } else {
300     // when inputx len < inputy len
301     // input_x = [128] Vs input_y = [128, 128, 128]
302     auto len_sub = shape_y_len - shape_x_len;
303     for (size_t i = 0; i < len_sub; i++) {
304       (void)output_shape.emplace_back(shape_y.GetDim(i));
305     }
306     for (size_t i = 0; i < shape_x_len; i++) {
307       int64_t dim_size = std::max(shape_y.GetDim(len_sub + i), shape_x.GetDim(i));
308       // if one dim is 0, the output dim is 0
309       dim_size = (shape_y.GetDim(len_sub + i) == 0 || shape_x.GetDim(i) == 0) ? 0 : dim_size;
310       (void)output_shape.emplace_back(dim_size);
311     }
312   }
313   shape_output = Shape(output_shape);
314   OP_LOGI("BroadcastInfer", "output1 shape is: %s.", to_string(shape_output).c_str());
315   return true;
316 }
317 
InferShapeAndTypeTwoInOneOutBroadcast(Operator & op,const string & input_name1,const string & input_name2,const string & output_name,bool & is_dynamic)318 bool InferShapeAndTypeTwoInOneOutBroadcast(Operator &op, const string &input_name1, const string &input_name2,
319                                            const string &output_name, bool &is_dynamic) {
320   PROFILING_PROTO_INIT(TbeGetName(op).c_str());
321   DataType input_dtype = op.GetInputDesc(input_name1).GetDataType();
322 
323   // output Desc
324   auto tensordesc_output = op.GetOutputDesc(output_name);
325   tensordesc_output.SetDataType(input_dtype);
326 
327   ge::Shape shapeX = op.GetInputDesc(input_name1).GetShape();
328   ge::Shape shapeY = op.GetInputDesc(input_name2).GetShape();
329   OP_LOGI(TbeGetName(op).c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(),
330           input_name2.c_str(), to_string(shapeY).c_str());
331   std::vector<int64_t> dimsX = shapeX.GetDims();
332   std::vector<int64_t> dimsY = shapeY.GetDims();
333   PROFILING_PROTO_AFTER_GET_SHAPE_REG();
334   // swap based on shape size
335   if (dimsX.size() < dimsY.size()) {
336     std::vector<int64_t> dimsTmp = dimsX;
337     dimsX = dimsY;
338     dimsY = dimsTmp;
339   }
340 
341   // unknown rank
342   if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) {
343     tensordesc_output.SetShape(ge::Shape(UNKNOWN_RANK));
344     OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%d.",
345             to_string(ge::Shape(UNKNOWN_RANK)).c_str(), input_dtype);
346     is_dynamic = false;
347     op.UpdateOutputDesc(output_name, tensordesc_output);
348     return true;
349   }
350 
351   // pad 1 for small shape
352   if (dimsX.size() != dimsY.size()) {
353     int dec = static_cast<int>(dimsX.size() - dimsY.size());
354     for (int i = 0; i < dec; i++) {
355       dimsY.insert(dimsY.begin(), (int64_t)1);
356     }
357   }
358 
359   // when not dynamic case, do infer shape only
360   if (!IsUnKnownShape(dimsY) && !IsUnKnownShape(dimsX)) {
361     std::vector<int64_t> dimVec(dimsX.size(), 0);
362     for (size_t i = 0; i < dimsX.size(); i++) {
363       dimVec[i] = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
364       dimVec[i] = (dimsY[i] == 0 || dimsX[i] == 0) ? 0 : dimVec[i];
365     }
366 
367     PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
368     tensordesc_output.SetShape(ge::Shape(dimVec));
369     is_dynamic = false;
370     op.UpdateOutputDesc(output_name, tensordesc_output);
371     PROFILING_PROTO_END();
372     return true;
373   }
374 
375   std::vector<int64_t> dimVec;
376   // dynamic case
377   for (size_t i = 0; i < dimsX.size(); i++) {
378     CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1),
379           VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
380             TbeGetName(op),
381             OtherErrMsg(ConcatString("The ", TbeGetName(op), "'s dimensions does not match the broadcast rule(",
382                                      dimsX[i], dimsY[i], ")."))),
383           return false);
384 
385     if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
386       if (dimsY[i] > 1) {
387         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
388         dimVec.push_back(dims);
389       } else if (dimsY[i] == 1) {
390         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
391         dimVec.push_back(dims);
392         dimVec[i] = -1;
393       } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
394         dimVec.push_back(-1);
395       }
396     } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
397       if (dimsX[i] > 1) {
398         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
399         dimVec.push_back(dims);
400       } else if (dimsX[i] == 0) {
401         dimVec.push_back(-1);
402       } else if (dimsX[i] == 1) {
403         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
404         dimVec.push_back(dims);
405         dimVec[i] = -1;
406       }
407     } else {
408       if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
409         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
410         dimVec.push_back(dims);
411         dimVec[i] = -1;
412       } else {
413         if (dimsY[i] == 0 || dimsX[i] == 0) {
414           dimVec.push_back(0);
415         } else {
416           int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
417           dimVec.push_back(dims);
418         }
419       }
420     }
421   }
422   ge::Shape outputShape = ge::Shape(dimVec);
423   tensordesc_output.SetShape(outputShape);
424 
425   OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(),
426           GeDataTypeToString(input_dtype).c_str());
427   is_dynamic = IsUnknown(dimVec);
428   if (is_dynamic) {
429     if (!InferShapeRangeTwoInOneOutBroadcast(op, input_name1, input_name2, output_name)) {
430       return false;
431     }
432   }
433   op.UpdateOutputDesc(output_name, tensordesc_output);
434   return true;
435 }
436 
InferShapeAndTypeTwoInOneOutBroadcast(Operator & op,const string & input_name1,const string & input_name2,const string & output_name)437 bool InferShapeAndTypeTwoInOneOutBroadcast(Operator &op, const string &input_name1, const string &input_name2,
438                                            const string &output_name) {
439   DataType input_dtype = op.GetInputDesc(input_name1).GetDataType();
440 
441   auto tensordesc_output = op.GetOutputDesc(output_name);
442 
443   ge::Shape shapeX = op.GetInputDesc(input_name1).GetShape();
444   ge::Shape shapeY = op.GetInputDesc(input_name2).GetShape();
445   OP_LOGI(TbeGetName(op).c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(),
446           input_name2.c_str(), to_string(shapeY).c_str());
447   std::vector<int64_t> dimsX = shapeX.GetDims();
448   std::vector<int64_t> dimsY = shapeY.GetDims();
449   // swap based on shape size
450   if (dimsX.size() < dimsY.size()) {
451     std::vector<int64_t> dimsTmp = dimsX;
452     dimsX = dimsY;
453     dimsY = dimsTmp;
454   }
455 
456   std::vector<int64_t> dimVec;
457 
458   // unknown rank
459   if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) {
460     tensordesc_output.SetShape(ge::Shape(UNKNOWN_RANK));
461     tensordesc_output.SetDataType(input_dtype);
462     OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%d.",
463             to_string(ge::Shape(UNKNOWN_RANK)).c_str(), input_dtype);
464     op.UpdateOutputDesc(output_name, tensordesc_output);
465     return true;
466   }
467 
468   // pad 1 for small shape
469   if (dimsX.size() != dimsY.size()) {
470     int dec = static_cast<int>(dimsX.size() - dimsY.size());
471     for (int i = 0; i < dec; i++) {
472       dimsY.insert(dimsY.begin(), (int64_t)1);
473     }
474   }
475 
476   for (size_t i = 0; i < dimsX.size(); i++) {
477     CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1),
478           VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
479             TbeGetName(op),
480             OtherErrMsg(ConcatString("The ", TbeGetName(op), "'s dimensions does not match the broadcast rule(",
481                                      dimsX[i], dimsY[i], ")."))),
482           return false);
483 
484     if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
485       if (dimsY[i] > 1) {
486         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
487         dimVec.push_back(dims);
488       } else if (dimsY[i] == 1) {
489         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
490         dimVec.push_back(dims);
491         dimVec[i] = -1;
492       } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
493         dimVec.push_back(0);
494       }
495     } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
496       if (dimsX[i] > 1) {
497         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
498         dimVec.push_back(dims);
499       } else if (dimsX[i] == 0) {
500         dimVec.push_back(0);
501       } else if (dimsX[i] == 1) {
502         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
503         dimVec.push_back(dims);
504         dimVec[i] = -1;
505       }
506     } else {
507       if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
508         int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
509         dimVec.push_back(dims);
510         dimVec[i] = -1;
511       } else {
512         if (dimsY[i] == 0 || dimsX[i] == 0) {
513           dimVec.push_back(0);
514         } else {
515           int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
516           dimVec.push_back(dims);
517         }
518       }
519     }
520   }
521   ge::Shape outputShape = ge::Shape(dimVec);
522 
523   tensordesc_output.SetShape(outputShape);
524   tensordesc_output.SetDataType(input_dtype);
525   OP_LOGI(TbeGetName(op).c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(),
526           GeDataTypeToString(input_dtype).c_str());
527   op.UpdateOutputDesc(output_name, tensordesc_output);
528 
529   return true;
530 }
531 
ToFormatString(ge::Format format)532 std::string ToFormatString(ge::Format format) { return GeFormatToString(format); }
533 
/**
 * \brief Append to out_range the range of one broadcast output dim, given the ranges of the two inputs.
 *
 * Lower bound: 0 if either input's lower bound is 0, otherwise the larger lower bound.
 * Upper bound (-1 means unbounded):
 *  - one side unbounded and the other's upper bound is 1     -> unbounded
 *  - both lower bounds are 1                                 -> larger upper bound (unbounded if either is)
 *  - exactly one lower bound is 1 (that side may broadcast)  -> the other side's upper bound
 *  - otherwise (no broadcasting)                             -> smaller upper bound; if one is unbounded,
 *                                                               the other (larger) bound
 * The conditions are written as explicit comparisons. Testing them through products of the bounds
 * could overflow int64_t for large bounds, which is undefined behavior.
 */
static void AddToOutputRange(std::vector<std::pair<int64_t, int64_t>> &out_range,
                             const std::pair<int64_t, int64_t> &shape_range_x,
                             const std::pair<int64_t, int64_t> &shape_range_y) {
  const bool any_zero_lower = (shape_range_x.first == 0) || (shape_range_y.first == 0);
  const int64_t first_range = any_zero_lower ? 0 : std::max(shape_range_x.first, shape_range_y.first);
  const bool x_unbounded = (shape_range_x.second == -1);
  const bool y_unbounded = (shape_range_y.second == -1);

  if ((x_unbounded && shape_range_y.second == 1) || (shape_range_x.second == 1 && y_unbounded)) {
    out_range.emplace_back(first_range, -1);
  } else if (shape_range_x.first == 1 && shape_range_y.first == 1) {
    const int64_t second_range =
      (x_unbounded || y_unbounded) ? -1 : std::max(shape_range_x.second, shape_range_y.second);
    out_range.emplace_back(first_range, second_range);
  } else if (shape_range_x.first == 1 || shape_range_y.first == 1) {
    // one shape size maybe 1, so will support broadcast
    const int64_t second_range = (shape_range_x.first == 1) ? shape_range_y.second : shape_range_x.second;
    out_range.emplace_back(first_range, second_range);
  } else {
    // no 1 in range.first, mean no broadcast for range: get intersect range
    const int64_t second_range = (x_unbounded || y_unbounded)
                                   ? std::max(shape_range_x.second, shape_range_y.second)
                                   : std::min(shape_range_x.second, shape_range_y.second);
    out_range.emplace_back(first_range, second_range);
  }
}
562 
InferShapeRangeTwoInOneOutBroadcast(Operator & op,const string & input_name1,const string & input_name2,const string & output_name)563 bool InferShapeRangeTwoInOneOutBroadcast(Operator &op, const string &input_name1, const string &input_name2,
564                                          const string &output_name) {
565   ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape();
566   ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape();
567 
568   std::vector<int64_t> dims_x = shape_x.GetDims();
569   std::vector<int64_t> dims_y = shape_y.GetDims();
570 
571   std::vector<std::pair<int64_t, int64_t>> shape_range_x;
572   op.GetInputDesc(input_name1).GetShapeRange(shape_range_x);
573   std::vector<std::pair<int64_t, int64_t>> shape_range_y;
574   op.GetInputDesc(input_name2).GetShapeRange(shape_range_y);
575 
576   MakeUpShapeRange(dims_x, shape_range_x);
577   MakeUpShapeRange(dims_y, shape_range_y);
578 
579   ge::Shape shape_out = op.GetOutputDesc(output_name).GetShape();
580   std::vector<int64_t> dims_out = shape_out.GetDims();
581   size_t size_shape_out = dims_out.size();
582 
583   std::vector<std::pair<int64_t, int64_t>> out_range;
584 
585   if (!IsUnknownRankShape(dims_out)) {
586     // shape switch by shape dim size
587     if (dims_x.size() < dims_y.size()) {
588       std::vector<int64_t> dims_tmp = dims_x;
589       dims_x = dims_y;
590       dims_y = dims_tmp;
591 
592       std::vector<std::pair<int64_t, int64_t>> range_temp = shape_range_x;
593       shape_range_x = shape_range_y;
594       shape_range_y = range_temp;
595     }
596 
597     while (dims_x.size() > shape_range_y.size()) {
598       shape_range_y.insert(shape_range_y.begin(), std::pair<int64_t, int64_t>(1, 1));
599     }
600 
601     for (size_t i = 0; i < size_shape_out; i++) {
602       if (dims_out[i] != -1) {
603         out_range.push_back(std::pair<int64_t, int64_t>(dims_out[i], dims_out[i]));
604         continue;
605       }
606       if (i < shape_range_x.size() && i < shape_range_y.size()) {
607         AddToOutputRange(out_range, shape_range_x[i], shape_range_y[i]);
608       }
609     }
610   }
611   OP_LOGI(TbeGetName(op).c_str(), "elewise out range is %s", to_string(out_range).c_str());
612   auto tensor_out = op.GetOutputDesc(output_name);
613   tensor_out.SetShapeRange(out_range);
614   op.UpdateOutputDesc(output_name, tensor_out);
615 
616   return true;
617 }
618 
GetInputDataType(const ge::DataType & dataType,const std::vector<ge::DataType> & supportList,std::string & dType)619 bool GetInputDataType(const ge::DataType &dataType, const std::vector<ge::DataType> &supportList, std::string &dType) {
620   std::vector<ge::DataType>::const_iterator supportIter = find(supportList.begin(), supportList.end(), dataType);
621   if (supportIter == supportList.end()) {
622     return false;
623   }
624 
625   std::map<ge::DataType, std::string>::const_iterator totalIter = DTYPE_STR_MAP.find(dataType);
626   if (totalIter == DTYPE_STR_MAP.end()) {
627     return false;
628   }
629 
630   dType = totalIter->second;
631   return true;
632 }
633 
CheckInputDataType(const Operator & op,std::string * data_type,const std::string & input_name,const std::vector<ge::DataType> & supportList)634 bool CheckInputDataType(const Operator &op, std::string *data_type, const std::string &input_name,
635                         const std::vector<ge::DataType> &supportList) {
636   DataType input_type = op.GetInputDescByName(input_name.c_str()).GetDataType();
637   if (false == GetInputDataType(input_type, supportList, *data_type)) {
638     LOG_ERROR("[ERROR]op [%s] [%s] do not supported dtype [%s]!\n", TbeGetName(op).c_str(), input_name.c_str(),
639               data_type->c_str());
640     return false;
641   }
642   return true;
643 }
644 
GetConstValue(const ge::Operator & op,const std::string & key_name,float & attr_value)645 bool GetConstValue(const ge::Operator &op, const std::string &key_name, float &attr_value) {
646   if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
647     LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
648     return false;
649   }
650   return true;
651 }
652 
GetConstValue(const ge::Operator & op,const std::string & key_name,int64_t & attr_value)653 bool GetConstValue(const ge::Operator &op, const std::string &key_name, int64_t &attr_value) {
654   if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
655     LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
656     return false;
657   }
658   return true;
659 }
660 
GetConstValue(const ge::Operator & op,const std::string & key_name,bool & attr_value)661 bool GetConstValue(const ge::Operator &op, const std::string &key_name, bool &attr_value) {
662   if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
663     LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
664     return false;
665   }
666   return true;
667 }
668 
GetConstValue(const ge::Operator & op,const std::string & key_name,std::vector<int32_t> & attr_value)669 bool GetConstValue(const ge::Operator &op, const std::string &key_name, std::vector<int32_t> &attr_value) {
670   if (ge::GRAPH_SUCCESS != op.GetAttr(key_name.c_str(), attr_value)) {
671     LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", TbeGetName(op).c_str(), key_name.c_str());
672     return false;
673   }
674   return true;
675 }
676 
/**
 * @brief Decodes a raw buffer of packed T values into a vector of int64_t.
 *
 * Only whole elements are decoded (data_size / sizeof(T)); trailing bytes are ignored.
 * Each element is copied byte-wise, so @p const_data does not need to be aligned for T.
 *
 * @param const_data start of the packed buffer; may be nullptr, which yields an empty result.
 * @param data_size  buffer size in bytes.
 * @return the decoded values widened to int64_t.
 */
template <typename T>
static std::vector<int64_t> GetConstIntData(const uint8_t *const_data, size_t data_size) {
  std::vector<int64_t> result;
  if (const_data == nullptr) {
    return result;
  }
  const size_t count = data_size / sizeof(T);
  result.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    // Byte copy instead of dereferencing a T* that might be misaligned.
    T value{};
    const uint8_t *element = const_data + i * sizeof(T);
    std::copy(element, element + sizeof(T), reinterpret_cast<uint8_t *>(&value));
    result.push_back(static_cast<int64_t>(value));
  }
  return result;
}
688 
GetConstIntData(const Tensor & data,DataType data_type,std::vector<int64_t> & const_values)689 bool GetConstIntData(const Tensor &data, DataType data_type, std::vector<int64_t> &const_values) {
690   using std::placeholders::_1;
691   using std::placeholders::_2;
692   const std::map<DataType, std::function<std::vector<int64_t>(const uint8_t *, size_t)>> type_call_map = {
693     {DT_INT8, std::bind(GetConstIntData<int8_t>, _1, _2)},
694     {DT_INT16, std::bind(GetConstIntData<int16_t>, _1, _2)},
695     {DT_INT32, std::bind(GetConstIntData<int32_t>, _1, _2)},
696     {DT_INT64, std::bind(GetConstIntData<int64_t>, _1, _2)},
697   };
698 
699   auto found = type_call_map.find(data_type);
700   if (found == type_call_map.end()) {
701     USER_GE_LOGE("[ERROR]GetConstIntData is not support data_type[%s]!", GeDataTypeToString(data_type).c_str());
702     return false;
703   }
704 
705   const_values = found->second(data.GetData(), data.GetSize());
706 
707   return true;
708 }
709 
GetConstValue(const Operator & op,const Tensor & const_tensor,const DataType & dtype,std::vector<int64_t> & const_data)710 bool GetConstValue(const Operator &op, const Tensor &const_tensor, const DataType &dtype,
711                    std::vector<int64_t> &const_data) {
712   CHECK(dtype != ge::DT_INT32 && dtype != ge::DT_INT64,
713         VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("not support this type")), return false);
714   if (dtype == ge::DT_INT32) {
715     const int32_t *const_data_ptr = reinterpret_cast<const int32_t *>(const_tensor.GetData());
716     size_t size = const_tensor.GetSize() / sizeof(int32_t);
717     for (size_t i = 0; i < size; ++i) {
718       const_data.push_back(static_cast<int32_t>(*(const_data_ptr + i)));
719       OP_LOGD(TbeGetName(op).c_str(), "const data int32 fusion pass ====== %d",
720               static_cast<int32_t>(*(const_data_ptr + i)));
721     }
722   } else if (dtype == ge::DT_INT64) {
723     const int64_t *const_data_ptr = reinterpret_cast<const int64_t *>(const_tensor.GetData());
724     size_t size = const_tensor.GetSize() / sizeof(int64_t);
725     for (size_t i = 0; i < size; ++i) {
726       const_data.push_back(static_cast<int64_t>(*(const_data_ptr + i)));
727       OP_LOGD(TbeGetName(op).c_str(), "const data int64 fusion pass ====== %ld",
728               static_cast<int64_t>(*(const_data_ptr + i)));
729     }
730   }
731   return true;
732 }
733 
GetConstValue(const Operator & op,const Tensor & const_tensor,const DataType & dtype,std::vector<uint64_t> & const_data)734 bool GetConstValue(const Operator &op, const Tensor &const_tensor, const DataType &dtype,
735                    std::vector<uint64_t> &const_data) {
736   size_t size = 0;
737   CHECK(dtype != ge::DT_UINT64,
738         VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("not support this type")), return false);
739   const uint64_t *const_data_ptr = reinterpret_cast<const uint64_t *>(const_tensor.GetData());
740   size = const_tensor.GetSize() / sizeof(uint64_t);
741   for (size_t i = 0; i < size; ++i) {
742     const_data.push_back(static_cast<uint64_t>(*(const_data_ptr + i)));
743     OP_LOGD(TbeGetName(op).c_str(), "const data uint64 fusion pass, const_data[%lu]",
744             static_cast<uint64_t>(*(const_data_ptr + i)));
745   }
746   return true;
747 }
748 
GetScalerValue(const Operator & op,const Tensor & const_tensor,const DataType & dtype,std::int64_t & const_data)749 bool GetScalerValue(const Operator &op, const Tensor &const_tensor, const DataType &dtype, std::int64_t &const_data) {
750   if (dtype == ge::DT_INT32) {
751     const int32_t *const_data_ptr = reinterpret_cast<const int32_t *>(const_tensor.GetData());
752     const_data = static_cast<int32_t>(*const_data_ptr);
753   } else if (dtype == ge::DT_INT64) {
754     const int64_t *const_data_ptr = reinterpret_cast<const int64_t *>(const_tensor.GetData());
755     const_data = static_cast<int64_t>(*const_data_ptr);
756   } else {
757     VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg(ConcatString("not support this type:", dtype)));
758     return false;
759   }
760   return true;
761 }
762 
to_string(const std::vector<int64_t> & shape)763 std::string to_string(const std::vector<int64_t> &shape) { return ops::to_string(shape); }
764 
to_string(const ge::Shape & shape)765 std::string to_string(const ge::Shape &shape) { return to_string(shape.GetDims()); }
766 
to_string(const std::vector<std::pair<int64_t,int64_t>> & ranges)767 std::string to_string(const std::vector<std::pair<int64_t, int64_t>> &ranges) { return ops::to_string(ranges); }
768 
769 static std::map<ge::DataType, std::string> kDataTypeToStringMap = {{ge::DataType::DT_FLOAT, "float"},
770                                                                    {ge::DataType::DT_FLOAT16, "float16"},
771                                                                    {ge::DataType::DT_INT8, "int8"},
772                                                                    {ge::DataType::DT_INT16, "int16"},
773                                                                    {ge::DataType::DT_UINT16, "uint16"},
774                                                                    {ge::DataType::DT_UINT8, "uint8"},
775                                                                    {ge::DataType::DT_INT32, "int32"},
776                                                                    {ge::DataType::DT_INT64, "int64"},
777                                                                    {ge::DataType::DT_UINT32, "uint32"},
778                                                                    {ge::DataType::DT_UINT64, "uint64"},
779                                                                    {ge::DataType::DT_BOOL, "bool"},
780                                                                    {ge::DataType::DT_DOUBLE, "double"},
781                                                                    {ge::DataType::DT_STRING, "string"},
782                                                                    {ge::DataType::DT_DUAL_SUB_INT8, "dual_sub_int8"},
783                                                                    {ge::DataType::DT_DUAL_SUB_UINT8, "dual_sub_uint8"},
784                                                                    {ge::DataType::DT_COMPLEX64, "complex64"},
785                                                                    {ge::DataType::DT_COMPLEX128, "complex128"},
786                                                                    {ge::DataType::DT_DUAL, "dual"},
787                                                                    {ge::DataType::DT_QINT8, "qint8"},
788                                                                    {ge::DataType::DT_QINT16, "qint16"},
789                                                                    {ge::DataType::DT_QINT32, "qint32"},
790                                                                    {ge::DataType::DT_QUINT8, "quint8"},
791                                                                    {ge::DataType::DT_QUINT16, "quint16"},
792                                                                    {ge::DataType::DT_RESOURCE, "resource"},
793                                                                    {ge::DataType::DT_STRING_REF, "string ref"},
794                                                                    {ge::DataType::DT_VARIANT, "dt_variant"},
795                                                                    {ge::DataType::DT_UNDEFINED, "undefined"},
796                                                                    {ge::DataType::DT_INT4, "int4"},
797                                                                    {ge::DataType::DT_UINT1, "uint1"},
798                                                                    {ge::DataType::DT_INT2, "int2"},
799                                                                    {ge::DataType::DT_UINT2, "uint2"},
800                                                                    {ge::DataType::DT_COMPLEX32, "complex32"},
801                                                                    {ge::DataType::DT_BF16, "bf16"}};
802 
803 static std::map<ge::Format, std::string> kFormatToStringMap = {
804   {ge::Format::FORMAT_NCHW, "NCHW"},
805   {ge::Format::FORMAT_NHWC, "NHWC"},
806   {ge::Format::FORMAT_ND, "Nd"},
807   {ge::Format::FORMAT_NC1HWC0, "NC1HWC0"},
808   {ge::Format::FORMAT_FRACTAL_Z, "FRACTAL_Z"},
809   {ge::Format::FORMAT_NC1C0HWPAD, "NC1C0HWPAD"},
810   {ge::Format::FORMAT_NHWC1C0, "NHWC1C0"},
811   {ge::Format::FORMAT_FSR_NCHW, "FSR_NCHW"},
812   {ge::Format::FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"},
813   {ge::Format::FORMAT_C1HWNC0, "C1HWNC0"},
814   {ge::Format::FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
815   {ge::Format::FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
816   {ge::Format::FORMAT_NC1HWC0_C04, "NC1HWC0_C04"},
817   {ge::Format::FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
818   {ge::Format::FORMAT_CHWN, "CHWN"},
819   {ge::Format::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "FRACTAL_DECONV_SP_STRIDE8_TRANS"},
820   {ge::Format::FORMAT_HWCN, "HWCN"},
821   {ge::Format::FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"},
822   {ge::Format::FORMAT_BN_WEIGHT, "BN_WEIGHT"},
823   {ge::Format::FORMAT_FILTER_HWCK, "FILTER_HWCK"},
824   {ge::Format::FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "HASHTABLE_LOOKUP_LOOKUPS"},
825   {ge::Format::FORMAT_HASHTABLE_LOOKUP_KEYS, "HASHTABLE_LOOKUP_KEYS"},
826   {ge::Format::FORMAT_HASHTABLE_LOOKUP_VALUE, "HASHTABLE_LOOKUP_VALUE"},
827   {ge::Format::FORMAT_HASHTABLE_LOOKUP_OUTPUT, "HASHTABLE_LOOKUP_OUTPUT"},
828   {ge::Format::FORMAT_HASHTABLE_LOOKUP_HITS, "HASHTABLE_LOOKUP_HITS"},
829   {ge::Format::FORMAT_C1HWNCoC0, "C1HWNCoC0"},
830   {ge::Format::FORMAT_MD, "MD"},
831   {ge::Format::FORMAT_NDHWC, "NDHWC"},
832   {ge::Format::FORMAT_FRACTAL_ZZ, "FRACTAL_ZZ"},
833   {ge::Format::FORMAT_FRACTAL_NZ, "FRACTAL_NZ"},
834   {ge::Format::FORMAT_NCDHW, "NCDHW"},
835   {ge::Format::FORMAT_DHWCN, "DHWCN"},
836   {ge::Format::FORMAT_NDC1HWC0, "NDC1HWC0"},
837   {ge::Format::FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
838   {ge::Format::FORMAT_CN, "CN"},
839   {ge::Format::FORMAT_NC, "NC"},
840   {ge::Format::FORMAT_DHWNC, "DHWNC"},
841   {ge::Format::FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
842   {ge::Format::FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
843   {ge::Format::FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"},
844   {ge::Format::FORMAT_RESERVED, "RESERVED"},
845   {ge::Format::FORMAT_ALL, "ALL"},
846   {ge::Format::FORMAT_NULL, "NULL"},
847   {ge::Format::FORMAT_ND_RNN_BIAS, "ND_RNN_BIAS"},
848   {ge::Format::FORMAT_FRACTAL_ZN_RNN, "FRACTAL_ZN_RNN"},
849   {ge::Format::FORMAT_NYUV, "NYUV"},
850   {ge::Format::FORMAT_NYUV_A, "NYUV_A"},
851   {ge::Format::FORMAT_NCL, "NCL"}};
852 
GeDataTypeToString(const ge::DataType datatype)853 std::string GeDataTypeToString(const ge::DataType datatype) {
854   auto iter = kDataTypeToStringMap.find(datatype);
855   if (iter != kDataTypeToStringMap.end()) {
856     return iter->second;
857   }
858   return "";
859 }
860 
GeFormatToString(const ge::Format format)861 std::string GeFormatToString(const ge::Format format) {
862   auto iter = kFormatToStringMap.find(format);
863   if (iter != kFormatToStringMap.end()) {
864     return iter->second;
865   }
866   return "";
867 }
868 
/**
 * @brief Tells whether @p dims is exactly the 1-D shape [0].
 *
 * This overload only recognizes [0]; shapes such as [2, 0] are NOT reported as empty here,
 * unlike IsEmptyTensor(const Shape &), which checks for any zero dim.
 */
bool IsEmptyTensor(const std::vector<int64_t> &dims) {
  return (dims.size() == 1U) && (dims.front() == 0);
}
876 
IsUnknownRank(const Operator & op,const std::string & tensor_name,const std::string & types)877 bool IsUnknownRank(const Operator &op, const std::string &tensor_name, const std::string &types) {
878   TensorDesc tensor_desc;
879   if (types == "input") {
880     tensor_desc = op.GetInputDesc(tensor_name);
881   } else if (types == "output") {
882     tensor_desc = op.GetOutputDesc(tensor_name);
883   } else {
884     VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
885                                         OtherErrMsg(ConcatString("invalid params:", types, " of types to judge.")));
886     return false;
887   }
888 
889   std::vector<int64_t> shape_vec = tensor_desc.GetShape().GetDims();
890   if (shape_vec.size() == 1 && shape_vec[0] == INPUT_NEGATIVE_NUM2) {
891     return true;
892   }
893   return false;
894 }
895 
IsUnknownRankShape(const std::vector<int64_t> & shape_vec)896 bool IsUnknownRankShape(const std::vector<int64_t> &shape_vec) {
897   if (shape_vec.size() == 1 && shape_vec[0] == ge::UNKNOWN_DIM_NUM) {
898     return true;
899   }
900   return false;
901 }
902 
IsUnknownRankShape(const Shape & input_shape)903 bool IsUnknownRankShape(const Shape &input_shape) {
904   auto dims = input_shape.GetDims();
905   return (dims.size() == 1UL) && (dims[0UL] == UNKNOWN_DIM_NUM);
906 }
907 
/**
 * @brief Tells whether any dim of @p shape_vec is -1 (unknown dim).
 *
 * The unknown-rank marker [-2] is NOT detected here; see IsUnknownRankShape for that.
 */
bool IsUnKnownShape(const std::vector<int64_t> &shape_vec) {
  return std::any_of(shape_vec.cbegin(), shape_vec.cend(), [](int64_t dim) { return dim == -1; });
}
912 
IsUnknown(const std::vector<int64_t> & shape_vec)913 bool IsUnknown(const std::vector<int64_t> &shape_vec) {
914   return (IsUnKnownShape(shape_vec) || IsUnknownRankShape(shape_vec));
915 }
916 
/**
 * @brief Tells whether any element of @p shape_vec equals -1.
 *
 * The vector is only read, despite the non-const reference in the signature.
 */
bool IsUnknownVec(std::vector<int64_t> &shape_vec) {
  const auto last = shape_vec.cend();
  return std::find(shape_vec.cbegin(), last, int64_t{-1}) != last;
}
925 
MakeUpShapeRange(const std::vector<int64_t> & shape,std::vector<std::pair<int64_t,int64_t>> & range)926 void MakeUpShapeRange(const std::vector<int64_t> &shape, std::vector<std::pair<int64_t, int64_t>> &range) {
927   if (IsUnknownRankShape(shape)) {
928     return;
929   }
930 
931   if (range.empty()) {
932     for (size_t i = 0; i < shape.size(); i++) {
933       if (shape[i] == -1) {
934         range.push_back(std::pair<int64_t, int64_t>(0, -1));
935       } else {
936         range.push_back(std::pair<int64_t, int64_t>(shape[i], shape[i]));
937       }
938     }
939   }
940 }
941 
MakeUpShapeRange(const ge::Shape & shape,std::vector<std::pair<int64_t,int64_t>> & range)942 void MakeUpShapeRange(const ge::Shape &shape, std::vector<std::pair<int64_t, int64_t>> &range) {
943   if (IsUnknownRankShape(shape)) {
944     return;
945   }
946 
947   if (range.empty()) {
948     for (size_t i = 0; i < shape.GetDimNum(); i++) {
949       int64_t dim = shape.GetDim(i);
950       if (dim == -1) {
951         range.push_back(std::pair<int64_t, int64_t>(0, -1));
952       } else {
953         range.push_back(std::pair<int64_t, int64_t>(dim, dim));
954       }
955     }
956   }
957 }
958 
DataTypeToStringDesc(const ge::DataType & dataType)959 std::string DataTypeToStringDesc(const ge::DataType &dataType) {
960   std::map<ge::DataType, std::string>::const_iterator totalIter = DTYPE_STR_MAP.find(dataType);
961   if (totalIter == DTYPE_STR_MAP.end()) {
962     return "UNDEFINED";
963   }
964   return totalIter->second;
965 }
966 
OneInOneOutDynamicInfer(Operator & op,const std::string & input_name,const std::vector<std::string> & output_name_list)967 bool OneInOneOutDynamicInfer(Operator &op, const std::string &input_name,
968                              const std::vector<std::string> &output_name_list) {
969   // get input desc
970   PROFILING_PROTO_INIT(TbeGetName(op).c_str());
971   auto input_desc = op.GetInputDesc(input_name);
972   vector<int64_t> input_shape = input_desc.GetShape().GetDims();
973   DataType input_dtype = input_desc.GetDataType();
974 
975   if (IsUnknown(input_shape)) {
976     std::vector<std::pair<int64_t, int64_t>> input_range;
977     input_desc.GetShapeRange(input_range);
978     MakeUpShapeRange(input_shape, input_range);
979 
980     auto output_desc = op.GetOutputDesc(0);
981     for (const string &output_name : output_name_list) {
982       output_desc = op.GetOutputDesc(output_name);
983       output_desc.SetShape(Shape(input_shape));
984       output_desc.SetOriginShape(Shape(input_shape));
985       output_desc.SetShapeRange(input_range);
986       output_desc.SetDataType(input_dtype);
987       op.UpdateOutputDesc(output_name, output_desc);
988     }
989   } else {
990     auto output_desc = op.GetOutputDesc(0);
991     PROFILING_PROTO_AFTER_GET_SHAPE_REG();
992     PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
993     for (const string &output_name : output_name_list) {
994       output_desc = op.GetOutputDesc(output_name);
995       output_desc.SetShape(Shape(input_shape));
996       output_desc.SetDataType(input_dtype);
997       op.UpdateOutputDesc(output_name, output_desc);
998     }
999     PROFILING_PROTO_END();
1000   }
1001   return true;
1002 }
1003 
FixShapeRangeWithDims(const std::vector<int64_t> & dims,std::vector<int64_t> & shape_1,std::vector<int64_t> & shape_2,std::vector<std::pair<int64_t,int64_t>> & range_1,std::vector<std::pair<int64_t,int64_t>> & range_2)1004 void FixShapeRangeWithDims(const std::vector<int64_t> &dims, std::vector<int64_t> &shape_1,
1005                            std::vector<int64_t> &shape_2, std::vector<std::pair<int64_t, int64_t>> &range_1,
1006                            std::vector<std::pair<int64_t, int64_t>> &range_2) {
1007   MakeUpShapeRange(shape_1, range_1);
1008   MakeUpShapeRange(shape_2, range_2);
1009   bool is_all_fix = dims.empty();
1010 
1011   if (shape_1 == UNKNOWN_RANK && shape_2 == UNKNOWN_RANK) {
1012     return;
1013   }
1014   if (shape_1 == UNKNOWN_RANK) {
1015     shape_1 = shape_2;
1016     range_1 = range_2;
1017     return;
1018   }
1019   if (shape_2 == UNKNOWN_RANK) {
1020     shape_2 = shape_1;
1021     range_2 = range_1;
1022     return;
1023   }
1024   if ((shape_1.size() != shape_2.size()) || (range_1.size() != range_2.size())) {
1025     return;
1026   }
1027   auto loop_size = is_all_fix ? shape_1.size() : dims.size();
1028   for (size_t i = 0; i < loop_size; i++) {
1029     auto dim_num = is_all_fix ? i : dims[i];
1030     if (shape_1[dim_num] != -1) {
1031       shape_2[dim_num] = shape_1[dim_num];
1032       range_1[dim_num] = std::pair<int64_t, int64_t>(shape_1[dim_num], shape_1[dim_num]);
1033       range_2[dim_num] = std::pair<int64_t, int64_t>(shape_1[dim_num], shape_1[dim_num]);
1034       continue;
1035     }
1036     if (shape_2[dim_num] != -1) {
1037       shape_1[dim_num] = shape_2[dim_num];
1038       range_1[dim_num] = std::pair<int64_t, int64_t>(shape_2[dim_num], shape_2[dim_num]);
1039       range_2[dim_num] = std::pair<int64_t, int64_t>(shape_2[dim_num], shape_2[dim_num]);
1040       continue;
1041     }
1042     // both the dim in shape1 and shape2 are -1
1043     auto range_1_min = range_1[dim_num].first;
1044     auto range_2_min = range_2[dim_num].first;
1045     auto range_1_max = range_1[dim_num].second;
1046     auto range_2_max = range_2[dim_num].second;
1047     auto range_fisrt = range_1_min > range_2_min ? range_1_min : range_2_min;
1048     auto range_second_min = range_1_max > range_2_max ? range_2_max : range_1_max;
1049     auto range_second_max = range_1_max > range_2_max ? range_1_max : range_2_max;
1050     range_second_min = range_second_min == -1 ? range_second_max : range_second_min;
1051     range_1[dim_num] = std::pair<int64_t, int64_t>(range_fisrt, range_second_min);
1052     range_2[dim_num] = std::pair<int64_t, int64_t>(range_fisrt, range_second_min);
1053   }
1054 }
1055 
TwoInOneOutDynamicInferNoBroadcast(Operator & op,const string & input1_name,const string & input2_name,const std::vector<string> & output_name_list)1056 bool TwoInOneOutDynamicInferNoBroadcast(Operator &op, const string &input1_name, const string &input2_name,
1057                                         const std::vector<string> &output_name_list) {
1058   // get input1 desc
1059   auto input1_desc = op.GetInputDesc(input1_name);
1060   vector<int64_t> input1_shape = input1_desc.GetShape().GetDims();
1061   DataType input_dtype = input1_desc.GetDataType();
1062 
1063   // get input2 desc
1064   auto input2_desc = op.GetInputDesc(input2_name);
1065   vector<int64_t> input2_shape = input2_desc.GetShape().GetDims();
1066 
1067   if (IsUnknown(input1_shape) || IsUnknown(input2_shape)) {
1068     std::vector<std::pair<int64_t, int64_t>> input1_range;
1069     input1_desc.GetShapeRange(input1_range);
1070     std::vector<std::pair<int64_t, int64_t>> input2_range;
1071     input2_desc.GetShapeRange(input2_range);
1072 
1073     vector<int64_t> dim_size = {};
1074     FixShapeRangeWithDims(dim_size, input1_shape, input2_shape, input1_range, input2_range);
1075 
1076     // update output desc
1077     for (const string &output_name : output_name_list) {
1078       auto output_desc = op.GetOutputDesc(output_name);
1079       output_desc.SetShape(Shape(input1_shape));
1080       output_desc.SetOriginShape(Shape(input1_shape));
1081       output_desc.SetShapeRange(input1_range);
1082       output_desc.SetDataType(input_dtype);
1083       op.UpdateOutputDesc(output_name, output_desc);
1084     }
1085   } else {
1086     for (const string &output_name : output_name_list) {
1087       auto output_desc = op.GetOutputDesc(output_name);
1088       output_desc.SetShape(Shape(input1_shape));
1089       output_desc.SetDataType(input_dtype);
1090       op.UpdateOutputDesc(output_name, output_desc);
1091     }
1092   }
1093   return true;
1094 }
1095 
IsEmptyTensor(TensorDesc tensor_desc)1096 bool IsEmptyTensor(TensorDesc tensor_desc) { return IsEmptyTensor(tensor_desc.GetShape()); }
1097 
IsEmptyTensor(const Shape & ge_shape)1098 bool IsEmptyTensor(const Shape &ge_shape) {
1099   bool is_empty = false;
1100   for (const auto &dim : ge_shape.GetDims()) {
1101     if (dim == 0) {
1102       is_empty = true;
1103       break;
1104     }
1105   }
1106   return is_empty;
1107 }
1108 
IsUnknownShape(const ge::Shape & shape)1109 bool IsUnknownShape(const ge::Shape &shape) {
1110   const auto &dims = shape.GetDims();
1111   return std::any_of(dims.begin(), dims.end(),
1112                      [](const int64_t &dim) { return (dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM); });
1113 }
1114 
IsUnknownDimNum(const ge::Shape & shape)1115 bool IsUnknownDimNum(const ge::Shape &shape) {
1116   const auto &dims = shape.GetDims();
1117   return (dims.size() == 1UL) && (dims[0UL] == UNKNOWN_DIM_NUM);
1118 }
1119 
IsScalar(const ge::Shape & shape)1120 bool IsScalar(const ge::Shape &shape) {
1121   const auto &dims = shape.GetDims();
1122   return dims.empty();
1123 }
1124 
SetOpInferDepends(Operator & op,const std::vector<std::string> & depend_names)1125 void SetOpInferDepends(Operator &op, const std::vector<std::string> &depend_names) {
1126   op.SetAttr(ATTR_NAME_OP_INFER_DEPENDS, depend_names);
1127 }
1128 
SetIsUnknownDimNum(ge::Shape & shape)1129 void SetIsUnknownDimNum(ge::Shape &shape) {
1130   std::vector<int64_t> dims(1UL, UNKNOWN_DIM_NUM);
1131   dims[0UL] = UNKNOWN_DIM_NUM;
1132   shape = ge::Shape(dims);
1133 }
1134 
1135 namespace array_ops {
1136 // If not overflow return true
CheckInt64MulOverflow(int64_t a,int64_t b)1137 bool CheckInt64MulOverflow(int64_t a, int64_t b) {
1138   if (a > 0) {
1139     if (b > 0) {
1140       if (a > (INT64_MAX / b)) {
1141         return false;
1142       }
1143     } else {
1144       if (b < (INT64_MIN / a)) {
1145         return false;
1146       }
1147     }
1148   } else {
1149     if (b > 0) {
1150       if (a < (INT64_MIN / b)) {
1151         return false;
1152       }
1153     } else {
1154       if ((a != 0) && (b < (INT64_MAX / a))) {
1155         return false;
1156       }
1157     }
1158   }
1159 
1160   return true;
1161 }
1162 
CalcMaxElementsCount(const Operator & op,const std::vector<std::pair<int64_t,int64_t>> & x_shape_range,const Shape & x_shape)1163 int64_t CalcMaxElementsCount(const Operator &op, const std::vector<std::pair<int64_t, int64_t>> &x_shape_range,
1164                              const Shape &x_shape) {
1165   int64_t max_elements_count = 1;
1166   auto x_shape_size = x_shape.GetShapeSize();
1167   if (x_shape_size > 0) {
1168     // when known dim, x_shape_size is max_elements_count
1169     max_elements_count = x_shape_size;
1170   } else {
1171     // unknown dim
1172     if (x_shape_range.empty()) {
1173       max_elements_count = -1;
1174     }
1175     for (const auto &x_range_i : x_shape_range) {
1176       if (x_range_i.second <= 0) {
1177         max_elements_count = -1;
1178         break;
1179       }
1180       if (array_ops::CheckInt64MulOverflow(max_elements_count, x_range_i.second)) {
1181         max_elements_count *= x_range_i.second;
1182       } else {
1183         max_elements_count = -1;
1184         break;
1185       }
1186     }
1187   }
1188 
1189   return max_elements_count;
1190 }
1191 
GenerateWorstYShapeAndYShapeRange(int64_t y_rank,int64_t max_elements_count,std::vector<std::pair<int64_t,int64_t>> & y_shape_range,Shape & y_shape)1192 void GenerateWorstYShapeAndYShapeRange(int64_t y_rank, int64_t max_elements_count,
1193                                        std::vector<std::pair<int64_t, int64_t>> &y_shape_range, Shape &y_shape) {
1194   y_shape = Shape(std::vector<int64_t>(y_rank, UNKNOWN_DIM));
1195   y_shape_range.clear();
1196   for (int64_t i = 0; i < y_rank; ++i) {
1197     y_shape_range.emplace_back(std::pair<int64_t, int64_t>(1, max_elements_count));
1198   }
1199 }
1200 
RepairAndCheckRange(const std::vector<std::pair<int64_t,int64_t>> & x_shape_range,std::vector<std::pair<int64_t,int64_t>> & value_range)1201 bool RepairAndCheckRange(const std::vector<std::pair<int64_t, int64_t>> &x_shape_range,
1202                          std::vector<std::pair<int64_t, int64_t>> &value_range) {
1203   bool has_zero_in_range = false;
1204   for (auto &range_i : value_range) {
1205     if (range_i.first < 0) {
1206       range_i.first = 1;
1207     }
1208     if (range_i.second < 0) {
1209       range_i.second = -1;
1210     }
1211     if (range_i.first == 0) {
1212       has_zero_in_range = true;
1213     }
1214   }
1215 
1216   for (auto &range_i : x_shape_range) {
1217     if (range_i.first == 0) {
1218       has_zero_in_range = true;
1219       break;
1220     }
1221   }
1222   return has_zero_in_range;
1223 }
1224 
InferShapeRangeForEmptyTensor(int64_t y_rank,int64_t max_elements_count,const std::vector<std::pair<int64_t,int64_t>> & value_range,std::vector<std::pair<int64_t,int64_t>> & y_shape_range,Shape & y_shape)1225 void InferShapeRangeForEmptyTensor(int64_t y_rank, int64_t max_elements_count,
1226                                    const std::vector<std::pair<int64_t, int64_t>> &value_range,
1227                                    std::vector<std::pair<int64_t, int64_t>> &y_shape_range, Shape &y_shape) {
1228   y_shape_range = value_range;
1229   int64_t known_dims_product = 1;
1230   std::vector<int64_t> y_dims = y_shape.GetDims();
1231   for (int64_t i = 0; i < y_rank; ++i) {
1232     if (y_shape_range[i].first == y_shape_range[i].second) {
1233       y_dims[i] = y_shape_range[i].first;
1234       if (max_elements_count != -1 && y_dims[i] != 0) {
1235         known_dims_product *= y_dims[i];
1236       }
1237     }
1238   }
1239   y_shape = Shape(y_dims);
1240 
1241   if (known_dims_product != 1) {
1242     auto cur_dim_max_elements_count = (max_elements_count - 1) / known_dims_product + 1;
1243     for (int64_t i = 0; i < y_rank; ++i) {
1244       if (y_dims[i] == -1) {
1245         if (y_shape_range[i].second != -1) {
1246           y_shape_range[i].second = std::min(cur_dim_max_elements_count, y_shape_range[i].second);
1247         } else {
1248           y_shape_range[i].second = cur_dim_max_elements_count;
1249         }
1250       }
1251     }
1252   }
1253 }
1254 
UpdateDimsAndShapeRange(const Operator & op,int64_t max_elements_count,const std::vector<std::pair<int64_t,int64_t>> & value_range,std::vector<int64_t> & y_dims,std::vector<std::pair<int64_t,int64_t>> & y_shape_range)1255 void UpdateDimsAndShapeRange(const Operator &op, int64_t max_elements_count,
1256                              const std::vector<std::pair<int64_t, int64_t>> &value_range, std::vector<int64_t> &y_dims,
1257                              std::vector<std::pair<int64_t, int64_t>> &y_shape_range) {
1258   size_t y_rank = y_dims.size();
1259   for (size_t i = 0; i < y_rank; ++i) {
1260     if (value_range[i].first == value_range[i].second) {
1261       y_dims[i] = value_range[i].first;
1262       y_shape_range[i] = std::pair<int64_t, int64_t>(y_dims[i], y_dims[i]);
1263     } else {
1264       if (max_elements_count == -1) {
1265         // while max_elements_count = -1, y shape range i is always value_range[i].second;
1266         y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, value_range[i].second);
1267         continue;
1268       }
1269       int64_t other_dims_range_lower_boundary = 1;
1270       for (size_t j = 0; j < y_rank; ++j) {
1271         if (i != j) {
1272           other_dims_range_lower_boundary *= value_range[j].first;
1273         }
1274       }
1275       int64_t cur_dim_range_max = (max_elements_count - 1) / other_dims_range_lower_boundary + 1;
1276       if (value_range[i].second > 0) {
1277         cur_dim_range_max = std::min(cur_dim_range_max, value_range[i].second);
1278       }
1279       y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, cur_dim_range_max);
1280     }
1281   }
1282 }
1283 
CalculateMaxInputDims(const std::vector<std::pair<int64_t,int64_t>> & x_range,const Operator & op)1284 int64_t CalculateMaxInputDims(const std::vector<std::pair<int64_t, int64_t>> &x_range, const Operator &op) {
1285   int64_t max_input_dims = 1;
1286   for (const auto &pair : x_range) {
1287     if (pair.second < 0) {
1288       max_input_dims = -1;
1289       break;
1290     }
1291 
1292     if (array_ops::CheckInt64MulOverflow(max_input_dims, pair.second)) {
1293       max_input_dims *= pair.second;
1294     } else {
1295       max_input_dims = INT64_MAX;
1296       GE_OP_LOGW(TbeGetName(op).c_str(), "Range Infer out of int64 max!Do set int64max!");
1297       break;
1298     }
1299   }
1300   return max_input_dims;
1301 }
1302 }  // namespace array_ops
1303 
IsSliceUnknownShape(const std::vector<int64_t> & dim_vec,const int64_t & begin,const int64_t & end)1304 bool IsSliceUnknownShape(const std::vector<int64_t> &dim_vec, const int64_t &begin, const int64_t &end) {
1305   if (begin < 0 || end >= static_cast<int64_t>(dim_vec.size())) {
1306     GE_OP_LOGE("FlattenV2", "index is out of range");
1307     return false;
1308   }
1309   for (int64_t i = begin; i < end + 1; i++) {
1310     if (dim_vec[i] == -1) {
1311       return true;
1312     }
1313   }
1314   return false;
1315 }
1316 
1317 void SetOpInferDepends(Operator &op, const std::vector<std::string> &depend_names);
1318 }  // namespace ge
1319