1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
18
19 #include <algorithm>
20
21 namespace mindspore::dataset {
Init(const std::shared_ptr<DatasetNode> & root)22 Status PullBasedIteratorConsumer::Init(const std::shared_ptr<DatasetNode> &root) {
23 return tree_adapter_lite_->Compile(root, num_epochs_);
24 }
25
GetRows(int64_t num_rows)26 std::vector<TensorRow> PullBasedIteratorConsumer::GetRows(int64_t num_rows) {
27 std::vector<TensorRow> rows;
28 for (int i = 0; i < num_rows; i++) {
29 TensorRow row;
30 RETURN_SECOND_IF_ERROR(tree_adapter_lite_->GetNextRow(&row), {});
31 if (row.empty()) {
32 break;
33 }
34 rows.push_back(row);
35 }
36
37 return rows;
38 }
39
GetNextAsVector(std::vector<TensorPtr> * const out)40 Status PullBasedIteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *const out) {
41 RETURN_UNEXPECTED_IF_NULL(out);
42 out->clear();
43
44 TensorRow res;
45 RETURN_IF_NOT_OK(tree_adapter_lite_->GetNextRow(&res));
46
47 // Return empty vector if there's no data
48 RETURN_OK_IF_TRUE(res.empty());
49
50 (void)std::copy(res.begin(), res.end(), std::back_inserter(*out));
51 return Status::OK();
52 }
53
GetNextAsMap(std::unordered_map<std::string,TensorPtr> * const out_map)54 Status PullBasedIteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out_map) {
55 RETURN_UNEXPECTED_IF_NULL(out_map);
56 out_map->clear();
57
58 TensorRow res;
59 RETURN_IF_NOT_OK(tree_adapter_lite_->GetNextRow(&res));
60
61 // Return empty map if there's no data
62 RETURN_OK_IF_TRUE(res.empty());
63
64 // Populate the out map from the row and return it
65 for (const auto &colMap : tree_adapter_lite_->GetColumnNameMap()) {
66 (*out_map)[colMap.first] = std::move(res[colMap.second]);
67 }
68 return Status::OK();
69 }
70
GetNextAsOrderedPair(std::vector<std::pair<std::string,std::shared_ptr<Tensor>>> * const vec)71 Status PullBasedIteratorConsumer::GetNextAsOrderedPair(
72 std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) {
73 CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
74
75 TensorRow curr_row;
76
77 RETURN_IF_NOT_OK(tree_adapter_lite_->GetNextRow(&curr_row));
78 RETURN_OK_IF_TRUE(curr_row.empty());
79 size_t num_cols = curr_row.size(); // num_cols is non-empty.
80 // order the column names according to their ids
81 if (column_order_.empty()) {
82 const int32_t invalid_col_id = -1;
83 column_order_.resize(num_cols, {std::string(), invalid_col_id});
84 for (const auto &itr : tree_adapter_lite_->GetColumnNameMap()) {
85 int32_t ind = itr.second;
86 CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds. Expecting in the range [0," +
87 std::to_string(num_cols) + "), but got " +
88 std::to_string(ind));
89 column_order_[ind] = std::make_pair(itr.first, ind);
90 }
91 // error check, make sure the ids in col_name_id_map are continuous and starts from 0
92 for (const auto &col : column_order_) {
93 if (col.second == invalid_col_id) {
94 std::string err_msg = "Invalid column id encountered.";
95 err_msg += " Note: It is unsupported and ambiguous to reuse the same column name for an output_column name";
96 err_msg += " if it is an input_column name that will already appear as one of the output columns.";
97 err_msg += " Use unique columns names.";
98 MS_LOG(ERROR) << err_msg;
99 RETURN_STATUS_UNEXPECTED(err_msg);
100 }
101 }
102 }
103 vec->reserve(num_cols);
104
105 std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
106 [curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); });
107
108 return Status::OK();
109 }
110
TreeGetters()111 TreeGetters::TreeGetters()
112 : root_(nullptr),
113 first_row_type_({}),
114 first_row_shape_({}),
115 estimated_row_shape_({}),
116 init_flag_(false),
117 first_row_obtained_(false) {
118 tree_adapter_lite_ = std::make_unique<TreeAdapterLite>(TreeAdapterLite::UsageFlag::kDeGetter);
119 }
120
Init(const std::shared_ptr<DatasetNode> & root)121 Status TreeGetters::Init(const std::shared_ptr<DatasetNode> &root) {
122 root_ = root;
123 return Status::OK();
124 }
125
GetRow(TensorRow * row)126 Status TreeGetters::GetRow(TensorRow *row) {
127 RETURN_UNEXPECTED_IF_NULL(row);
128 Status get_next_status = tree_adapter_lite_->GetNextRow(row);
129 return get_next_status;
130 }
131
GetOutputTypes(std::vector<DataType> * types)132 Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
133 RETURN_UNEXPECTED_IF_NULL(types);
134 RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
135 *types = first_row_type_;
136 return Status::OK();
137 }
138
GetOutputShapes(std::vector<TensorShape> * shapes,bool estimate)139 Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes, bool estimate) {
140 RETURN_UNEXPECTED_IF_NULL(shapes);
141 RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
142 *shapes = first_row_shape_;
143
144 if (estimate) {
145 estimated_row_shape_ = first_row_shape_;
146 TensorRow row;
147 RETURN_IF_NOT_OK(GetRow(&row));
148
149 while (!row.empty()) {
150 std::vector<TensorShape> cur_row_shape;
151 (void)std::transform(row.begin(), row.end(), std::back_inserter(cur_row_shape),
152 [=](auto &t) { return t->shape(); });
153
154 // calculate dynamic shape
155 CHECK_FAIL_RETURN_SYNTAX_ERROR(cur_row_shape.size() == estimated_row_shape_.size(),
156 "Inconsistent shapes, expect same shape for each data row, last data row: " +
157 std::to_string(cur_row_shape.size()) +
158 ", current data row: " + std::to_string(estimated_row_shape_.size()));
159 size_t shape_size = cur_row_shape.size();
160 std::vector<TensorShape> dynamic_shapes;
161 for (size_t i = 0; i < shape_size; i++) {
162 CHECK_FAIL_RETURN_SYNTAX_ERROR(
163 cur_row_shape[i].Size() == estimated_row_shape_[i].Size(),
164 "Inconsistent shapes, expect same shape for each data row, last data row: " + cur_row_shape[i].ToString() +
165 ", current data row: " + estimated_row_shape_[i].ToString());
166
167 std::vector<dsize_t> vec;
168 for (size_t j = 0; j < estimated_row_shape_[i].Size(); j++) {
169 dsize_t dim = cur_row_shape[i][j] == estimated_row_shape_[i][j] ? cur_row_shape[i][j] : -1;
170 vec.push_back(dim);
171 }
172 dynamic_shapes.emplace_back(TensorShape(vec));
173 }
174 estimated_row_shape_ = dynamic_shapes;
175 RETURN_IF_NOT_OK(GetRow(&row));
176 }
177
178 *shapes = estimated_row_shape_;
179 }
180 return Status::OK();
181 }
182
GetBatchSize(int64_t * batch_size)183 Status TreeGetters::GetBatchSize(int64_t *batch_size) {
184 RETURN_UNEXPECTED_IF_NULL(batch_size);
185 RETURN_IF_NOT_OK(InternalInit());
186 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_lite_->GetRoot());
187 RETURN_UNEXPECTED_IF_NULL(root);
188 *batch_size = root->GetTreeBatchSize();
189 CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != 0, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
190 return Status::OK();
191 }
192
GetRepeatCount(int64_t * repeat_count)193 Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
194 RETURN_UNEXPECTED_IF_NULL(repeat_count);
195 RETURN_IF_NOT_OK(InternalInit());
196 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_lite_->GetRoot());
197 RETURN_UNEXPECTED_IF_NULL(root);
198 *repeat_count = root->GetTreeRepeatCount();
199 return Status::OK();
200 }
201
GetNumClasses(int64_t * num_classes)202 Status TreeGetters::GetNumClasses(int64_t *num_classes) {
203 RETURN_UNEXPECTED_IF_NULL(num_classes);
204 RETURN_IF_NOT_OK(InternalInit());
205 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_lite_->GetRoot());
206 RETURN_UNEXPECTED_IF_NULL(root);
207 RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
208 return Status::OK();
209 }
210
GetColumnNames(std::vector<std::string> * output)211 Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
212 RETURN_UNEXPECTED_IF_NULL(output);
213 RETURN_IF_NOT_OK(InternalInit());
214 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_lite_->GetRoot());
215 RETURN_UNEXPECTED_IF_NULL(root);
216 std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
217 CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map can not be empty.");
218 std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
219 std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
220 [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
221 return a.second < b.second;
222 });
223 std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
224 [](const std::pair<std::string, int32_t> &p) { return p.first; });
225 return Status::OK();
226 }
227
GetClassIndexing(std::vector<std::pair<std::string,std::vector<int32_t>>> * output_class_indexing)228 Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
229 RETURN_UNEXPECTED_IF_NULL(output_class_indexing);
230 RETURN_IF_NOT_OK(InternalInit());
231 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_lite_->GetRoot());
232 RETURN_UNEXPECTED_IF_NULL(root);
233 RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
234 return Status::OK();
235 }
236
InternalInit()237 Status TreeGetters::InternalInit() {
238 if (init_flag_) {
239 return Status::OK();
240 }
241
242 Status s = tree_adapter_lite_->Compile(root_, 1);
243 if (s.IsOk()) {
244 init_flag_ = true;
245 }
246 return s;
247 }
248
GetFirstRowShapeAndType()249 Status TreeGetters::GetFirstRowShapeAndType() {
250 RETURN_OK_IF_TRUE(first_row_obtained_);
251 RETURN_IF_NOT_OK(InternalInit());
252 TensorRow first_row;
253 RETURN_IF_NOT_OK(GetRow(&first_row));
254 std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
255 [](const TensorPtr &t) { return t->type(); });
256 std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_shape_),
257 [](const TensorPtr &t) { return t->shape(); });
258 first_row_obtained_ = true;
259 return Status::OK();
260 }
261 } // namespace mindspore::dataset
262