1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23
24 #include "minddata/dataset/engine/consumers/tree_consumer.h"
25 #include "minddata/dataset/engine/datasetops/data_queue_op.h"
26 #include "minddata/dataset/engine/opt/pre/getter_pass.h"
27 #ifndef ENABLE_SECURITY
28 #include "minddata/dataset/engine/perf/auto_tune.h"
29 #endif
30 #include "minddata/dataset/engine/perf/info_collector.h"
31 #ifndef ENABLE_SECURITY
32 #include "minddata/dataset/engine/perf/profiling.h"
33 #endif
34 #include "minddata/dataset/engine/tree_adapter.h"
35 #ifndef ENABLE_ANDROID
36 #include "minddata/dataset/kernels/data/data_utils.h"
37 #include "minddata/mindrecord/include/shard_header.h"
38 #include "minddata/mindrecord/include/shard_index_generator.h"
39 #include "minddata/mindrecord/include/shard_writer.h"
40 #endif
41 #ifdef WITH_BACKEND
42 #include "utils/ms_context.h"
43 #endif
44
45 namespace mindspore {
46 namespace dataset {
47 #ifndef ENABLE_SECURITY
48 using ProfilingRegistrationState = ProfilingManager::ProfilingRegistrationState;
49 #endif
50 // TreeConsumer
// Default-construct a TreeConsumer that runs the pipeline for a single epoch,
// delegating to the num_epochs constructor.
TreeConsumer::TreeConsumer() : TreeConsumer(1) {}
52
// Construct a TreeConsumer for the given number of epochs and create the
// TreeAdapter that compiles and owns the execution tree.
TreeConsumer::TreeConsumer(int32_t num_epochs) : num_epochs_(num_epochs) {
  tree_adapter_ = std::make_unique<TreeAdapter>();
}
56
// Compile the IR tree rooted at `root` into an execution tree and, when
// security mode is off, register this consumer with the global profiling
// manager so the MD profiler can observe the pipeline.
Status TreeConsumer::Init(const std::shared_ptr<DatasetNode> &root) {
  RETURN_IF_NOT_OK(tree_adapter_->Compile(root));
#ifndef ENABLE_SECURITY
  profiling_manager_ = GlobalContext::profiling_manager();
  RETURN_IF_NOT_OK(RegisterProfilingManager());
#endif
  return Status::OK();
}
65
// Overload taking a resume point. The base TreeConsumer cannot resume from an
// intermediate step, so this warns and falls back to initializing from the
// beginning (global_step and dataset_size are intentionally ignored).
Status TreeConsumer::Init(const std::shared_ptr<DatasetNode> &root, int64_t global_step, int64_t dataset_size) {
  MS_LOG(WARNING) << "TreeConsumer does not support initializing from intermediate epoch or step, change to "
                     "initialize from the beginning.";
  return Init(root);
}
71
Terminate()72 Status TreeConsumer::Terminate() {
73 if (tree_adapter_->AllTasks() != nullptr) {
74 return tree_adapter_->AllTasks()->ServiceStop();
75 }
76 return Status::OK();
77 }
78
79 #ifndef ENABLE_SECURITY
RegisterProfilingManager()80 Status IteratorConsumer::RegisterProfilingManager() {
81 auto profiler_state = profiling_manager_->GetProfilerTreeState(tree_adapter_->tree_.get());
82 // This should never happen
83 CHECK_FAIL_RETURN_UNEXPECTED(profiler_state != ProfilingManager::kEnabledTreeRegistered,
84 "Something went wrong. Current tree is already registered with the MD Profiler");
85 if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered && profiling_manager_->IsProfiling()) {
86 MS_LOG(WARNING) << "Dataset Profiling is already enabled for a different data pipeline.";
87 } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered &&
88 profiling_manager_->IsAutotuning()) {
89 MS_LOG(WARNING) << "AutoTune for dataset is already enabled for a different data pipeline.";
90 } else if (profiler_state == ProfilingRegistrationState::kEnabledTreeNotRegistered) {
91 // Profiling infrastructures need to be initialized before Op launching
92 // Setup profiling manager
93 RETURN_IF_NOT_OK(profiling_manager_->RegisterTree(this->tree_adapter_.get()));
94 // dataset_iterator node is used for graph mode
95 std::shared_ptr<Tracing> iterator_tracing = std::make_shared<DatasetIteratorTracing>();
96 RETURN_IF_NOT_OK(profiling_manager_->RegisterTracingNode(iterator_tracing));
97 RETURN_IF_NOT_OK(tree_adapter_->SetProfilingManagerPtr(profiling_manager_, iterator_tracing));
98 // Launch Monitor Thread
99 RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
100 } else {
101 MS_LOG(INFO) << "Unable to register this tree with ProfilingManager.";
102 }
103 return Status::OK();
104 }
105
RegisterProfilingManager()106 Status ToDevice::RegisterProfilingManager() {
107 auto profiler_state = profiling_manager_->GetProfilerTreeState(tree_adapter_->tree_.get());
108 // This should never happen
109 CHECK_FAIL_RETURN_UNEXPECTED(profiler_state != ProfilingManager::kEnabledTreeRegistered,
110 "Something went wrong. Current tree is already registered with the MD Profiler");
111 if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered && profiling_manager_->IsProfiling()) {
112 MS_LOG(WARNING) << "Dataset Profiling is already enabled for a different data pipeline.";
113 } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered &&
114 profiling_manager_->IsAutotuning()) {
115 MS_LOG(WARNING) << "AutoTune for dataset is already enabled for a different data pipeline.";
116 } else if (profiler_state == ProfilingRegistrationState::kEnabledTreeNotRegistered) {
117 // Profiling infrastructures need to be initialized before Op launching
118 // Setup profiling manager
119 RETURN_IF_NOT_OK(profiling_manager_->RegisterTree(this->tree_adapter_.get()));
120 // device_queue node is used for graph mode
121 std::shared_ptr<Tracing> device_queue_tracing = std::make_shared<DeviceQueueTracing>();
122 RETURN_IF_NOT_OK(profiling_manager_->RegisterTracingNode(device_queue_tracing));
123 RETURN_IF_NOT_OK(tree_adapter_->SetProfilingManagerPtr(profiling_manager_));
124 // Launch Monitor Thread
125 RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
126 } else {
127 MS_LOG(INFO) << "Unable to register this tree with ProfilingManager.";
128 }
129 return Status::OK();
130 }
131
RegisterProfilingManager()132 Status TreeConsumer::RegisterProfilingManager() {
133 if (profiling_manager_->IsProfiling()) {
134 return {StatusCode::kMDUnexpectedError, "Dataset Profiling is not supported for this kind of dataset."};
135 }
136 return Status::OK();
137 }
138
// Enable AutoTune for this consumer's pipeline. AutoTune depends on the MD
// profiler, so when no profiler is enabled yet this initializes and starts the
// profiling manager, registers this tree, and then launches the AutoTune
// thread. If profiling or AutoTune is already active (for this or another
// pipeline), AutoTune is not enabled and a warning is logged instead.
Status TreeConsumer::InitAutoTune() {
  auto profiler_state = profiling_manager_->GetProfilerTreeState(tree_adapter_->tree_.get());
  if (profiler_state == ProfilingRegistrationState::kNotEnabled) {
    // Init ProfilingManager to `Enable` it.
    RETURN_IF_NOT_OK(profiling_manager_->Init(true));
    // Register this tree
    RETURN_IF_NOT_OK(RegisterProfilingManager());
    // Start Profiler
    RETURN_IF_NOT_OK(profiling_manager_->Start());
    // AutoTune object and thread init
    autotune_ = std::make_unique<AutoTune>(this->tree_adapter_.get(), GetProfilingManager());
    RETURN_IF_NOT_OK(autotune_->LaunchThread());
  } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered && profiling_manager_->IsProfiling()) {
    MS_LOG(WARNING) << "Cannot enable AutoTune for the current data pipeline as Dataset Profiling is enabled for "
                       "another data pipeline.";
  } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered &&
             profiling_manager_->IsAutotuning()) {
    MS_LOG(WARNING)
      << "Cannot enable AutoTune for the current data pipeline as it is already enabled for another data pipeline.";
  } else if (profiler_state == ProfilingManager::kEnabledTreeRegistered && profiling_manager_->IsProfiling()) {
    MS_LOG(WARNING)
      << "Cannot enable AutoTune for the current data pipeline as Dataset Profiling is already enabled for the "
         "current data pipeline.";
  } else {
    MS_LOG(WARNING) << "Cannot enable AutoTune for the current data pipeline.";
  }
  return Status::OK();
}
167 #endif
168
GetOffload()169 std::string TreeConsumer::GetOffload() { return (tree_adapter_->GetOffloadJson()).dump(); }
170
171 // IteratorConsumer
// Compile the IR tree for iteration. A non-zero global_step means we are
// resuming mid-run, which requires a TreeAdapter created with the reset usage
// flag. When security mode is off, also hook up profiling (only if the
// profiler was already enabled) and AutoTune (if enabled in the config).
Status IteratorConsumer::Init(const std::shared_ptr<DatasetNode> &root, int64_t global_step, int64_t dataset_size) {
  if (global_step != 0) {
    tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  }
  RETURN_IF_NOT_OK(tree_adapter_->Compile(root, num_epochs_, global_step, dataset_size));
#ifndef ENABLE_SECURITY
  profiling_manager_ = GlobalContext::profiling_manager();
  if (profiling_manager_->IsProfiling()) {
    // Init has been called already
    RETURN_IF_NOT_OK(RegisterProfilingManager());
  }
  if (GlobalContext::config_manager()->enable_autotune()) {
    RETURN_IF_NOT_OK(InitAutoTune());
  }
#endif
  return Status::OK();
}
189
GetNextAsVector(std::vector<TensorPtr> * const out)190 Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *const out) {
191 RETURN_UNEXPECTED_IF_NULL(out);
192 RETURN_IF_NOT_OK(CollectPipelineInfoStart("IteratorConsumer", "GetNextAsVector"));
193 out->clear();
194 TensorRow res;
195 RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));
196
197 // Return empty vector if there's no data
198 if (res.empty()) {
199 RETURN_IF_NOT_OK(
200 CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsVector", {{"TensorRowFlags", res.FlagName()}}));
201 return Status::OK();
202 }
203
204 // Filter meta column
205 std::vector<size_t> to_keep_indices;
206 for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
207 std::string column_name = colMap.first;
208 // Need to filter meta column start with kDftMetaColumnPrefix
209 size_t pos = column_name.find(kDftMetaColumnPrefix);
210 if (pos != std::string::npos && pos == 0) {
211 continue;
212 }
213 to_keep_indices.push_back(colMap.second);
214 }
215 if (to_keep_indices.empty()) {
216 std::string err_msg = "No effective column found, maybe all columns are meta column and will be filtered. ";
217 err_msg += "If you want to output meta column please rename column name to a new one which is not start with ";
218 err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
219 RETURN_STATUS_UNEXPECTED(err_msg);
220 }
221 std::sort(to_keep_indices.begin(), to_keep_indices.end());
222 (void)std::transform(to_keep_indices.begin(), to_keep_indices.end(), std::back_inserter(*out),
223 [&res](const auto &it) { return std::move(res[it]); });
224 RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsVector"));
225 return Status::OK();
226 }
227
GetNextAsMap(std::unordered_map<std::string,TensorPtr> * const out_map)228 Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out_map) {
229 RETURN_UNEXPECTED_IF_NULL(out_map);
230 RETURN_IF_NOT_OK(CollectPipelineInfoStart("IteratorConsumer", "GetNextAsMap"));
231
232 out_map->clear();
233 TensorRow res;
234 RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));
235
236 // Return empty map if there's no data
237 if (res.empty()) {
238 RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsMap", {{"TensorRowFlags", res.FlagName()}}));
239 return Status::OK();
240 }
241
242 // Populate the out map from the row and return it
243 for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
244 std::string column_name = colMap.first;
245 // Need to filter meta column start with kDftMetaColumnPrefix
246 size_t pos = column_name.find(kDftMetaColumnPrefix);
247 if (pos != std::string::npos && pos == 0) {
248 continue;
249 }
250 (*out_map)[colMap.first] = std::move(res[colMap.second]);
251 }
252 if (out_map->empty()) {
253 std::string err_msg = "No effective column found, maybe all columns are meta column and will be filtered. ";
254 err_msg += "If you want to output meta column please rename column name to a new one which is not start with ";
255 err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
256 RETURN_STATUS_UNEXPECTED(err_msg);
257 }
258 RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsMap"));
259 return Status::OK();
260 }
261
GetNextAsOrderedPair(std::vector<std::pair<std::string,std::shared_ptr<Tensor>>> * const vec)262 Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) {
263 CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
264 RETURN_IF_NOT_OK(CollectPipelineInfoStart("IteratorConsumer", "GetNextAsOrderedPair"));
265
266 TensorRow curr_row;
267
268 RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
269
270 // Return empty pair if there's no data
271 if (curr_row.empty()) {
272 RETURN_IF_NOT_OK(
273 CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsOrderedPair", {{"TensorRowFlags", curr_row.FlagName()}}));
274 return Status::OK();
275 }
276
277 size_t num_cols = curr_row.size(); // num_cols is non-empty.
278 // order the column names according to their ids
279 if (column_order_.empty()) {
280 for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
281 int32_t ind = itr.second;
282 CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
283 // Need to filter meta column start with kDftMetaColumnPrefix
284 size_t pos = itr.first.find(kDftMetaColumnPrefix);
285 if (pos != std::string::npos && pos == 0) {
286 continue;
287 }
288 column_order_[ind] = itr.first;
289 }
290 }
291
292 if (column_order_.empty()) {
293 std::string err_msg = "No effective column found, maybe all columns are meta column and will be filtered. ";
294 err_msg += "If you want to output meta column please rename column name to a new one which is not start with ";
295 err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
296 RETURN_STATUS_UNEXPECTED(err_msg);
297 }
298 vec->reserve(column_order_.size());
299
300 std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
301 [curr_row](const auto &col) { return std::make_pair(col.second, curr_row[col.first]); });
302 RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsOrderedPair"));
303 return Status::OK();
304 }
305
306 // ToDevice
// Compile the IR tree for device-queue consumption. A non-zero global_step
// means we are resuming mid-run, which requires a TreeAdapter created with
// the reset usage flag. When security mode is off, also hook up profiling
// (only if the profiler was already enabled) and AutoTune (if configured).
Status ToDevice::Init(const std::shared_ptr<DatasetNode> &root, int64_t global_step, int64_t dataset_size) {
  if (global_step != 0) {
    tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  }
  RETURN_IF_NOT_OK(tree_adapter_->Compile(root, num_epochs_, global_step, dataset_size));
#ifndef ENABLE_SECURITY
  profiling_manager_ = GlobalContext::profiling_manager();
  if (profiling_manager_->IsProfiling()) {
    // Init has been called already
    RETURN_IF_NOT_OK(RegisterProfilingManager());
  }
  if (GlobalContext::config_manager()->enable_autotune()) {
    RETURN_IF_NOT_OK(InitAutoTune());
  }
#endif
  return Status::OK();
}
324
Send()325 Status ToDevice::Send() {
326 RETURN_IF_NOT_OK(tree_adapter_->Launch());
327 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
328 CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
329 return Status::OK();
330 }
331
Continue()332 Status ToDevice::Continue() {
333 // tree_.root() must be DataQueueOp
334 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
335 CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
336 DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
337 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DataQueueOp");
338 op->ContinueSend();
339 return Status::OK();
340 }
341
Stop()342 Status ToDevice::Stop() {
343 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
344 CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
345 DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
346 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DataQueueOp");
347 op->StopSend();
348
349 return Status::OK();
350 }
351
GetDataInfo(std::vector<DataType> * const types,std::vector<TensorShape> * const shapes)352 Status ToDevice::GetDataInfo(std::vector<DataType> *const types, std::vector<TensorShape> *const shapes) {
353 RETURN_UNEXPECTED_IF_NULL(types);
354 RETURN_UNEXPECTED_IF_NULL(shapes);
355 // tree_.root() must be DataQueueOp
356 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
357 CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
358 DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
359 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DataQueueOp");
360 DATA_INFO data_info;
361 RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
362 for (auto el : data_info) {
363 types->push_back(el.first);
364 shapes->push_back(el.second);
365 }
366 return Status::OK();
367 }
368
GetMbufQueueSize(size_t * queue_size)369 Status ToDevice::GetMbufQueueSize(size_t *queue_size) {
370 RETURN_UNEXPECTED_IF_NULL(queue_size);
371 // tree_.root() must be DataQueueOp
372 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
373 CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
374 DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
375 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetMbufQueueSize only supported by DataQueueOp");
376 RETURN_IF_NOT_OK(op->GetMbufQueueSize(queue_size));
377 return Status::OK();
378 }
379
GetSendInfo(std::vector<std::vector<double>> * send_info)380 Status ToDevice::GetSendInfo(std::vector<std::vector<double>> *send_info) {
381 RETURN_UNEXPECTED_IF_NULL(send_info);
382 // tree_.root() must be DataQueueOp
383 std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
384 CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
385 DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
386 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetSendInfo only supported by DataQueueOp");
387 DATA_INFO data_info;
388 *send_info = op->GetSendInfo();
389 return Status::OK();
390 }
391
// Terminate the consumer. On Ascend targets the DataQueueOp may be blocked
// waiting on the device queue, so wake it up first; then delegate to the base
// class to stop all tasks.
Status ToDevice::Terminate() {
#ifdef WITH_BACKEND
  RETURN_UNEXPECTED_IF_NULL(MsContext::GetInstance());
  if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
    DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
    CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DataQueueOp");
    op->StopWaiting();
  }
