/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
#include "minddata/dataset/engine/tree_adapter.h"

#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_writer.h"
#endif

namespace mindspore {
namespace dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }

Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); }
Status TreeConsumer::Terminate() {
  CHECK_FAIL_RETURN_UNEXPECTED(tree_adapter_->AllTasks() != nullptr, " Execution tree has not been built");
  return tree_adapter_->AllTasks()->ServiceStop();
}

// IteratorConsumer
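// A minimal usage sketch (an illustration only; it assumes an IR tree `ds_node` built elsewhere by
// the dataset API and an IteratorConsumer constructed with the desired number of epochs):
//   IteratorConsumer consumer(/*num_epochs=*/1);
//   RETURN_IF_NOT_OK(consumer.Init(ds_node));
//   std::unordered_map<std::string, TensorPtr> row;
//   RETURN_IF_NOT_OK(consumer.GetNextAsMap(&row));  // an empty row signals the end of data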
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
  return tree_adapter_->Compile(std::move(d), num_epochs_);
}

Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  out->clear();

  TensorRow res;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));

  // Return an empty vector if there's no data
  RETURN_OK_IF_TRUE(res.empty());

  // Filter out meta columns
  std::vector<size_t> to_keep_indices;
  for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = colMap.first;
    // Meta columns start with kDftMetaColumnPrefix and must be filtered out
    size_t pos = column_name.find(kDftMetaColumnPrefix);
    if (pos != std::string::npos && pos == 0) {
      continue;
    }
    to_keep_indices.push_back(colMap.second);
  }
  if (to_keep_indices.size() == 0) {
    std::string err_msg = "No effective column found; maybe all columns are meta columns and were filtered out. ";
    err_msg += "To output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  std::sort(to_keep_indices.begin(), to_keep_indices.end());
  (void)std::transform(to_keep_indices.begin(), to_keep_indices.end(), std::back_inserter(*out),
                       [&res](const auto &it) { return std::move(res[it]); });

  return Status::OK();
}

Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out_map) {
  RETURN_UNEXPECTED_IF_NULL(out_map);
  out_map->clear();

  TensorRow res;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));

  // Return an empty map if there's no data
  RETURN_OK_IF_TRUE(res.empty());

  // Populate the out map from the row and return it
  for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = colMap.first;
    // Meta columns start with kDftMetaColumnPrefix and must be filtered out
    size_t pos = column_name.find(kDftMetaColumnPrefix);
    if (pos != std::string::npos && pos == 0) {
      continue;
    }
    (*out_map)[colMap.first] = std::move(res[colMap.second]);
  }
  if (out_map->size() == 0) {
    std::string err_msg = "No effective column found; maybe all columns are meta columns and were filtered out. ";
    err_msg += "To output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  return Status::OK();
}

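// GetNextAsOrderedPair returns the next row as (column name, tensor) pairs sorted by column id.
// The ordering is computed once, with meta columns filtered out, and cached in column_order_ for
// later calls.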
Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) {
  CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");

  TensorRow curr_row;

  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
  RETURN_OK_IF_TRUE(curr_row.empty());

  size_t num_cols = curr_row.size();  // curr_row is non-empty, so num_cols > 0
  // order the column names according to their ids
  if (column_order_.empty()) {
    for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
      int32_t ind = itr.second;
      CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
      // Meta columns start with kDftMetaColumnPrefix and must be filtered out
      size_t pos = itr.first.find(kDftMetaColumnPrefix);
      if (pos != std::string::npos && pos == 0) {
        continue;
      }
      column_order_[ind] = itr.first;
    }
  }

  if (column_order_.size() == 0) {
    std::string err_msg = "No effective column found; maybe all columns are meta columns and were filtered out. ";
    err_msg += "To output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  vec->reserve(column_order_.size());

  std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
                 [curr_row](const auto &col) { return std::make_pair(col.second, curr_row[col.first]); });

  return Status::OK();
}

// ToDevice
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); }

Status ToDevice::Send() {
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  return Status::OK();
}

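// Continue, Stop, GetDataInfo and Terminate all act on the tree root, which is expected to be a
// DeviceQueueOp; the dynamic_cast below returns an error status for any other op type.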
Status ToDevice::Continue() {
  // tree_.root() must be DeviceQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp");
  op->ContinueSend();
  return Status::OK();
}

Status ToDevice::Stop() {
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
  op->StopSend();

  return Status::OK();
}

Status ToDevice::GetDataInfo(std::vector<DataType> *const types, std::vector<TensorShape> *const shapes) {
  RETURN_UNEXPECTED_IF_NULL(types);
  RETURN_UNEXPECTED_IF_NULL(shapes);
  // tree_.root() must be DeviceQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp");
  DATA_INFO data_info;
  RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
  for (auto el : data_info) {
    types->push_back(el.first);
    shapes->push_back(el.second);
  }
  return Status::OK();
}

Status ToDevice::Terminate() {
#ifdef ENABLE_TDTQUE
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
  op->StopWaiting();
#endif
  return TreeConsumer::Terminate();
}

#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {
  if (dataset_path_.empty()) {
    std::string err = "SaveToDisk failed, dataset_path must not be empty";
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  Path dir(dataset_path_);
  if (dir.IsDirectory()) {
    std::string err = "SaveToDisk failed, dataset_path must not be a directory";
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  std::string real_path;
  if (Path::RealPath(dir.ParentPath(), real_path).IsError()) {
    std::string err_msg = "SaveToDisk failed, can not get real dataset path: " + dir.ParentPath();
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (access(dir.ParentPath().c_str(), R_OK) == -1) {
    std::string err_msg = "SaveToDisk failed, no access to specified dataset path: " + dataset_path_;
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (num_files_ <= 0 || num_files_ > 1000) {
    std::string err = "SaveToDisk failed, num_files must be between 1 and 1000, but got " + std::to_string(num_files_);
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  if (dataset_type_ != "mindrecord") {
    std::string err = "SaveToDisk failed, only \"mindrecord\" dataset format is supported, but got " + dataset_type_;
    MS_LOG(ERROR) << err;
    RETURN_STATUS_SYNTAX_ERROR(err);
  }
  return Status::OK();
}

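// Save iterates the pipeline once and writes it out in mindrecord format: the schema and index
// fields are derived from the first row, every row is then split into raw (json) and blob parts
// and written through the ShardWriter, and the index files are generated at the end.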
Status SaveToDisk::Save() {
  std::vector<std::string> file_names;
  if (num_files_ == 1) {
    file_names.push_back(dataset_path_);
  } else {
    for (int32_t i = 0; i < num_files_; i++) {
      file_names.push_back(dataset_path_ + std::to_string(i));
    }
  }

  auto mr_header = std::make_shared<mindrecord::ShardHeader>();
  auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
  std::vector<std::string> blob_fields;
  RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names));

  std::unordered_map<std::string, int32_t> column_name_id_map;
  for (auto el : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = el.first;
    (void)std::transform(column_name.begin(), column_name.end(), column_name.begin(),
                         [](unsigned char c) { return ispunct(c) ? '_' : c; });
    column_name_id_map[column_name] = el.second;
  }

  TensorRow row;
  uint64_t mr_schema_id = 0;
  bool first_loop = true;  // build schema in first loop
  auto PreTensorRowShapes = std::map<std::string, std::vector<int>>();

  do {
    nlohmann::json row_raw_data;
    std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
    RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
    if (row.empty()) {
      break;
    }
    RETURN_IF_NOT_OK(CheckTensorRowShapes(column_name_id_map, row, &PreTensorRowShapes));
    if (first_loop) {
      nlohmann::json mr_json;
      std::vector<std::string> index_fields;
      RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields));
      MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump();
      RETURN_IF_NOT_OK(
        mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id));
      RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header));
      first_loop = false;
    }
    // construct data
    if (!row.empty()) {  // write data
      RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data));
      std::shared_ptr<std::vector<uint8_t>> output_bin_data;
      RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data));
      std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data;
      raw_data.insert(
        std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
      std::vector<std::vector<uint8_t>> bin_data;
      if (output_bin_data != nullptr) {
        bin_data.emplace_back(*output_bin_data);
      }
      RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data));
    }
  } while (!row.empty());

  RETURN_IF_NOT_OK(mr_writer->Commit());
  RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names));
  return Status::OK();
}

template <typename T>
bool SaveToDisk::map_compare(T const &lhs, T const &rhs) {
  return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

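// Check that, apart from dimension 0, every multi-dimensional column keeps the same shape from row
// to row; mindrecord only allows the first dimension to vary. Scalars, 1-D tensors, bytes and
// string columns are skipped.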
Status SaveToDisk::CheckTensorRowShapes(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                        const TensorRow &row,
                                        std::map<std::string, std::vector<int>> *PreTensorRowShapes_ptr) {
  std::map<std::string, std::vector<int>> CurrTensorRowShapes;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());

    if (mr_shape.empty() || mr_shape.size() == 1) continue;  // ignore scalar and one-dimensional tensors
    std::string mr_type;
    std::string el = column_type.ToString();
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_type == "bytes" || mr_type == "string") continue;
    mr_shape.erase(mr_shape.begin());  // ignore the first dimension
    CurrTensorRowShapes[column_name] = mr_shape;
  }
  if (PreTensorRowShapes_ptr->empty()) {
    *PreTensorRowShapes_ptr = CurrTensorRowShapes;
    return Status::OK();
  }
  auto res = map_compare(*PreTensorRowShapes_ptr, CurrTensorRowShapes);
  CHECK_FAIL_RETURN_UNEXPECTED(res,
                               "Error: apart from dimension 0, the tensor shapes differ from those of the previous row.");
  return Status::OK();
}

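// Build the mindrecord schema from the first row: map each column's dataset type to a mindrecord
// type, record the shape with dimension 0 set to -1, and collect scalar non-bytes columns as
// candidate index fields.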
Status SaveToDisk::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          const TensorRow &row, nlohmann::json *schema,
                                          std::vector<std::string> *index_fields) {
  if (schema == nullptr) {
    RETURN_STATUS_UNEXPECTED("schema can not be nullptr.");
  }
  if (index_fields == nullptr) {
    RETURN_STATUS_UNEXPECTED("index_fields can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be empty.");
  }
  nlohmann::json dataset_schema;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    std::string mr_type;
    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());
    std::string el = column_type.ToString();
    dataset_schema[column_name] = el;
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_shape.empty()) {
      if (mr_type == "bytes") {  // map to int32 when bytes without shape.
        mr_type = "int32";
      }
      (*schema)[column_name] = {{"type", mr_type}};
    } else {
      if (mr_type == "string") {  // mindrecord can not support string with shape.
        std::string err_msg("Invalid data, mindrecord can not support multi-dimensional string tensor.");
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
      if (mr_type == "bytes") {  // ignore shape of bytes in mindrecord
        (*schema)[column_name] = {{"type", mr_type}};
      } else {
        mr_shape[0] = -1;  // make first dimension -1
        (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
      }
    }
    if (mr_type == "bytes" || !mr_shape.empty()) continue;
    index_fields->emplace_back(column_name);  // candidate of index fields
  }
  MS_LOG(DEBUG) << "Schema of dataset: " << dataset_schema.dump();
  return Status::OK();
}

inline Status ValidateInputParams(nlohmann::json *row_raw_data,
                                  std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data,
                                  const std::unordered_map<std::string, int32_t> &column_name_id_map) {
  if (row_raw_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_raw_data can not be nullptr.");
  }
  if (row_bin_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_bin_data can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be empty.");
  }
  return Status::OK();
}

Status SaveToDisk::FetchFloatData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
                                  std::unique_ptr<std::vector<uint8_t>> *data_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  auto column_type = tensor->type();
  Status s;
  if (column_type == DataType::DE_FLOAT32) {
    std::unique_ptr<float> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_FLOAT64) {
    std::unique_ptr<double> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  }
  return Status::OK();
}

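// Convert a single tensor into its mindrecord raw/blob representation, dispatching on dtype.
// Integer types that mindrecord does not store directly are widened through TransformTensor's
// convert path (int8/int16/uint16 -> int32, uint32 -> int64).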
Status SaveToDisk::FetchItemData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
                                 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  RETURN_UNEXPECTED_IF_NULL(tensor);
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(row_bin_data);
  auto column_type = tensor->type();
  Status s;
  std::unique_ptr<std::vector<uint8_t>> data_ptr;
  if (column_type == DataType::DE_INT8) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int8_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT16) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int16_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT16) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<uint16_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT8) {
    std::unique_ptr<uint8_t> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT32) {
    std::unique_ptr<int32_t> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT32) {
    std::unique_ptr<int64_t> data;
    std::unique_ptr<uint32_t> dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT64) {
    std::unique_ptr<int64_t> data, dummy;
    s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
    RETURN_IF_NOT_OK(s);
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_FLOAT32 || column_type == DataType::DE_FLOAT64) {
    s = FetchFloatData(tensor, column_name, row_raw_data, &data_ptr);
    RETURN_IF_NOT_OK(s);
  } else if (column_type == DataType::DE_STRING) {
    std::string_view sv;
    RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {}));  // assume scalar string tensor
    std::string ss(sv);
    (*row_raw_data)[column_name] = std::move(ss);
  } else {
    RETURN_STATUS_UNEXPECTED("Invalid dtype, got unexpected type when casting data: " + column_type.ToString());
  }
  if (data_ptr != nullptr) {
    (*row_bin_data)[column_name] = std::move(data_ptr);
  }
  return Status::OK();
}

Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
                                          const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          nlohmann::json *row_raw_data,
                                          std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(row_bin_data);
  Status s;
  s = ValidateInputParams(row_raw_data, row_bin_data, column_name_id_map);
  if (s.IsError()) {
    return s;
  }
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    s = FetchItemData(tensor, column_name, row_raw_data, row_bin_data);
    RETURN_IF_NOT_OK(s);
  }
  return Status::OK();
}

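// Copy a tensor buffer of num_of_elements elements into *data_ptr as raw bytes. When need_convert
// is true, the source elements (type S) are widened one by one to type T before being copied.
// For scalar tensors (empty shape) the single value is also returned through *data.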
template <typename T, typename S>
Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
                                   std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
                                   std::unique_ptr<S> *s, bool need_convert) {
  // No need to check src since we support some scenarios that src is nullptr and num_of_elements is 0.
  RETURN_UNEXPECTED_IF_NULL(data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  RETURN_UNEXPECTED_IF_NULL(s);

  *data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
  if (need_convert) {
    auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
    std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
    auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
    auto el = std::make_unique<T>();
    for (uint32_t i = 0; i < num_of_elements; ++i) {
      *el = *(s_ptr + i);
      auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
      for (uint32_t j = 0; j < sizeof(T); ++j) {
        *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
      }
    }
  } else {
    std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
  }
  if (shape.empty()) {
    *data = std::make_unique<T>();
    auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
    for (uint32_t i = 0; i < sizeof(T); ++i) {
      *(t_ptr + i) = *((*data_ptr)->begin() + i);
    }
  }
  return Status::OK();
}
#endif

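// TreeGetters
// A lightweight consumer that answers pipeline queries (output types/shapes, batch size, repeat
// count, number of classes, class indexing, column names); the tree is compiled lazily on first use.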
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), first_row_obtained_(false) {
  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
}

Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
  RETURN_UNEXPECTED_IF_NULL(d);
  root_ = std::move(d);
  return Status::OK();
}

Status TreeGetters::GetRow(TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  return tree_adapter_->GetNext(row);
}

Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
  RETURN_UNEXPECTED_IF_NULL(types);
  RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
  *types = first_row_type_;
  return Status::OK();
}

Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
  RETURN_UNEXPECTED_IF_NULL(shapes);
  RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
  *shapes = first_row_shape_;
  return Status::OK();
}

Status TreeGetters::GetBatchSize(int64_t *batch_size) {
  RETURN_UNEXPECTED_IF_NULL(batch_size);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  *batch_size = root->GetTreeBatchSize();
  CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
  return Status::OK();
}

Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
  RETURN_UNEXPECTED_IF_NULL(repeat_count);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  *repeat_count = root->GetTreeRepeatCount();
  return Status::OK();
}

Status TreeGetters::GetNumClasses(int64_t *num_classes) {
  RETURN_UNEXPECTED_IF_NULL(num_classes);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
  return Status::OK();
}

Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
  RETURN_UNEXPECTED_IF_NULL(output);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map can not be empty.");
  std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
  std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
            [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
              return a.second < b.second;
            });
  std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
                 [](const std::pair<std::string, int32_t> &p) { return p.first; });
  return Status::OK();
}

Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  RETURN_UNEXPECTED_IF_NULL(output_class_indexing);
  RETURN_IF_NOT_OK(InternalInit());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  RETURN_UNEXPECTED_IF_NULL(root);
  RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
  return Status::OK();
}

Status TreeGetters::InternalInit() {
  if (init_flag_) {
    return Status::OK();
  }

  Status s = tree_adapter_->Compile(std::move(root_), 1);
  if (s.IsOk()) {
    init_flag_ = true;
  }
  return s;
}

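// Fetch the first row once, cache its per-column types and shapes, and serve later calls from the
// cached values.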
Status TreeGetters::GetFirstRowShapeAndType() {
  RETURN_OK_IF_TRUE(first_row_obtained_);
  RETURN_IF_NOT_OK(InternalInit());
  TensorRow first_row;
  RETURN_IF_NOT_OK(GetRow(&first_row));
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
                 [](const TensorPtr &t) { return t->type(); });
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_shape_),
                 [](const TensorPtr &t) { return t->shape(); });
  first_row_obtained_ = true;
  return Status::OK();
}

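// BuildVocabConsumer
// Pulling a single row drives the pipeline far enough to build the vocab; the fetched row itself is
// expected to be an EOE.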
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }

Status BuildVocabConsumer::Start() {
  // Getting one row triggers building the vocab
  TensorRow row;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
  // The returned row should be an EOE, which is an empty row
  CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "BuildVocab: The fetched row from BuildVocab should be an EOE.");
  return Status::OK();
}
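
// DatasetSizeGetter
// GetDatasetSize caches its result in dataset_size_; DryRun compiles a separate tree for an IR
// sub-tree and counts its rows one by one.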
Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
  if (dataset_size_ == -1) {
    RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
    dataset_size_ = *size;  // cache the result for later calls
  }

  *size = dataset_size_;
  return Status::OK();
}

Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
  root_ = std::move(d);
  return Status::OK();
}

Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
  tree_adapters_.push_back(tree_adapter);
  RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
  TensorRow row;
  RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
  int64_t row_cnt = 0;
  while (!row.empty()) {
    ++row_cnt;
    RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
  }
  *dataset_size = row_cnt;
  return Status::OK();
}

Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  return tree_adapter->GetNext(row);
}

Status DatasetSizeGetter::Terminate() {
  for (const auto &tree : tree_adapters_) {
    RETURN_UNEXPECTED_IF_NULL(tree);
    RETURN_UNEXPECTED_IF_NULL(tree->AllTasks());
    RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore