1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_column.h"
18
19 #include "utils/ms_utils.h"
20 #include "minddata/mindrecord/include/common/shard_utils.h"
21 #include "minddata/mindrecord/include/shard_error.h"
22
23 namespace mindspore {
24 namespace mindrecord {
ShardColumn(const std::shared_ptr<ShardHeader> & shard_header,bool compress_integer)25 ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
26 auto first_schema = shard_header->GetSchemas()[0];
27 json schema_json = first_schema->GetSchema();
28 Init(schema_json, compress_integer);
29 }
30
ShardColumn(const json & schema_json,bool compress_integer)31 ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); }
32
Init(const json & schema_json,bool compress_integer)33 void ShardColumn::Init(const json &schema_json, bool compress_integer) {
34 auto schema = schema_json["schema"];
35 auto blob_fields = schema_json["blob_fields"];
36
37 bool has_integer_array = false;
38 for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
39 const std::string &column_name = it.key();
40 column_name_.push_back(column_name);
41
42 json it_value = it.value();
43
44 std::string str_type = it_value["type"];
45 column_data_type_.push_back(ColumnDataTypeMap.at(str_type));
46 if (it_value.find("shape") != it_value.end()) {
47 std::vector<int64_t> vec(it_value["shape"].size());
48 std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
49 column_shape_.push_back(vec);
50 if (str_type == "int32" || str_type == "int64") {
51 has_integer_array = true;
52 }
53 } else {
54 std::vector<int64_t> vec = {};
55 column_shape_.push_back(vec);
56 }
57 }
58
59 for (uint64_t i = 0; i < column_name_.size(); i++) {
60 column_name_id_[column_name_[i]] = i;
61 }
62
63 for (const auto &field : blob_fields) {
64 blob_column_.push_back(field);
65 }
66
67 for (uint64_t i = 0; i < blob_column_.size(); i++) {
68 blob_column_id_[blob_column_[i]] = i;
69 }
70
71 has_compress_blob_ = (compress_integer && has_integer_array);
72 num_blob_column_ = blob_column_.size();
73 }
74
GetColumnTypeByName(const std::string & column_name,ColumnDataType * column_data_type,uint64_t * column_data_type_size,std::vector<int64_t> * column_shape,ColumnCategory * column_category)75 Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type,
76 uint64_t *column_data_type_size, std::vector<int64_t> *column_shape,
77 ColumnCategory *column_category) {
78 RETURN_UNEXPECTED_IF_NULL(column_data_type);
79 RETURN_UNEXPECTED_IF_NULL(column_data_type_size);
80 RETURN_UNEXPECTED_IF_NULL(column_shape);
81 RETURN_UNEXPECTED_IF_NULL(column_category);
82 // Skip if column not found
83 *column_category = CheckColumnName(column_name);
84 CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid data, column category is not found.");
85
86 // Get data type and size
87 auto column_id = column_name_id_[column_name];
88 *column_data_type = column_data_type_[column_id];
89 *column_data_type_size = ColumnDataTypeSize[*column_data_type];
90 *column_shape = column_shape_[column_id];
91 return Status::OK();
92 }
93
GetColumnValueByName(const std::string & column_name,const std::vector<uint8_t> & columns_blob,const json & columns_json,const unsigned char ** data,std::unique_ptr<unsigned char[]> * data_ptr,uint64_t * const n_bytes,ColumnDataType * column_data_type,uint64_t * column_data_type_size,std::vector<int64_t> * column_shape)94 Status ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
95 const json &columns_json, const unsigned char **data,
96 std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
97 ColumnDataType *column_data_type, uint64_t *column_data_type_size,
98 std::vector<int64_t> *column_shape) {
99 RETURN_UNEXPECTED_IF_NULL(column_data_type);
100 RETURN_UNEXPECTED_IF_NULL(column_data_type_size);
101 RETURN_UNEXPECTED_IF_NULL(column_shape);
102 // Skip if column not found
103 auto column_category = CheckColumnName(column_name);
104 CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid data, column category is not found.");
105 // Get data type and size
106 auto column_id = column_name_id_[column_name];
107 *column_data_type = column_data_type_[column_id];
108 *column_data_type_size = ColumnDataTypeSize[*column_data_type];
109 *column_shape = column_shape_[column_id];
110
111 // Retrieve value from json
112 if (column_category == ColumnInRaw) {
113 RETURN_IF_NOT_OK(GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes));
114 *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
115 return Status::OK();
116 }
117
118 // Retrieve value from blob
119 RETURN_IF_NOT_OK(GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes));
120 if (*data == nullptr) {
121 *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
122 }
123 return Status::OK();
124 }
125
GetColumnFromJson(const std::string & column_name,const json & columns_json,std::unique_ptr<unsigned char[]> * data_ptr,uint64_t * n_bytes)126 Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
127 std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
128 RETURN_UNEXPECTED_IF_NULL(n_bytes);
129 RETURN_UNEXPECTED_IF_NULL(data_ptr);
130 auto column_id = column_name_id_[column_name];
131 auto column_data_type = column_data_type_[column_id];
132
133 // Initialize num bytes
134 *n_bytes = ColumnDataTypeSize[column_data_type];
135 auto json_column_value = columns_json[column_name];
136 CHECK_FAIL_RETURN_UNEXPECTED(
137 json_column_value.is_string() || json_column_value.is_number(),
138 "Invalid data, column value [" + json_column_value.dump() + "] is not string or number.");
139 switch (column_data_type) {
140 case ColumnFloat32: {
141 return GetFloat<float>(data_ptr, json_column_value, false);
142 }
143 case ColumnFloat64: {
144 return GetFloat<double>(data_ptr, json_column_value, true);
145 }
146 case ColumnInt32: {
147 return GetInt<int32_t>(data_ptr, json_column_value);
148 }
149 case ColumnInt64: {
150 return GetInt<int64_t>(data_ptr, json_column_value);
151 }
152 default: {
153 // Convert string to c_str
154 std::string tmp_string;
155 if (json_column_value.is_string()) {
156 tmp_string = json_column_value.get<string>();
157 } else {
158 tmp_string = json_column_value.dump();
159 }
160 *n_bytes = tmp_string.size();
161 auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
162 *data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
163 for (uint32_t i = 0; i < *n_bytes; i++) {
164 (*data_ptr)[i] = *(data + i);
165 }
166 break;
167 }
168 }
169 return Status::OK();
170 }
171
172 template <typename T>
GetFloat(std::unique_ptr<unsigned char[]> * data_ptr,const json & json_column_value,bool use_double)173 Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
174 bool use_double) {
175 RETURN_UNEXPECTED_IF_NULL(data_ptr);
176 std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
177 if (json_column_value.is_number()) {
178 array_data[0] = json_column_value;
179 } else {
180 // Convert string to float
181 try {
182 if (use_double) {
183 array_data[0] = json_column_value.get<double>();
184 } else {
185 array_data[0] = json_column_value.get<float>();
186 }
187 } catch (json::exception &e) {
188 RETURN_STATUS_UNEXPECTED("Failed to convert [" + json_column_value.dump() + "] to float, " +
189 std::string(e.what()));
190 }
191 }
192
193 auto data = reinterpret_cast<const unsigned char *>(array_data.get());
194 *data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
195 for (uint32_t i = 0; i < sizeof(T); i++) {
196 (*data_ptr)[i] = *(data + i);
197 }
198 return Status::OK();
199 }
200
201 template <typename T>
GetInt(std::unique_ptr<unsigned char[]> * data_ptr,const json & json_column_value)202 Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
203 RETURN_UNEXPECTED_IF_NULL(data_ptr);
204 std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
205 int64_t temp_value;
206 bool less_than_zero = false;
207
208 if (json_column_value.is_number_integer()) {
209 const json json_zero = 0;
210 if (json_column_value < json_zero) {
211 less_than_zero = true;
212 }
213 temp_value = json_column_value;
214 } else if (json_column_value.is_string()) {
215 std::string string_value = json_column_value;
216 try {
217 if (!string_value.empty() && string_value[0] == '-') {
218 temp_value = std::stoll(string_value);
219 less_than_zero = true;
220 } else {
221 temp_value = static_cast<int64_t>(std::stoull(string_value));
222 }
223 } catch (std::invalid_argument &e) {
224 RETURN_STATUS_UNEXPECTED("Failed to convert [" + string_value + "] to int, " + std::string(e.what()));
225 } catch (std::out_of_range &e) {
226 RETURN_STATUS_UNEXPECTED("Failed to convert [" + string_value + "] to int, " + std::string(e.what()));
227 }
228 } else {
229 RETURN_STATUS_UNEXPECTED("Invalid data, column value [" + json_column_value.dump() + "] is not string or number.");
230 }
231
232 if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
233 (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
234 RETURN_STATUS_UNEXPECTED("Invalid data, column value [" + std::to_string(temp_value) + "] is out of range.");
235 }
236 array_data[0] = static_cast<T>(temp_value);
237
238 auto data = reinterpret_cast<const unsigned char *>(array_data.get());
239 *data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
240 for (uint32_t i = 0; i < sizeof(T); i++) {
241 (*data_ptr)[i] = *(data + i);
242 }
243 return Status::OK();
244 }
245
GetColumnFromBlob(const std::string & column_name,const std::vector<uint8_t> & columns_blob,const unsigned char ** data,std::unique_ptr<unsigned char[]> * data_ptr,uint64_t * const n_bytes)246 Status ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
247 const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
248 uint64_t *const n_bytes) {
249 RETURN_UNEXPECTED_IF_NULL(data);
250 uint64_t offset_address = 0;
251 auto column_id = column_name_id_[column_name];
252 RETURN_IF_NOT_OK(GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address));
253 auto column_data_type = column_data_type_[column_id];
254 if (has_compress_blob_ && column_data_type == ColumnInt32) {
255 RETURN_IF_NOT_OK(UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address));
256 } else if (has_compress_blob_ && column_data_type == ColumnInt64) {
257 RETURN_IF_NOT_OK(UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address));
258 } else {
259 *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
260 }
261
262 return Status::OK();
263 }
264
CheckColumnName(const std::string & column_name)265 ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
266 auto it_column = column_name_id_.find(column_name);
267 if (it_column == column_name_id_.end()) {
268 return ColumnNotFound;
269 }
270 auto it_blob = blob_column_id_.find(column_name);
271 return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
272 }
273
CompressBlob(const std::vector<uint8_t> & blob,int64_t * compression_size)274 std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) {
275 // Skip if no compress columns
276 *compression_size = 0;
277 if (!CheckCompressBlob()) {
278 return blob;
279 }
280
281 std::vector<uint8_t> dst_blob;
282 uint64_t i_src = 0;
283 for (int64_t i = 0; i < num_blob_column_; i++) {
284 // Get column data type
285 auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]];
286 auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type;
287
288 // Compress and return is blob has 1 column only
289 if (num_blob_column_ == 1) {
290 dst_blob = CompressInt(blob, int_type);
291 *compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
292 return dst_blob;
293 }
294
295 // Just copy and continue if column dat type is not int32/int64
296 uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
297 if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
298 dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
299 i_src += kInt64Len + num_bytes;
300 continue;
301 }
302
303 // Get column slice in source blob
304 std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
305 // Compress column
306 auto dst_blob_slice = CompressInt(blob_slice, int_type);
307 // Get new column size
308 auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
309 // Append new column size
310 dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
311 // Append new column data
312 dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
313 i_src += kInt64Len + num_bytes;
314 }
315 MS_LOG(DEBUG) << "Compress blob data from " << blob.size() << " to " << dst_blob.size() << ".";
316 *compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
317 return dst_blob;
318 }
319
// Compress an array of little-endian integers.
// Output layout: [element count, kBytesOfColumnLen bytes, big-endian][type bitmap][payload].
// Each element is re-encoded little-endian in the narrowest integer type that holds it, and its
// type code is recorded in the bitmap, kDataTypeBits per element, first element in the most
// significant position of each bitmap byte.
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
  const uint64_t src_int_width = kUnsignedOne << static_cast<uint8_t>(int_type);
  // Number of elements
  const uint64_t num_ints = src_bytes.size() / src_int_width;
  // Bitmap size in bytes, rounded up
  const uint64_t bitmap_size = (num_ints + kNumDataOfByte - 1) / kNumDataOfByte;

  // Allocate the worst case up front; trimmed to the real size at the end.
  vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);

  // Header: number of elements
  const vector<uint8_t> count_bytes = UIntToBytesBig(num_ints, kInt32Type);
  for (uint64_t k = 0; k < kBytesOfColumnLen; ++k) {
    dst_bytes[k] = count_bytes[k];
  }

  // Payload: narrowed integers, written right after the bitmap
  uint64_t write_pos = kBytesOfColumnLen + bitmap_size;
  uint64_t read_pos = 0;
  for (uint64_t idx = 0; idx < num_ints; ++idx, read_pos += src_int_width) {
    // Decode the element and find the narrowest type that represents it.
    IntegerType narrow_type = kInt8Type;
    const int64_t element = BytesLittleToMinIntType(src_bytes, read_pos, int_type, &narrow_type);

    // Emit the element using the narrowed width.
    const auto narrow_bytes = UIntToBytesLittle(static_cast<uint64_t>(element), narrow_type);
    for (const uint8_t byte : narrow_bytes) {
      dst_bytes[write_pos++] = byte;
    }

    // Record the element's data type in the bitmap.
    const auto bit_shift = kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (idx % kNumDataOfByte));
    dst_bytes[idx / kNumDataOfByte + kBytesOfColumnLen] |= (static_cast<uint8_t>(narrow_type) << bit_shift);
  }

  // Drop the unused tail.
  dst_bytes.resize(write_pos);
  MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << ".";
  return dst_bytes;
}
362
GetColumnAddressInBlock(const uint64_t & column_id,const std::vector<uint8_t> & columns_blob,uint64_t * num_bytes,uint64_t * shift_idx)363 Status ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
364 uint64_t *num_bytes, uint64_t *shift_idx) {
365 RETURN_UNEXPECTED_IF_NULL(num_bytes);
366 RETURN_UNEXPECTED_IF_NULL(shift_idx);
367 if (num_blob_column_ == 1) {
368 *num_bytes = columns_blob.size();
369 *shift_idx = 0;
370 return Status::OK();
371 }
372 auto blob_id = blob_column_id_[column_name_[column_id]];
373
374 for (int32_t i = 0; i < blob_id; i++) {
375 *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
376 }
377 *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
378
379 (*shift_idx) += kInt64Len;
380
381 return Status::OK();
382 }
383
384 template <typename T>
UncompressInt(const uint64_t & column_id,std::unique_ptr<unsigned char[]> * const data_ptr,const std::vector<uint8_t> & columns_blob,uint64_t * num_bytes,uint64_t shift_idx)385 Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
386 const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx) {
387 RETURN_UNEXPECTED_IF_NULL(data_ptr);
388 RETURN_UNEXPECTED_IF_NULL(num_bytes);
389 auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
390 *num_bytes = sizeof(T) * num_elements;
391
392 // Parse integer array
393 uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte;
394 auto array_data = std::make_unique<T[]>(num_elements);
395
396 for (uint64_t i = 0; i < num_elements; i++) {
397 uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte];
398 uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask;
399 auto mr_int_type = static_cast<IntegerType>(i_type);
400 int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type);
401 i_source += (kUnsignedOne << i_type);
402 array_data[i] = static_cast<T>(i64);
403 }
404
405 auto data = reinterpret_cast<const unsigned char *>(array_data.get());
406 *data_ptr = std::make_unique<unsigned char[]>(*num_bytes);
407 // field is none. for example: numpy is null
408 if (*num_bytes == 0) {
409 return Status::OK();
410 }
411 CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data.");
412 return Status::OK();
413 }
414
BytesBigToUInt64(const std::vector<uint8_t> & bytes_array,const uint64_t & pos,const IntegerType & i_type)415 uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
416 const IntegerType &i_type) {
417 uint64_t result = 0;
418 for (uint64_t i = 0; i < (kUnsignedOne << static_cast<uint8_t>(i_type)); i++) {
419 result = (result << kBitsOfByte) + bytes_array[pos + i];
420 }
421 return result;
422 }
423
UIntToBytesBig(uint64_t value,const IntegerType & i_type)424 std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) {
425 uint64_t n_bytes = kUnsignedOne << static_cast<uint8_t>(i_type);
426 std::vector<uint8_t> result(n_bytes, 0);
427 for (uint64_t i = 0; i < n_bytes; i++) {
428 result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max();
429 value >>= kBitsOfByte;
430 }
431 return result;
432 }
433
UIntToBytesLittle(uint64_t value,const IntegerType & i_type)434 std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) {
435 uint64_t n_bytes = kUnsignedOne << static_cast<uint8_t>(i_type);
436 std::vector<uint8_t> result(n_bytes, 0);
437 for (uint64_t i = 0; i < n_bytes; i++) {
438 result[i] = value & std::numeric_limits<uint8_t>::max();
439 value >>= kBitsOfByte;
440 }
441 return result;
442 }
443
BytesLittleToMinIntType(const std::vector<uint8_t> & bytes_array,const uint64_t & pos,const IntegerType & src_i_type,IntegerType * dst_i_type)444 int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
445 const IntegerType &src_i_type, IntegerType *dst_i_type) {
446 uint64_t u_temp = 0;
447 for (uint64_t i = 0; i < (kUnsignedOne << static_cast<uint8_t>(src_i_type)); i++) {
448 u_temp = (u_temp << kBitsOfByte) +
449 bytes_array[pos + (kUnsignedOne << static_cast<uint8_t>(src_i_type)) - kUnsignedOne - i];
450 }
451
452 int64_t i_out;
453 switch (src_i_type) {
454 case kInt8Type: {
455 i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max());
456 break;
457 }
458 case kInt16Type: {
459 i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max());
460 break;
461 }
462 case kInt32Type: {
463 i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max());
464 break;
465 }
466 case kInt64Type: {
467 i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max());
468 break;
469 }
470 default: {
471 i_out = 0;
472 }
473 }
474
475 if (!dst_i_type) {
476 return i_out;
477 }
478
479 if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) &&
480 i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) {
481 *dst_i_type = kInt8Type;
482 } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) &&
483 i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) {
484 *dst_i_type = kInt16Type;
485 } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
486 i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
487 *dst_i_type = kInt32Type;
488 } else {
489 *dst_i_type = kInt64Type;
490 }
491 return i_out;
492 }
493 } // namespace mindrecord
494 } // namespace mindspore
495