• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/include/dataset/datasets.h"
18 #include <algorithm>
19 #include <fstream>
20 #include <unordered_set>
21 #include <utility>
22 #include <nlohmann/json.hpp>
23 
24 #include "minddata/dataset/core/tensor.h"
25 #include "minddata/dataset/engine/runtime_context.h"
26 #include "minddata/dataset/include/dataset/constants.h"
27 #include "minddata/dataset/include/dataset/iterator.h"
28 #include "minddata/dataset/include/dataset/samplers.h"
29 #include "minddata/dataset/include/dataset/transforms.h"
30 #include "minddata/dataset/util/path.h"
31 #include "minddata/dataset/util/status.h"
32 #include "minddata/dataset/core/client.h"
33 #include "minddata/dataset/core/type_id.h"
34 #include "minddata/dataset/engine/consumers/tree_consumer.h"
35 #include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
36 
37 #include "minddata/dataset/kernels/c_func_op.h"
38 #include "minddata/dataset/kernels/tensor_op.h"
39 
40 #ifndef ENABLE_ANDROID
41 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
42 #endif
43 
44 #ifndef ENABLE_ANDROID
45 #include "minddata/dataset/text/sentence_piece_vocab.h"
46 #include "minddata/dataset/text/vocab.h"
47 #endif
48 
49 // Sampler headers (in alphabetical order)
50 #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
51 
52 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
53 
54 // IR non-leaf nodes
55 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
56 #ifndef ENABLE_ANDROID
57 #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
58 #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
59 #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
60 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
61 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
62 #endif
63 
64 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
65 #include "minddata/dataset/engine/ir/datasetops/project_node.h"
66 
67 #ifndef ENABLE_ANDROID
68 #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
69 #endif
70 
71 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
72 #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
73 
74 #ifndef ENABLE_ANDROID
75 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
76 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
77 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
78 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
79 #endif
80 
81 #include "minddata/dataset/core/config_manager.h"
82 #include "minddata/dataset/util/random.h"
83 #include "minddata/dataset/util/services.h"
84 
85 // IR leaf nodes
86 #include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
87 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
88 
89 // IR leaf nodes disabled for android
90 #ifndef ENABLE_ANDROID
91 #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
92 #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
93 #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
94 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
95 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
96 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
97 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
98 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
99 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
100 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
101 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
102 #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
103 #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
104 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
105 #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
106 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
107 #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
108 #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
109 #endif
110 
111 namespace mindspore {
112 namespace dataset {
113 
114 // convert MSTensorVec to DE TensorRow, return empty if fails
VecToRow(const MSTensorVec & v)115 TensorRow VecToRow(const MSTensorVec &v) {
116   TensorRow row;
117   row.reserve(v.size());
118   for (const MSTensor &t : v) {
119     std::shared_ptr<Tensor> rt;
120     Status rc = Tensor::CreateFromMSTensor(t, &rt);
121     if (rc.IsError()) {
122       MS_LOG_ERROR << "Convert from MSTensor to DETensor failed:" << rc.ToString() << ".";
123       return {};
124     }
125     row.emplace_back(rt);
126   }
127   return row;
128 }
129 
130 // convert DE TensorRow to MSTensorVec, won't fail
RowToVec(const TensorRow & v)131 MSTensorVec RowToVec(const TensorRow &v) {
132   MSTensorVec rv;
133   rv.reserve(v.size());
134   std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> MSTensor {
135     return mindspore::MSTensor(std::make_shared<DETensor>(t));
136   });
137   return rv;
138 }
139 
140 // Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper
FuncPtrConverter(std::function<MSTensorVec (MSTensorVec)> func,TensorRow in_row)141 TensorRow FuncPtrConverter(std::function<MSTensorVec(MSTensorVec)> func, TensorRow in_row) {
142   return VecToRow(func(RowToVec(in_row)));
143 }
144 
145 // Function to create the iterator, which will build and launch the execution tree.
CreateIteratorCharIF(std::vector<std::vector<char>> columns,int32_t num_epochs)146 std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs) {
147   std::shared_ptr<Iterator> iter;
148   try {
149     auto ds = shared_from_this();
150 
151     // The specified columns will be selected from the dataset and passed down the pipeline
152     // in the order specified, other columns will be discarded.
153     if (!VectorCharToString(columns).empty()) {
154       ds = ds->Project(VectorCharToString(columns));
155     }
156 
157     iter = std::make_shared<Iterator>();
158     Status rc = iter->BuildAndLaunchTree(ds, num_epochs);
159     if (rc.IsError()) {
160       MS_LOG(ERROR) << "CreateIterator failed." << rc;
161       return nullptr;
162     }
163   } catch (const std::exception &err) {
164     MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
165     return nullptr;
166   }
167 
168   return iter;
169 }
170 
171 // Function to create the iterator, which will build and launch the execution tree.
CreatePullBasedIterator(std::vector<std::vector<char>> columns)172 std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(std::vector<std::vector<char>> columns) {
173   // The specified columns will be selected from the dataset and passed down the pipeline
174   // in the order specified, other columns will be discarded.
175   // This code is not in a try/catch block because there is no execution tree class that will be created.
176   auto ds = shared_from_this();
177   if (!VectorCharToString(columns).empty()) {
178     ds = ds->Project(VectorCharToString(columns));
179   }
180 
181   std::shared_ptr<PullIterator> iter = std::make_shared<PullIterator>();
182   Status rc = iter->BuildAndLaunchTree(ds);
183   if (rc.IsError()) {
184     MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << rc;
185   }
186   RETURN_SECOND_IF_ERROR(rc, nullptr);
187   return iter;
188 }
189 
190 #ifndef ENABLE_ANDROID
191 // Function to return a transferred Node that transfers data through a device.
DeviceQueueCharIF(const std::vector<char> & queue_name,const std::vector<char> & device_type,int32_t device_id,int32_t num_epochs,bool send_epoch_end,int32_t total_batches,bool create_data_info_queue)192 bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type,
193                                 int32_t device_id, int32_t num_epochs, bool send_epoch_end, int32_t total_batches,
194                                 bool create_data_info_queue) {
195   Status rc;
196 
197   // Build and launch tree
198   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
199   rc = runtime_context->Init();
200   if (rc.IsError()) {
201     MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
202     return false;
203   }
204 
205   // Add TransferNode IR on top of dataset
206   auto ds =
207     std::make_shared<TransferNode>(shared_from_this()->IRNode(), CharToString(queue_name), CharToString(device_type),
208                                    device_id, send_epoch_end, total_batches, create_data_info_queue);
209 
210   // Get ToDevice consumer
211   auto consumer = std::make_unique<ToDevice>(num_epochs);
212   ToDevice *consumer_ptr = consumer.get();
213   if (consumer_ptr == nullptr) {
214     MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
215     return false;
216   }
217   rc = consumer->Init(ds);
218   if (rc.IsError()) {
219     MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
220     return false;
221   }
222   runtime_context->AssignConsumer(std::move(consumer));
223 
224   // Send data to device
225   rc = consumer_ptr->Send();
226   if (rc.IsError()) {
227     MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
228     return false;
229   }
230 
231   return true;
232 }
233 
234 // Function to create the saver, which will build and launch the execution tree and save data
SaveCharIF(const std::vector<char> & dataset_path,int32_t num_files,const std::vector<char> & dataset_type)235 bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files,
236                          const std::vector<char> &dataset_type) {
237   Status rc;
238   // Build and launch tree
239   auto ds = shared_from_this();
240   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
241   rc = runtime_context->Init();
242   if (rc.IsError()) {
243     MS_LOG(ERROR) << "CreateSaver failed." << rc;
244     return false;
245   }
246 
247   // Get SaveToDisk consumer
248   auto consumer = std::make_unique<SaveToDisk>(CharToString(dataset_path), num_files, CharToString(dataset_type));
249   rc = consumer->ValidateParams();
250   if (rc.IsError()) {
251     MS_LOG(ERROR) << "CreateSaver failed." << rc;
252     return false;
253   }
254   SaveToDisk *consumer_ptr = consumer.get();
255   if (consumer_ptr == nullptr) {
256     MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
257     return false;
258   }
259   rc = consumer->Init(ds->IRNode());
260   if (rc.IsError()) {
261     MS_LOG(ERROR) << "CreateSaver failed." << rc;
262     return false;
263   }
264 
265   runtime_context->AssignConsumer(std::move(consumer));
266 
267   // Save data into file
268   rc = consumer_ptr->Save();
269   if (rc.IsError()) {
270     MS_LOG(ERROR) << "Saver: Failed to save data into file. Error status: " << rc;
271     return false;
272   }
273 
274   // Shut down the data pipeline
275   rc = runtime_context->Terminate();
276   if (rc.IsError()) {
277     MS_LOG(ERROR) << "Saver: Failed to shut down pipeline. Error status: " << rc;
278     return false;
279   }
280 
281   return true;
282 }
283 #endif
284 
285 // Constructor
Dataset()286 Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
287 
GetDatasetSize(bool estimate)288 int64_t Dataset::GetDatasetSize(bool estimate) {
289   int64_t dataset_size = -1;
290   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
291   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
292   std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
293   DatasetSizeGetter *consumer = size_getter.get();
294   if (consumer == nullptr) {
295     MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
296     return -1;
297   }
298   runtime_context->AssignConsumer(size_getter);
299   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
300   RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
301   return dataset_size;
302 }
303 
GetOutputTypes()304 std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
305   std::vector<DataType> types;
306   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
307   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
308   TreeGetters *consumer = tree_getters_.get();
309   if (consumer == nullptr) {
310     MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
311     return std::vector<mindspore::DataType>();
312   }
313   runtime_context->AssignConsumer(tree_getters_);
314   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
315   RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
316   std::vector<mindspore::DataType> ret_types;
317   std::transform(
318     types.begin(), types.end(), std::back_inserter(ret_types),
319     [](const DataType &d) -> mindspore::DataType { return static_cast<mindspore::DataType>(DETypeToMSType(d)); });
320   return ret_types;
321 }
322 
GetOutputShapes()323 std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
324   std::vector<TensorShape> shapes;
325   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
326   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
327   TreeGetters *consumer = tree_getters_.get();
328   if (consumer == nullptr) {
329     MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
330     return std::vector<std::vector<int64_t>>();
331   }
332   runtime_context->AssignConsumer(tree_getters_);
333   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
334   RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
335   std::vector<std::vector<int64_t>> ret_shapes;
336   std::transform(shapes.begin(), shapes.end(), std::back_inserter(ret_shapes),
337                  [](const TensorShape &s) -> std::vector<int64_t> { return s.AsVector(); });
338   return ret_shapes;
339 }
340 
GetNumClasses()341 int64_t Dataset::GetNumClasses() {
342   int64_t num_classes = -1;
343   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
344   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
345   TreeGetters *consumer = tree_getters_.get();
346   if (consumer == nullptr) {
347     MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
348     return -1;
349   }
350   runtime_context->AssignConsumer(tree_getters_);
351   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
352   RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
353   return num_classes;
354 }
355 
GetColumnNamesCharIF()356 std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
357   std::vector<std::string> col_names;
358   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
359   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
360   TreeGetters *consumer = tree_getters_.get();
361   if (consumer == nullptr) {
362     MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
363     return std::vector<std::vector<char>>();
364   }
365   runtime_context->AssignConsumer(tree_getters_);
366   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
367   RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
368   return VectorStringToChar(col_names);
369 }
370 
GetClassIndexingCharIF()371 std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClassIndexingCharIF() {
372   std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
373   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
374   RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
375   TreeGetters *consumer = tree_getters_.get();
376   if (consumer == nullptr) {
377     MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
378     return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
379   }
380   runtime_context->AssignConsumer(tree_getters_);
381   RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
382   RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
383   return ClassIndexStringToChar(output_class_indexing);
384 }
385 
386 /// \brief Function to create a SchemaObj
387 /// \param[in] schema_file Path of schema file
388 /// \return Shared pointer to the current schema
SchemaCharIF(const std::vector<char> & schema_file)389 std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file) {
390   auto schema = std::make_shared<SchemaObj>(CharToString(schema_file));
391   return schema->Init() ? schema : nullptr;
392 }
393 
394 // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
395 // (In alphabetical order)
396 
397 // Function to create a Batch dataset
BatchDataset(std::shared_ptr<Dataset> input,int32_t batch_size,bool drop_remainder)398 BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
399   // Default values
400   auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
401   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
402 }
403 
404 #ifndef ENABLE_ANDROID
405 // Function to create a BucketBatchByLength dataset
BucketBatchByLengthDataset(std::shared_ptr<Dataset> input,const std::vector<std::vector<char>> & column_names,const std::vector<int32_t> & bucket_boundaries,const std::vector<int32_t> & bucket_batch_sizes,std::function<MSTensorVec (MSTensorVec)> element_length_function,const std::map<std::vector<char>,std::pair<std::vector<int64_t>,MSTensor>> & pad_info,bool pad_to_bucket_boundary,bool drop_remainder)406 BucketBatchByLengthDataset::BucketBatchByLengthDataset(
407   std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &column_names,
408   const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
409   std::function<MSTensorVec(MSTensorVec)> element_length_function,
410   const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary,
411   bool drop_remainder) {
412   std::shared_ptr<TensorOp> c_func = nullptr;
413   if (element_length_function != nullptr) {
414     c_func = std::make_shared<CFuncOp>(std::bind(FuncPtrConverter, element_length_function, std::placeholders::_1));
415   }
416 
417   std::map<std::vector<char>, std::pair<TensorShape, std::shared_ptr<Tensor>>> map;
418   for (auto const &p : pad_info) {
419     const MSTensor &t = p.second.second;
420     std::shared_ptr<Tensor> rt;
421     Status rc = Tensor::CreateFromMemory(TensorShape(t.Shape()), MSTypeToDEType(static_cast<TypeId>(t.DataType())),
422                                          (const uchar *)(t.Data().get()), t.DataSize(), &rt);
423     if (rc.IsError()) {
424       MS_LOG_ERROR << "Fail to create DETensor from MSTensor for pad_info: " << rc.ToString() << ".";
425       map.clear();
426       break;
427     }
428     map.insert({p.first, {TensorShape(p.second.first), rt}});
429   }
430 
431   auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), VectorCharToString(column_names),
432                                                       bucket_boundaries, bucket_batch_sizes, c_func,
433                                                       PadInfoCharToString(map), pad_to_bucket_boundary, drop_remainder);
434 
435   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
436 }
437 
ConcatDataset(const std::vector<std::shared_ptr<Dataset>> & datasets)438 ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
439   std::vector<std::shared_ptr<DatasetNode>> all_datasets;
440   (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
441                        [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
442                          return (dataset != nullptr) ? dataset->IRNode() : nullptr;
443                        });
444 
445   auto ds = std::make_shared<ConcatNode>(all_datasets);
446 
447   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
448 }
449 
FilterDataset(std::shared_ptr<Dataset> input,std::function<MSTensorVec (MSTensorVec)> predicate,const std::vector<std::vector<char>> & input_columns)450 FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTensorVec(MSTensorVec)> predicate,
451                              const std::vector<std::vector<char>> &input_columns) {
452   std::shared_ptr<TensorOp> c_func = nullptr;
453   if (predicate) {
454     c_func = std::make_shared<CFuncOp>(std::bind(FuncPtrConverter, predicate, std::placeholders::_1));
455   }
456   auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, VectorCharToString(input_columns));
457 
458   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
459 }
460 #endif
461 
MapDataset(std::shared_ptr<Dataset> input,std::vector<std::shared_ptr<TensorOperation>> operations,const std::vector<std::vector<char>> & input_columns,const std::vector<std::vector<char>> & output_columns,const std::vector<std::vector<char>> & project_columns,const std::shared_ptr<DatasetCache> & cache,std::vector<std::shared_ptr<DSCallback>> callbacks)462 MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
463                        const std::vector<std::vector<char>> &input_columns,
464                        const std::vector<std::vector<char>> &output_columns,
465                        const std::vector<std::vector<char>> &project_columns,
466                        const std::shared_ptr<DatasetCache> &cache, std::vector<std::shared_ptr<DSCallback>> callbacks) {
467   auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns),
468                                       VectorCharToString(output_columns), VectorCharToString(project_columns), cache,
469                                       callbacks);
470 
471   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
472 }
473 
ProjectDataset(std::shared_ptr<Dataset> input,const std::vector<std::vector<char>> & columns)474 ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns) {
475   auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns));
476 
477   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
478 }
479 #ifndef ENABLE_ANDROID
RenameDataset(std::shared_ptr<Dataset> input,const std::vector<std::vector<char>> & input_columns,const std::vector<std::vector<char>> & output_columns)480 RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
481                              const std::vector<std::vector<char>> &output_columns) {
482   auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns),
483                                          VectorCharToString(output_columns));
484 
485   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
486 }
487 #endif
488 
RepeatDataset(std::shared_ptr<Dataset> input,int32_t count)489 RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
490   auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
491 
492   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
493 }
494 
ShuffleDataset(std::shared_ptr<Dataset> input,int32_t buffer_size)495 ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) {
496   // Pass in reshuffle_each_epoch with true
497   auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
498 
499   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
500 }
501 
502 #ifndef ENABLE_ANDROID
SkipDataset(std::shared_ptr<Dataset> input,int32_t count)503 SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
504   auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
505 
506   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
507 }
508 
TakeDataset(std::shared_ptr<Dataset> input,int32_t count)509 TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
510   auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
511 
512   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
513 }
514 
ZipDataset(const std::vector<std::shared_ptr<Dataset>> & datasets)515 ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
516   std::vector<std::shared_ptr<DatasetNode>> all_datasets;
517   (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
518                        [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
519                          return (dataset != nullptr) ? dataset->IRNode() : nullptr;
520                        });
521   auto ds = std::make_shared<ZipNode>(all_datasets);
522 
523   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
524 }
525 #endif
GetBatchSize()526 int64_t Dataset::GetBatchSize() {
527   int64_t batch_size = -1;
528   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
529   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
530   RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
531   RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
532   return batch_size;
533 }
534 
GetRepeatCount()535 int64_t Dataset::GetRepeatCount() {
536   int64_t repeat_count = 0;
537   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
538   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
539   RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
540   RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
541   return repeat_count;
542 }
543 
SetNumWorkers(int32_t num_workers)544 std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
545   if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
546     return nullptr;
547   }
548   return shared_from_this();
549 }
550 
551 #ifndef ENABLE_ANDROID
BuildSentencePieceVocabCharIF(const std::vector<std::vector<char>> & col_names,int32_t vocab_size,float character_coverage,SentencePieceModel model_type,const std::map<std::vector<char>,std::vector<char>> & params)552 std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
553   const std::vector<std::vector<char>> &col_names, int32_t vocab_size, float character_coverage,
554   SentencePieceModel model_type, const std::map<std::vector<char>, std::vector<char>> &params) {
555   auto vocab = std::make_shared<SentencePieceVocab>();
556   auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, VectorCharToString(col_names), vocab_size,
557                                                      character_coverage, model_type, UnorderedMapCharToString(params));
558 
559   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
560   Status rc = runtime_context->Init();
561   if (rc.IsError()) {
562     MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
563     return nullptr;
564   }
565 
566   auto consumer = std::make_unique<BuildVocabConsumer>();
567   BuildVocabConsumer *bv_consumer = consumer.get();
568   if (bv_consumer == nullptr) {
569     MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
570     return nullptr;
571   }
572   rc = consumer->Init(ds);
573   if (rc.IsError()) {
574     MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
575     return nullptr;
576   }
577   runtime_context->AssignConsumer(std::move(consumer));
578 
579   // Run tree here to starting building SentencePieceVocab
580   rc = bv_consumer->Start();
581   if (rc.IsError()) {
582     MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
583     return nullptr;
584   }
585   return vocab;
586 }
587 
BuildVocabCharIF(const std::vector<std::vector<char>> & columns,const std::pair<int64_t,int64_t> & freq_range,int64_t top_k,const std::vector<std::vector<char>> & special_tokens,bool special_first)588 std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<char>> &columns,
589                                                  const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
590                                                  const std::vector<std::vector<char>> &special_tokens,
591                                                  bool special_first) {
592   auto vocab = std::make_shared<Vocab>();
593   auto ds = std::make_shared<BuildVocabNode>(IRNode(), vocab, VectorCharToString(columns), freq_range, top_k,
594                                              VectorCharToString(special_tokens), special_first);
595 
596   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
597   Status rc = runtime_context->Init();
598   if (rc.IsError()) {
599     MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
600     return nullptr;
601   }
602 
603   auto consumer = std::make_unique<BuildVocabConsumer>();
604   BuildVocabConsumer *bv_consumer = consumer.get();
605   if (bv_consumer == nullptr) {
606     MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
607     return nullptr;
608   }
609   rc = consumer->Init(ds);
610   if (rc.IsError()) {
611     MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
612     return nullptr;
613   }
614   runtime_context->AssignConsumer(std::move(consumer));
615 
616   // Run tree here to starting building vocab
617   rc = bv_consumer->Start();
618   if (rc.IsError()) {
619     MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
620     return nullptr;
621   }
622   return vocab;
623 }
624 #endif
625 
Batch(int32_t batch_size,bool drop_remainder)626 std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
627   return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
628 }
629 
630 struct SchemaObj::Data {
631   int32_t num_rows_;
632   std::string dataset_type_;
633   std::string schema_file_;
634   nlohmann::json columns_;
635 };
636 
SchemaObj(const std::vector<char> & schema_file)637 SchemaObj::SchemaObj(const std::vector<char> &schema_file) : data_(std::make_shared<Data>()) {
638   data_->schema_file_ = CharToString(schema_file);
639   data_->dataset_type_ = "";
640   data_->num_rows_ = 0;
641 }
642 
643 // SchemaObj Init function
Init()644 Status SchemaObj::Init() {
645   if (data_ != nullptr && !data_->schema_file_.empty()) {
646     std::string real_path;
647     RETURN_IF_NOT_OK(Path::RealPath(data_->schema_file_, real_path));
648     Path schema_file(real_path);
649     CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
650                                  "The file " + data_->schema_file_ + " does not exist or permission denied!");
651 
652     nlohmann::json js;
653     try {
654       std::ifstream in(real_path);
655       in >> js;
656       CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
657                                    "\"columns\" node is required in the schema json file.");
658     } catch (const std::exception &err) {
659       std::string err_msg = "Schema file failed to load: ";
660       RETURN_STATUS_SYNTAX_ERROR(err_msg);
661     }
662     return from_json(js);
663   }
664   return Status::OK();
665 }
666 
667 // Function to add a column to schema with a mstype de_type and known shape
add_column_char(const std::vector<char> & name,mindspore::DataType de_type,const std::vector<int32_t> & shape)668 Status SchemaObj::add_column_char(const std::vector<char> &name, mindspore::DataType de_type,
669                                   const std::vector<int32_t> &shape) {
670   DataType data_type = dataset::MSTypeToDEType(static_cast<TypeId>(de_type));
671   return add_column_char(name, StringToChar(data_type.ToString()), shape);
672 }
673 
674 // Function to add a column to schema with a string de_type and known shape
add_column_char(const std::vector<char> & name,const std::vector<char> & de_type,const std::vector<int32_t> & shape)675 Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vector<char> &de_type,
676                                   const std::vector<int32_t> &shape) {
677   DataType data_type(CharToString(de_type));
678   CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
679 
680   nlohmann::json new_column;
681   new_column["name"] = CharToString(name);
682   new_column["type"] = data_type.ToString();
683   new_column["shape"] = shape;
684   new_column["rank"] = shape.size();
685 
686   data_->columns_.push_back(new_column);
687   return Status::OK();
688 }
689 
690 // Function to add a column to schema with a mstype de_type and without shape
add_column_char(const std::vector<char> & name,mindspore::DataType de_type)691 Status SchemaObj::add_column_char(const std::vector<char> &name, mindspore::DataType de_type) {
692   DataType data_type = dataset::MSTypeToDEType(static_cast<TypeId>(de_type));
693   return add_column_char(name, StringToChar(data_type.ToString()));
694 }
695 
696 // Function to add a column to schema with a string de_type and without shape
add_column_char(const std::vector<char> & name,const std::vector<char> & de_type)697 Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vector<char> &de_type) {
698   DataType data_type(CharToString(de_type));
699   CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
700 
701   nlohmann::json new_column;
702   new_column["name"] = CharToString(name);
703   new_column["type"] = data_type.ToString();
704   new_column["rank"] = 1;
705 
706   data_->columns_.push_back(new_column);
707   return Status::OK();
708 }
709 
schema_to_json(nlohmann::json * out_json)710 Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
711   nlohmann::json json_file;
712   json_file["columns"] = data_->columns_;
713   std::string str_dataset_type_(data_->dataset_type_);
714   if (str_dataset_type_ != "") {
715     json_file["datasetType"] = str_dataset_type_;
716   }
717 
718   if (data_->num_rows_ > 0) {
719     json_file["numRows"] = data_->num_rows_;
720   }
721   *out_json = json_file;
722   return Status::OK();
723 }
724 
to_json_char()725 const std::vector<char> SchemaObj::to_json_char() {
726   nlohmann::json json_file;
727   this->schema_to_json(&json_file);
728   return StringToChar(json_file.dump(2));
729 }
730 
set_dataset_type(std::string dataset_type)731 void SchemaObj::set_dataset_type(std::string dataset_type) { data_->dataset_type_ = dataset_type.data(); }
732 
set_num_rows(int32_t num_rows)733 void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; }
734 
get_num_rows() const735 int32_t SchemaObj::get_num_rows() const { return data_->num_rows_; }
736 
parse_column(nlohmann::json columns)737 Status SchemaObj::parse_column(nlohmann::json columns) {
738   std::string name, de_type;
739   std::vector<int32_t> shape;
740 
741   data_->columns_.clear();
742   if (columns.type() == nlohmann::json::value_t::array) {
743     // reference to python list
744     for (auto column : columns) {
745       auto key_name = column.find("name");
746       if (key_name == column.end()) {
747         RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
748       }
749       name = *key_name;
750 
751       auto key_type = column.find("type");
752       if (key_type == column.end()) {
753         RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
754       }
755       de_type = *key_type;
756 
757       shape.clear();
758       auto key_shape = column.find("shape");
759       if (key_shape != column.end()) {
760         shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
761       }
762       RETURN_IF_NOT_OK(add_column(name, de_type, shape));
763     }
764   } else if (columns.type() == nlohmann::json::value_t::object) {
765     for (const auto &it_child : columns.items()) {
766       name = it_child.key();
767       auto key_type = it_child.value().find("type");
768       if (key_type == it_child.value().end()) {
769         RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
770       }
771       de_type = *key_type;
772 
773       shape.clear();
774       auto key_shape = it_child.value().find("shape");
775       if (key_shape != it_child.value().end()) {
776         shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
777       }
778 
779       RETURN_IF_NOT_OK(add_column(name, de_type, shape));
780     }
781   } else {
782     RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
783   }
784   return Status::OK();
785 }
786 
from_json(nlohmann::json json_obj)787 Status SchemaObj::from_json(nlohmann::json json_obj) {
788   for (const auto &it_child : json_obj.items()) {
789     if (it_child.key() == "datasetType") {
790       std::string str_dataset_type_ = it_child.value();
791       data_->dataset_type_ = str_dataset_type_.data();
792     } else if (it_child.key() == "numRows") {
793       data_->num_rows_ = it_child.value();
794     } else if (it_child.key() == "columns") {
795       RETURN_IF_NOT_OK(parse_column(it_child.value()));
796     } else {
797       RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
798     }
799   }
800   if (data_->columns_.empty()) {
801     RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
802   }
803   if (data_->num_rows_ < 0) {
804     RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
805   }
806 
807   return Status::OK();
808 }
809 
FromJSONStringCharIF(const std::vector<char> & json_string)810 Status SchemaObj::FromJSONStringCharIF(const std::vector<char> &json_string) {
811   try {
812     nlohmann::json js = nlohmann::json::parse(CharToString(json_string));
813     CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
814                                  "\"columns\" node is required in the schema json JSON.");
815     RETURN_IF_NOT_OK(from_json(js));
816   } catch (const std::exception &err) {
817     std::string err_msg = "FromJSONString: JSON string failed to parse: ";
818     err_msg += err.what();
819     RETURN_STATUS_SYNTAX_ERROR(err_msg);
820   }
821   return Status::OK();
822 }
823 
ParseColumnStringCharIF(const std::vector<char> & json_string)824 Status SchemaObj::ParseColumnStringCharIF(const std::vector<char> &json_string) {
825   try {
826     nlohmann::json js = nlohmann::json::parse(CharToString(json_string));
827     RETURN_IF_NOT_OK(parse_column(js));
828   } catch (const std::exception &err) {
829     std::string err_msg = "ParseColumnString: JSON string failed to parse: ";
830     err_msg += err.what();
831     RETURN_STATUS_SYNTAX_ERROR(err_msg);
832   }
833   return Status::OK();
834 }
835 
836 // OTHER FUNCTIONS
837 
838 #ifndef ENABLE_ANDROID
839 
CreateDatasetCacheCharIF(session_id_type id,uint64_t mem_sz,bool spill,std::optional<std::vector<char>> hostname,std::optional<int32_t> port,std::optional<int32_t> num_connections,std::optional<int32_t> prefetch_sz)840 std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
841                                                        std::optional<std::vector<char>> hostname,
842                                                        std::optional<int32_t> port,
843                                                        std::optional<int32_t> num_connections,
844                                                        std::optional<int32_t> prefetch_sz) {
845   auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
846   return cache;
847 }
848 #endif
849 
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)850 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
851                            const std::vector<std::vector<char>> &column_names, bool decode,
852                            const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
853   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
854   auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
855                                         VectorCharToString(column_names), decode, sampler_obj, cache);
856   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
857 }
858 
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)859 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
860                            const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
861                            const std::shared_ptr<DatasetCache> &cache) {
862   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
863   auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
864                                         VectorCharToString(column_names), decode, sampler_obj, cache);
865   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
866 }
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)867 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
868                            const std::vector<std::vector<char>> &column_names, bool decode,
869                            const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
870   auto sampler_obj = sampler.get().Parse();
871   auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
872                                         VectorCharToString(column_names), decode, sampler_obj, cache);
873   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
874 }
875 
876 #ifndef ENABLE_ANDROID
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)877 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
878                              const std::shared_ptr<Sampler> &sampler, bool decode,
879                              const std::set<std::vector<char>> &extensions,
880                              const std::shared_ptr<DatasetCache> &cache) {
881   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
882   auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
883                                          SetCharToString(extensions), cache);
884   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
885 }
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)886 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
887                              const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions,
888                              const std::shared_ptr<DatasetCache> &cache) {
889   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
890   auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
891                                          SetCharToString(extensions), cache);
892   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
893 }
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)894 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
895                              const std::reference_wrapper<Sampler> sampler, bool decode,
896                              const std::set<std::vector<char>> &extensions,
897                              const std::shared_ptr<DatasetCache> &cache) {
898   auto sampler_obj = sampler.get().Parse();
899   auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
900                                          SetCharToString(extensions), cache);
901   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
902 }
903 
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)904 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
905                                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
906   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
907   auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
908   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
909 }
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)910 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
911                                const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
912   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
913   auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
914   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
915 }
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)916 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
917                                const std::reference_wrapper<Sampler> sampler,
918                                const std::shared_ptr<DatasetCache> &cache) {
919   auto sampler_obj = sampler.get().Parse();
920   auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
921   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
922 }
923 
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)924 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
925                                  const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
926   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
927   auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
928   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
929 }
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)930 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
931                                  const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
932   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
933   auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
934   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
935 }
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)936 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
937                                  const std::reference_wrapper<Sampler> sampler,
938                                  const std::shared_ptr<DatasetCache> &cache) {
939   auto sampler_obj = sampler.get().Parse();
940   auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
941   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
942 }
943 
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)944 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
945                                      const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
946                                      const std::shared_ptr<Sampler> &sampler,
947                                      const std::shared_ptr<DatasetCache> &cache) {
948   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
949   auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
950                                              CharToString(task), decode, sampler_obj, cache);
951   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
952 }
953 
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)954 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
955                                      const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
956                                      const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
957   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
958   auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
959                                              CharToString(task), decode, sampler_obj, cache);
960   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
961 }
962 
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)963 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
964                                      const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
965                                      const std::reference_wrapper<Sampler> sampler,
966                                      const std::shared_ptr<DatasetCache> &cache) {
967   auto sampler_obj = sampler.get().Parse();
968   auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
969                                              CharToString(task), decode, sampler_obj, cache);
970   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
971 }
972 
CLUEDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<char> & task,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)973 CLUEDataset::CLUEDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &task,
974                          const std::vector<char> &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
975                          int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
976   auto ds = std::make_shared<CLUENode>(VectorCharToString(dataset_files), CharToString(task), CharToString(usage),
977                                        num_samples, shuffle, num_shards, shard_id, cache);
978   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
979 }
980 
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)981 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
982                          const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
983                          const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
984   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
985   auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
986                                        decode, sampler_obj, cache, extra_metadata);
987   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
988 }
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)989 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
990                          const std::vector<char> &task, const bool &decode, const Sampler *sampler,
991                          const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
992   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
993   auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
994                                        decode, sampler_obj, cache, extra_metadata);
995   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
996 }
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)997 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
998                          const std::vector<char> &task, const bool &decode,
999                          const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache,
1000                          const bool &extra_metadata) {
1001   auto sampler_obj = sampler.get().Parse();
1002   auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
1003                                        decode, sampler_obj, cache, extra_metadata);
1004   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1005 }
1006 
CSVDataset(const std::vector<std::vector<char>> & dataset_files,char field_delim,const std::vector<std::shared_ptr<CsvBase>> & column_defaults,const std::vector<std::vector<char>> & column_names,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1007 CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char field_delim,
1008                        const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
1009                        const std::vector<std::vector<char>> &column_names, int64_t num_samples, ShuffleMode shuffle,
1010                        int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
1011   auto ds =
1012     std::make_shared<CSVNode>(VectorCharToString(dataset_files), field_delim, column_defaults,
1013                               VectorCharToString(column_names), num_samples, shuffle, num_shards, shard_id, cache);
1014   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1015 }
1016 
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1017 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1018                            const std::vector<char> &downgrade, int32_t scale, bool decode,
1019                            const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1020   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1021   auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1022                                         decode, sampler_obj, cache);
1023   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1024 }
1025 
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1026 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1027                            const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
1028                            const std::shared_ptr<DatasetCache> &cache) {
1029   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1030   auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1031                                         decode, sampler_obj, cache);
1032   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1033 }
1034 
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1035 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1036                            const std::vector<char> &downgrade, int32_t scale, bool decode,
1037                            const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
1038   auto sampler_obj = sampler.get().Parse();
1039   auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1040                                         decode, sampler_obj, cache);
1041   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1042 }
1043 
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1044 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1045                              bool decode, const std::shared_ptr<Sampler> &sampler,
1046                              const std::shared_ptr<DatasetCache> &cache) {
1047   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1048   auto ds =
1049     std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1050   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1051 }
1052 
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1053 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1054                              bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1055   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1056   auto ds =
1057     std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1058   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1059 }
1060 
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1061 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1062                              bool decode, const std::reference_wrapper<Sampler> sampler,
1063                              const std::shared_ptr<DatasetCache> &cache) {
1064   auto sampler_obj = sampler.get().Parse();
1065   auto ds =
1066     std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1067   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1068 }
1069 
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1070 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
1071                                        const std::shared_ptr<Sampler> &sampler,
1072                                        const std::set<std::vector<char>> &extensions,
1073                                        const std::map<std::vector<char>, int32_t> &class_indexing,
1074                                        const std::shared_ptr<DatasetCache> &cache) {
1075   // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1076   bool recursive = false;
1077 
1078   // Create logical representation of ImageFolderDataset.
1079   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1080   auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1081                                               SetCharToString(extensions), MapCharToString(class_indexing), cache);
1082   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1083 }
1084 
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1085 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
1086                                        const std::set<std::vector<char>> &extensions,
1087                                        const std::map<std::vector<char>, int32_t> &class_indexing,
1088                                        const std::shared_ptr<DatasetCache> &cache) {
1089   // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1090   bool recursive = false;
1091 
1092   // Create logical representation of ImageFolderDataset.
1093   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1094   auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1095                                               SetCharToString(extensions), MapCharToString(class_indexing), cache);
1096   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1097 }
1098 
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1099 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
1100                                        const std::reference_wrapper<Sampler> sampler,
1101                                        const std::set<std::vector<char>> &extensions,
1102                                        const std::map<std::vector<char>, int32_t> &class_indexing,
1103                                        const std::shared_ptr<DatasetCache> &cache) {
1104   // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1105   bool recursive = false;
1106 
1107   // Create logical representation of ImageFolderDataset.
1108   auto sampler_obj = sampler.get().Parse();
1109   auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1110                                               SetCharToString(extensions), MapCharToString(class_indexing), cache);
1111   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1112 }
1113 
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1114 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1115                                  const std::shared_ptr<Sampler> &sampler,
1116                                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
1117                                  const std::shared_ptr<DatasetCache> &cache) {
1118   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1119   auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1120                                            MapCharToString(class_indexing), decode, cache);
1121   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1122 }
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const Sampler * sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1123 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1124                                  const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
1125                                  bool decode, const std::shared_ptr<DatasetCache> &cache) {
1126   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1127   auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1128                                            MapCharToString(class_indexing), decode, cache);
1129   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1130 }
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1131 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1132                                  const std::reference_wrapper<Sampler> sampler,
1133                                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
1134                                  const std::shared_ptr<DatasetCache> &cache) {
1135   auto sampler_obj = sampler.get().Parse();
1136   auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1137                                            MapCharToString(class_indexing), decode, cache);
1138   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1139 }
1140 
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1141 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1142                                  const std::vector<std::vector<char>> &columns_list,
1143                                  const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample,
1144                                  int64_t num_padded, ShuffleMode shuffle_mode,
1145                                  const std::shared_ptr<DatasetCache> &cache) {
1146   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1147   nlohmann::json sample = nullptr;
1148   if (padded_sample) {
1149     sample = *padded_sample;
1150   }
1151   auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1152                                            sample, num_padded, shuffle_mode, cache);
1153   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1154 }
1155 
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const Sampler * sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1156 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1157                                  const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
1158                                  const nlohmann::json *padded_sample, int64_t num_padded, ShuffleMode shuffle_mode,
1159                                  const std::shared_ptr<DatasetCache> &cache) {
1160   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1161   nlohmann::json sample = nullptr;
1162   if (padded_sample) {
1163     sample = *padded_sample;
1164   }
1165   auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1166                                            sample, num_padded, shuffle_mode, cache);
1167   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1168 }
1169 
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const std::reference_wrapper<Sampler> sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1170 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1171                                  const std::vector<std::vector<char>> &columns_list,
1172                                  const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
1173                                  int64_t num_padded, ShuffleMode shuffle_mode,
1174                                  const std::shared_ptr<DatasetCache> &cache) {
1175   auto sampler_obj = sampler.get().Parse();
1176   nlohmann::json sample = nullptr;
1177   if (padded_sample) {
1178     sample = *padded_sample;
1179   }
1180 
1181   auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1182                                            sample, num_padded, shuffle_mode, cache);
1183   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1184 }
1185 
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1186 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1187                                  const std::vector<std::vector<char>> &columns_list,
1188                                  const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample,
1189                                  int64_t num_padded, ShuffleMode shuffle_mode,
1190                                  const std::shared_ptr<DatasetCache> &cache) {
1191   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1192   nlohmann::json sample = nullptr;
1193   if (padded_sample) {
1194     sample = *padded_sample;
1195   }
1196 
1197   auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1198                                            sampler_obj, sample, num_padded, shuffle_mode, cache);
1199   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1200 }
1201 
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const Sampler * sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1202 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1203                                  const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
1204                                  const nlohmann::json *padded_sample, int64_t num_padded, ShuffleMode shuffle_mode,
1205                                  const std::shared_ptr<DatasetCache> &cache) {
1206   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1207   nlohmann::json sample = nullptr;
1208   if (padded_sample) {
1209     sample = *padded_sample;
1210   }
1211 
1212   auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1213                                            sampler_obj, sample, num_padded, shuffle_mode, cache);
1214   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1215 }
1216 
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const std::reference_wrapper<Sampler> sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1217 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1218                                  const std::vector<std::vector<char>> &columns_list,
1219                                  const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
1220                                  int64_t num_padded, ShuffleMode shuffle_mode,
1221                                  const std::shared_ptr<DatasetCache> &cache) {
1222   auto sampler_obj = sampler.get().Parse();
1223   nlohmann::json sample = nullptr;
1224   if (padded_sample) {
1225     sample = *padded_sample;
1226   }
1227   auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1228                                            sampler_obj, sample, num_padded, shuffle_mode, cache);
1229   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1230 }
1231 #endif
1232 
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1233 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1234                            const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1235   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1236   auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1237   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1238 }
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1239 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
1240                            const std::shared_ptr<DatasetCache> &cache) {
1241   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1242   auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1243   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1244 }
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1245 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1246                            const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
1247   auto sampler_obj = sampler.get().Parse();
1248   auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1249   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1250 }
1251 
1252 #ifndef ENABLE_ANDROID
TextFileDataset(const std::vector<std::vector<char>> & dataset_files,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1253 TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
1254                                  ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1255                                  const std::shared_ptr<DatasetCache> &cache) {
1256   auto ds = std::make_shared<TextFileNode>(VectorCharToString(dataset_files), num_samples, shuffle, num_shards,
1257                                            shard_id, cache);
1258   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1259 }
1260 
USPSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1261 USPSDataset::USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
1262                          ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1263                          const std::shared_ptr<DatasetCache> &cache) {
1264   auto ds = std::make_shared<USPSNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
1265                                        shard_id, cache);
1266   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1267 }
1268 
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1269 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1270                        const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1271                        bool decode, const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
1272                        bool extra_metadata) {
1273   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1274   auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1275                                       MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1276   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1277 }
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1278 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1279                        const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1280                        bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache,
1281                        bool extra_metadata) {
1282   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1283   auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1284                                       MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1285   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1286 }
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1287 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1288                        const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1289                        bool decode, const std::reference_wrapper<Sampler> sampler,
1290                        const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) {
1291   auto sampler_obj = sampler.get().Parse();
1292   auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1293                                       MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1294   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1295 }  // namespace dataset
1296 
RandomDataDataset(const int32_t & total_rows,std::shared_ptr<SchemaObj> schema,const std::vector<std::vector<char>> & columns_list,std::shared_ptr<DatasetCache> cache)1297 RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
1298                                      const std::vector<std::vector<char>> &columns_list,
1299                                      std::shared_ptr<DatasetCache> cache) {
1300   auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache);
1301   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1302 }
RandomDataDataset(const int32_t & total_rows,const std::vector<char> & schema_path,const std::vector<std::vector<char>> & columns_list,std::shared_ptr<DatasetCache> cache)1303 RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
1304                                      const std::vector<std::vector<char>> &columns_list,
1305                                      std::shared_ptr<DatasetCache> cache) {
1306   auto ds =
1307     std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache);
1308   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1309 }
1310 
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1311 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
1312                        const std::shared_ptr<DatasetCache> &cache) {
1313   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1314   auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
1315   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1316 }
1317 
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1318 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
1319                        const std::shared_ptr<DatasetCache> &cache) {
1320   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1321   auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
1322   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1323 }
1324 
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1325 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
1326                        const std::shared_ptr<DatasetCache> &cache) {
1327   auto sampler_obj = sampler.get().Parse();
1328   auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
1329   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1330 }
1331 
TFRecordDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<char> & schema,const std::vector<std::vector<char>> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,std::shared_ptr<DatasetCache> cache)1332 TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
1333                                  const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
1334                                  ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
1335                                  std::shared_ptr<DatasetCache> cache) {
1336   auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema),
1337                                            VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id,
1338                                            shard_equal_rows, cache);
1339   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1340 }
TFRecordDataset(const std::vector<std::vector<char>> & dataset_files,std::shared_ptr<SchemaObj> schema,const std::vector<std::vector<char>> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,std::shared_ptr<DatasetCache> cache)1341 TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, std::shared_ptr<SchemaObj> schema,
1342                                  const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
1343                                  ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
1344                                  std::shared_ptr<DatasetCache> cache) {
1345   auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list),
1346                                            num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache);
1347   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1348 }
1349 
1350 #endif
1351 }  // namespace dataset
1352 }  // namespace mindspore
1353