1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/kernels/data/data_utils.h"
18
#include <algorithm>
#include <cmath>
#include <limits>
#include <string>
#include <utility>
#include <vector>
24
25 #include "minddata/dataset/core/data_type.h"
26 #ifdef ENABLE_PYTHON
27 #include "minddata/dataset/core/pybind_support.h"
28 #endif
29 #include "minddata/dataset/core/tensor.h"
30 #include "minddata/dataset/core/tensor_shape.h"
31 #include "minddata/dataset/include/dataset/constants.h"
32 #include "minddata/dataset/kernels/data/type_cast_op.h"
33 #include "minddata/dataset/util/status.h"
34
35 namespace mindspore {
36 namespace dataset {
37 template <typename T>
OneHotEncodingImpl(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,dsize_t num_classes,int64_t index,double smoothing_rate)38 Status OneHotEncodingImpl(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
39 int64_t index, double smoothing_rate) {
40 RETURN_UNEXPECTED_IF_NULL(output);
41 T class_idx;
42 if (input->Rank() == 0) {
43 RETURN_IF_NOT_OK(input->GetItemAt<T>(&class_idx, {}));
44 } else {
45 RETURN_IF_NOT_OK(input->GetItemAt<T>(&class_idx, {index}));
46 }
47 if (class_idx >= static_cast<int64_t>(num_classes)) {
48 RETURN_STATUS_UNEXPECTED("OneHot: index values should not bigger than num classes: " + std::to_string(num_classes) +
49 ", but got: " + std::to_string(class_idx));
50 }
51 if ((*output)->type().IsInt()) {
52 RETURN_IF_NOT_OK((*output)->SetItemAt<T>({index, static_cast<dsize_t>(class_idx)}, 1));
53 } else if ((*output)->type() == DataType::DE_FLOAT64) {
54 auto itr = (*output)->begin<double>();
55 auto end = (*output)->end<double>();
56 double val = smoothing_rate / static_cast<double>(num_classes);
57 for (; itr != end; itr++) {
58 *itr = val;
59 }
60 double value = (1. - smoothing_rate) + val;
61 RETURN_IF_NOT_OK((*output)->SetItemAt<double>({index, static_cast<dsize_t>(class_idx)}, value));
62 } else {
63 RETURN_STATUS_UNEXPECTED("OneHot: signed input case only supports signed int as input but got:" +
64 input->type().ToString());
65 }
66 return Status::OK();
67 }
68
OneHotEncoding(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,dsize_t num_classes,double smoothing_rate)69 Status OneHotEncoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
70 double smoothing_rate) {
71 RETURN_UNEXPECTED_IF_NULL(output);
72 input->Squeeze();
73
74 CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() <= 1,
75 "OneHot: Only support scalar or 1D input, got rank: " + std::to_string(input->Rank()));
76 CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsInt(),
77 "OneHot: Only support input of int type, but got: " + input->type().ToString());
78 try {
79 dsize_t num_elements = 1;
80 if (input->Rank() == 1) {
81 num_elements = input->shape()[0];
82 }
83 TensorShape out_shape({num_elements, num_classes});
84 std::shared_ptr<Tensor> out;
85 mindspore::dataset::DataType type = input->type();
86 if (abs(smoothing_rate) > std::numeric_limits<double>::epsilon()) {
87 type = DataType(DataType::DE_FLOAT64);
88 }
89 RETURN_IF_NOT_OK(Tensor::CreateEmpty(out_shape, type, &out));
90 RETURN_IF_NOT_OK(out->Zero());
91 for (dsize_t i = 0; i < num_elements; ++i) {
92 if (input->type() == DataType::DE_INT8) {
93 RETURN_IF_NOT_OK(OneHotEncodingImpl<int8_t>(input, &out, num_classes, i, smoothing_rate));
94 } else if (input->type() == DataType::DE_INT16) {
95 RETURN_IF_NOT_OK(OneHotEncodingImpl<int16_t>(input, &out, num_classes, i, smoothing_rate));
96 } else if (input->type() == DataType::DE_INT32) {
97 RETURN_IF_NOT_OK(OneHotEncodingImpl<int32_t>(input, &out, num_classes, i, smoothing_rate));
98 } else if (input->type() == DataType::DE_INT64) {
99 RETURN_IF_NOT_OK(OneHotEncodingImpl<int64_t>(input, &out, num_classes, i, smoothing_rate));
100 } else if (input->type() == DataType::DE_UINT8) {
101 RETURN_IF_NOT_OK(OneHotEncodingImpl<uint8_t>(input, &out, num_classes, i, smoothing_rate));
102 } else if (input->type() == DataType::DE_UINT16) {
103 RETURN_IF_NOT_OK(OneHotEncodingImpl<uint16_t>(input, &out, num_classes, i, smoothing_rate));
104 } else if (input->type() == DataType::DE_UINT32) {
105 RETURN_IF_NOT_OK(OneHotEncodingImpl<uint32_t>(input, &out, num_classes, i, smoothing_rate));
106 } else if (input->type() == DataType::DE_UINT64) {
107 RETURN_IF_NOT_OK(OneHotEncodingImpl<uint64_t>(input, &out, num_classes, i, smoothing_rate));
108 } else {
109 RETURN_STATUS_UNEXPECTED("OneHot: OneHot only supports input of int type, but got:" + input->type().ToString());
110 }
111 }
112 out->Squeeze();
113 *output = out;
114 return Status::OK();
115 } catch (const std::exception &e) {
116 RETURN_STATUS_UNEXPECTED("OneHot: Unexpected errors occurred: " + std::string(e.what()));
117 }
118 }
119
FillHelper(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * out,const std::shared_ptr<Tensor> & fill_output)120 Status FillHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out,
121 const std::shared_ptr<Tensor> &fill_output) {
122 RETURN_UNEXPECTED_IF_NULL(out);
123 const DataType &input_type = input->type();
124 const TensorShape &input_shape = input->shape();
125 switch (input_type.value()) {
126 case DataType::DE_BOOL: {
127 bool value = false;
128 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
129 RETURN_IF_NOT_OK((*out)->Fill<bool>(value));
130 break;
131 }
132 case DataType::DE_INT8: {
133 int8_t value = 0;
134 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
135 RETURN_IF_NOT_OK((*out)->Fill<int8_t>(value));
136 break;
137 }
138 case DataType::DE_UINT8: {
139 uint8_t value = 0;
140 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
141 RETURN_IF_NOT_OK((*out)->Fill<uint8_t>(value));
142 break;
143 }
144 case DataType::DE_UINT16: {
145 uint16_t value = 0;
146 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
147 RETURN_IF_NOT_OK((*out)->Fill<uint16_t>(value));
148 break;
149 }
150 case DataType::DE_INT16: {
151 int16_t value = 0;
152 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
153 RETURN_IF_NOT_OK((*out)->Fill<int16_t>(value));
154 break;
155 }
156 case DataType::DE_UINT32: {
157 uint32_t value = 0;
158 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
159 RETURN_IF_NOT_OK((*out)->Fill<uint32_t>(value));
160 break;
161 }
162 case DataType::DE_INT32: {
163 int32_t value = 0;
164 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
165 RETURN_IF_NOT_OK((*out)->Fill<int32_t>(value));
166 break;
167 }
168 case DataType::DE_UINT64: {
169 uint64_t value = 0;
170 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
171 RETURN_IF_NOT_OK((*out)->Fill<uint64_t>(value));
172 break;
173 }
174 case DataType::DE_INT64: {
175 int64_t value = 0;
176 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
177 RETURN_IF_NOT_OK((*out)->Fill<int64_t>(value));
178 break;
179 }
180 case DataType::DE_FLOAT16: {
181 int64_t value = 0;
182 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
183 RETURN_IF_NOT_OK((*out)->Fill<float>(value));
184 break;
185 }
186 case DataType::DE_FLOAT32: {
187 float value = 0.;
188 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
189 RETURN_IF_NOT_OK((*out)->Fill<float>(value));
190 break;
191 }
192 case DataType::DE_FLOAT64: {
193 double value = 0.;
194 RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
195 RETURN_IF_NOT_OK((*out)->Fill<double>(value));
196 break;
197 }
198 case DataType::DE_STRING:
199 case DataType::DE_BYTES: {
200 std::vector<std::string> strings;
201 std::string_view fill_string_view;
202 RETURN_IF_NOT_OK(fill_output->GetItemAt(&fill_string_view, {}));
203 std::string fill_string = std::string(fill_string_view);
204 for (int i = 0; i < input_shape.NumOfElements(); i++) {
205 (void)strings.emplace_back(fill_string);
206 }
207 RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, input_shape, DataType(input_type.value()), out));
208 break;
209 }
210 default:
211 RETURN_STATUS_UNEXPECTED("Fill: invalid data type, filling values into tensors of type " + input_type.ToString() +
212 " is not supported.");
213 break;
214 }
215 return Status::OK();
216 }
217
Fill(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,const std::shared_ptr<Tensor> & fill_value)218 Status Fill(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
219 const std::shared_ptr<Tensor> &fill_value) {
220 RETURN_UNEXPECTED_IF_NULL(output);
221 const DataType &fill_type = fill_value->type();
222 const DataType &input_type = input->type();
223 const TensorShape &input_shape = input->shape();
224
225 if (fill_type.IsString() || input_type.IsString()) {
226 CHECK_FAIL_RETURN_UNEXPECTED(
227 fill_type == input_type,
228 "Fill: fill_value and the input tensor must be of the same data type when involving strings "
229 "or bytes, but got fill_value data type " +
230 fill_type.ToString() + " and input tensor data type " + input_type.ToString());
231 }
232
233 CHECK_FAIL_RETURN_UNEXPECTED(
234 fill_value->shape() == TensorShape({}),
235 "Fill: the shape of fill_value is not a scalar, got shape:" + fill_value->shape().ToString());
236
237 std::shared_ptr<Tensor> out, fill_output;
238
239 if (input_type.IsNumeric() && fill_type.IsNumeric() && input_type != fill_type) {
240 RETURN_IF_NOT_OK(TypeCast(fill_value, &fill_output, input_type));
241 } else {
242 fill_output = fill_value;
243 }
244
245 if (input_type.IsNumeric()) {
246 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input_shape, input_type, &out));
247 }
248 RETURN_IF_NOT_OK(FillHelper(input, &out, fill_output));
249 *output = out;
250 return Status::OK();
251 }
252
253 template <typename FROM, typename TO>
Cast(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)254 void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
255 auto in_itr = input->begin<FROM>();
256 auto out_itr = (*output)->begin<TO>();
257 auto out_end = (*output)->end<TO>();
258
259 for (; out_itr != out_end; ++in_itr, ++out_itr) {
260 *out_itr = static_cast<TO>(*in_itr);
261 }
262 }
263
264 template <typename T>
CastFrom(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)265 void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
266 switch ((*output)->type().value()) {
267 case DataType::DE_BOOL:
268 Cast<T, bool>(input, output);
269 break;
270 case DataType::DE_INT8:
271 Cast<T, int8_t>(input, output);
272 break;
273 case DataType::DE_UINT8:
274 Cast<T, uint8_t>(input, output);
275 break;
276 case DataType::DE_INT16:
277 Cast<T, int16_t>(input, output);
278 break;
279 case DataType::DE_UINT16:
280 Cast<T, uint16_t>(input, output);
281 break;
282 case DataType::DE_INT32:
283 Cast<T, int32_t>(input, output);
284 break;
285 case DataType::DE_UINT32:
286 Cast<T, uint32_t>(input, output);
287 break;
288 case DataType::DE_INT64:
289 Cast<T, int64_t>(input, output);
290 break;
291 case DataType::DE_UINT64:
292 Cast<T, uint64_t>(input, output);
293 break;
294 case DataType::DE_FLOAT16:
295 Cast<T, float16>(input, output);
296 break;
297 case DataType::DE_FLOAT32:
298 Cast<T, float>(input, output);
299 break;
300 case DataType::DE_FLOAT64:
301 Cast<T, double>(input, output);
302 break;
303 case DataType::DE_UNKNOWN:
304 default:
305 MS_LOG(ERROR) << "TypeCast: Casting to type " + (*output)->type().ToString() + " is valid, supported datatype: " +
306 "[bool, int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64].";
307 break;
308 }
309 }
310
311 // Type cast operator
TypeCast(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,const DataType & data_type)312 Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
313 RETURN_UNEXPECTED_IF_NULL(output);
314 switch (input->type().value()) {
315 case DataType::DE_BOOL:
316 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
317 CastFrom<bool>(input, output);
318 break;
319 case DataType::DE_INT8:
320 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
321 CastFrom<int8_t>(input, output);
322 break;
323 case DataType::DE_UINT8:
324 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
325 CastFrom<uint8_t>(input, output);
326 break;
327 case DataType::DE_INT16:
328 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
329 CastFrom<int16_t>(input, output);
330 break;
331 case DataType::DE_UINT16:
332 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
333 CastFrom<uint16_t>(input, output);
334 break;
335 case DataType::DE_INT32:
336 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
337 CastFrom<int32_t>(input, output);
338 break;
339 case DataType::DE_UINT32:
340 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
341 CastFrom<uint32_t>(input, output);
342 break;
343 case DataType::DE_INT64:
344 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
345 CastFrom<int64_t>(input, output);
346 break;
347 case DataType::DE_UINT64:
348 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
349 CastFrom<uint64_t>(input, output);
350 break;
351 case DataType::DE_FLOAT16:
352 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
353 CastFrom<float16>(input, output);
354 break;
355 case DataType::DE_FLOAT32:
356 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
357 CastFrom<float>(input, output);
358 break;
359 case DataType::DE_FLOAT64:
360 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
361 CastFrom<double>(input, output);
362 break;
363 case DataType::DE_STRING:
364 if (data_type == DataType::DE_STRING) {
365 *output = input;
366 break;
367 } else {
368 RETURN_STATUS_UNEXPECTED("TypeCast: TypeCast does not support cast from string to " + data_type.ToString());
369 }
370 case DataType::DE_BYTES:
371 if (data_type == DataType::DE_BYTES) {
372 *output = input;
373 break;
374 } else {
375 RETURN_STATUS_UNEXPECTED("TypeCast: TypeCast does not support cast from bytes to " + data_type.ToString());
376 }
377 case DataType::DE_UNKNOWN:
378 default:
379 // sanity check, unreachable code.
380 RETURN_STATUS_UNEXPECTED(
381 "TypeCast: Typecast does not support input with type " + input->type().ToString() + ", supported datatype: " +
382 "[bool, int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64].");
383 }
384 return Status::OK();
385 }
386
ToFloat16(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)387 Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
388 // initiate new tensor for type cast
389 DataType new_type = DataType("float16");
390 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), new_type, output));
391
392 auto in_itr = input->begin<float>();
393 auto in_end = input->end<float>();
394 auto out_itr = (*output)->begin<float16>();
395 auto out_end = (*output)->end<float16>();
396
397 for (; (in_itr != in_end) && (out_itr != out_end); ++in_itr, ++out_itr) {
398 float element = *in_itr;
399 float float16_max = static_cast<float>(std::numeric_limits<float16>::max());
400 float float16_min = static_cast<float>(std::numeric_limits<float16>::lowest());
401 if (element > float16_max || element < float16_min) {
402 RETURN_STATUS_UNEXPECTED("ToFloat16: value " + std::to_string(element) +
403 "in input data is outside of valid float16 range [" + std::to_string(float16_max) +
404 ", " + std::to_string(float16_min) + "].");
405 }
406
407 *out_itr = float16(*in_itr);
408 }
409
410 return Status::OK();
411 }
412
PadEnd(const std::shared_ptr<Tensor> & src,std::shared_ptr<Tensor> * dst,const std::vector<dsize_t> & pad_shape,const std::shared_ptr<Tensor> & pad_val)413 Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
414 const std::shared_ptr<Tensor> &pad_val) {
415 if (pad_val == nullptr) {
416 if (src->type().IsNumeric()) {
417 return PadEndNumeric(src, dst, pad_shape, 0.0);
418 } else {
419 return PadEndString(src, dst, pad_shape, "");
420 }
421 }
422 CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
423 "PadEnd: can not pad numeric and string tensors together, but got: " +
424 pad_val->type().ToString() + " and " + src->type().ToString() + ".");
425 if (pad_val->type().IsNumeric()) {
426 std::shared_ptr<Tensor> float_pad_value;
427 RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
428 float val = 0.;
429 RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
430 return PadEndNumeric(src, dst, pad_shape, val);
431 } else {
432 CHECK_FAIL_RETURN_UNEXPECTED(src->type() == pad_val->type(),
433 "PadEnd: can not pad string and byte tensors together, but got: " +
434 pad_val->type().ToString() + " and " + src->type().ToString() + ".");
435 std::string_view val;
436 RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
437 return PadEndString(src, dst, pad_shape, std::string(val));
438 }
439 }
440
PadEndNumeric(const std::shared_ptr<Tensor> & src,std::shared_ptr<Tensor> * dst,const std::vector<dsize_t> & pad_shape,float pad_val)441 Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
442 const std::vector<dsize_t> &pad_shape, float pad_val) {
443 CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "PadEnd: input or output can't be nullptr");
444 if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
445 (*dst) = src; // if no padding, copy the pointer
446 } else {
447 CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(),
448 "PadEnd: invalid pad shape, as rank of input is: " + std::to_string(src->Rank()) +
449 ", and rank of pad value: " + std::to_string(pad_shape.size()));
450 RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(pad_shape), src->type(), dst));
451 auto tensor_type = src->type().value();
452 if (std::fabs(pad_val) <= std::numeric_limits<float>::epsilon()) { // if pad with zero, don't care what type it is
453 RETURN_IF_NOT_OK((*dst)->Zero());
454 } else if (tensor_type == DataType::DE_INT8) {
455 RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(static_cast<int8_t>(pad_val)));
456 } else if (tensor_type == DataType::DE_BOOL) {
457 RETURN_IF_NOT_OK((*dst)->Fill<bool>(static_cast<bool>(pad_val)));
458 } else if (tensor_type == DataType::DE_UINT8) {
459 RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(static_cast<uint8_t>(pad_val)));
460 } else if (tensor_type == DataType::DE_INT16) {
461 RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(static_cast<int16_t>(pad_val)));
462 } else if (tensor_type == DataType::DE_FLOAT16) {
463 RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
464 } else if (tensor_type == DataType::DE_UINT16) {
465 RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(static_cast<uint16_t>(pad_val)));
466 } else if (tensor_type == DataType::DE_INT32) {
467 RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(static_cast<int32_t>(pad_val)));
468 } else if (tensor_type == DataType::DE_UINT32) {
469 RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(static_cast<uint32_t>(pad_val)));
470 } else if (tensor_type == DataType::DE_INT64) {
471 RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(static_cast<int64_t>(pad_val)));
472 } else if (tensor_type == DataType::DE_UINT64) {
473 RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(static_cast<uint64_t>(pad_val)));
474 } else if (tensor_type == DataType::DE_FLOAT32) {
475 RETURN_IF_NOT_OK((*dst)->Fill<float>(static_cast<float>(pad_val)));
476 } else if (tensor_type == DataType::DE_FLOAT64) {
477 RETURN_IF_NOT_OK((*dst)->Fill<double>(static_cast<double>(pad_val)));
478 } else {
479 RETURN_STATUS_UNEXPECTED(
480 "PadEnd: Incorrect/Unknown datatype, supported datatype is: [bool, int8, uint8, int16, uint16, int32, uint32, "
481 "int64, uint64, float16, float32, float64].");
482 }
483 std::vector<dsize_t> cur_ind(src->Rank(), 0);
484 RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
485 }
486 return Status::OK();
487 }
488
PadEndNumericHelper(const std::shared_ptr<Tensor> & src,const std::shared_ptr<Tensor> & dst,std::vector<dsize_t> cur_ind,size_t cur_dim)489 Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, const std::shared_ptr<Tensor> &dst,
490 std::vector<dsize_t> cur_ind, size_t cur_dim) {
491 if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
492 RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind));
493 } else { // not the last dimension, keep doing recursion
494 dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
495 for (dsize_t i = 0; i < min_ind; i++) {
496 cur_ind[cur_dim] = i;
497 RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
498 }
499 }
500 return Status::OK();
501 }
502
PadEndString(const std::shared_ptr<Tensor> & src,std::shared_ptr<Tensor> * dst,const std::vector<dsize_t> & pad_shape,const std::string & pad_val)503 Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
504 const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
505 CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
506 if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
507 (*dst) = src; // if no padding, copy the pointer
508 } else {
509 CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(),
510 "PadEnd: invalid pad shape, as rank of input is: " + std::to_string(src->Rank()) +
511 ", and rank of pad value: " + std::to_string(pad_shape.size()));
512 std::vector<dsize_t> cur_ind(src->Rank(), 0);
513 std::vector<std::string> strings;
514 RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
515 RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, TensorShape(pad_shape), src->type(), dst));
516 }
517 return Status::OK();
518 }
519
// Recursively flattens `src` into `dst` in row-major order, appending `pad_value` for every
// cell of `dst_shape` that lies beyond the extent of `src` in any dimension.
// `cur_ind` is the multi-index being visited; `cur_dim` is the dimension this level walks.
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
                          const std::string &pad_value) {
  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
    // Copy the overlapping elements of the innermost dimension.
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      std::string_view item;
      RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
      (void)dst->emplace_back(item);
    }
    // Pad the remainder of this row out to the destination width.
    for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
      (void)dst->emplace_back(pad_value);
    }

  } else {  // not the last dimension, keep doing recursion
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
    }
    // Whole slices of dst beyond src's extent in this dimension: emit
    // (missing slices) * (elements per slice, via strides) copies of the pad value.
    dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
    for (dsize_t i = 0; i < count; i++) {
      (void)dst->emplace_back(pad_value);
    }
  }
  return Status::OK();
}
548
549 template <typename T>
MaskHelper(const std::shared_ptr<Tensor> & input,const std::shared_ptr<Tensor> & output,const std::shared_ptr<Tensor> & value_tensor,RelationalOp op)550 Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
551 const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
552 T value;
553 RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
554 auto in_itr = input->begin<T>();
555 auto out_itr = output->begin<bool>();
556 for (; in_itr != input->end<T>(); ++in_itr, ++out_itr) {
557 switch (op) {
558 case RelationalOp::kEqual:
559 *out_itr = (*in_itr == value);
560 break;
561 case RelationalOp::kNotEqual:
562 *out_itr = (*in_itr != value);
563 break;
564 case RelationalOp::kGreater:
565 *out_itr = (*in_itr > value);
566 break;
567 case RelationalOp::kGreaterEqual:
568 *out_itr = (*in_itr >= value);
569 break;
570 case RelationalOp::kLess:
571 *out_itr = (*in_itr < value);
572 break;
573 case RelationalOp::kLessEqual:
574 *out_itr = (*in_itr <= value);
575 break;
576 default:
577 RETURN_STATUS_UNEXPECTED(
578 "Mask: unknown relational operator, supported operator is: equal, notEqual, greater, less, lessEqual.");
579 }
580 }
581 return Status::OK();
582 }
583
Mask(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,const std::shared_ptr<Tensor> & value,RelationalOp op)584 Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
585 RelationalOp op) {
586 CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
587 "Mask: input datatype does not match the value datatype, both should be numeric or "
588 "non-numerical in the same time.");
589 CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(),
590 "Mask: value shape should be a scalar, got shape:" + value->shape().ToString());
591
592 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_BOOL), output));
593
594 std::unique_ptr<TypeCastOp> value_cast_op = std::make_unique<TypeCastOp>(input->type());
595 std::shared_ptr<Tensor> casted_value;
596 if (input->type().IsNumeric()) {
597 RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
598 } else {
599 casted_value = value;
600 }
601
602 switch (input->type().value()) {
603 case DataType::DE_BOOL:
604 RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
605 break;
606 case DataType::DE_INT8:
607 RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
608 break;
609 case DataType::DE_UINT8:
610 RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
611 break;
612 case DataType::DE_UINT16:
613 RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
614 break;
615 case DataType::DE_INT16:
616 RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
617 break;
618 case DataType::DE_UINT32:
619 RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
620 break;
621 case DataType::DE_INT32:
622 RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
623 break;
624 case DataType::DE_UINT64:
625 RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
626 break;
627 case DataType::DE_INT64:
628 RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
629 break;
630 case DataType::DE_FLOAT16:
631 RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
632 break;
633 case DataType::DE_FLOAT32:
634 RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
635 break;
636 case DataType::DE_FLOAT64:
637 RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
638 break;
639 case DataType::DE_STRING:
640 RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
641 break;
642 case DataType::DE_BYTES:
643 RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
644 break;
645 case DataType::DE_UNKNOWN:
646 default:
647 RETURN_STATUS_UNEXPECTED(
648 "Mask: unsupported input datatype, support datatype is:[bool, int8, uint8, int16, uint16, int32, uint32, "
649 "int64, uint64, float16, float32, float64, string, bytes].");
650 break;
651 }
652 return Status::OK();
653 }
654
// Concatenates all 1D tensors in `input` (optionally with `prepend` before and `append`
// after) along axis 0 into a single 1D tensor pushed onto `output`. All participating
// tensors must share one data type.
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, const std::shared_ptr<Tensor> &prepend,
                   const std::shared_ptr<Tensor> &append) {
  CHECK_FAIL_RETURN_UNEXPECTED(input.size() > 0, "Concatenate: input cannot be empty.");
  // Normalize a negative axis (e.g. -1) to its positive equivalent; only axis 0 is supported.
  axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
  CHECK_FAIL_RETURN_UNEXPECTED(
    axis == 0, "Concatenate: only 1D input supported, input 'axis' should be 0, but got: " + std::to_string(axis));

  TensorShape t = TensorShape::CreateScalar();

  DataType first_dtype = input[0]->type();

  // Ordered list of every tensor to concatenate: prepend, then inputs, then append.
  TensorRow tensor_list;

  if (prepend != nullptr) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      first_dtype == prepend->type(),
      "Concatenate: input datatype does not match the prepend datatype, got input datatype: " + first_dtype.ToString() +
        ", prepend datatype:" + prepend->type().ToString());
    CHECK_FAIL_RETURN_UNEXPECTED(
      prepend->shape().Rank() == 1,
      "Concatenate: only 1D input supported, got rank of prepend: " + std::to_string(prepend->shape().Rank()));
    tensor_list.emplace_back(prepend);
  }

  for (const auto &tensor : input) {
    CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == tensor->type(), "Concatenate: inconsistent datatype of input.");
    CHECK_FAIL_RETURN_UNEXPECTED(
      tensor->shape().Rank() == 1,
      "Concatenate: only 1D input supported, got rank of input: " + std::to_string(tensor->shape().Rank()));
    tensor_list.emplace_back(tensor);
  }

  if (append != nullptr) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      first_dtype == append->type(),
      "Concatenate: input datatype does not match the append datatype, got input datatype: " + first_dtype.ToString() +
        ", append datatype:" + append->type().ToString());
    CHECK_FAIL_RETURN_UNEXPECTED(
      append->shape().Rank() == 1,
      "Concatenate: only 1D append supported, got rank of append:" + std::to_string(append->shape().Rank()));
    tensor_list.emplace_back(append);
  }

  // create final shape
  // Non-axis dims are copied from the first tensor; the axis dim is the sum of all lengths.
  for (dsize_t i = 0; i < tensor_list[0]->shape().Rank(); i++) {
    if (i != axis) {
      t = t.AppendDim(tensor_list[0]->shape()[i]);
    } else {
      dsize_t new_shape = 0;
      for (auto &tensor : tensor_list) {
        new_shape = tensor->shape()[i] + new_shape;
      }
      t = t.AppendDim(new_shape);
    }
  }

  std::shared_ptr<Tensor> out;

  if (input[0]->type().IsNumeric()) {
    // Numeric: insert each tensor at an advancing offset along the axis.
    RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, tensor_list[0]->type(), &out));
    std::vector<dsize_t> index(axis + 1, 0);

    auto n = index.size() - 1;
    for (auto &tensor : tensor_list) {
      RETURN_IF_NOT_OK(out->InsertTensor({index}, tensor, true));
      index[n] = index[n] + tensor->shape()[axis];  // advance the insert offset along the axis
    }
  } else {
    // Strings: flatten all elements in order and rebuild as one tensor.
    std::vector<std::string> strings;

    for (auto &i : tensor_list) {
      auto itr = i->begin<std::string_view>();
      for (; itr != i->end<std::string_view>(); ++itr) {
        (void)strings.emplace_back(*itr);
      }
    }
    RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, input[0]->type(), &out));
  }

  output->push_back(out);

  return Status::OK();
}
738
739 #ifndef ENABLE_ANDROID
// Unpacks a batched tensor (first dimension = batch) into one CVTensor per batch element.
// NOTE(review): CreateFromMemory is handed the address of each slice inside `input`'s
// buffer — whether the resulting tensors copy or alias that memory depends on
// Tensor::CreateFromMemory; confirm lifetime expectations against its implementation.
Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
                                   std::vector<std::shared_ptr<CVTensor>> *output) {
  RETURN_UNEXPECTED_IF_NULL(output);
  std::vector<int64_t> tensor_shape = input->shape().AsVector();
  TensorShape remaining({-1});
  std::vector<int64_t> index(tensor_shape.size(), 0);
  CHECK_FAIL_RETURN_UNEXPECTED(
    tensor_shape.size() > 1,
    "MixUpBatch: input must be at least 2-D in order to unpack, but got rank: " + std::to_string(tensor_shape.size()));
  // Shape of one batch element: the input shape minus the leading batch dimension.
  TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));

  // Walk the batch dimension; index[0] selects the current element.
  for (; index[0] < tensor_shape[0]; index[0]++) {
    uchar *start_addr_of_index = nullptr;
    std::shared_ptr<Tensor> out;

    RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
    RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
    std::shared_ptr<CVTensor> cv_out = CVTensor::AsCVTensor(std::move(out));
    if (!cv_out->mat().data) {
      RETURN_STATUS_UNEXPECTED("[Internal ERROR] MixUpBatch: allocate memory failed.");
    }
    output->push_back(cv_out);
  }
  return Status::OK();
}
765 #endif
766
BatchTensorToTensorVector(const std::shared_ptr<Tensor> & input,std::vector<std::shared_ptr<Tensor>> * output)767 Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output) {
768 RETURN_UNEXPECTED_IF_NULL(output);
769 std::vector<int64_t> tensor_shape = input->shape().AsVector();
770 TensorShape remaining({-1});
771 std::vector<int64_t> index(tensor_shape.size(), 0);
772 CHECK_FAIL_RETURN_UNEXPECTED(
773 tensor_shape.size() > 1,
774 "CutMixBatch: input must be at least 2-D in order to unpack, but got rank:" + std::to_string(tensor_shape.size()));
775 TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));
776
777 for (; index[0] < tensor_shape[0]; index[0]++) {
778 uchar *start_addr_of_index = nullptr;
779 std::shared_ptr<Tensor> out;
780
781 RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
782 RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
783 output->push_back(out);
784 }
785 return Status::OK();
786 }
787
TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> & input,std::shared_ptr<Tensor> * output)788 Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output) {
789 RETURN_UNEXPECTED_IF_NULL(output);
790 CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "CutMixBatch: the input is empty.");
791 std::vector<int64_t> tensor_shape = input.front()->shape().AsVector();
792 (void)tensor_shape.insert(tensor_shape.begin(), input.size());
793 RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(tensor_shape), input.at(0)->type(), output));
794 for (int i = 0; i < input.size(); i++) {
795 RETURN_IF_NOT_OK((*output)->InsertTensor({i}, input[i]));
796 }
797 return Status::OK();
798 }
799
// Trait selecting the hash-map type used by UniqueHelper to map each distinct
// input value to its first-seen index. The default relies on std::hash<T>.
template <typename T>
struct UniqueOpHashMap {
  using map_type = std::unordered_map<T, int32_t>;
};
#ifndef ENABLE_ANDROID
// Non-Android builds: float16 comes with a usable std::hash specialization
// (presumably provided by the pybind/Eigen half type — TODO confirm), so the
// default unordered_map works.
template <>
struct UniqueOpHashMap<float16> {
  using map_type = std::unordered_map<float16, int32_t>;
};

