1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/core/tensor_row.h"
18
19 #include <utility>
20
21 namespace mindspore {
22 namespace dataset {
23
TensorRow()24 TensorRow::TensorRow() noexcept : id_(kDefaultRowId), path_({}), tensor_row_flag_(kFlagNone) {}
25
TensorRow(size_type n,const TensorRow::value_type & t)26 TensorRow::TensorRow(size_type n, const TensorRow::value_type &t) noexcept
27 : id_(kDefaultRowId), path_({}), row_(n, t), tensor_row_flag_(kFlagNone) {}
28
TensorRow(const TensorRow::vector_type & v)29 TensorRow::TensorRow(const TensorRow::vector_type &v)
30 : id_(kDefaultRowId), path_({}), row_(v), tensor_row_flag_(kFlagNone) {}
31
TensorRow(row_id_type id,const std::initializer_list<value_type> & lst)32 TensorRow::TensorRow(row_id_type id, const std::initializer_list<value_type> &lst)
33 : id_(id), path_({}), row_(lst), tensor_row_flag_(kFlagNone) {}
34
TensorRow(const TensorRow & tr)35 TensorRow::TensorRow(const TensorRow &tr)
36 : id_(tr.id_), path_(tr.path_), row_(tr.row_), tensor_row_flag_(tr.tensor_row_flag_) {}
37
TensorRow(TensorRow::TensorRowFlags flag)38 TensorRow::TensorRow(TensorRow::TensorRowFlags flag) : id_(kDefaultRowId), path_({}), tensor_row_flag_(flag) {}
39
// Copy-assigns every field of `tr` (tensor pointers, id, source paths, flag) into this row.
// Self-assignment is detected and leaves the object unchanged.
TensorRow &TensorRow::operator=(const TensorRow &tr) {
  if (this != &tr) {
    row_ = tr.row_;  // shallow: the shared tensor pointers are copied, not the tensors
    id_ = tr.id_;
    path_ = tr.path_;
    tensor_row_flag_ = tr.tensor_row_flag_;
  }
  return *this;
}
50
Clone(TensorRow * new_tr) const51 Status TensorRow::Clone(TensorRow *new_tr) const {
52 RETURN_UNEXPECTED_IF_NULL(new_tr);
53 new_tr->row_.clear();
54 for (const std::shared_ptr<Tensor> &s : row_) {
55 std::shared_ptr<Tensor> d;
56 RETURN_IF_NOT_OK(Tensor::CreateFromTensor(s, &d));
57 (void)new_tr->row_.emplace_back(std::move(d));
58 }
59 new_tr->id_ = id_;
60 new_tr->path_ = path_;
61 new_tr->tensor_row_flag_ = tensor_row_flag_;
62 return Status::OK();
63 }
64
operator =(const std::initializer_list<TensorRow::value_type> & lst)65 TensorRow &TensorRow::operator=(const std::initializer_list<TensorRow::value_type> &lst) {
66 row_ = lst;
67 tensor_row_flag_ = kFlagNone;
68 return *this;
69 }
70
TensorRow(TensorRow::vector_type && v)71 TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept
72 : id_(kDefaultRowId), path_({}), row_(std::move(v)), tensor_row_flag_(kFlagNone) {}
73
TensorRow(row_id_type id,std::initializer_list<value_type> && lst)74 TensorRow::TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept
75 : id_(id), path_({}), row_(std::move(lst)), tensor_row_flag_(kFlagNone) {}
76
TensorRow(TensorRow && tr)77 TensorRow::TensorRow(TensorRow &&tr) noexcept {
78 id_ = tr.id_;
79 path_ = std::move(tr.path_);
80 row_ = std::move(tr.row_);
81 tensor_row_flag_ = tr.tensor_row_flag_;
82 }
83
operator =(TensorRow && tr)84 TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept {
85 if (this == &tr) {
86 return *this;
87 }
88 row_ = std::move(tr.row_);
89 id_ = tr.id_;
90 tr.id_ = kDefaultRowId;
91 path_ = std::move(tr.path_);
92 tensor_row_flag_ = tr.tensor_row_flag_;
93 return *this;
94 }
95
operator =(std::initializer_list<TensorRow::value_type> && lst)96 TensorRow &TensorRow::operator=(std::initializer_list<TensorRow::value_type> &&lst) noexcept {
97 row_ = std::move(lst);
98 tensor_row_flag_ = kFlagNone;
99 return *this;
100 }
101
ValidateTensorRow(const TensorRow & input,const DataType & data_type)102 Status TensorRow::ValidateTensorRow(const TensorRow &input, const DataType &data_type) {
103 if (data_type == DataType::DE_UNKNOWN) {
104 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
105 }
106 if (data_type.IsString()) {
107 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type string and bytes are not supported.");
108 }
109 if (input.size() != 1) {
110 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow must have exactly one tensor.");
111 }
112 return Status::OK();
113 }
114
115 } // namespace dataset
116 } // namespace mindspore
117