#endif
  return TreeConsumer::Terminate();
}
405
// Restart the pipeline from `step`: stop and tear down the current execution
// tree, clear the device queue when running on GPU, then recompile the
// original IR with a reset-usage TreeAdapter and relaunch it.
Status TreeConsumer::Reset(int64_t step, const int64_t dataset_size) {
  MS_LOG(INFO) << "Resetting TreeConsumer";

  MS_LOG(INFO) << "Terminating pipeline with UUID:" << tree_adapter_->tree_->GetUniqueId();
  // Keep the original IR so the new adapter can recompile the same pipeline.
  std::shared_ptr<DatasetNode> old_root = tree_adapter_->input_ir_;
  RETURN_IF_NOT_OK(this->Stop());
  RETURN_IF_NOT_OK(this->Terminate());
#ifdef WITH_BACKEND
  RETURN_UNEXPECTED_IF_NULL(MsContext::GetInstance());
  if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
    // clear the device if GPU is used.
    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
    // Unlike the other call sites, a non-DataQueueOp root is tolerated here:
    // clearing the device is best-effort during reset.
    DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
    if (op != nullptr) {
      MS_LOG(INFO) << "Clearing the GPU device";
      RETURN_IF_NOT_OK(op->ClearDevice());
    }
  }
#endif
  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step, dataset_size));
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  MS_LOG(INFO) << "Launched a new pipeline after reset. UUID: " << tree_adapter_->tree_->GetUniqueId();
  std::shared_ptr<DatasetOp> root2 = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root2 != nullptr, "Root is a nullptr.");
  return Status::OK();
}
434
435 #ifndef ENABLE_ANDROID
436 // SaveToDisk
ValidateParams()437 Status SaveToDisk::ValidateParams() {
438 if (dataset_path_.empty()) {
439 std::string err = "SaveToDisk failed, dataset_path must not be empty";
440 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
441 }
442 Path dir(dataset_path_);
443 if (dir.IsDirectory()) {
444 std::string err = "SaveToDisk failed, dataset_path must not be a directory";
445 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
446 }
447 std::string real_path;
448 if (dir.ParentPath().empty()) {
449 dir = Path(".") / dir;
450 }
451 if (Path::RealPath(dir.ParentPath(), real_path).IsError()) {
452 std::string err_msg = "SaveToDisk failed, can not get real dataset path: " + dir.ParentPath();
453 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
454 }
455 if (access(dir.ParentPath().c_str(), R_OK) == -1) {
456 std::string err_msg = "SaveToDisk failed, no access to specified dataset path: " + dataset_path_;
457 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
458 }
459 if (num_files_ <= 0 || num_files_ > 1000) {
460 std::string err = "SaveToDisk failed, num_files must between 1 and 1000, but got " + std::to_string(num_files_);
461 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
462 }
463 if (dataset_type_ != "mindrecord") {
464 std::string err = "SaveToDisk failed, only \"mindrecord\" dataset format is supported, but got " + dataset_type_;
465 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
466 }
467 return Status::OK();
468 }
469
Save()470 Status SaveToDisk::Save() {
471 RETURN_IF_NOT_OK(CollectPipelineInfoStart("SaveToDisk", "Save"));
472 std::vector<std::string> file_names;
473 if (num_files_ == 1) {
474 file_names.push_back(dataset_path_);
475 } else {
476 for (int32_t i = 0; i < num_files_; i++) {
477 file_names.push_back(dataset_path_ + std::to_string(i));
478 }
479 }
480
481 auto mr_header = std::make_shared<mindrecord::ShardHeader>();
482 auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
483 std::vector<std::string> blob_fields;
484 RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names));
485
486 std::unordered_map<std::string, int32_t> column_name_id_map;
487 for (auto el : tree_adapter_->GetColumnNameMap()) {
488 std::string column_name = el.first;
489 (void)std::transform(column_name.begin(), column_name.end(), column_name.begin(),
490 [](unsigned char c) { return ispunct(c) ? '_' : c; });
491 column_name_id_map[column_name] = el.second;
492 }
493
494 TensorRow row;
495 uint64_t mr_schema_id = 0;
496 bool first_loop = true; // build schema in first loop
497 auto PreTensorRowShapes = std::map<std::string, std::vector<int>>();
498
499 do {
500 nlohmann::json row_raw_data;
501 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
502 RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
503 if (row.empty()) {
504 break;
505 }
506 RETURN_IF_NOT_OK(CheckTensorRowShapes(column_name_id_map, row, &PreTensorRowShapes));
507 if (first_loop) {
508 nlohmann::json mr_json;
509 std::vector<std::string> index_fields;
510 RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields));
511 MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump();
512 RETURN_IF_NOT_OK(
513 mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id));
514 RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header));
515 first_loop = false;
516 }
517 // construct data
518 if (!row.empty()) { // write data
519 RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data));
520 std::shared_ptr<std::vector<uint8_t>> output_bin_data;
521 RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data));
522 std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data;
523 raw_data.insert(
524 std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
525 std::vector<std::vector<uint8_t>> bin_data;
526 if (output_bin_data != nullptr) {
527 bin_data.emplace_back(*output_bin_data);
528 }
529 RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data));
530 }
531 } while (!row.empty());
532
533 RETURN_IF_NOT_OK(mr_writer->Commit());
534 RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names));
535 RETURN_IF_NOT_OK(CollectPipelineInfoEnd("SaveToDisk", "Save"));
536 return Status::OK();
537 }
538
// Compare two map-like containers for equality: same size and element-wise
// equal (key, value) pairs. std::map iterates in sorted key order, so
// std::equal gives a deterministic comparison for the shape maps used below.
template <typename T>
bool SaveToDisk::map_compare(T const &lhs, T const &rhs) {
  return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin());
}
543
// Verify that the current row's tensor shapes match those of previous rows.
// MindRecord allows the leading dimension to vary but all other dimensions
// must stay fixed across rows; scalars, 1-D tensors, and string/bytes columns
// are exempt. On the first row, *PreTensorRowShapes_ptr is simply recorded.
Status SaveToDisk::CheckTensorRowShapes(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                        const TensorRow &row,
                                        std::map<std::string, std::vector<int>> *PreTensorRowShapes_ptr) {
  std::map<std::string, std::vector<int>> CurrTensorRowShapes;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());

    if (mr_shape.empty() || mr_shape.size() == 1) {
      continue;  // ignore scalar and one dimension tensor
    }
    std::string mr_type;
    std::string el = column_type.ToString();
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_type == "bytes" || mr_type == "string") {
      // Variable-length payload types carry no fixed shape to compare.
      continue;
    }
    mr_shape.erase(mr_shape.begin());  // ignore the first dimension
    CurrTensorRowShapes[column_name] = mr_shape;
  }
  if (PreTensorRowShapes_ptr->empty()) {
    // First row seen: record its shapes as the reference for later rows.
    *PreTensorRowShapes_ptr = CurrTensorRowShapes;
    return Status::OK();
  }
  auto res = map_compare(*PreTensorRowShapes_ptr, CurrTensorRowShapes);
  CHECK_FAIL_RETURN_UNEXPECTED(res,
                               "Tensor with dynamic shape do not currently support saving. Except for the shape of "
                               "dimension 0, the other dimension shapes must be fixed. "
                               "You can reshape the Tensor to a fixed shape before saving.");
  return Status::OK();
}
586
// Derive the MindRecord schema from one sample row: for each column, map its
// MindData type to a MindRecord type and record its shape (leading dimension
// forced to -1, i.e. variable). Scalar non-bytes columns are also collected
// as index-field candidates. Multi-dimensional string columns are rejected.
Status SaveToDisk::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          const TensorRow &row, nlohmann::json *schema,
                                          std::vector<std::string> *index_fields) {
  if (schema == nullptr) {
    RETURN_STATUS_UNEXPECTED("schema can not be nullptr.");
  }
  if (index_fields == nullptr) {
    RETURN_STATUS_UNEXPECTED("index_fields can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be nullptr..");
  }
  nlohmann::json dataset_schema;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    std::string mr_type;
    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());
    std::string el = column_type.ToString();
    dataset_schema[column_name] = el;
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_shape.empty()) {
      // Scalar column: type only, no shape entry.
      (*schema)[column_name] = {{"type", mr_type}};
    } else {
      if (mr_type == "string") {  // mindrecord can not support string with shape.
        std::string err_msg("Invalid data, mindrecord can not support multi-dimensional string tensor.");
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
      if (mr_type == "bytes") {  // ignore shape of bytes in mindrecord
        (*schema)[column_name] = {{"type", mr_type}};
      } else {
        mr_shape[0] = -1;  // make first dimension -1
        (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
      }
    }
    if (mr_type == "bytes" || !mr_shape.empty()) {
      // Only scalar, non-bytes columns are usable as index fields.
      continue;
    }
    index_fields->emplace_back(column_name);  // candidate of index fields
  }
  MS_LOG(DEBUG) << "Schema of dataset: " << dataset_schema.dump();
  return Status::OK();
}
640
// Sanity-check the output arguments used when fetching a row's data:
// both destinations must be non-null and the column map non-empty.
inline Status ValidateInputParams(nlohmann::json *row_raw_data,
                                  std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data,
                                  const std::unordered_map<std::string, int32_t> &column_name_id_map) {
  if (row_raw_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_raw_data can not be nullptr.");
  }
  if (row_bin_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_bin_data can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be nullptr.");
  }
  return Status::OK();
}
655
FetchIntData(std::shared_ptr<Tensor> tensor,std::string column_name,nlohmann::json * row_raw_data,std::unique_ptr<std::vector<uint8_t>> * data_ptr)656 Status SaveToDisk::FetchIntData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
657 std::unique_ptr<std::vector<uint8_t>> *data_ptr) {
658 RETURN_UNEXPECTED_IF_NULL(row_raw_data);
659 RETURN_UNEXPECTED_IF_NULL(data_ptr);
660 auto column_type = tensor->type();
661 Status s;
662 if (column_type == DataType::DE_INT8) {
663 std::unique_ptr<int32_t> data;
664 std::unique_ptr<int8_t> dummy;
665 RETURN_IF_NOT_OK(
666 TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
667 if (data != nullptr) {
668 (*row_raw_data)[column_name] = std::move(*data);
669 }
670 } else if (column_type == DataType::DE_UINT8) {
671 std::unique_ptr<int32_t> data;
672 std::unique_ptr<uint8_t> dummy;
673 RETURN_IF_NOT_OK(
674 TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
675 if (data != nullptr) {
676 (*row_raw_data)[column_name] = std::move(*data);
677 }
678 } else if (column_type == DataType::DE_INT16) {
679 std::unique_ptr<int32_t> data;
680 std::unique_ptr<int16_t> dummy;
681 RETURN_IF_NOT_OK(
682 TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
683 if (data != nullptr) {
684 (*row_raw_data)[column_name] = std::move(*data);
685 }
686 } else if (column_type == DataType::DE_UINT16) {
687 std::unique_ptr<int32_t> data;
688 std::unique_ptr<uint16_t> dummy;
689 RETURN_IF_NOT_OK(
690 TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
691 if (data != nullptr) {
692 (*row_raw_data)[column_name] = std::move(*data);
693 }
694 } else if (column_type == DataType::DE_INT32) {
695 std::unique_ptr<int32_t> data, dummy;
696 RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
697 if (data != nullptr) {
698 (*row_raw_data)[column_name] = std::move(*data);
699 }
700 } else if (column_type == DataType::DE_UINT32) {
701 std::unique_ptr<int64_t> data;
702 std::unique_ptr<uint32_t> dummy;
703 RETURN_IF_NOT_OK(
704 TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
705 if (data != nullptr) {
706 (*row_raw_data)[column_name] = std::move(*data);
707 }
708 } else if (column_type == DataType::DE_INT64 || column_type == DataType::DE_UINT64) {
709 std::unique_ptr<int64_t> data, dummy;
710 RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
711 if (data != nullptr) {
712 (*row_raw_data)[column_name] = std::move(*data);
713 }
714 }
715
716 return Status::OK();
717 }
718
FetchFloatData(std::shared_ptr<Tensor> tensor,std::string column_name,nlohmann::json * row_raw_data,std::unique_ptr<std::vector<uint8_t>> * data_ptr)719 Status SaveToDisk::FetchFloatData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
720 std::unique_ptr<std::vector<uint8_t>> *data_ptr) {
721 RETURN_UNEXPECTED_IF_NULL(row_raw_data);
722 RETURN_UNEXPECTED_IF_NULL(data_ptr);
723 auto column_type = tensor->type();
724 Status s;
725 if (column_type == DataType::DE_FLOAT16) {
726 std::unique_ptr<float> data, dummy;
727 std::shared_ptr<Tensor> out_tensor;
728 RETURN_IF_NOT_OK(TypeCast(tensor, &out_tensor, DataType("float32")));
729 RETURN_IF_NOT_OK(
730 TransformTensor(out_tensor->GetBuffer(), out_tensor->shape(), out_tensor->Size(), &data, data_ptr, &dummy));
731 if (data != nullptr) {
732 (*row_raw_data)[column_name] = std::move(*data);
733 }
734 } else if (column_type == DataType::DE_FLOAT32) {
735 std::unique_ptr<float> data, dummy;
736 RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
737 if (data != nullptr) {
738 (*row_raw_data)[column_name] = std::move(*data);
739 }
740 } else if (column_type == DataType::DE_FLOAT64) {
741 std::unique_ptr<double> data, dummy;
742 RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
743 if (data != nullptr) {
744 (*row_raw_data)[column_name] = std::move(*data);
745 }
746 }
747 return Status::OK();
748 }
749
FetchItemData(std::shared_ptr<Tensor> tensor,std::string column_name,nlohmann::json * row_raw_data,std::map<std::string,std::unique_ptr<std::vector<uint8_t>>> * row_bin_data)750 Status SaveToDisk::FetchItemData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
751 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
752 RETURN_UNEXPECTED_IF_NULL(tensor);
753 RETURN_UNEXPECTED_IF_NULL(row_raw_data);
754 RETURN_UNEXPECTED_IF_NULL(row_bin_data);
755 auto column_type = tensor->type();
756 Status s;
757 std::unique_ptr<std::vector<uint8_t>> data_ptr;
758 if (column_type == DataType::DE_BOOL) {
759 std::unique_ptr<int32_t> data;
760 std::unique_ptr<int8_t> dummy;
761 RETURN_IF_NOT_OK(
762 TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true));
763 if (data != nullptr) {
764 (*row_raw_data)[column_name] = std::move(*data);
765 }
766 } else if (column_type == DataType::DE_INT8 || column_type == DataType::DE_UINT8 ||
767 column_type == DataType::DE_INT16 || column_type == DataType::DE_UINT16 ||
768 column_type == DataType::DE_INT32 || column_type == DataType::DE_UINT32 ||
769 column_type == DataType::DE_INT64 || column_type == DataType::DE_UINT64) {
770 s = FetchIntData(tensor, column_name, row_raw_data, &data_ptr);
771 RETURN_IF_NOT_OK(s);
772 } else if (column_type == DataType::DE_FLOAT16 || column_type == DataType::DE_FLOAT32 ||
773 column_type == DataType::DE_FLOAT64) {
774 s = FetchFloatData(tensor, column_name, row_raw_data, &data_ptr);
775 RETURN_IF_NOT_OK(s);
776 } else if (column_type == DataType::DE_BYTES) {
777 std::unique_ptr<char> data;
778 std::unique_ptr<char> dummy;
779 CHECK_FAIL_RETURN_UNEXPECTED(tensor->shape().Rank() == 1 || tensor->shape().Rank() == 0,
780 "Currently, multi-dimensional bytes cannot be converted to MindRecord.");
781 if (tensor->shape().Rank() == 1) {
782 CHECK_FAIL_RETURN_UNEXPECTED(tensor->shape().AsVector()[0] == 1,
783 "Currently, multi-dimensional bytes cannot be converted to MindRecord.");
784 }
785 // current only support one bytes to mindrecord field
786 uint32_t string_length = 0;
787 RETURN_IF_NOT_OK(tensor->GetStringLength(&string_length));
788 RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer() + kOffsetSize * 2, TensorShape({1}), string_length, &data,
789 &data_ptr, &dummy));
790 } else if (column_type.IsString()) {
791 std::string_view sv;
792 RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {})); // assume scalar string tensor
793 std::string ss(sv);
794 (*row_raw_data)[column_name] = std::move(ss);
795 } else {
796 RETURN_STATUS_UNEXPECTED("Invalid dtype, got unexpected type when casting data: " + column_type.ToString());
797 }
798 if (data_ptr != nullptr) {
799 (*row_bin_data)[column_name] = std::move(data_ptr);
800 }
801 return Status::OK();
802 }
803
FetchDataFromTensorRow(const TensorRow & row,const std::unordered_map<std::string,int32_t> & column_name_id_map,nlohmann::json * row_raw_data,std::map<std::string,std::unique_ptr<std::vector<uint8_t>>> * row_bin_data)804 Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
805 const std::unordered_map<std::string, int32_t> &column_name_id_map,
806 nlohmann::json *row_raw_data,
807 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
808 RETURN_UNEXPECTED_IF_NULL(row_raw_data);
809 RETURN_UNEXPECTED_IF_NULL(row_bin_data);
810 Status s;
811 s = ValidateInputParams(row_raw_data, row_bin_data, column_name_id_map);
812 if (s.IsError()) {
813 return s;
814 }
815 for (auto &col : column_name_id_map) {
816 auto idx = col.second;
817 auto column_name = col.first;
818 auto &tensor = row[idx];
819 s = FetchItemData(tensor, column_name, row_raw_data, row_bin_data);
820 RETURN_IF_NOT_OK(s);
821 }
822 return Status::OK();
823 }
824
825 template <typename T, typename S>
TransformTensor(const unsigned char * src,const TensorShape & shape,const int64_t num_of_elements,std::unique_ptr<T> * data,std::unique_ptr<std::vector<uint8_t>> * data_ptr,std::unique_ptr<S> * s,bool need_convert)826 Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
827 std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
828 std::unique_ptr<S> *s, bool need_convert) {
829 // No need to check src since we support some scenarios that src is nullptr and num_of_elements is 0.
830 RETURN_UNEXPECTED_IF_NULL(data);
831 RETURN_UNEXPECTED_IF_NULL(data_ptr);
832 RETURN_UNEXPECTED_IF_NULL(s);
833
834 *data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
835 if (need_convert) {
836 auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
837 (void)std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
838 auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
839 auto el = std::make_unique<T>();
840 for (uint32_t i = 0; i < num_of_elements; ++i) {
841 *el = *(s_ptr + i);
842 auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
843 for (uint32_t j = 0; j < sizeof(T); ++j) {
844 *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
845 }
846 }
847 } else {
848 (void)std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
849 }
850 if (shape.empty()) {
851 *data = std::make_unique<T>();
852 auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
853 for (uint32_t i = 0; i < sizeof(T); ++i) {
854 *(t_ptr + i) = *((*data_ptr)->begin() + i);
855 }
856 }
857 return Status::OK();
858 }
859 #endif
860
Init(const std::shared_ptr<DatasetNode> & root)861 Status BuildVocabConsumer::Init(const std::shared_ptr<DatasetNode> &root) { return tree_adapter_->Compile(root, 1); }
862
Start()863 Status BuildVocabConsumer::Start() {
864 RETURN_IF_NOT_OK(CollectPipelineInfoStart("BuildVocabConsumer", "Start"));
865 // Getting one row would trigger building the vocab
866 TensorRow row;
867 RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
868 // The returned row would EOE which is an empty row
869 CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "BuildVocab: The fetched row from BuildVocab should be an EOE.");
870 RETURN_IF_NOT_OK(CollectPipelineInfoEnd("BuildVocabConsumer", "Start"));
871 return Status::OK();
872 }
GetDatasetSize(int64_t * size,bool estimate)873 Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
874 if (dataset_size_ == -1) {
875 RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
876 dataset_size_ = *size; // save the previous result
877 }
878
879 *size = dataset_size_;
880 return Status::OK();
881 }
882
// Remember the IR root so later size queries can traverse the tree.
Status DatasetSizeGetter::Init(const std::shared_ptr<DatasetNode> &root) {
  root_ = root;
  return Status::OK();
}
887
DryRun(const std::shared_ptr<DatasetNode> & ir_node,int64_t * dataset_size)888 Status DatasetSizeGetter::DryRun(const std::shared_ptr<DatasetNode> &ir_node, int64_t *dataset_size) {
889 RETURN_UNEXPECTED_IF_NULL(dataset_size);
890 std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
891 tree_adapters_.push_back(tree_adapter);
892 RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
893 TensorRow row;
894 RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
895 int64_t row_cnt = 0;
896 while (!row.empty()) {
897 ++row_cnt;
898 RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
899 }
900 *dataset_size = row_cnt;
901 return Status::OK();
902 }
903
// Thin wrapper: fetch the next row from the given adapter's compiled tree.
Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  return tree_adapter->GetNext(row);
}
908
Terminate()909 Status DatasetSizeGetter::Terminate() {
910 for (const auto &tree : tree_adapters_) {
911 RETURN_UNEXPECTED_IF_NULL(tree);
912 RETURN_UNEXPECTED_IF_NULL(tree->AllTasks());
913 RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
914 }
915 return Status::OK();
916 }
917 } // namespace dataset
918 } // namespace mindspore
919