#else
// Android builds lack std::hash for float16, so hash by converting the value
// to size_t. NOTE(review): this truncates toward integral values, so distinct
// halves can collide — acceptable for a hash, just more collisions.
struct gn_hash {
  size_t operator()(const float16 &f) const { return static_cast<std::size_t>(f); }
};

template <>
struct UniqueOpHashMap<float16> {
  using map_type = std::unordered_map<float16, int32_t, gn_hash>;
};
#endif

// float and double use the default std::hash; these specializations mirror the
// primary template.
template <>
struct UniqueOpHashMap<float> {
  using map_type = std::unordered_map<float, int32_t>;
};

template <>
struct UniqueOpHashMap<double> {
  using map_type = std::unordered_map<double, int32_t>;
};
830
831 template <typename T>
UniqueHelper(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,std::shared_ptr<Tensor> * output_idx,std::shared_ptr<Tensor> * output_cnt)832 Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
833 std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
834 const dsize_t N = input->Size();
835 RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_INT32), output_idx));
836
837 typename UniqueOpHashMap<T>::map_type uniq;
838 uniq.reserve(2 * N);
839 auto in_iter = input->begin<T>();
840 auto out_idx_iter = (*output_idx)->begin<int32_t>();
841 int32_t i = 0;
842 for (; in_iter != input->end<T>(); ++in_iter, ++out_idx_iter) {
843 auto it = uniq.emplace(*in_iter, i);
844 *out_idx_iter = it.first->second;
845 if (it.second) {
846 ++i;
847 }
848 }
849 auto uniq_size = uniq.size();
850 RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), input->type(), output));
851 auto out_iter = (*output)->begin<T>();
852 for (const auto &item : uniq) {
853 *(out_iter + static_cast<ptrdiff_t>(item.second)) = item.first;
854 }
855 RETURN_IF_NOT_OK(
856 Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), DataType(DataType::DE_INT32), output_cnt));
857 RETURN_IF_NOT_OK((*output_cnt)->Zero());
858
859 auto out_cnt_iter = (*output_cnt)->begin<int32_t>();
860 out_idx_iter = (*output_idx)->begin<int32_t>();
861 for (int32_t j = 0; j < N; ++j) {
862 auto idx = *(out_idx_iter + static_cast<ptrdiff_t>(j));
863 ++*(out_cnt_iter + static_cast<ptrdiff_t>(idx));
864 }
865 return Status::OK();
866 }
867
Unique(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,std::shared_ptr<Tensor> * output_idx,std::shared_ptr<Tensor> * output_cnt)868 Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
869 std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
870 CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "Unique: only 1D input supported, but got rank: " +
871 std::to_string(input->shape().Rank()));
872 if (input->type() == DataType::DE_INT64) {
873 RETURN_IF_NOT_OK(UniqueHelper<int64_t>(input, output, output_idx, output_cnt));
874 } else if (input->type() == DataType::DE_INT32) {
875 RETURN_IF_NOT_OK(UniqueHelper<int32_t>(input, output, output_idx, output_cnt));
876 } else if (input->type() == DataType::DE_INT16) {
877 RETURN_IF_NOT_OK(UniqueHelper<int16_t>(input, output, output_idx, output_cnt));
878 } else if (input->type() == DataType::DE_INT8) {
879 RETURN_IF_NOT_OK(UniqueHelper<int8_t>(input, output, output_idx, output_cnt));
880 } else if (input->type() == DataType::DE_UINT64) {
881 RETURN_IF_NOT_OK(UniqueHelper<uint64_t>(input, output, output_idx, output_cnt));
882 } else if (input->type() == DataType::DE_UINT32) {
883 RETURN_IF_NOT_OK(UniqueHelper<uint32_t>(input, output, output_idx, output_cnt));
884 } else if (input->type() == DataType::DE_UINT16) {
885 RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt));
886 } else if (input->type() == DataType::DE_UINT8) {
887 RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt));
888 } else if (input->type() == DataType::DE_FLOAT16) {
889 RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt));
890 } else if (input->type() == DataType::DE_FLOAT32) {
891 RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt));
892 } else if (input->type() == DataType::DE_FLOAT64) {
893 RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt));
894 } else {
895 RETURN_STATUS_UNEXPECTED("Unique: Unique op only supports numeric input.");
896 }
897 return Status::OK();
898 }
899 } // namespace dataset
900 } // namespace mindspore
901