• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/kernels/data/data_utils.h"
18 
19 #include <algorithm>
20 #include <limits>
21 #include <string>
22 #include <vector>
23 #include <utility>
24 
25 #include "minddata/dataset/include/dataset/constants.h"
26 #include "minddata/dataset/core/data_type.h"
27 #ifdef ENABLE_PYTHON
28 #include "minddata/dataset/core/pybind_support.h"
29 #endif
30 #include "minddata/dataset/core/tensor.h"
31 #include "minddata/dataset/core/tensor_shape.h"
32 #include "minddata/dataset/kernels/data/type_cast_op.h"
33 #include "minddata/dataset/util/status.h"
34 
35 namespace mindspore {
36 namespace dataset {
OneHotEncodingUnsigned(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,dsize_t num_classes,int64_t index)37 Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
38                               dsize_t num_classes, int64_t index) {
39   uint64_t class_idx;
40   if (input->Rank() == 0) {
41     RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {}));
42   } else {
43     RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index}));
44   }
45   if (class_idx >= static_cast<uint64_t>(num_classes)) {
46     RETURN_STATUS_UNEXPECTED("OneHot: index values should not bigger than num classes: " + std::to_string(num_classes) +
47                              ", but got: " + std::to_string(class_idx));
48   }
49   if (input->type() == DataType::DE_UINT64) {
50     RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
51   } else if (input->type() == DataType::DE_UINT32) {
52     RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
53   } else if (input->type() == DataType::DE_UINT16) {
54     RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
55   } else if (input->type() == DataType::DE_UINT8) {
56     RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
57   } else {
58     RETURN_STATUS_UNEXPECTED("OneHot: OneHot unsigned only supports unsigned int as input.");
59   }
60   return Status::OK();
61 }
62 
OneHotEncodingSigned(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,dsize_t num_classes,int64_t index)63 Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
64                             int64_t index) {
65   int64_t class_idx;
66   if (input->Rank() == 0) {
67     RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {}));
68   } else {
69     RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index}));
70   }
71   if (class_idx >= static_cast<int64_t>(num_classes)) {
72     RETURN_STATUS_UNEXPECTED("OneHot: index values should not bigger than num classes: " + std::to_string(num_classes) +
73                              ", but got: " + std::to_string(class_idx));
74   }
75   if (input->type() == DataType::DE_INT64) {
76     RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
77   } else if (input->type() == DataType::DE_INT32) {
78     RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
79   } else if (input->type() == DataType::DE_INT16) {
80     RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
81   } else if (input->type() == DataType::DE_INT8) {
82     RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
83   } else {
84     RETURN_STATUS_UNEXPECTED("OneHot: OneHot signed only supports signed int as input.");
85   }
86   return Status::OK();
87 }
88 
OneHotEncoding(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,dsize_t num_classes)89 Status OneHotEncoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
90   input->Squeeze();
91 
92   if (input->Rank() > 1) {  // We expect the input to be int he first dimension
93     RETURN_STATUS_UNEXPECTED("OneHot: OneHot only supports scalars or 1D input, got rank: " +
94                              std::to_string(input->Rank()));
95   }
96   if (!input->type().IsInt()) {
97     RETURN_STATUS_UNEXPECTED("OneHot: OneHot only not support input of int type.");
98   }
99   try {
100     dsize_t num_elements = 1;
101     if (input->Rank() == 1) num_elements = input->shape()[0];
102     TensorShape out_shape({num_elements, num_classes});
103     std::shared_ptr<Tensor> out;
104     RETURN_IF_NOT_OK(Tensor::CreateEmpty(out_shape, input->type(), &out));
105     RETURN_IF_NOT_OK(out->Zero());
106     for (dsize_t i = 0; i < num_elements; ++i) {
107       if (input->type().IsUnsignedInt()) {
108         RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
109       } else {
110         RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
111       }
112     }
113     out->Squeeze();
114     *output = out;
115     return Status::OK();
116   } catch (const std::exception &e) {
117     std::string err_msg = "Error raised in OneHot operation: ";
118     err_msg += e.what();
119     RETURN_STATUS_UNEXPECTED(err_msg);
120   }
121 }
122 
FillHelper(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * out,std::shared_ptr<Tensor> fill_output,std::shared_ptr<Tensor> fill_value)123 Status FillHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out,
124                   std::shared_ptr<Tensor> fill_output, std::shared_ptr<Tensor> fill_value) {
125   const DataType &input_type = input->type();
126   const TensorShape &input_shape = input->shape();
127   switch (input_type.value()) {
128     case DataType::DE_BOOL: {
129       bool value = 0;
130       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
131       RETURN_IF_NOT_OK((*out)->Fill<bool>(value));
132       break;
133     }
134     case DataType::DE_INT8: {
135       int8_t value = 0;
136       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
137       RETURN_IF_NOT_OK((*out)->Fill<int8_t>(value));
138       break;
139     }
140     case DataType::DE_UINT8: {
141       uint8_t value = 0;
142       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
143       RETURN_IF_NOT_OK((*out)->Fill<uint8_t>(value));
144       break;
145     }
146     case DataType::DE_UINT16: {
147       uint16_t value = 0;
148       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
149       RETURN_IF_NOT_OK((*out)->Fill<uint16_t>(value));
150       break;
151     }
152     case DataType::DE_INT16: {
153       int16_t value = 0;
154       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
155       RETURN_IF_NOT_OK((*out)->Fill<int16_t>(value));
156       break;
157     }
158     case DataType::DE_UINT32: {
159       uint32_t value = 0;
160       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
161       RETURN_IF_NOT_OK((*out)->Fill<uint32_t>(value));
162       break;
163     }
164     case DataType::DE_INT32: {
165       int32_t value = 0;
166       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
167       RETURN_IF_NOT_OK((*out)->Fill<int32_t>(value));
168       break;
169     }
170     case DataType::DE_UINT64: {
171       uint64_t value = 0;
172       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
173       RETURN_IF_NOT_OK((*out)->Fill<uint64_t>(value));
174       break;
175     }
176     case DataType::DE_INT64: {
177       int64_t value = 0;
178       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
179       RETURN_IF_NOT_OK((*out)->Fill<int64_t>(value));
180       break;
181     }
182     case DataType::DE_FLOAT16: {
183       int64_t value = 0;
184       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
185       RETURN_IF_NOT_OK((*out)->Fill<float>(value));
186       break;
187     }
188     case DataType::DE_FLOAT32: {
189       float value = 0;
190       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
191       RETURN_IF_NOT_OK((*out)->Fill<float>(value));
192       break;
193     }
194     case DataType::DE_FLOAT64: {
195       double value = 0;
196       RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
197       RETURN_IF_NOT_OK((*out)->Fill<double>(value));
198       break;
199     }
200     case DataType::DE_STRING: {
201       std::vector<std::string> strings;
202       std::string_view fill_string_view;
203       RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
204       std::string fill_string = std::string(fill_string_view);
205       for (int i = 0; i < input_shape.NumOfElements(); i++) {
206         strings.emplace_back(fill_string);
207       }
208       RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, input_shape, out));
209       break;
210     }
211     case DataType::DE_UNKNOWN: {
212       RETURN_STATUS_UNEXPECTED("Fill: unknown input datatype.");
213       break;
214     }
215   }
216   return Status::OK();
217 }
218 
Fill(const std::shared_ptr<Tensor> input,std::shared_ptr<Tensor> * output,std::shared_ptr<Tensor> fill_value)219 Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
220   const DataType &fill_type = fill_value->type();
221   const DataType &input_type = input->type();
222   const TensorShape &input_shape = input->shape();
223 
224   CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
225                                "Fill: fill datatype is string but the input datatype is not string.");
226 
227   CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}),
228                                "Fill: the shape of fill_value is not a scalar.");
229 
230   std::shared_ptr<Tensor> out, fill_output;
231 
232   if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
233     auto op = std::make_unique<TypeCastOp>(input_type);
234     RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
235   } else {
236     fill_output = fill_value;
237   }
238 
239   if (input_type.IsNumeric()) {
240     RETURN_IF_NOT_OK(Tensor::CreateEmpty(input_shape, input_type, &out));
241   }
242   RETURN_IF_NOT_OK(FillHelper(input, &out, fill_output, fill_value));
243   *output = out;
244   return Status::OK();
245 }
246 
247 template <typename FROM, typename TO>
Cast(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)248 void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
249   auto in_itr = input->begin<FROM>();
250   auto out_itr = (*output)->begin<TO>();
251   auto out_end = (*output)->end<TO>();
252 
253   for (; out_itr != out_end; ++in_itr, ++out_itr) *out_itr = static_cast<TO>(*in_itr);
254 }
255 
256 template <typename T>
CastFrom(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)257 void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
258   switch ((*output)->type().value()) {
259     case DataType::DE_BOOL:
260       Cast<T, bool>(input, output);
261       break;
262     case DataType::DE_INT8:
263       Cast<T, int8_t>(input, output);
264       break;
265     case DataType::DE_UINT8:
266       Cast<T, uint8_t>(input, output);
267       break;
268     case DataType::DE_INT16:
269       Cast<T, int16_t>(input, output);
270       break;
271     case DataType::DE_UINT16:
272       Cast<T, uint16_t>(input, output);
273       break;
274     case DataType::DE_INT32:
275       Cast<T, int32_t>(input, output);
276       break;
277     case DataType::DE_UINT32:
278       Cast<T, uint32_t>(input, output);
279       break;
280     case DataType::DE_INT64:
281       Cast<T, int64_t>(input, output);
282       break;
283     case DataType::DE_UINT64:
284       Cast<T, uint64_t>(input, output);
285       break;
286     case DataType::DE_FLOAT16:
287       Cast<T, float16>(input, output);
288       break;
289     case DataType::DE_FLOAT32:
290       Cast<T, float>(input, output);
291       break;
292     case DataType::DE_FLOAT64:
293       Cast<T, double>(input, output);
294       break;
295     case DataType::DE_UNKNOWN:
296       MS_LOG(ERROR) << "TypeCast: unknown datatype.";
297       break;
298   }
299 }
300 
301 // Type cast operator
TypeCast(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,const DataType & data_type)302 Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
303   RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output));
304 
305   switch (input->type().value()) {
306     case DataType::DE_BOOL:
307       CastFrom<bool>(input, output);
308       break;
309     case DataType::DE_INT8:
310       CastFrom<int8_t>(input, output);
311       break;
312     case DataType::DE_UINT8:
313       CastFrom<uint8_t>(input, output);
314       break;
315     case DataType::DE_INT16:
316       CastFrom<int16_t>(input, output);
317       break;
318     case DataType::DE_UINT16:
319       CastFrom<uint16_t>(input, output);
320       break;
321     case DataType::DE_INT32:
322       CastFrom<int32_t>(input, output);
323       break;
324     case DataType::DE_UINT32:
325       CastFrom<uint32_t>(input, output);
326       break;
327     case DataType::DE_INT64:
328       CastFrom<int64_t>(input, output);
329       break;
330     case DataType::DE_UINT64:
331       CastFrom<uint64_t>(input, output);
332       break;
333     case DataType::DE_FLOAT16:
334       CastFrom<float16>(input, output);
335       break;
336     case DataType::DE_FLOAT32:
337       CastFrom<float>(input, output);
338       break;
339     case DataType::DE_FLOAT64:
340       CastFrom<double>(input, output);
341       break;
342     case DataType::DE_UNKNOWN:
343       // sanity check, unreachable code.
344       RETURN_STATUS_UNEXPECTED(
345         "TypeCast: TypeCast does not support input of this type, supported is: [bool, int8, int16, int32, int64, uint8,"
346         " uint16, uint32, uint64, float16, float32, float64]");
347   }
348   return Status::OK();
349 }
350 
ToFloat16(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)351 Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
352   // initiate new tensor for type cast
353   DataType new_type = DataType("float16");
354   RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), new_type, output));
355 
356   auto in_itr = input->begin<float>();
357   auto in_end = input->end<float>();
358   auto out_itr = (*output)->begin<float16>();
359   auto out_end = (*output)->end<float16>();
360 
361   for (; (in_itr != in_end) && (out_itr != out_end); ++in_itr, ++out_itr) {
362     float element = *in_itr;
363     float float16_max = static_cast<float>(std::numeric_limits<float16>::max());
364     float float16_min = static_cast<float>(std::numeric_limits<float16>::lowest());
365     if (element > float16_max || element < float16_min) {
366       RETURN_STATUS_UNEXPECTED("ToFloat16: value " + std::to_string(element) + " is outside of valid float16 range [" +
367                                std::to_string(float16_max) + ", " + std::to_string(float16_min) + "].");
368     }
369 
370     *out_itr = float16(*in_itr);
371   }
372 
373   return Status::OK();
374 }
375 
PadEnd(const std::shared_ptr<Tensor> & src,std::shared_ptr<Tensor> * dst,const std::vector<dsize_t> & pad_shape,const std::shared_ptr<Tensor> & pad_val)376 Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
377               const std::shared_ptr<Tensor> &pad_val) {
378   if (pad_val == nullptr) {
379     if (src->type().IsNumeric()) {
380       return PadEndNumeric(src, dst, pad_shape, 0);
381     } else {
382       return PadEndString(src, dst, pad_shape, "");
383     }
384   }
385   CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
386                                "PadEnd: pad_value and item of dataset are not of the same type, type of pad_value is:" +
387                                  pad_val->type().ToString() +
388                                  ", and type of dataset item is:" + src->type().ToString() + ".");
389   if (pad_val->type().IsNumeric()) {
390     std::shared_ptr<Tensor> float_pad_value;
391     RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
392     float val = 0;
393     RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
394     return PadEndNumeric(src, dst, pad_shape, val);
395   }
396   std::string_view val;
397   RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
398   return PadEndString(src, dst, pad_shape, std::string(val));
399 }
400 
/// \brief Pad a numeric tensor out to `pad_shape` with the scalar `pad_val`.
///     If no padding is needed (scalar input, or shapes already match), `*dst` aliases `src`.
///     Otherwise a new tensor is allocated, pre-filled with `pad_val` (cast to src's dtype),
///     and the original data is copied into the leading region.
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
                     const std::vector<dsize_t> &pad_shape, float pad_val) {
  CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "PadEnd: input or output can't be nullptr");
  if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
    (*dst) = src;  // if no padding, copy the pointer
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(),
                                 "PadEnd: invalid pad shape, as rank of input is: " + std::to_string(src->Rank()) +
                                   ", and rank of pad value: " + std::to_string(pad_shape.size()));
    RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(pad_shape), src->type(), dst));
    auto tensor_type = src->type().value();
    if (pad_val == 0) {  // if pad with zero, don't care what type it is
      RETURN_IF_NOT_OK((*dst)->Zero());
    } else if (tensor_type == DataType::DE_INT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(static_cast<int8_t>(pad_val)));
    } else if (tensor_type == DataType::DE_BOOL) {
      RETURN_IF_NOT_OK((*dst)->Fill<bool>(static_cast<bool>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(static_cast<uint8_t>(pad_val)));
    } else if (tensor_type == DataType::DE_INT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(static_cast<int16_t>(pad_val)));
    } else if (tensor_type == DataType::DE_FLOAT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(static_cast<uint16_t>(pad_val)));
    } else if (tensor_type == DataType::DE_INT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(static_cast<int32_t>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(static_cast<uint32_t>(pad_val)));
    } else if (tensor_type == DataType::DE_INT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(static_cast<int64_t>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(static_cast<uint64_t>(pad_val)));
    } else if (tensor_type == DataType::DE_FLOAT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<float>(static_cast<float>(pad_val)));
    } else if (tensor_type == DataType::DE_FLOAT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<double>(static_cast<double>(pad_val)));
    } else {
      RETURN_STATUS_UNEXPECTED("PadEnd: Incorrect/Unknown datatype");
    }
    // Recursively copy src's data over the pre-filled destination, dimension by dimension.
    std::vector<dsize_t> cur_ind(src->Rank(), 0);
    RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
  }
  return Status::OK();
}
PadEndNumericHelper(const std::shared_ptr<Tensor> & src,std::shared_ptr<Tensor> dst,std::vector<dsize_t> cur_ind,size_t cur_dim)446 Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
447                            std::vector<dsize_t> cur_ind, size_t cur_dim) {
448   if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
449     RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind));
450   } else {  // not the last dimension, keep doing recursion
451     dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
452     for (dsize_t i = 0; i < min_ind; i++) {
453       cur_ind[cur_dim] = i;
454       RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
455     }
456   }
457   return Status::OK();
458 }
459 
PadEndString(const std::shared_ptr<Tensor> & src,std::shared_ptr<Tensor> * dst,const std::vector<dsize_t> & pad_shape,const std::string & pad_val)460 Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
461                     const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
462   CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
463   if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
464     (*dst) = src;  // if no padding, copy the pointer
465   } else {
466     CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(),
467                                  "PadEnd: invalid pad shape, as rank of input is: " + std::to_string(src->Rank()) +
468                                    ", and rank of pad value: " + std::to_string(pad_shape.size()));
469     std::vector<dsize_t> cur_ind(src->Rank(), 0);
470     std::vector<std::string> strings;
471     RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
472     RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, TensorShape(pad_shape), dst));
473   }
474   return Status::OK();
475 }
476 
/// \brief Recursively flatten `src` into `dst` in row-major order, appending `pad_value`
///     wherever `dst_shape` extends beyond src's extent in any dimension.
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
                          const std::string &pad_value) {
  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    // Copy the overlapping elements of this row.
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      std::string_view item;
      RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
      dst->emplace_back(item);
    }
    // Pad the remainder of the row.
    for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
      dst->emplace_back(pad_value);
    }

  } else {  // not the last dimension, keep doing recursion
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
    }
    // For dst slices wholly beyond src's extent, emit pad values for every element they contain
    // (stride = number of flattened elements per slice at this dimension).
    dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
    for (dsize_t i = 0; i < count; i++) {
      dst->emplace_back(pad_value);
    }
  }
  return Status::OK();
}
505 
506 template <typename T>
MaskHelper(const std::shared_ptr<Tensor> & input,const std::shared_ptr<Tensor> & output,const std::shared_ptr<Tensor> & value_tensor,RelationalOp op)507 Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
508                   const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
509   T value;
510   RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
511   auto in_itr = input->begin<T>();
512   auto out_itr = output->begin<bool>();
513   for (; in_itr != input->end<T>(); ++in_itr, ++out_itr) {
514     switch (op) {
515       case RelationalOp::kEqual:
516         *out_itr = (*in_itr == value);
517         break;
518       case RelationalOp::kNotEqual:
519         *out_itr = (*in_itr != value);
520         break;
521       case RelationalOp::kGreater:
522         *out_itr = (*in_itr > value);
523         break;
524       case RelationalOp::kGreaterEqual:
525         *out_itr = (*in_itr >= value);
526         break;
527       case RelationalOp::kLess:
528         *out_itr = (*in_itr < value);
529         break;
530       case RelationalOp::kLessEqual:
531         *out_itr = (*in_itr <= value);
532         break;
533       default:
534         RETURN_STATUS_UNEXPECTED("Mask: unknown relational operator.");
535     }
536   }
537   return Status::OK();
538 }
539 
/// \brief Build a boolean mask by comparing each element of `input` against the scalar `value`
///     with relational operator `op`. `*output` is a DE_BOOL tensor with input's shape.
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
            RelationalOp op) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
                               "Mask: input datatype does not match the value datatype, both should be numeric or "
                               "non-numerical in the same time.");
  CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Mask: value shape is not a scalar");

  RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_BOOL), output));

  // Numeric comparisons require the scalar to be in input's dtype; strings are compared as-is.
  std::unique_ptr<TypeCastOp> value_cast_op = std::make_unique<TypeCastOp>(input->type());
  std::shared_ptr<Tensor> casted_value;
  if (input->type().IsNumeric()) {
    RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
  } else {
    casted_value = value;
  }

  // Dispatch element-wise comparison on the input dtype.
  switch (input->type().value()) {
    case DataType::DE_BOOL:
      RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT8:
      RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT8:
      RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT16:
      RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT16:
      RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT32:
      RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT32:
      RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT64:
      RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT64:
      RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT16:
      RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
      break;
    case DataType::DE_STRING:
      RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
      break;
    case DataType::DE_UNKNOWN:
      RETURN_STATUS_UNEXPECTED("Mask: unsupported input datatype.");
      break;
  }
  return Status::OK();
}
603 
Concatenate(const TensorRow & input,TensorRow * output,int8_t axis,std::shared_ptr<Tensor> prepend,std::shared_ptr<Tensor> append)604 Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
605                    std::shared_ptr<Tensor> append) {
606   CHECK_FAIL_RETURN_UNEXPECTED(input.size() > 0, "Concatenate: input is null");
607   axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
608   CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Concatenate: only 1D input supported, got rank: " + std::to_string(axis));
609 
610   TensorShape t = TensorShape::CreateScalar();
611 
612   DataType first_dtype = input[0]->type();
613 
614   TensorRow tensor_list;
615 
616   if (prepend != nullptr) {
617     CHECK_FAIL_RETURN_UNEXPECTED(
618       first_dtype == prepend->type(),
619       "Concatenate: input datatype does not match the prepend datatype: " + prepend->type().ToString());
620     CHECK_FAIL_RETURN_UNEXPECTED(
621       prepend->shape().Rank() == 1,
622       "Concatenate: only 1D input supported, got rank of input: " + std::to_string(prepend->shape().Rank()));
623     tensor_list.emplace_back(prepend);
624   }
625 
626   for (dsize_t i = 0; i < input.size(); i++) {
627     CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == input[i]->type(), "Concatenate: inconsistent datatype of input.");
628     CHECK_FAIL_RETURN_UNEXPECTED(
629       input[i]->shape().Rank() == 1,
630       "Concatenate: only 1D input supported, got rank of input: " + std::to_string(input[i]->shape().Rank()));
631     tensor_list.emplace_back(input[i]);
632   }
633 
634   if (append != nullptr) {
635     CHECK_FAIL_RETURN_UNEXPECTED(
636       first_dtype == append->type(),
637       "Concatenate: input datatype does not match the append datatype: " + append->type().ToString());
638     CHECK_FAIL_RETURN_UNEXPECTED(
639       append->shape().Rank() == 1,
640       "Concatenate: only 1D append supported, got rank of input: " + std::to_string(append->shape().Rank()));
641     tensor_list.emplace_back(append);
642   }
643 
644   //  create final shape
645   for (dsize_t i = 0; i < tensor_list[0]->shape().Rank(); i++) {
646     if (i != axis) {
647       t = t.AppendDim(tensor_list[0]->shape()[i]);
648     } else {
649       dsize_t new_shape = 0;
650       for (dsize_t j = 0; j < tensor_list.size(); j++) {
651         new_shape = tensor_list[j]->shape()[i] + new_shape;
652       }
653       t = t.AppendDim(new_shape);
654     }
655   }
656 
657   std::shared_ptr<Tensor> out;
658 
659   if (input[0]->type().IsNumeric()) {
660     RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, tensor_list[0]->type(), &out));
661     std::vector<dsize_t> index(axis + 1, 0);
662 
663     int n = index.size() - 1;
664     for (dsize_t i = 0; i < tensor_list.size(); i++) {
665       RETURN_IF_NOT_OK(out->InsertTensor({index}, tensor_list[i], true));
666       index[n] = index[n] + tensor_list[i]->shape()[axis];
667     }
668   } else {
669     std::vector<std::string> strings;
670 
671     for (dsize_t i = 0; i < tensor_list.size(); i++) {
672       auto itr = tensor_list[i]->begin<std::string_view>();
673       for (; itr != tensor_list[i]->end<std::string_view>(); ++itr) {
674         strings.emplace_back(*itr);
675       }
676     }
677     RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, &out));
678   }
679 
680   output->push_back(out);
681 
682   return Status::OK();
683 }
684 
685 #ifndef ENABLE_ANDROID
BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> & input,std::vector<std::shared_ptr<CVTensor>> * output)686 Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
687                                    std::vector<std::shared_ptr<CVTensor>> *output) {
688   std::vector<int64_t> tensor_shape = input->shape().AsVector();
689   TensorShape remaining({-1});
690   std::vector<int64_t> index(tensor_shape.size(), 0);
691   if (tensor_shape.size() <= 1) {
692     RETURN_STATUS_UNEXPECTED("MixUpBatch: input must be at least 2-D in order to unpack, but got rank: " +
693                              std::to_string(tensor_shape.size()));
694   }
695   TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));
696 
697   for (; index[0] < tensor_shape[0]; index[0]++) {
698     uchar *start_addr_of_index = nullptr;
699     std::shared_ptr<Tensor> out;
700 
701     RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
702     RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
703     std::shared_ptr<CVTensor> cv_out = CVTensor::AsCVTensor(std::move(out));
704     if (!cv_out->mat().data) {
705       RETURN_STATUS_UNEXPECTED("[Internal ERROR] MixUpBatch: allocate memory failed.");
706     }
707     output->push_back(cv_out);
708   }
709   return Status::OK();
710 }
711 #endif
712 
BatchTensorToTensorVector(const std::shared_ptr<Tensor> & input,std::vector<std::shared_ptr<Tensor>> * output)713 Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output) {
714   std::vector<int64_t> tensor_shape = input->shape().AsVector();
715   TensorShape remaining({-1});
716   std::vector<int64_t> index(tensor_shape.size(), 0);
717   if (tensor_shape.size() <= 1) {
718     RETURN_STATUS_UNEXPECTED("CutMixBatch: input must be at least 2-D in order to unpack, but got rank:" +
719                              std::to_string(tensor_shape.size()));
720   }
721   TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));
722 
723   for (; index[0] < tensor_shape[0]; index[0]++) {
724     uchar *start_addr_of_index = nullptr;
725     std::shared_ptr<Tensor> out;
726 
727     RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
728     RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
729     output->push_back(out);
730   }
731   return Status::OK();
732 }
733 
TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> & input,std::shared_ptr<Tensor> * output)734 Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output) {
735   if (input.empty()) {
736     RETURN_STATUS_UNEXPECTED("CutMixBatch: the input is empty.");
737   }
738   std::vector<int64_t> tensor_shape = input.front()->shape().AsVector();
739   tensor_shape.insert(tensor_shape.begin(), input.size());
740   RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(tensor_shape), input.at(0)->type(), output));
741   for (int i = 0; i < input.size(); i++) {
742     RETURN_IF_NOT_OK((*output)->InsertTensor({i}, input[i]));
743   }
744   return Status::OK();
745 }
746 
// Trait mapping an element type T to the hash-map type used by UniqueHelper.
// The generic case relies on std::hash<T>.
template <typename T>
struct UniqueOpHashMap {
  using map_type = std::unordered_map<T, int32_t>;
};
#ifndef ENABLE_ANDROID
// Non-Android builds can use the default map for float16 as well --
// presumably a std::hash<float16> specialization is provided by the
// float16 support headers; confirm against pybind_support.h.
template <>
struct UniqueOpHashMap<float16> {
  using map_type = std::unordered_map<float16, int32_t>;
};

