/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/datasetops/data_queue_op.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
#ifndef ENABLE_SECURITY
#include "minddata/dataset/engine/perf/auto_tune.h"
#endif
#include "minddata/dataset/engine/perf/info_collector.h"
#ifndef ENABLE_SECURITY
#include "minddata/dataset/engine/perf/profiling.h"
#endif
#include "minddata/dataset/engine/tree_adapter.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_writer.h"
#endif
#ifdef WITH_BACKEND
#include "utils/ms_context.h"
#endif

namespace mindspore {
namespace dataset {
#ifndef ENABLE_SECURITY
using ProfilingRegistrationState = ProfilingManager::ProfilingRegistrationState;
#endif
// TreeConsumer
TreeConsumer::TreeConsumer() : TreeConsumer(1) {}

TreeConsumer::TreeConsumer(int32_t num_epochs) : num_epochs_(num_epochs) {
  tree_adapter_ = std::make_unique<TreeAdapter>();
}

Status TreeConsumer::Init(const std::shared_ptr<DatasetNode> &root) {
  RETURN_IF_NOT_OK(tree_adapter_->Compile(root));
#ifndef ENABLE_SECURITY
  profiling_manager_ = GlobalContext::profiling_manager();
  RETURN_IF_NOT_OK(RegisterProfilingManager());
#endif
  return Status::OK();
}

Status TreeConsumer::Init(const std::shared_ptr<DatasetNode> &root, int64_t global_step, int64_t dataset_size) {
  MS_LOG(WARNING) << "TreeConsumer does not support initializing from an intermediate epoch or step; "
                     "initializing from the beginning instead.";
  return Init(root);
}
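
// Note: IteratorConsumer and ToDevice provide their own Init overloads below
// that do support resuming from a given global step.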

Status TreeConsumer::Terminate() {
  if (tree_adapter_->AllTasks() != nullptr) {
    return tree_adapter_->AllTasks()->ServiceStop();
  }
  return Status::OK();
}

#ifndef ENABLE_SECURITY
Status IteratorConsumer::RegisterProfilingManager() {
  auto profiler_state = profiling_manager_->GetProfilerTreeState(tree_adapter_->tree_.get());
  // This should never happen
  CHECK_FAIL_RETURN_UNEXPECTED(profiler_state != ProfilingManager::kEnabledTreeRegistered,
                               "Something went wrong. Current tree is already registered with the MD Profiler");
  if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered && profiling_manager_->IsProfiling()) {
    MS_LOG(WARNING) << "Dataset Profiling is already enabled for a different data pipeline.";
  } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered &&
             profiling_manager_->IsAutotuning()) {
    MS_LOG(WARNING) << "AutoTune for dataset is already enabled for a different data pipeline.";
  } else if (profiler_state == ProfilingRegistrationState::kEnabledTreeNotRegistered) {
    // Profiling infrastructures need to be initialized before Op launching
    // Setup profiling manager
    RETURN_IF_NOT_OK(profiling_manager_->RegisterTree(this->tree_adapter_.get()));
    // dataset_iterator node is used for graph mode
    std::shared_ptr<Tracing> iterator_tracing = std::make_shared<DatasetIteratorTracing>();
    RETURN_IF_NOT_OK(profiling_manager_->RegisterTracingNode(iterator_tracing));
    RETURN_IF_NOT_OK(tree_adapter_->SetProfilingManagerPtr(profiling_manager_, iterator_tracing));
    // Launch Monitor Thread
    RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
  } else {
    MS_LOG(INFO) << "Unable to register this tree with ProfilingManager.";
  }
  return Status::OK();
}

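// Like IteratorConsumer::RegisterProfilingManager() above, registration hinges
// on the profiler tree state: kEnabledTreeNotRegistered registers this tree and
// launches the monitor thread, kEnabledDifferentTreeRegistered only logs a
// warning, and kEnabledTreeRegistered is rejected as an internal error. The
// only difference here is the tracing node type (DeviceQueueTracing).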
Status ToDevice::RegisterProfilingManager() {
  auto profiler_state = profiling_manager_->GetProfilerTreeState(tree_adapter_->tree_.get());
  // This should never happen
  CHECK_FAIL_RETURN_UNEXPECTED(profiler_state != ProfilingManager::kEnabledTreeRegistered,
                               "Something went wrong. Current tree is already registered with the MD Profiler");
  if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered && profiling_manager_->IsProfiling()) {
    MS_LOG(WARNING) << "Dataset Profiling is already enabled for a different data pipeline.";
  } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered &&
             profiling_manager_->IsAutotuning()) {
    MS_LOG(WARNING) << "AutoTune for dataset is already enabled for a different data pipeline.";
  } else if (profiler_state == ProfilingRegistrationState::kEnabledTreeNotRegistered) {
    // Profiling infrastructures need to be initialized before Op launching
    // Setup profiling manager
    RETURN_IF_NOT_OK(profiling_manager_->RegisterTree(this->tree_adapter_.get()));
    // device_queue node is used for graph mode
    std::shared_ptr<Tracing> device_queue_tracing = std::make_shared<DeviceQueueTracing>();
    RETURN_IF_NOT_OK(profiling_manager_->RegisterTracingNode(device_queue_tracing));
    RETURN_IF_NOT_OK(tree_adapter_->SetProfilingManagerPtr(profiling_manager_));
    // Launch Monitor Thread
    RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
  } else {
    MS_LOG(INFO) << "Unable to register this tree with ProfilingManager.";
  }
  return Status::OK();
}

Status TreeConsumer::RegisterProfilingManager() {
  if (profiling_manager_->IsProfiling()) {
    return {StatusCode::kMDUnexpectedError, "Dataset Profiling is not supported for this kind of dataset."};
  }
  return Status::OK();
}

Status TreeConsumer::InitAutoTune() {
  auto profiler_state = profiling_manager_->GetProfilerTreeState(tree_adapter_->tree_.get());
  if (profiler_state == ProfilingRegistrationState::kNotEnabled) {
    // Init ProfilingManager to enable it.
    RETURN_IF_NOT_OK(profiling_manager_->Init(true));
    // Register this tree
    RETURN_IF_NOT_OK(RegisterProfilingManager());
    // Start Profiler
    RETURN_IF_NOT_OK(profiling_manager_->Start());
    // AutoTune object and thread init
    autotune_ = std::make_unique<AutoTune>(this->tree_adapter_.get(), GetProfilingManager());
    RETURN_IF_NOT_OK(autotune_->LaunchThread());
  } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered && profiling_manager_->IsProfiling()) {
    MS_LOG(WARNING) << "Cannot enable AutoTune for the current data pipeline as Dataset Profiling is enabled for "
                       "another data pipeline.";
  } else if (profiler_state == ProfilingManager::kEnabledDifferentTreeRegistered &&
             profiling_manager_->IsAutotuning()) {
    MS_LOG(WARNING)
      << "Cannot enable AutoTune for the current data pipeline as it is already enabled for another data pipeline.";
  } else if (profiler_state == ProfilingManager::kEnabledTreeRegistered && profiling_manager_->IsProfiling()) {
    MS_LOG(WARNING)
      << "Cannot enable AutoTune for the current data pipeline as Dataset Profiling is already enabled for the "
         "current data pipeline.";
  } else {
    MS_LOG(WARNING) << "Cannot enable AutoTune for the current data pipeline.";
  }
  return Status::OK();
}
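
// AutoTune is driven by the global config: when
// GlobalContext::config_manager()->enable_autotune() is true, the Init()
// overloads of IteratorConsumer and ToDevice call InitAutoTune() after
// compiling the tree, which also turns on profiling if it was not enabled yet.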
#endif

std::string TreeConsumer::GetOffload() { return (tree_adapter_->GetOffloadJson()).dump(); }

// IteratorConsumer
Status IteratorConsumer::Init(const std::shared_ptr<DatasetNode> &root, int64_t global_step, int64_t dataset_size) {
  if (global_step != 0) {
    tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  }
  RETURN_IF_NOT_OK(tree_adapter_->Compile(root, num_epochs_, global_step, dataset_size));
#ifndef ENABLE_SECURITY
  profiling_manager_ = GlobalContext::profiling_manager();
  if (profiling_manager_->IsProfiling()) {
    // Init has been called already
    RETURN_IF_NOT_OK(RegisterProfilingManager());
  }
  if (GlobalContext::config_manager()->enable_autotune()) {
    RETURN_IF_NOT_OK(InitAutoTune());
  }
#endif
  return Status::OK();
}
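
// A non-zero global_step indicates a resume: the tree adapter is recreated
// with the kDeReset usage flag and compiled against the target step, so the
// pipeline continues from that point instead of from row zero.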

Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *const out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  RETURN_IF_NOT_OK(CollectPipelineInfoStart("IteratorConsumer", "GetNextAsVector"));
  out->clear();
  TensorRow res;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));

  // Return an empty vector if there's no data
  if (res.empty()) {
    RETURN_IF_NOT_OK(
      CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsVector", {{"TensorRowFlags", res.FlagName()}}));
    return Status::OK();
  }

  // Filter out meta columns
  std::vector<size_t> to_keep_indices;
  for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = colMap.first;
    // Meta columns are those whose names start with kDftMetaColumnPrefix
    size_t pos = column_name.find(kDftMetaColumnPrefix);
    if (pos != std::string::npos && pos == 0) {
      continue;
    }
    to_keep_indices.push_back(colMap.second);
  }
  if (to_keep_indices.empty()) {
    std::string err_msg = "No effective column found, maybe all columns are meta columns and have been filtered. ";
    err_msg += "If you want to output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  std::sort(to_keep_indices.begin(), to_keep_indices.end());
  (void)std::transform(to_keep_indices.begin(), to_keep_indices.end(), std::back_inserter(*out),
                       [&res](const auto &it) { return std::move(res[it]); });
  RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsVector"));
  return Status::OK();
}
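
// A minimal consumption sketch (illustrative only; `consumer` stands for an
// IteratorConsumer that has already been Init()-ed):
//
//   std::vector<TensorPtr> row;
//   RETURN_IF_NOT_OK(consumer.GetNextAsVector(&row));
//   while (!row.empty()) {
//     // ... use the tensors in `row` ...
//     RETURN_IF_NOT_OK(consumer.GetNextAsVector(&row));
//   }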

Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out_map) {
  RETURN_UNEXPECTED_IF_NULL(out_map);
  RETURN_IF_NOT_OK(CollectPipelineInfoStart("IteratorConsumer", "GetNextAsMap"));

  out_map->clear();
  TensorRow res;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res));

  // Return an empty map if there's no data
  if (res.empty()) {
    RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsMap", {{"TensorRowFlags", res.FlagName()}}));
    return Status::OK();
  }

  // Populate the out map from the row and return it
  for (const auto &colMap : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = colMap.first;
    // Skip meta columns, whose names start with kDftMetaColumnPrefix
    size_t pos = column_name.find(kDftMetaColumnPrefix);
    if (pos != std::string::npos && pos == 0) {
      continue;
    }
    (*out_map)[colMap.first] = std::move(res[colMap.second]);
  }
  if (out_map->empty()) {
    std::string err_msg = "No effective column found, maybe all columns are meta columns and have been filtered. ";
    err_msg += "If you want to output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsMap"));
  return Status::OK();
}

Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) {
  CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
  RETURN_IF_NOT_OK(CollectPipelineInfoStart("IteratorConsumer", "GetNextAsOrderedPair"));

  TensorRow curr_row;

  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));

  // Return an empty vector if there's no data
  if (curr_row.empty()) {
    RETURN_IF_NOT_OK(
      CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsOrderedPair", {{"TensorRowFlags", curr_row.FlagName()}}));
    return Status::OK();
  }

  size_t num_cols = curr_row.size();  // num_cols is non-zero here.
  // Order the column names according to their ids
  if (column_order_.empty()) {
    for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
      int32_t ind = itr.second;
      CHECK_FAIL_RETURN_UNEXPECTED(ind >= 0 && static_cast<size_t>(ind) < num_cols, "column id out of bounds.");
      // Skip meta columns, whose names start with kDftMetaColumnPrefix
      size_t pos = itr.first.find(kDftMetaColumnPrefix);
      if (pos != std::string::npos && pos == 0) {
        continue;
      }
      column_order_[ind] = itr.first;
    }
  }

  if (column_order_.empty()) {
    std::string err_msg = "No effective column found, maybe all columns are meta columns and have been filtered. ";
    err_msg += "If you want to output a meta column, please rename it so that it does not start with ";
    err_msg += "\"" + std::string(kDftMetaColumnPrefix) + "\"";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  vec->reserve(column_order_.size());

  std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
                 [&curr_row](const auto &col) { return std::make_pair(col.second, curr_row[col.first]); });
  RETURN_IF_NOT_OK(CollectPipelineInfoEnd("IteratorConsumer", "GetNextAsOrderedPair"));
  return Status::OK();
}
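
// column_order_ is cached across calls: the column layout is fixed once the
// tree is compiled, so the id-to-name mapping only needs to be built from the
// column name map on the first fetched row.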

// ToDevice
Status ToDevice::Init(const std::shared_ptr<DatasetNode> &root, int64_t global_step, int64_t dataset_size) {
  if (global_step != 0) {
    tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  }
  RETURN_IF_NOT_OK(tree_adapter_->Compile(root, num_epochs_, global_step, dataset_size));
#ifndef ENABLE_SECURITY
  profiling_manager_ = GlobalContext::profiling_manager();
  if (profiling_manager_->IsProfiling()) {
    // Init has been called already
    RETURN_IF_NOT_OK(RegisterProfilingManager());
  }
  if (GlobalContext::config_manager()->enable_autotune()) {
    RETURN_IF_NOT_OK(InitAutoTune());
  }
#endif
  return Status::OK();
}

Status ToDevice::Send() {
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  return Status::OK();
}

Status ToDevice::Continue() {
  // tree_.root() must be DataQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DataQueueOp");
  op->ContinueSend();
  return Status::OK();
}

Status ToDevice::Stop() {
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DataQueueOp");
  op->StopSend();

  return Status::OK();
}

Status ToDevice::GetDataInfo(std::vector<DataType> *const types, std::vector<TensorShape> *const shapes) {
  RETURN_UNEXPECTED_IF_NULL(types);
  RETURN_UNEXPECTED_IF_NULL(shapes);
  // tree_.root() must be DataQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DataQueueOp");
  DATA_INFO data_info;
  RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
  for (const auto &el : data_info) {
    types->push_back(el.first);
    shapes->push_back(el.second);
  }
  return Status::OK();
}

Status ToDevice::GetMbufQueueSize(size_t *queue_size) {
  RETURN_UNEXPECTED_IF_NULL(queue_size);
  // tree_.root() must be DataQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetMbufQueueSize only supported by DataQueueOp");
  RETURN_IF_NOT_OK(op->GetMbufQueueSize(queue_size));
  return Status::OK();
}

Status ToDevice::GetSendInfo(std::vector<std::vector<double>> *send_info) {
  RETURN_UNEXPECTED_IF_NULL(send_info);
  // tree_.root() must be DataQueueOp
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetSendInfo only supported by DataQueueOp");
  *send_info = op->GetSendInfo();
  return Status::OK();
}

Status ToDevice::Terminate() {
#ifdef WITH_BACKEND
  RETURN_UNEXPECTED_IF_NULL(MsContext::GetInstance());
  if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
    DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
    CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopWaiting only supported by DataQueueOp");
    op->StopWaiting();
  }
#endif
  return TreeConsumer::Terminate();
}

Status TreeConsumer::Reset(int64_t step, const int64_t dataset_size) {
  MS_LOG(INFO) << "Resetting TreeConsumer";

  MS_LOG(INFO) << "Terminating pipeline with UUID:" << tree_adapter_->tree_->GetUniqueId();
  std::shared_ptr<DatasetNode> old_root = tree_adapter_->input_ir_;
  RETURN_IF_NOT_OK(this->Stop());
  RETURN_IF_NOT_OK(this->Terminate());
#ifdef WITH_BACKEND
  RETURN_UNEXPECTED_IF_NULL(MsContext::GetInstance());
  if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
    // Clear the device if GPU is used.
    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
    DataQueueOp *op = dynamic_cast<DataQueueOp *>(root.get());
    if (op != nullptr) {
      MS_LOG(INFO) << "Clearing the GPU device";
      RETURN_IF_NOT_OK(op->ClearDevice());
    }
  }
#endif
  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step, dataset_size));
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  MS_LOG(INFO) << "Launched a new pipeline after reset. UUID: " << tree_adapter_->tree_->GetUniqueId();
  std::shared_ptr<DatasetOp> root2 = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root2 != nullptr, "Root is a nullptr.");
  return Status::OK();
}
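
// Reset sequence: stop and terminate the running tree, optionally clear the
// GPU data queue, then recompile the saved IR (old_root) with the kDeReset
// flag at the requested step and relaunch, so the new pipeline continues from
// that step rather than from the beginning.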

#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {
  if (dataset_path_.empty()) {
    std::string err = "SaveToDisk failed, dataset_path must not be empty";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
  }
  Path dir(dataset_path_);
  if (dir.IsDirectory()) {
    std::string err = "SaveToDisk failed, dataset_path must not be a directory";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
  }
  std::string real_path;
  if (dir.ParentPath().empty()) {
    dir = Path(".") / dir;
  }
  if (Path::RealPath(dir.ParentPath(), real_path).IsError()) {
    std::string err_msg = "SaveToDisk failed, can not get real dataset path: " + dir.ParentPath();
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (access(dir.ParentPath().c_str(), R_OK) == -1) {
    std::string err_msg = "SaveToDisk failed, no access to specified dataset path: " + dataset_path_;
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (num_files_ <= 0 || num_files_ > 1000) {
    std::string err = "SaveToDisk failed, num_files must be between 1 and 1000, but got " + std::to_string(num_files_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
  }
  if (dataset_type_ != "mindrecord") {
    std::string err = "SaveToDisk failed, only \"mindrecord\" dataset format is supported, but got " + dataset_type_;
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err);
  }
  return Status::OK();
}

Status SaveToDisk::Save() {
  RETURN_IF_NOT_OK(CollectPipelineInfoStart("SaveToDisk", "Save"));
  std::vector<std::string> file_names;
  if (num_files_ == 1) {
    file_names.push_back(dataset_path_);
  } else {
    for (int32_t i = 0; i < num_files_; i++) {
      file_names.push_back(dataset_path_ + std::to_string(i));
    }
  }

  auto mr_header = std::make_shared<mindrecord::ShardHeader>();
  auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
  std::vector<std::string> blob_fields;
  RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names));

  std::unordered_map<std::string, int32_t> column_name_id_map;
  for (auto el : tree_adapter_->GetColumnNameMap()) {
    std::string column_name = el.first;
    (void)std::transform(column_name.begin(), column_name.end(), column_name.begin(),
                         [](unsigned char c) { return ispunct(c) ? '_' : c; });
    column_name_id_map[column_name] = el.second;
  }

  TensorRow row;
  uint64_t mr_schema_id = 0;
  bool first_loop = true;  // build schema in first loop
  auto PreTensorRowShapes = std::map<std::string, std::vector<int>>();

  do {
    nlohmann::json row_raw_data;
    std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
    RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
    if (row.empty()) {
      break;
    }
    RETURN_IF_NOT_OK(CheckTensorRowShapes(column_name_id_map, row, &PreTensorRowShapes));
    if (first_loop) {
      nlohmann::json mr_json;
      std::vector<std::string> index_fields;
      RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields));
      MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump();
      RETURN_IF_NOT_OK(
        mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id));
      RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header));
      first_loop = false;
    }
    // construct data
    if (!row.empty()) {  // write data
      RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data));
      std::shared_ptr<std::vector<uint8_t>> output_bin_data;
      RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data));
      std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data;
      raw_data.insert(
        std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
      std::vector<std::vector<uint8_t>> bin_data;
      if (output_bin_data != nullptr) {
        bin_data.emplace_back(*output_bin_data);
      }
      RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data));
    }
  } while (!row.empty());

  RETURN_IF_NOT_OK(mr_writer->Commit());
  RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names));
  RETURN_IF_NOT_OK(CollectPipelineInfoEnd("SaveToDisk", "Save"));
  return Status::OK();
}
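
// Save() walks the compiled pipeline row by row: the first row determines the
// mindrecord schema (FetchMetaFromTensorRow), every row is shape-checked
// against the previous one, converted into raw JSON plus blob bytes, and
// written through the ShardWriter; Commit() and ShardIndexGenerator::Finalize()
// then seal the output files.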

template <typename T>
bool SaveToDisk::map_compare(T const &lhs, T const &rhs) {
  return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

Status SaveToDisk::CheckTensorRowShapes(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                        const TensorRow &row,
                                        std::map<std::string, std::vector<int>> *PreTensorRowShapes_ptr) {
  std::map<std::string, std::vector<int>> CurrTensorRowShapes;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());

    if (mr_shape.empty() || mr_shape.size() == 1) {
      continue;  // ignore scalar and one dimension tensor
    }
    std::string mr_type;
    std::string el = column_type.ToString();
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_type == "bytes" || mr_type == "string") {
      continue;
    }
    mr_shape.erase(mr_shape.begin());  // ignore the first dimension
    CurrTensorRowShapes[column_name] = mr_shape;
  }
  if (PreTensorRowShapes_ptr->empty()) {
    *PreTensorRowShapes_ptr = CurrTensorRowShapes;
    return Status::OK();
  }
  auto res = map_compare(*PreTensorRowShapes_ptr, CurrTensorRowShapes);
  CHECK_FAIL_RETURN_UNEXPECTED(res,
                               "Tensors with dynamic shapes are not currently supported when saving. Except for "
                               "dimension 0, all dimension sizes must be fixed. "
                               "You can reshape the Tensor to a fixed shape before saving.");
  return Status::OK();
}
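
// Example: rows whose column shapes are (2, 224, 224) and (5, 224, 224) pass
// this check, since dimension 0 is dropped before comparing; (2, 224, 224)
// followed by (2, 112, 112) would fail it.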

Status SaveToDisk::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          const TensorRow &row, nlohmann::json *schema,
                                          std::vector<std::string> *index_fields) {
  if (schema == nullptr) {
    RETURN_STATUS_UNEXPECTED("schema can not be nullptr.");
  }
  if (index_fields == nullptr) {
    RETURN_STATUS_UNEXPECTED("index_fields can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be empty.");
  }
  nlohmann::json dataset_schema;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    std::string mr_type;
    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());
    std::string el = column_type.ToString();
    dataset_schema[column_name] = el;
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Invalid type, unsupported data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_shape.empty()) {
      (*schema)[column_name] = {{"type", mr_type}};
    } else {
      if (mr_type == "string") {  // mindrecord can not support string with shape.
        std::string err_msg("Invalid data, mindrecord can not support multi-dimensional string tensor.");
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
      if (mr_type == "bytes") {  // ignore shape of bytes in mindrecord
        (*schema)[column_name] = {{"type", mr_type}};
      } else {
        mr_shape[0] = -1;  // make first dimension -1
        (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
      }
    }
    if (mr_type == "bytes" || !mr_shape.empty()) {
      continue;
    }
    index_fields->emplace_back(column_name);  // candidate of index fields
  }
  MS_LOG(DEBUG) << "Schema of dataset: " << dataset_schema.dump();
  return Status::OK();
}
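
// Illustrative output (hypothetical column names): a scalar int32 "label"
// becomes {"type": "int32"} and is an index-field candidate, a bytes "image"
// keeps only its type, and a float32 "feature" of shape (N, 128) becomes
// {"type": "float32", "shape": [-1, 128]}.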

inline Status ValidateInputParams(nlohmann::json *row_raw_data,
                                  std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data,
                                  const std::unordered_map<std::string, int32_t> &column_name_id_map) {
  if (row_raw_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_raw_data can not be nullptr.");
  }
  if (row_bin_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("row_bin_data can not be nullptr.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("column_name_id_map can not be empty.");
  }
  return Status::OK();
}

Status SaveToDisk::FetchIntData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
                                std::unique_ptr<std::vector<uint8_t>> *data_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  auto column_type = tensor->type();
  if (column_type == DataType::DE_INT8) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int8_t> dummy;
    RETURN_IF_NOT_OK(
      TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT8) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<uint8_t> dummy;
    RETURN_IF_NOT_OK(
      TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT16) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int16_t> dummy;
    RETURN_IF_NOT_OK(
      TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT16) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<uint16_t> dummy;
    RETURN_IF_NOT_OK(
      TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT32) {
    std::unique_ptr<int32_t> data, dummy;
    RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_UINT32) {
    std::unique_ptr<int64_t> data;
    std::unique_ptr<uint32_t> dummy;
    RETURN_IF_NOT_OK(
      TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy, true));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT64 || column_type == DataType::DE_UINT64) {
    std::unique_ptr<int64_t> data, dummy;
    RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  }

  return Status::OK();
}
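
// Note the widening above: int8/uint8/int16/uint16 are converted to int32 and
// uint32 to int64 (need_convert = true), while int32/int64/uint64 pass through
// unchanged, presumably because mindrecord only stores int32/int64 fields.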

Status SaveToDisk::FetchFloatData(std::shared_ptr<Tensor> tensor, std::string column_name,
                                  nlohmann::json *row_raw_data, std::unique_ptr<std::vector<uint8_t>> *data_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  auto column_type = tensor->type();
  if (column_type == DataType::DE_FLOAT16) {
    std::unique_ptr<float> data, dummy;
    std::shared_ptr<Tensor> out_tensor;
    RETURN_IF_NOT_OK(TypeCast(tensor, &out_tensor, DataType("float32")));
    RETURN_IF_NOT_OK(
      TransformTensor(out_tensor->GetBuffer(), out_tensor->shape(), out_tensor->Size(), &data, data_ptr, &dummy));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_FLOAT32) {
    std::unique_ptr<float> data, dummy;
    RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_FLOAT64) {
    std::unique_ptr<double> data, dummy;
    RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, data_ptr, &dummy));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  }
  return Status::OK();
}

Status SaveToDisk::FetchItemData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
                                 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  RETURN_UNEXPECTED_IF_NULL(tensor);
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(row_bin_data);
  auto column_type = tensor->type();
  Status s;
  std::unique_ptr<std::vector<uint8_t>> data_ptr;
  if (column_type == DataType::DE_BOOL) {
    std::unique_ptr<int32_t> data;
    std::unique_ptr<int8_t> dummy;
    RETURN_IF_NOT_OK(
      TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true));
    if (data != nullptr) {
      (*row_raw_data)[column_name] = std::move(*data);
    }
  } else if (column_type == DataType::DE_INT8 || column_type == DataType::DE_UINT8 ||
             column_type == DataType::DE_INT16 || column_type == DataType::DE_UINT16 ||
             column_type == DataType::DE_INT32 || column_type == DataType::DE_UINT32 ||
             column_type == DataType::DE_INT64 || column_type == DataType::DE_UINT64) {
    s = FetchIntData(tensor, column_name, row_raw_data, &data_ptr);
    RETURN_IF_NOT_OK(s);
  } else if (column_type == DataType::DE_FLOAT16 || column_type == DataType::DE_FLOAT32 ||
             column_type == DataType::DE_FLOAT64) {
    s = FetchFloatData(tensor, column_name, row_raw_data, &data_ptr);
    RETURN_IF_NOT_OK(s);
  } else if (column_type == DataType::DE_BYTES) {
    std::unique_ptr<char> data;
    std::unique_ptr<char> dummy;
    CHECK_FAIL_RETURN_UNEXPECTED(tensor->shape().Rank() == 1 || tensor->shape().Rank() == 0,
                                 "Currently, multi-dimensional bytes cannot be converted to MindRecord.");
    if (tensor->shape().Rank() == 1) {
      CHECK_FAIL_RETURN_UNEXPECTED(tensor->shape().AsVector()[0] == 1,
                                   "Currently, multi-dimensional bytes cannot be converted to MindRecord.");
    }
    // Currently only a single bytes value per mindrecord field is supported
    uint32_t string_length = 0;
    RETURN_IF_NOT_OK(tensor->GetStringLength(&string_length));
    RETURN_IF_NOT_OK(TransformTensor(tensor->GetBuffer() + kOffsetSize * 2, TensorShape({1}), string_length, &data,
                                     &data_ptr, &dummy));
  } else if (column_type.IsString()) {
    std::string_view sv;
    RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {}));  // assume scalar string tensor
    std::string ss(sv);
    (*row_raw_data)[column_name] = std::move(ss);
  } else {
    RETURN_STATUS_UNEXPECTED("Invalid dtype, got unexpected type when casting data: " + column_type.ToString());
  }
  if (data_ptr != nullptr) {
    (*row_bin_data)[column_name] = std::move(data_ptr);
  }
  return Status::OK();
}

Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
                                          const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          nlohmann::json *row_raw_data,
                                          std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  RETURN_UNEXPECTED_IF_NULL(row_raw_data);
  RETURN_UNEXPECTED_IF_NULL(row_bin_data);
  Status s;
  s = ValidateInputParams(row_raw_data, row_bin_data, column_name_id_map);
  if (s.IsError()) {
    return s;
  }
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    s = FetchItemData(tensor, column_name, row_raw_data, row_bin_data);
    RETURN_IF_NOT_OK(s);
  }
  return Status::OK();
}

template <typename T, typename S>
Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
                                   std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
                                   std::unique_ptr<S> *s, bool need_convert) {
  // No need to check src, since scenarios where src is nullptr and num_of_elements is 0 are supported.
  RETURN_UNEXPECTED_IF_NULL(data);
  RETURN_UNEXPECTED_IF_NULL(data_ptr);
  RETURN_UNEXPECTED_IF_NULL(s);

  *data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
  if (need_convert) {
    // Element-wise convert the source type S into the (wider) target type T
    auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
    (void)std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
    auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
    auto el = std::make_unique<T>();
    for (int64_t i = 0; i < num_of_elements; ++i) {
      *el = *(s_ptr + i);
      auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
      for (uint32_t j = 0; j < sizeof(T); ++j) {
        *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
      }
    }
  } else {
    (void)std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
  }
  if (shape.empty()) {
    // Scalar tensor: also return the single element itself
    *data = std::make_unique<T>();
    auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
    for (uint32_t i = 0; i < sizeof(T); ++i) {
      *(t_ptr + i) = *((*data_ptr)->begin() + i);
    }
  }
  return Status::OK();
}
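
// Worked example: four DE_INT8 values {1, -2, 3, -4} transformed with
// T = int32_t, S = int8_t and need_convert = true yield a 16-byte buffer
// holding the same values widened to int32; if the shape were empty (scalar),
// *data would additionally receive the single element.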
#endif

Status BuildVocabConsumer::Init(const std::shared_ptr<DatasetNode> &root) { return tree_adapter_->Compile(root, 1); }

Status BuildVocabConsumer::Start() {
  RETURN_IF_NOT_OK(CollectPipelineInfoStart("BuildVocabConsumer", "Start"));
  // Getting one row triggers building the vocab
  TensorRow row;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
  // The returned row should be an EOE, which is an empty row
  CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "BuildVocab: The fetched row from BuildVocab should be an EOE.");
  RETURN_IF_NOT_OK(CollectPipelineInfoEnd("BuildVocabConsumer", "Start"));
  return Status::OK();
}

Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
  if (dataset_size_ == -1) {
    RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
    dataset_size_ = *size;  // save the previous result
  }

  *size = dataset_size_;
  return Status::OK();
}
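
// dataset_size_ == -1 acts as the "not yet computed" sentinel; once computed,
// the size is cached so repeated GetDatasetSize() calls do not re-walk the
// pipeline.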

Status DatasetSizeGetter::Init(const std::shared_ptr<DatasetNode> &root) {
  root_ = root;
  return Status::OK();
}

Status DatasetSizeGetter::DryRun(const std::shared_ptr<DatasetNode> &ir_node, int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
  tree_adapters_.push_back(tree_adapter);
  RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
  TensorRow row;
  RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
  int64_t row_cnt = 0;
  while (!row.empty()) {
    ++row_cnt;
    RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
  }
  *dataset_size = row_cnt;
  return Status::OK();
}
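
// DryRun compiles a throwaway tree (kDeGetter usage) and pulls every row just
// to count them, which can be expensive for large pipelines; the adapters are
// kept in tree_adapters_ so Terminate() can stop their tasks later.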

Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  return tree_adapter->GetNext(row);
}

Status DatasetSizeGetter::Terminate() {
  for (const auto &tree : tree_adapters_) {
    RETURN_UNEXPECTED_IF_NULL(tree);
    RETURN_UNEXPECTED_IF_NULL(tree->AllTasks());
    RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore