1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_
18 #define MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_
19
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "utils/log_adapter.h"
27
28 namespace mindspore {
// Byte multipliers: 1 GB = 2^30 bytes, 1 MB = 2^20 bytes.
constexpr size_t kGBToByte = static_cast<size_t>(1) << 30;
constexpr size_t kMBToByte = static_cast<size_t>(1) << 20;
31
SizeToInt(size_t u)32 inline int SizeToInt(size_t u) {
33 if (u > static_cast<size_t>((std::numeric_limits<int>::max)())) {
34 MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int.";
35 }
36 return static_cast<int>(u);
37 }
38
SizeToUint(size_t u)39 inline uint32_t SizeToUint(size_t u) {
40 if (u > static_cast<size_t>((std::numeric_limits<uint32_t>::max)())) {
41 MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of uint32_t.";
42 }
43 return static_cast<uint32_t>(u);
44 }
45
SizeToLong(size_t u)46 inline int64_t SizeToLong(size_t u) {
47 if (u > static_cast<size_t>((std::numeric_limits<int64_t>::max)())) {
48 MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int64_t.";
49 }
50 return static_cast<int64_t>(u);
51 }
52
// Converts a size_t to uint64_t with a plain cast; no range check is performed.
inline uint64_t SizeToUlong(size_t u) {
  const auto widened = static_cast<uint64_t>(u);
  return widened;
}
54
IntToSize(int u)55 inline size_t IntToSize(int u) {
56 if (u < 0) {
57 MS_LOG(INTERNAL_EXCEPTION) << "The int value(" << u << ") is less than 0.";
58 }
59 return static_cast<size_t>(u);
60 }
61
// Converts an int64_t to size_t, clamping negative inputs to 0 instead of reporting an error.
inline size_t LongToSizeClipNeg(int64_t u) {
  if (u >= 0) {
    return static_cast<size_t>(u);
  }
  return 0;
}
63
LongToSize(int64_t u)64 inline size_t LongToSize(int64_t u) {
65 if (u < 0) {
66 MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
67 }
68 return static_cast<size_t>(u);
69 }
70
LongVecToSizeVec(const std::vector<int64_t> & vec)71 inline std::vector<size_t> LongVecToSizeVec(const std::vector<int64_t> &vec) {
72 std::vector<size_t> result;
73 result.reserve(vec.size());
74 (void)std::transform(vec.begin(), vec.end(), std::back_inserter(result), LongToSize);
75 return result;
76 }
77
LongToUint(int64_t u)78 inline uint32_t LongToUint(int64_t u) {
79 if (u < 0) {
80 MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
81 }
82 if (u > static_cast<int64_t>((std::numeric_limits<uint32_t>::max)())) {
83 MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") exceeds the maximum value of uint32_t.";
84 }
85 return static_cast<uint32_t>(u);
86 }
87
FloatToSize(float u)88 inline size_t FloatToSize(float u) {
89 if (u < 0) {
90 MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") is less than 0.";
91 }
92
93 if (u > static_cast<float>((std::numeric_limits<size_t>::max)())) {
94 MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of size_t.";
95 }
96 return static_cast<size_t>(u);
97 }
// Converts an int32_t to float; magnitudes above 2^24 may be rounded.
inline float IntToFloat(int32_t v) {
  const auto as_float = static_cast<float>(v);
  return as_float;
}
99
FloatToInt(float u)100 inline int FloatToInt(float u) {
101 if (u > static_cast<float>((std::numeric_limits<int>::max)())) {
102 MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int.";
103 }
104 return static_cast<int>(u);
105 }
106
FloatToLong(float u)107 inline int FloatToLong(float u) {
108 if (u > static_cast<float>((std::numeric_limits<int64_t>::max)())) {
109 MS_LOG(INTERNAL_EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int64_t.";
110 }
111 return static_cast<int64_t>(u);
112 }
113
DoubleToLong(double u)114 inline int64_t DoubleToLong(double u) {
115 if (u > static_cast<double>((std::numeric_limits<int64_t>::max)())) {
116 MS_LOG(INTERNAL_EXCEPTION) << "The double value(" << u << ") exceeds the maximum value of int64_t.";
117 }
118 return static_cast<int64_t>(u);
119 }
120
// Converts a size_t to float; values above 2^24 may be rounded.
inline float SizeToFloat(size_t v) {
  const auto as_float = static_cast<float>(v);
  return as_float;
}
122
// Converts an int64_t to double; magnitudes above 2^53 may be rounded.
inline double LongToDouble(int64_t v) {
  const auto as_double = static_cast<double>(v);
  return as_double;
}
124
// Converts an int64_t to float; magnitudes above 2^24 may be rounded.
inline float LongToFloat(int64_t v) {
  const auto as_float = static_cast<float>(v);
  return as_float;
}
126
// Widens a float to double; every float value is exactly representable as double.
inline double FloatToDouble(float v) {
  const double widened = v;
  return widened;
}
128
IntToUint(int32_t u)129 inline uint32_t IntToUint(int32_t u) {
130 if (u < 0) {
131 MS_LOG(INTERNAL_EXCEPTION) << "The int32_t value(" << u << ") is less than 0.";
132 }
133 return static_cast<uint32_t>(u);
134 }
135
UintToInt(uint32_t u)136 inline int32_t UintToInt(uint32_t u) {
137 if (u > static_cast<uint32_t>((std::numeric_limits<int32_t>::max)())) {
138 MS_LOG(INTERNAL_EXCEPTION) << "The uint32_t value(" << u << ") exceeds the maximum value of int32_t.";
139 }
140 return static_cast<int32_t>(u);
141 }
142
LongToUlong(int64_t u)143 inline uint64_t LongToUlong(int64_t u) {
144 if (u < 0) {
145 MS_LOG(INTERNAL_EXCEPTION) << "The int64_t value(" << u << ") is less than 0.";
146 }
147 return static_cast<uint64_t>(u);
148 }
149
LongToInt(int64_t u)150 inline int32_t LongToInt(int64_t u) {
151 if (u > static_cast<int64_t>((std::numeric_limits<int32_t>::max)())) {
152 MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int.";
153 }
154 return static_cast<int32_t>(u);
155 }
156
// Widens an int32_t to int64_t; every int32_t value is representable, so no check is needed.
inline int64_t IntToLong(int32_t v) {
  const int64_t widened = v;
  return widened;
}
158
UlongToLong(uint64_t u)159 inline int64_t UlongToLong(uint64_t u) {
160 if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) {
161 MS_LOG(INTERNAL_EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";
162 }
163 return static_cast<int64_t>(u);
164 }
165
UlongToUint(uint64_t u)166 inline unsigned int UlongToUint(uint64_t u) {
167 if (u > static_cast<uint64_t>((std::numeric_limits<unsigned int>::max)())) {
168 MS_LOG(INTERNAL_EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of unsigned int.";
169 }
170 return static_cast<unsigned int>(u);
171 }
172
IntMulWithOverflowCheck(int a,int b)173 inline int IntMulWithOverflowCheck(int a, int b) {
174 int out = a * b;
175 if (a != 0) {
176 bool overflow = ((out / a) != b);
177 if (overflow) {
178 MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
179 }
180 }
181 return out;
182 }
183
LongMulWithOverflowCheck(int64_t a,int64_t b)184 inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) {
185 int64_t out = a * b;
186 if (a != 0) {
187 bool overflow = ((out / a) != b);
188 if (overflow) {
189 MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
190 }
191 }
192 return out;
193 }
194
SizetMulWithOverflowCheck(size_t a,size_t b)195 inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {
196 size_t out = a * b;
197 if (a != 0) {
198 if ((out / a) != b) {
199 MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
200 }
201 }
202 return out;
203 }
204
Uint32tMulWithOverflowCheck(uint32_t a,uint32_t b)205 inline uint32_t Uint32tMulWithOverflowCheck(uint32_t a, uint32_t b) {
206 uint32_t out = a * b;
207 if (a != 0) {
208 if ((out / a) != b) {
209 MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
210 }
211 }
212 return out;
213 }
214
SizetAddWithOverflowCheck(size_t x,size_t y)215 inline size_t SizetAddWithOverflowCheck(size_t x, size_t y) {
216 size_t sum = x + y;
217 if (sum < x || sum < y) {
218 MS_LOG(EXCEPTION) << "Add: a(" << x << ") + b(" << y << ") result is overflow";
219 }
220 return sum;
221 }
222
Uint32tAddWithOverflowCheck(uint32_t x,uint32_t y)223 inline uint32_t Uint32tAddWithOverflowCheck(uint32_t x, uint32_t y) {
224 uint32_t sum = x + y;
225 if (sum < x || sum < y) {
226 MS_LOG(EXCEPTION) << "Add: a(" << x << ") + b(" << y << ") result is overflow";
227 }
228 return sum;
229 }
230
AddressOffset(void * address,size_t offset)231 inline uint8_t *AddressOffset(void *address, size_t offset) {
232 MS_EXCEPTION_IF_NULL(address);
233 return static_cast<uint8_t *>(address) + offset;
234 }
235
CalAddressOffset(void * dst_address,void * ori_address)236 inline size_t CalAddressOffset(void *dst_address, void *ori_address) {
237 MS_EXCEPTION_IF_NULL(dst_address);
238 MS_EXCEPTION_IF_NULL(ori_address);
239 return static_cast<uint8_t *>(dst_address) - static_cast<uint8_t *>(ori_address);
240 }
241
Convert2Int(const std::vector<size_t> & v)242 inline std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
243 std::vector<int64_t> result;
244 (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
245 return result;
246 }
247
Convert2Long(const std::vector<size_t> & v)248 inline std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
249 std::vector<int64_t> result;
250 (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
251 return result;
252 }
253
Convert2SizeT(const std::vector<int64_t> & v)254 inline std::vector<size_t> Convert2SizeT(const std::vector<int64_t> &v) {
255 std::vector<size_t> result;
256 (void)std::transform(v.begin(), v.end(), std::back_inserter(result), LongToSize);
257 return result;
258 }
259
// Converts each int64_t element to size_t, clamping negative elements to 0. Element order is preserved.
inline std::vector<size_t> Convert2SizeTClipNeg(const std::vector<int64_t> &v) {
  std::vector<size_t> result;
  result.reserve(v.size());
  // `elem` is named so it does not shadow the vector `v`. The non-negative branch casts straight to
  // size_t rather than relying on an implicit signed-to-unsigned conversion.
  auto ConvertFunc = [](int64_t elem) -> size_t { return elem < 0 ? size_t{0} : static_cast<size_t>(elem); };
  (void)std::transform(v.begin(), v.end(), std::back_inserter(result), ConvertFunc);
  return result;
}
266
// Returns true when both shapes have the same rank and identical dimensions in the same order.
// std::vector's operator== compares sizes first, then elements pairwise.
inline bool ShapeVectorIsSame(const std::vector<int64_t> &shape, const std::vector<int64_t> &check_shape) {
  return shape == check_shape;
}
279
// Formats a shape as "(d0, d1, ...)"; an empty shape yields "()".
inline std::string ShapeVectorToStr(const std::vector<int64_t> &shp) {
  std::ostringstream oss;
  oss << "(";
  // The separator is empty before the first dimension and ", " before every later one.
  const char *separator = "";
  for (const int64_t dim : shp) {
    oss << separator << dim;
    separator = ", ";
  }
  oss << ")";
  return oss.str();
}
295 } // namespace mindspore
296
297 #endif // MINDSPORE_CORE_UTILS_CONVERT_UTILS_BASE_H_
298