• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_
19 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "securec/include/securec.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/tensor.h"
#include "transform/graph_ir/types.h"
#include "graph/tensor.h"
#include "utils/shape_utils.h"
31 
32 namespace mindspore {
33 namespace transform {
34 class TransformUtil {
35  public:
/*
 * Parameters:
 *     data: [int64_t] the integer value to convert
 *     size: [int] the requested size (presumably the length of the returned list; confirm against the definition)
 * Return:
 *     [std::vector<int64_t>] the list built from data
 * */
42   static std::vector<int64_t> ConvertIntToList(int64_t data, int size);
43 
44   /*
45    * Parameters:
46    *     type: [MeDataType] the data type for ME tensor
47    * Return:
48    *     [GeDataType] the data type for ge tensor
49    * */
50   static GeDataType ConvertDataType(const MeDataType &type);
51 
/*
 * Parameters:
 *     format: [string] the data format in ME op
 * Return:
 *     [GeFormat] the data format for ge tensor
 * */
58   static GeFormat ConvertFormat(const std::string &format);
59 
60   /*
61    * Parameters:
62    *     type: [MeDataType] the data type for ME tensor
63    * Return:
64    *     [size_t] the buff size for the type in ME
65    * */
66   static size_t GetDataTypeSize(const MeDataType &type);
67 
/*
 * Parameters:
 *     shape: [ShapeVector] the shape of the ME tensor
 *     me_type: [MeDataType] the data type of the ME tensor
 *     format: [string] the data format in ME
 * Return:
 *     [shared_ptr<GeTensorDesc>] the shared pointer of ge tensor description
 * */
76   static std::shared_ptr<GeTensorDesc> GetGeTensorDesc(const ShapeVector &shape, const MeDataType &me_type,
77                                                        const std::string &format);
78 
/*
 * Parameters:
 *     tensor: [MeTensorPtr] the data tensor in ME
 *     format: [string] the data format in ME op
 * Return:
 *     [GeTensorPtr] the data tensor in GE
 * */
87   static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format);
88 
89   /*
90    * Parameters:
91    *     me_tensors: [vector<MeTensorPtr>] the data tensors in ME
92    *     format: [string] the data format in ME op
93    * Return:
94    *     [std::vector<GeTensorPtr>] the data tensors in GE
95    * */
96   static std::vector<GeTensorPtr> ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
97                                                       const std::string &format);
98 
99   /*
100    * Parameters:
101    *     tensor: [GeTensor] the data tensor in GE
102    * Return:
103    *     [MeTensor] the data tensor in ME
104    * */
105   static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor);
106 
/*
 * Parameters:
 *     ge_tensor: [GeTensorPtr] the data tensor in GE
 *     request_dims: [ShapeVector] the shape the output ME tensor must be adjusted to
 * Return:
 *     [MeTensorPtr] the data tensor in ME
 * */
114   static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const ShapeVector &request_dims);
115   /*
116    * Parameters:
117    *     ge_tensors: [std::vector<GeTensorPtr>] the data tensor in GE
118    *     request_dims [std::vector<ShapeVector>] the output Me tensors must adjust to this shapes
119    * Return:
120    *     [std::vector<MeTensorPtr>] the data tensor in ME
121    * */
122   static std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
123                                                    const std::vector<ShapeVector> &request_dims);
124   /*
125    * Parameters:
126    *     ge_tensors: [std::vector<GeTensorPtr>] the data tensor in GE
127    * Return:
128    *     [std::vector<MeTensorPtr>] the data tensor in ME
129    * */
130   static std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors);
131   /*
132    * Parameters:
133    *     ge_tensor: [GeTensor] the data tensor in GE
134    *     me_dims: [ShapeVector] the shape of created Me tensor
135    *     me_type: [TypeId] the type of created Me tensor
136    * Return:
137    *     [MeTensor] the data tensor in ME
138    * */
139   static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims, const TypeId &me_type);
140   /*
141    * Parameters:
142    *     type: [GeDataType] the ge tensor data type
143    * Return:
144    *     [MeDataType] the me tensor data type
145    * */
146   static MeDataType ConvertGeDataType(const GeDataType &type);
147 
148   /*
149    * Parameters:
150    *     me_dims: [ShapeVector] the me shape
151    * Return:
152    *     [GeShape] the ge shape
153    * */
154   static GeShape ConvertMeShape(const ShapeVector &me_dims);
155 
/*
 * Parameters:
 *     ge_shape: [GeShape] the ge shape
 * Return:
 *     [ShapeVector] the me shape
 * */
162   static ShapeVector ConvertGeShape(const GeShape &ge_shape);
163 
/* Function:
 *     Convert GeShape to the ME request shape. Supported patterns:
 *         {1, x, 1, 1} --> {x}
 *         {x, 1, 1, 1} --> {x}
 *         {x, x, 1, 1} --> {x, x}
 *         {x, x, x, 1} --> {x, x, x}
 *         {x, x, x, x} --> {x, x, x, x}
 *     If none of the patterns match, the original ge dims are returned.
 * Parameters:
 *     ge_shape: [GeShape] the ge shape
 *     request_dims: [ShapeVector] request dims
 * Return:
 *     [ShapeVector] the me shape
 * */
178   static ShapeVector ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims);
179 
/*
 * Function:
 *     Format the elements of a vector as "{ e0, e1, ... }" for logging. At most
 *     MAX_PRINT_NUM (100) elements are written; the "... to be continue" marker
 *     is appended only when elements were actually left out.
 * Parameters:
 *     vec: [std::vector<T>] the vector to print, T must be an arithmetic type
 * Return:
 *     [string] value string
 * */
template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
static std::string PrintVector(const std::vector<T> &vec) {
  const size_t MAX_PRINT_NUM = 100;
  const size_t print_num = std::min(vec.size(), MAX_PRINT_NUM);
  std::stringstream ss;
  ss << "{ ";
  for (size_t i = 0; i < print_num; ++i) {
    ss << std::to_string(vec[i]) << ", ";
  }

  // Decide truncation from the real element count: a vector holding exactly
  // MAX_PRINT_NUM elements is printed completely and must not be marked as cut off.
  if (vec.size() > MAX_PRINT_NUM) {
    ss << "... to be continue}";
  } else {
    ss << "}";
  }
  return ss.str();
}
207 
/*
 * Parameters:
 *     ge_tensor: [GeTensorPtr] the ge tensor
 * Return:
 *     [string] value string
 * */
214   static std::string PrintGeTensor(const GeTensorPtr ge_tensor);
215 
216   /*
217    * Parameters:
218    *     data: [uint8_t *] the ge tensor data pointer
219    *     size: [size_t] the ge tensor data bytes
220    * Return:
221    *     [shared_ptr<std::vector<T>]  vector pointer
222    * */
223   template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
MakeVector(const uint8_t * const data,size_t size)224   static std::vector<T> MakeVector(const uint8_t *const data, size_t size) {
225     auto dest = std::vector<T>(size / sizeof(T));
226     if (data == nullptr) {
227       return dest;
228     }
229 
230     errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size);
231     if (EOK != ret) {
232       return std::vector<T>();
233     }
234     return dest;
235   }
236 };
237 }  // namespace transform
238 }  // namespace mindspore
239 
240 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_
241