#else
// Android builds lack a std::hash<float16>; supply a hash functor that
// converts the half-precision value to size_t.
struct gn_hash {
  size_t operator()(const float16 &f) const { return static_cast<std::size_t>(f); }
};

template <>
struct UniqueOpHashMap<float16> {
  using map_type = std::unordered_map<float16, int32_t, gn_hash>;
};
#endif

// float/double specializations match the primary template; kept explicit.
// NOTE(review): with IEEE semantics NaN != NaN, so every NaN input would be
// treated as a distinct unique value -- confirm this is intended.
template <>
struct UniqueOpHashMap<float> {
  using map_type = std::unordered_map<float, int32_t>;
};

template <>
struct UniqueOpHashMap<double> {
  using map_type = std::unordered_map<double, int32_t>;
};
777 
778 template <typename T>
UniqueHelper(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output,std::shared_ptr<Tensor> * output_idx,std::shared_ptr<Tensor> * output_cnt)779 Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
780                     std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
781   const dsize_t N = input->Size();
782   RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_INT32), output_idx));
783 
784   typename UniqueOpHashMap<T>::map_type uniq;
785   uniq.reserve(2 * N);
786   auto in_iter = input->begin<T>();
787   auto out_idx_iter = (*output_idx)->begin<int32_t>();
788   int32_t i = 0;
789   for (; in_iter != input->end<T>(); ++in_iter, ++out_idx_iter) {
790     auto it = uniq.emplace(*in_iter, i);
791     *out_idx_iter = it.first->second;
792     if (it.second) {
793       ++i;
794     }
795   }
796   auto uniq_size = uniq.size();
797   RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), input->type(), output));
798   auto out_iter = (*output)->begin<T>();
799   for (const auto &item : uniq) {
800     *(out_iter + static_cast<ptrdiff_t>(item.second)) = item.first;
801   }
802   RETURN_IF_NOT_OK(
803     Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), DataType(DataType::DE_INT32), output_cnt));
804   RETURN_IF_NOT_OK((*output_cnt)->Zero());
805 
806   auto out_cnt_iter = (*output_cnt)->begin<int32_t>();
807   out_idx_iter = (*output_idx)->begin<int32_t>();
808   for (int32_t j = 0; j < N; ++j) {
809     auto idx = *(out_idx_iter + static_cast<ptrdiff_t>(j));
810     ++*(out_cnt_iter + static_cast<ptrdiff_t>(idx));
811   }
812   return Status::OK();
813 }
814 
// Computes unique values, per-element indices, and per-value counts for a
// 1-D numeric tensor by dispatching to the UniqueHelper<T> instantiation
// that matches the input's dtype. Non-numeric or non-1D inputs are rejected.
Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
              std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "Unique: only 1D input supported, but got rank: " +
                                                             std::to_string(input->shape().Rank()));
  // One branch per supported element type; each forwards to the typed helper.
  if (input->type() == DataType::DE_INT64) {
    RETURN_IF_NOT_OK(UniqueHelper<int64_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_INT32) {
    RETURN_IF_NOT_OK(UniqueHelper<int32_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_INT16) {
    RETURN_IF_NOT_OK(UniqueHelper<int16_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_INT8) {
    RETURN_IF_NOT_OK(UniqueHelper<int8_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT64) {
    RETURN_IF_NOT_OK(UniqueHelper<uint64_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT32) {
    RETURN_IF_NOT_OK(UniqueHelper<uint32_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT16) {
    RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT8) {
    RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_FLOAT16) {
    RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_FLOAT32) {
    RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_FLOAT64) {
    RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt));
  } else {
    // String and other non-numeric dtypes fall through to an error.
    RETURN_STATUS_UNEXPECTED("Unique: Unique op only supports numeric input.");
  }
  return Status::OK();
}
846 
847 }  // namespace dataset
848 }  // namespace mindspore
849