/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
#include "minddata/dataset/engine/tree_adapter.h"

#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_writer.h"
#endif

namespace mindspore {
namespace dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }

Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); }
Status TreeConsumer::Terminate() {
  CHECK_FAIL_RETURN_UNEXPECTED(tree_adapter_->AllTasks() != nullptr, "Execution tree has not been built");
  return tree_adapter_->AllTasks()->ServiceStop();
}

// IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
  return tree_adapter_->Compile(std::move(d), num_epochs_);
}

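// Fetch the next row from the tree and return its tensors as a vector, dropping meta columns
// (names prefixed with kDftMetaColumnPrefix) while preserving the column order.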
Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  out->clear();

  TensorRow res;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));

  // Return empty vector if there's no data
  RETURN_OK_IF_TRUE(res.empty());

  // Filter out meta columns
  std::vector<size_t> to_keep_indices;
  for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = colMap.first;
    // Skip meta columns, whose names start with kDftMetaColumnPrefix
    size_t pos = column_name.find(kDftMetaColumnPrefix);
    if (pos != std::string::npos && pos == 0) {
      continue;
    }
    to_keep_indices.push_back(colMap.second);
  }
  if (to_keep_indices.size() == 0) {
    std::string err_msg = "No effective column found; all columns may be meta columns, which are filtered out. ";
    err_msg += "If you want to output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  std::sort(to_keep_indices.begin(), to_keep_indices.end());
  (void)std::transform(to_keep_indices.begin(), to_keep_indices.end(), std::back_inserter(*out),
                       [&res](const auto &it) { return std::move(res[it]); });

  return Status::OK();
}

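// Fetch the next row from the tree and return it as a column-name to tensor map, skipping meta columns.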
Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out_map) {
  RETURN_UNEXPECTED_IF_NULL(out_map);
  out_map->clear();

  TensorRow res;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));

  // Return empty map if there's no data
  RETURN_OK_IF_TRUE(res.empty());

  // Populate the out map from the row and return it
  for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = colMap.first;
    // Skip meta columns, whose names start with kDftMetaColumnPrefix
    size_t pos = column_name.find(kDftMetaColumnPrefix);
    if (pos != std::string::npos && pos == 0) {
      continue;
    }
    (*out_map)[colMap.first] = std::move(res[colMap.second]);
  }
  if (out_map->size() == 0) {
    std::string err_msg = "No effective column found; all columns may be meta columns, which are filtered out. ";
    err_msg += "If you want to output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  return Status::OK();
}

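// Fetch the next row as (column name, tensor) pairs ordered by column id. The ordering is computed once
// from the column-name map and cached in column_order_ for subsequent calls.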
Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) {
  CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");

  TensorRow curr_row;

  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
  RETURN_OK_IF_TRUE(curr_row.empty());

  size_t num_cols = curr_row.size();  // curr_row is non-empty, so num_cols > 0.
  // order the column names according to their ids
  if (column_order_.empty()) {
    for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
      int32_t ind = itr.second;
      CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
      // Skip meta columns, whose names start with kDftMetaColumnPrefix
      size_t pos = itr.first.find(kDftMetaColumnPrefix);
      if (pos != std::string::npos && pos == 0) {
        continue;
      }
      column_order_[ind] = itr.first;
    }
  }

  if (column_order_.size() == 0) {
    std::string err_msg = "No effective column found; all columns may be meta columns, which are filtered out. ";
    err_msg += "If you want to output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  vec->reserve(column_order_.size());

  std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
                 [curr_row](const auto &col) { return std::make_pair(col.second, curr_row[col.first]); });

  return Status::OK();
}

// ToDevice
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); }

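// Launch the execution tree; the root op (expected to be a DeviceQueueOp) then starts pushing rows to the device.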
Status ToDevice::Send() {
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  return Status::OK();
}

Status ToDevice::Continue() {
  // tree_.root() must be DeviceQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp");
  op->ContinueSend();
  return Status::OK();
}

Status ToDevice::Stop() {
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
  op->StopSend();

  return Status::OK();
}

Status ToDevice::GetDataInfo(std::vector<DataType> *const types, std::vector<TensorShape> *const shapes) {
  RETURN_UNEXPECTED_IF_NULL(types);
  RETURN_UNEXPECTED_IF_NULL(shapes);
  // tree_.root() must be DeviceQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp");
  DATA_INFO data_info;
  RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
  for (auto el : data_info) {
    types->push_back(el.first);
    shapes->push_back(el.second);
  }
  return Status::OK();
}

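// Stop the root DeviceQueueOp from waiting on the queue (only when built with ENABLE_TDTQUE),
// then shut the tree down through TreeConsumer::Terminate().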
Status ToDevice::Terminate() {
#ifdef ENABLE_TDTQUE
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
  op->StopWaiting();
#endif
  return TreeConsumer::Terminate();
}

#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {
  if (dataset_path_.empty()) {
    std::string err = "SaveToDisk failed, dataset_path must not be empty";
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  Path dir(dataset_path_);
  if (dir.IsDirectory()) {
    std::string err = "SaveToDisk failed, dataset_path must not be a directory";
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  std::string real_path;
  if (Path::RealPath(dir.ParentPath(), real_path).IsError()) {
    std::string err_msg = "SaveToDisk failed, can not get real dataset path: " + dir.ParentPath();
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (access(dir.ParentPath().c_str(), R_OK) == -1) {
    std::string err_msg = "SaveToDisk failed, no access to specified dataset path: " + dataset_path_;
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (num_files_ <= 0 || num_files_ > 1000) {
    std::string err = "SaveToDisk failed, num_files must be between 1 and 1000, but got " + std::to_string(num_files_);
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  if (dataset_type_ != "mindrecord") {
    std::string err = "SaveToDisk failed, only \"mindrecord\" dataset format is supported, but got " + dataset_type_;
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  return Status::OK();
}

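// Iterate over the whole pipeline and write every row to MindRecord files: build the schema from the
// first row, write each row's raw (json) and binary parts through the ShardWriter, then generate the index.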
Status SaveToDisk::Save() {
  std::vector<std::string> file_names;
  if (num_files_ == 1) {
    file_names.push_back(dataset_path_);
  } else {
    for (int32_t i = 0; i < num_files_; i++) {
      file_names.push_back(dataset_path_ + std::to_string(i));
    }
  }

  auto mr_header = std::make_shared<mindrecord::ShardHeader>();
  auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
  std::vector<std::string> blob_fields;
  RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names));

  std::unordered_map<std::string, int32_t> column_name_id_map;
  for (auto el : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = el.first;
    (void)std::transform(column_name.begin(), column_name.end(), column_name.begin(),
                         [](unsigned char c) { return ispunct(c) ? '_' : c; });
    column_name_id_map[column_name] = el.second;
  }

  TensorRow row;
  uint64_t mr_schema_id = 0;
  bool first_loop = true;  // build schema in first loop
  auto PreTensorRowShapes = std::map<std::string, std::vector<int>>();

  do {
    nlohmann::json row_raw_data;
    std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
    RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
    if (row.empty()) {
      break;
    }
    RETURN_IF_NOT_OK(CheckTensorRowShapes(column_name_id_map, row, &PreTensorRowShapes));
    if (first_loop) {
      nlohmann::json mr_json;
      std::vector<std::string> index_fields;
      RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields));
      MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump();
      RETURN_IF_NOT_OK(
        mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id));
      RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header));
      first_loop = false;
    }
    // construct data
    if (!row.empty()) {  // write data
      RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data));
      std::shared_ptr<std::vector<uint8_t>> output_bin_data;
      RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data));
      std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data;
      raw_data.insert(
        std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
      std::vector<std::vector<uint8_t>> bin_data;
      if (output_bin_data != nullptr) {
        bin_data.emplace_back(*output_bin_data);
      }
      RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data));
    }
  } while (!row.empty());

  RETURN_IF_NOT_OK(mr_writer->Commit());
  RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names));
  return Status::OK();
}

template <typename T>
bool SaveToDisk::map_compare(T const &lhs, T const &rhs) {
  return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

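// Verify that, apart from dimension 0, every multi-dimensional column keeps the same shape across rows;
// scalar, 1-D, string and bytes columns are skipped.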
Status SaveToDisk::CheckTensorRowShapes(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                        const TensorRow &row,
                                        std::map<std::string, std::vector<int>> *PreTensorRowShapes_ptr) {
  std::map<std::string, std::vector<int>> CurrTensorRowShapes;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());

    if (mr_shape.empty() || mr_shape.size() == 1) continue;  // ignore scalar and one-dimensional tensors
    std::string mr_type;
    std::string el = column_type.ToString();
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_type == "bytes" || mr_type == "string") continue;
    mr_shape.erase(mr_shape.begin());  // ignore the first dimension
    CurrTensorRowShapes[column_name] = mr_shape;
  }
  if (PreTensorRowShapes_ptr->empty()) {
    *PreTensorRowShapes_ptr = CurrTensorRowShapes;
    return Status::OK();
  }
  auto res = map_compare(*PreTensorRowShapes_ptr, CurrTensorRowShapes);
  CHECK_FAIL_RETURN_UNEXPECTED(res,
                               "Error: besides dimension 0, tensor shapes differ from the previous row's.");
  return Status::OK();
}

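// Build the MindRecord schema (type and shape per column) from the first fetched row and collect the
// scalar, non-bytes columns as index field candidates.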
Status SaveToDisk::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          const TensorRow &row, nlohmann::json *schema,
                                          std::vector<std::string> *index_fields) {
  if (schema == nullptr) {
    RETURN_STATUS_UNEXPECTED("schema can not be nullptr.");
  }
  if (index_fields == nullptr) {
    RETURN_STATUS_UNEXPECTED("index_fields can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be empty.");
  }
  nlohmann::json dataset_schema;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    std::string mr_type;
    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());
    std::string el = column_type.ToString();
    dataset_schema[column_name] = el;
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_shape.empty()) {
      if (mr_type == "bytes") {  // map to int32 when bytes has no shape.
        mr_type = "int32";
      }
      (*schema)[column_name] = {{"type", mr_type}};
    } else {
      if (mr_type == "string") {  // mindrecord can not support string with shape.
        std::string err_msg("Invalid data, mindrecord can not support multi-dimensional string tensor.");
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
      if (mr_type == "bytes") {  // ignore shape of bytes in mindrecord
        (*schema)[column_name] = {{"type", mr_type}};
      } else {
        mr_shape[0] = -1;  // make first dimension -1
        (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
      }
    }
    if (mr_type == "bytes" || !mr_shape.empty()) continue;
    index_fields->emplace_back(column_name);  // candidate of index fields
  }
  MS_LOG(DEBUG) << "Schema of dataset: " << dataset_schema.dump();
  return Status::OK();
}

inline Status ValidateInputParams(nlohmann::json *row_raw_data,
                                  std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data,
                                  const std::unordered_map<std::string, int32_t> &column_name_id_map) {
  if (row_raw_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_raw_data can not be nullptr.");
  }
  if (row_bin_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_bin_data can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be empty.");
  }
  return Status::OK();
}

Status SaveToDisk::FetchFloatData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
                                  std::unique_ptr<std::vector<uint8_t>> *data_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  auto column_type = tensor->type();
  Status s;
  if (column_type == DataType::DE_FLOAT32) {
    std::unique_ptr<float> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_FLOAT64) {
    std::unique_ptr<double> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  }
  return Status::OK();
}

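// Convert a single tensor into the json/binary form expected by MindRecord, widening small integer
// types (int8/int16/uint16 to int32, uint32 to int64) before writing.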
Status SaveToDisk::FetchItemData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
                                 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  RETURN_UNEXPECTED_IF_NULL(tensor);
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(row_bin_data);
  auto column_type = tensor->type();
  Status s;
  std::unique_ptr<std::vector<uint8_t>> data_ptr;
  if (column_type == DataType::DE_INT8) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int8_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT16) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int16_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT16) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<uint16_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT8) {
    std::unique_ptr<uint8_t> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT32) {
    std::unique_ptr<int32_t> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT32) {
    std::unique_ptr<int64_t> data;
    std::unique_ptr<uint32_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT64) {
    std::unique_ptr<int64_t> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_FLOAT32 || column_type == DataType::DE_FLOAT64) {
    s = FetchFloatData(tensor, column_name, row_raw_data, &data_ptr);
    RETURN_IF_NOT_OK(s);
  } else if (column_type == DataType::DE_STRING) {
    std::string_view sv;
    RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {}));  // assume scalar string tensor
    std::string ss(sv);
    (*row_raw_data)[column_name] = std::move(ss);
  } else {
    RETURN_STATUS_UNEXPECTED("Invalid dtype, got unexpected type when casting data: " + column_type.ToString());
  }
  if (data_ptr != nullptr) {
    (*row_bin_data)[column_name] = std::move(data_ptr);
  }
  return Status::OK();
}

Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
                                          const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          nlohmann::json *row_raw_data,
                                          std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(row_bin_data);
  Status s;
  s = ValidateInputParams(row_raw_data, row_bin_data, column_name_id_map);
  if (s.IsError()) {
    return s;
  }
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    s = FetchItemData(tensor, column_name, row_raw_data, row_bin_data);
    RETURN_IF_NOT_OK(s);
  }
  return Status::OK();
}

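// Copy a tensor buffer into a byte vector for MindRecord, optionally converting each element from the
// source type S to the destination type T; for scalar tensors the single value is also returned via *data.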
template <typename T, typename S>
Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
                                   std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
                                   std::unique_ptr<S> *s, bool need_convert) {
  // No need to check src: src may legitimately be nullptr when num_of_elements is 0.
  RETURN_UNEXPECTED_IF_NULL(data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  RETURN_UNEXPECTED_IF_NULL(s);

  *data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
  if (need_convert) {
    auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
    std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
    auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
    auto el = std::make_unique<T>();
    for (uint32_t i = 0; i < num_of_elements; ++i) {
      *el = *(s_ptr + i);
      auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
      for (uint32_t j = 0; j < sizeof(T); ++j) {
        *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
      }
    }
  } else {
    std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
  }
  if (shape.empty()) {
    *data = std::make_unique<T>();
    auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
    for (uint32_t i = 0; i < sizeof(T); ++i) {
      *(t_ptr + i) = *((*data_ptr)->begin() + i);
    }
  }
  return Status::OK();
}
#endif

TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), first_row_obtained_(false) {
  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
}

Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
  RETURN_UNEXPECTED_IF_NULL(d);
  root_ = std::move(d);
  return Status::OK();
}

Status TreeGetters::GetRow(TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  return tree_adapter_->GetNext(row);
}

Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
  RETURN_UNEXPECTED_IF_NULL(types);
  RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
  *types = first_row_type_;
  return Status::OK();
}

Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
  RETURN_UNEXPECTED_IF_NULL(shapes);
  RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
  *shapes = first_row_shape_;
  return Status::OK();
}

Status TreeGetters::GetBatchSize(int64_t *batch_size) {
  RETURN_UNEXPECTED_IF_NULL(batch_size);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  *batch_size = root->GetTreeBatchSize();
  CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
  return Status::OK();
}

Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
  RETURN_UNEXPECTED_IF_NULL(repeat_count);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  *repeat_count = root->GetTreeRepeatCount();
  return Status::OK();
}

Status TreeGetters::GetNumClasses(int64_t *num_classes) {
  RETURN_UNEXPECTED_IF_NULL(num_classes);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
  return Status::OK();
}

Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
  RETURN_UNEXPECTED_IF_NULL(output);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id_map can not be empty.");
  std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
  std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
            [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
              return a.second < b.second;
            });
  std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
                 [](const std::pair<std::string, int32_t> &p) { return p.first; });
  return Status::OK();
}

Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  RETURN_UNEXPECTED_IF_NULL(output_class_indexing);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
  return Status::OK();
}

Status TreeGetters::InternalInit() {
  if (init_flag_) {
    return Status::OK();
  }

  Status s = tree_adapter_->Compile(std::move(root_), 1);
  if (s.IsOk()) {
    init_flag_ = true;
  }
  return s;
}

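// Lazily fetch the first row of the getter pipeline and cache its per-column types and shapes for
// GetOutputTypes / GetOutputShapes.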
Status TreeGetters::GetFirstRowShapeAndType() {
  RETURN_OK_IF_TRUE(first_row_obtained_);
  RETURN_IF_NOT_OK(InternalInit());
  TensorRow first_row;
  RETURN_IF_NOT_OK(GetRow(&first_row));
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
                 [](const TensorPtr &t) { return t->type(); });
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_shape_),
                 [](const TensorPtr &t) { return t->shape(); });
  first_row_obtained_ = true;
  return Status::OK();
}

Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }

Status BuildVocabConsumer::Start() {
  // Getting one row triggers building the vocab
  TensorRow row;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
  // The returned row should be an EOE, which is an empty row
  CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "BuildVocab: The fetched row from BuildVocab should be an EOE.");
  return Status::OK();
}
Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
  if (dataset_size_ == -1) {
    RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
    dataset_size_ = *size;  // cache the result for subsequent calls
  }

  *size = dataset_size_;
  return Status::OK();
}
Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
  root_ = std::move(d);
  return Status::OK();
}
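// Compile a fresh tree for the given IR node and count its rows by iterating it to exhaustion; the
// per-call tree adapters are kept in tree_adapters_ so Terminate() can stop them later.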
Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
  tree_adapters_.push_back(tree_adapter);
  RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
  TensorRow row;
  RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
  int64_t row_cnt = 0;
  while (!row.empty()) {
    ++row_cnt;
    RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
  }
  *dataset_size = row_cnt;
  return Status::OK();
}
Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  return tree_adapter->GetNext(row);
}
Status DatasetSizeGetter::Terminate() {
  for (const auto &tree : tree_adapters_) {
    RETURN_UNEXPECTED_IF_NULL(tree);
    RETURN_UNEXPECTED_IF_NULL(tree->AllTasks());
    RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore