• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_
18 #define MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_
19 
#include <algorithm>
#include <cmath>
#include <iterator>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "utils/log_adapter.h"
27 
28 namespace mindspore {
// Byte counts of one gibibyte (2^30 bytes) and one mebibyte (2^20 bytes).
const size_t kGBToByte = static_cast<size_t>(1) << 30;
const size_t kMBToByte = static_cast<size_t>(1) << 20;
31 
SizeToInt(size_t u)32 inline int SizeToInt(size_t u) {
33   if (u > static_cast<size_t>((std::numeric_limits<int>::max)())) {
34     MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int.";
35   }
36   return static_cast<int>(u);
37 }
38 
SizeToUint(size_t u)39 inline uint32_t SizeToUint(size_t u) {
40   if (u > static_cast<size_t>((std::numeric_limits<uint32_t>::max)())) {
41     MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of uint32_t.";
42   }
43   return static_cast<uint32_t>(u);
44 }
45 
SizeToLong(size_t u)46 inline int64_t SizeToLong(size_t u) {
47   if (u > static_cast<size_t>((std::numeric_limits<int64_t>::max)())) {
48     MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int64_t.";
49   }
50   return static_cast<int64_t>(u);
51 }
52 
/// @brief Widens a size_t to uint64_t (no range check is performed).
inline uint64_t SizeToUlong(size_t u) {
  const auto widened = static_cast<uint64_t>(u);
  return widened;
}
54 
IntToSize(int u)55 inline size_t IntToSize(int u) {
56   if (u < 0) {
57     MS_LOG(INTERNAL_EXCEPTION) << "The int value(" << u << ") is less than 0.";
58   }
59   return static_cast<size_t>(u);
60 }
61 
/// @brief Converts an int64_t to size_t, mapping every negative input to 0.
inline size_t LongToSizeClipNeg(int64_t u) {
  if (u < 0) {
    return 0;
  }
  return static_cast<size_t>(u);
}
63 
LongToSize(int64_t u)64 inline size_t LongToSize(int64_t u) {
65   if (u < 0) {
66     MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
67   }
68   return static_cast<size_t>(u);
69 }
70 
LongVecToSizeVec(const std::vector<int64_t> & vec)71 inline std::vector<size_t> LongVecToSizeVec(const std::vector<int64_t> &vec) {
72   std::vector<size_t> result;
73   result.reserve(vec.size());
74   (void)std::transform(vec.begin(), vec.end(), std::back_inserter(result), LongToSize);
75   return result;
76 }
77 
LongToUint(int64_t u)78 inline uint32_t LongToUint(int64_t u) {
79   if (u < 0) {
80     MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
81   }
82   if (u > static_cast<int64_t>((std::numeric_limits<uint32_t>::max)())) {
83     MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") exceeds the maximum value of uint32_t.";
84   }
85   return static_cast<uint32_t>(u);
86 }
87 
FloatToSize(float u)88 inline size_t FloatToSize(float u) {
89   if (u < 0) {
90     MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") is less than 0.";
91   }
92 
93   if (u > static_cast<float>((std::numeric_limits<size_t>::max)())) {
94     MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of size_t.";
95   }
96   return static_cast<size_t>(u);
97 }
/// @brief Converts an int32_t to float (values beyond 2^24 in magnitude may be rounded).
inline float IntToFloat(int32_t v) {
  const auto result = static_cast<float>(v);
  return result;
}
99 
FloatToInt(float u)100 inline int FloatToInt(float u) {
101   if (u > static_cast<float>((std::numeric_limits<int>::max)())) {
102     MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int.";
103   }
104   return static_cast<int>(u);
105 }
106 
FloatToLong(float u)107 inline int FloatToLong(float u) {
108   if (u > static_cast<float>((std::numeric_limits<int64_t>::max)())) {
109     MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int64_t.";
110   }
111   return static_cast<int64_t>(u);
112 }
113 
DoubleToLong(double u)114 inline int64_t DoubleToLong(double u) {
115   if (u > static_cast<double>((std::numeric_limits<int64_t>::max)())) {
116     MS_LOG(INTERNAL_EXCEPTION) << "The double value(" << u << ") exceeds the maximum value of int64_t.";
117   }
118   return static_cast<int64_t>(u);
119 }
120 
/// @brief Converts a size_t to float (large values may be rounded).
inline float SizeToFloat(size_t v) {
  const auto result = static_cast<float>(v);
  return result;
}
122 
/// @brief Converts an int64_t to double (values beyond 2^53 in magnitude may be rounded).
inline double LongToDouble(int64_t v) {
  const auto result = static_cast<double>(v);
  return result;
}
124 
/// @brief Converts an int64_t to float (values beyond 2^24 in magnitude may be rounded).
inline float LongToFloat(int64_t v) {
  const auto result = static_cast<float>(v);
  return result;
}
126 
/// @brief Widens a float to double; the conversion is exact.
inline double FloatToDouble(float v) {
  const double widened{v};
  return widened;
}
128 
IntToUint(int32_t u)129 inline uint32_t IntToUint(int32_t u) {
130   if (u < 0) {
131     MS_LOG(INTERNAL_EXCEPTION) << "The int32_t value(" << u << ") is less than 0.";
132   }
133   return static_cast<uint32_t>(u);
134 }
135 
UintToInt(uint32_t u)136 inline int32_t UintToInt(uint32_t u) {
137   if (u > static_cast<uint32_t>((std::numeric_limits<int32_t>::max)())) {
138     MS_LOG(INTERNAL_EXCEPTION) << "The uint32_t value(" << u << ") exceeds the maximum value of int32_t.";
139   }
140   return static_cast<int32_t>(u);
141 }
142 
LongToUlong(int64_t u)143 inline uint64_t LongToUlong(int64_t u) {
144   if (u < 0) {
145     MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
146   }
147   return static_cast<uint64_t>(u);
148 }
149 
LongToInt(int64_t u)150 inline int32_t LongToInt(int64_t u) {
151   if (u > static_cast<int64_t>((std::numeric_limits<int32_t>::max)())) {
152     MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int.";
153   }
154   return static_cast<int32_t>(u);
155 }
156 
/// @brief Widens an int32_t to int64_t; the conversion is exact.
inline int64_t IntToLong(int32_t v) {
  const int64_t widened{v};
  return widened;
}
158 
UlongToLong(uint64_t u)159 inline int64_t UlongToLong(uint64_t u) {
160   if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) {
161     MS_LOG(INTERNAL_EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";
162   }
163   return static_cast<int64_t>(u);
164 }
165 
UlongToUint(uint64_t u)166 inline unsigned int UlongToUint(uint64_t u) {
167   if (u > static_cast<uint64_t>((std::numeric_limits<unsigned int>::max)())) {
168     MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of unsigned int.";
169   }
170   return static_cast<unsigned int>(u);
171 }
172 
IntMulWithOverflowCheck(int a,int b)173 inline int IntMulWithOverflowCheck(int a, int b) {
174   int out = a * b;
175   if (a != 0) {
176     bool overflow = ((out / a) != b);
177     if (overflow) {
178       MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
179     }
180   }
181   return out;
182 }
183 
LongMulWithOverflowCheck(int64_t a,int64_t b)184 inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) {
185   int64_t out = a * b;
186   if (a != 0) {
187     bool overflow = ((out / a) != b);
188     if (overflow) {
189       MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
190     }
191   }
192   return out;
193 }
194 
SizetMulWithOverflowCheck(size_t a,size_t b)195 inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
196   size_t out = a * b;
197   if (a != 0) {
198     if ((out / a) != b) {
199       MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
200     }
201   }
202   return out;
203 }
204 
Uint32tMulWithOverflowCheck(uint32_t a,uint32_t b)205 inline uint32_t Uint32tMulWithOverflowCheck(uint32_t a, uint32_t b) {
206   uint32_t out = a * b;
207   if (a != 0) {
208     if ((out / a) != b) {
209       MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
210     }
211   }
212   return out;
213 }
214 
SizetAddWithOverflowCheck(size_t x,size_t y)215 inline size_t SizetAddWithOverflowCheck(size_t x, size_t y) {
216   size_t sum = x + y;
217   if (sum < x || sum < y) {
218     MS_LOG(EXCEPTION) << "Add: a(" << x << ") + b(" << y << ") result is overflow";
219   }
220   return sum;
221 }
222 
Uint32tAddWithOverflowCheck(uint32_t x,uint32_t y)223 inline uint32_t Uint32tAddWithOverflowCheck(uint32_t x, uint32_t y) {
224   uint32_t sum = x + y;
225   if (sum < x || sum < y) {
226     MS_LOG(EXCEPTION) << "Add: a(" << x << ") + b(" << y << ") result is overflow";
227   }
228   return sum;
229 }
230 
AddressOffset(void * address,size_t offset)231 inline uint8_t *AddressOffset(void *address, size_t offset) {
232   MS_EXCEPTION_IF_NULL(address);
233   return static_cast<uint8_t *>(address) + offset;
234 }
235 
CalAddressOffset(void * dst_address,void * ori_address)236 inline size_t CalAddressOffset(void *dst_address, void *ori_address) {
237   MS_EXCEPTION_IF_NULL(dst_address);
238   MS_EXCEPTION_IF_NULL(ori_address);
239   return static_cast<uint8_t *>(dst_address) - static_cast<uint8_t *>(ori_address);
240 }
241 
Convert2Int(const std::vector<size_t> & v)242 inline std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
243   std::vector<int64_t> result;
244   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
245   return result;
246 }
247 
Convert2Long(const std::vector<size_t> & v)248 inline std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
249   std::vector<int64_t> result;
250   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
251   return result;
252 }
253 
Convert2SizeT(const std::vector<int64_t> & v)254 inline std::vector<size_t> Convert2SizeT(const std::vector<int64_t> &v) {
255   std::vector<size_t> result;
256   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), LongToSize);
257   return result;
258 }
259 
/// @brief Converts a vector of int64_t to a vector of size_t, mapping every negative element to 0.
/// @param v Source values.
/// @return A vector of the same length and order as v.
inline std::vector<size_t> Convert2SizeTClipNeg(const std::vector<int64_t> &v) {
  std::vector<size_t> result;
  result.reserve(v.size());  // Size is known up front: allocate once.
  // Cast directly to the destination type. This avoids an implicit signed->unsigned return conversion and does
  // not shadow the outer `v`.
  (void)std::transform(v.begin(), v.end(), std::back_inserter(result),
                       [](int64_t value) -> size_t { return value < 0 ? size_t{0} : static_cast<size_t>(value); });
  return result;
}
266 
/// @brief Checks whether two shape vectors are identical.
/// @param shape First shape.
/// @param check_shape Second shape.
/// @return true when both have the same length and equal elements in the same order.
inline bool ShapeVectorIsSame(const std::vector<int64_t> &shape, const std::vector<int64_t> &check_shape) {
  // std::vector::operator== compares sizes first, then elements in order, which is exactly the manual loop's logic.
  return shape == check_shape;
}
279 
/// @brief Formats a shape vector as a parenthesised, comma-separated list, e.g. "(2, 3, 4)"; empty gives "()".
/// @param shp Shape to format.
/// @return The formatted string.
inline std::string ShapeVectorToStr(const std::vector<int64_t> &shp) {
  // Built directly with std::to_string. This header never includes <sstream>, and std::to_string is not affected
  // by the global stream locale (no digit grouping).
  std::string out = "(";
  for (size_t i = 0; i < shp.size(); ++i) {
    if (i != 0) {
      out += ", ";
    }
    out += std::to_string(shp[i]);
  }
  out += ")";
  return out;
}
295 }  // namespace mindspore
296 
297 #endif  // MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_
298