1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/include/dataset/datasets.h"
18 #include <algorithm>
19 #include <fstream>
20 #include <unordered_set>
21 #include <utility>
22 #include <nlohmann/json.hpp>
23
24 #include "minddata/dataset/core/tensor.h"
25 #include "minddata/dataset/engine/runtime_context.h"
26 #include "minddata/dataset/include/dataset/constants.h"
27 #include "minddata/dataset/include/dataset/iterator.h"
28 #include "minddata/dataset/include/dataset/samplers.h"
29 #include "minddata/dataset/include/dataset/transforms.h"
30 #include "minddata/dataset/util/path.h"
31 #include "minddata/dataset/util/status.h"
32 #include "minddata/dataset/core/client.h"
33 #include "minddata/dataset/core/type_id.h"
34 #include "minddata/dataset/engine/consumers/tree_consumer.h"
35 #include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
36
37 #include "minddata/dataset/kernels/c_func_op.h"
38 #include "minddata/dataset/kernels/tensor_op.h"
39
40 #ifndef ENABLE_ANDROID
41 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
42 #endif
43
44 #ifndef ENABLE_ANDROID
45 #include "minddata/dataset/text/sentence_piece_vocab.h"
46 #include "minddata/dataset/text/vocab.h"
47 #endif
48
49 // Sampler headers (in alphabetical order)
50 #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
51
52 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
53
54 // IR non-leaf nodes
55 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
56 #ifndef ENABLE_ANDROID
57 #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
58 #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
59 #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
60 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
61 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
62 #endif
63
64 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
65 #include "minddata/dataset/engine/ir/datasetops/project_node.h"
66
67 #ifndef ENABLE_ANDROID
68 #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
69 #endif
70
71 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
72 #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
73
74 #ifndef ENABLE_ANDROID
75 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
76 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
77 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
78 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
79 #endif
80
81 #include "minddata/dataset/core/config_manager.h"
82 #include "minddata/dataset/util/random.h"
83 #include "minddata/dataset/util/services.h"
84
85 // IR leaf nodes
86 #include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
87 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
88
89 // IR leaf nodes disabled for android
90 #ifndef ENABLE_ANDROID
91 #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
92 #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
93 #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
94 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
95 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
96 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
97 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
98 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
99 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
100 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
101 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
102 #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
103 #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
104 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
105 #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
106 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
107 #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
108 #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
109 #endif
110
111 namespace mindspore {
112 namespace dataset {
113
114 // convert MSTensorVec to DE TensorRow, return empty if fails
VecToRow(const MSTensorVec & v)115 TensorRow VecToRow(const MSTensorVec &v) {
116 TensorRow row;
117 row.reserve(v.size());
118 for (const MSTensor &t : v) {
119 std::shared_ptr<Tensor> rt;
120 Status rc = Tensor::CreateFromMSTensor(t, &rt);
121 if (rc.IsError()) {
122 MS_LOG_ERROR << "Convert from MSTensor to DETensor failed:" << rc.ToString() << ".";
123 return {};
124 }
125 row.emplace_back(rt);
126 }
127 return row;
128 }
129
130 // convert DE TensorRow to MSTensorVec, won't fail
RowToVec(const TensorRow & v)131 MSTensorVec RowToVec(const TensorRow &v) {
132 MSTensorVec rv;
133 rv.reserve(v.size());
134 std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> MSTensor {
135 return mindspore::MSTensor(std::make_shared<DETensor>(t));
136 });
137 return rv;
138 }
139
140 // Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper
FuncPtrConverter(std::function<MSTensorVec (MSTensorVec)> func,TensorRow in_row)141 TensorRow FuncPtrConverter(std::function<MSTensorVec(MSTensorVec)> func, TensorRow in_row) {
142 return VecToRow(func(RowToVec(in_row)));
143 }
144
145 // Function to create the iterator, which will build and launch the execution tree.
CreateIteratorCharIF(std::vector<std::vector<char>> columns,int32_t num_epochs)146 std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs) {
147 std::shared_ptr<Iterator> iter;
148 try {
149 auto ds = shared_from_this();
150
151 // The specified columns will be selected from the dataset and passed down the pipeline
152 // in the order specified, other columns will be discarded.
153 if (!VectorCharToString(columns).empty()) {
154 ds = ds->Project(VectorCharToString(columns));
155 }
156
157 iter = std::make_shared<Iterator>();
158 Status rc = iter->BuildAndLaunchTree(ds, num_epochs);
159 if (rc.IsError()) {
160 MS_LOG(ERROR) << "CreateIterator failed." << rc;
161 return nullptr;
162 }
163 } catch (const std::exception &err) {
164 MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
165 return nullptr;
166 }
167
168 return iter;
169 }
170
171 // Function to create the iterator, which will build and launch the execution tree.
CreatePullBasedIterator(std::vector<std::vector<char>> columns)172 std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(std::vector<std::vector<char>> columns) {
173 // The specified columns will be selected from the dataset and passed down the pipeline
174 // in the order specified, other columns will be discarded.
175 // This code is not in a try/catch block because there is no execution tree class that will be created.
176 auto ds = shared_from_this();
177 if (!VectorCharToString(columns).empty()) {
178 ds = ds->Project(VectorCharToString(columns));
179 }
180
181 std::shared_ptr<PullIterator> iter = std::make_shared<PullIterator>();
182 Status rc = iter->BuildAndLaunchTree(ds);
183 if (rc.IsError()) {
184 MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << rc;
185 }
186 RETURN_SECOND_IF_ERROR(rc, nullptr);
187 return iter;
188 }
189
190 #ifndef ENABLE_ANDROID
191 // Function to return a transferred Node that transfers data through a device.
DeviceQueueCharIF(const std::vector<char> & queue_name,const std::vector<char> & device_type,int32_t device_id,int32_t num_epochs,bool send_epoch_end,int32_t total_batches,bool create_data_info_queue)192 bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type,
193 int32_t device_id, int32_t num_epochs, bool send_epoch_end, int32_t total_batches,
194 bool create_data_info_queue) {
195 Status rc;
196
197 // Build and launch tree
198 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
199 rc = runtime_context->Init();
200 if (rc.IsError()) {
201 MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
202 return false;
203 }
204
205 // Add TransferNode IR on top of dataset
206 auto ds =
207 std::make_shared<TransferNode>(shared_from_this()->IRNode(), CharToString(queue_name), CharToString(device_type),
208 device_id, send_epoch_end, total_batches, create_data_info_queue);
209
210 // Get ToDevice consumer
211 auto consumer = std::make_unique<ToDevice>(num_epochs);
212 ToDevice *consumer_ptr = consumer.get();
213 if (consumer_ptr == nullptr) {
214 MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
215 return false;
216 }
217 rc = consumer->Init(ds);
218 if (rc.IsError()) {
219 MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
220 return false;
221 }
222 runtime_context->AssignConsumer(std::move(consumer));
223
224 // Send data to device
225 rc = consumer_ptr->Send();
226 if (rc.IsError()) {
227 MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
228 return false;
229 }
230
231 return true;
232 }
233
234 // Function to create the saver, which will build and launch the execution tree and save data
SaveCharIF(const std::vector<char> & dataset_path,int32_t num_files,const std::vector<char> & dataset_type)235 bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files,
236 const std::vector<char> &dataset_type) {
237 Status rc;
238 // Build and launch tree
239 auto ds = shared_from_this();
240 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
241 rc = runtime_context->Init();
242 if (rc.IsError()) {
243 MS_LOG(ERROR) << "CreateSaver failed." << rc;
244 return false;
245 }
246
247 // Get SaveToDisk consumer
248 auto consumer = std::make_unique<SaveToDisk>(CharToString(dataset_path), num_files, CharToString(dataset_type));
249 rc = consumer->ValidateParams();
250 if (rc.IsError()) {
251 MS_LOG(ERROR) << "CreateSaver failed." << rc;
252 return false;
253 }
254 SaveToDisk *consumer_ptr = consumer.get();
255 if (consumer_ptr == nullptr) {
256 MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
257 return false;
258 }
259 rc = consumer->Init(ds->IRNode());
260 if (rc.IsError()) {
261 MS_LOG(ERROR) << "CreateSaver failed." << rc;
262 return false;
263 }
264
265 runtime_context->AssignConsumer(std::move(consumer));
266
267 // Save data into file
268 rc = consumer_ptr->Save();
269 if (rc.IsError()) {
270 MS_LOG(ERROR) << "Saver: Failed to save data into file. Error status: " << rc;
271 return false;
272 }
273
274 // Shut down the data pipeline
275 rc = runtime_context->Terminate();
276 if (rc.IsError()) {
277 MS_LOG(ERROR) << "Saver: Failed to shut down pipeline. Error status: " << rc;
278 return false;
279 }
280
281 return true;
282 }
283 #endif
284
285 // Constructor
Dataset()286 Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
287
GetDatasetSize(bool estimate)288 int64_t Dataset::GetDatasetSize(bool estimate) {
289 int64_t dataset_size = -1;
290 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
291 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
292 std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
293 DatasetSizeGetter *consumer = size_getter.get();
294 if (consumer == nullptr) {
295 MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
296 return -1;
297 }
298 runtime_context->AssignConsumer(size_getter);
299 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
300 RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
301 return dataset_size;
302 }
303
GetOutputTypes()304 std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
305 std::vector<DataType> types;
306 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
307 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
308 TreeGetters *consumer = tree_getters_.get();
309 if (consumer == nullptr) {
310 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
311 return std::vector<mindspore::DataType>();
312 }
313 runtime_context->AssignConsumer(tree_getters_);
314 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
315 RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
316 std::vector<mindspore::DataType> ret_types;
317 std::transform(
318 types.begin(), types.end(), std::back_inserter(ret_types),
319 [](const DataType &d) -> mindspore::DataType { return static_cast<mindspore::DataType>(DETypeToMSType(d)); });
320 return ret_types;
321 }
322
GetOutputShapes()323 std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
324 std::vector<TensorShape> shapes;
325 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
326 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
327 TreeGetters *consumer = tree_getters_.get();
328 if (consumer == nullptr) {
329 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
330 return std::vector<std::vector<int64_t>>();
331 }
332 runtime_context->AssignConsumer(tree_getters_);
333 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
334 RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
335 std::vector<std::vector<int64_t>> ret_shapes;
336 std::transform(shapes.begin(), shapes.end(), std::back_inserter(ret_shapes),
337 [](const TensorShape &s) -> std::vector<int64_t> { return s.AsVector(); });
338 return ret_shapes;
339 }
340
GetNumClasses()341 int64_t Dataset::GetNumClasses() {
342 int64_t num_classes = -1;
343 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
344 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
345 TreeGetters *consumer = tree_getters_.get();
346 if (consumer == nullptr) {
347 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
348 return -1;
349 }
350 runtime_context->AssignConsumer(tree_getters_);
351 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
352 RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
353 return num_classes;
354 }
355
GetColumnNamesCharIF()356 std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
357 std::vector<std::string> col_names;
358 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
359 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
360 TreeGetters *consumer = tree_getters_.get();
361 if (consumer == nullptr) {
362 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
363 return std::vector<std::vector<char>>();
364 }
365 runtime_context->AssignConsumer(tree_getters_);
366 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
367 RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
368 return VectorStringToChar(col_names);
369 }
370
GetClassIndexingCharIF()371 std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClassIndexingCharIF() {
372 std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
373 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
374 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
375 TreeGetters *consumer = tree_getters_.get();
376 if (consumer == nullptr) {
377 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
378 return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
379 }
380 runtime_context->AssignConsumer(tree_getters_);
381 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
382 RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
383 return ClassIndexStringToChar(output_class_indexing);
384 }
385
386 /// \brief Function to create a SchemaObj
387 /// \param[in] schema_file Path of schema file
388 /// \return Shared pointer to the current schema
SchemaCharIF(const std::vector<char> & schema_file)389 std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file) {
390 auto schema = std::make_shared<SchemaObj>(CharToString(schema_file));
391 return schema->Init() ? schema : nullptr;
392 }
393
394 // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
395 // (In alphabetical order)
396
397 // Function to create a Batch dataset
BatchDataset(std::shared_ptr<Dataset> input,int32_t batch_size,bool drop_remainder)398 BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
399 // Default values
400 auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
401 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
402 }
403
404 #ifndef ENABLE_ANDROID
405 // Function to create a BucketBatchByLength dataset
BucketBatchByLengthDataset(std::shared_ptr<Dataset> input,const std::vector<std::vector<char>> & column_names,const std::vector<int32_t> & bucket_boundaries,const std::vector<int32_t> & bucket_batch_sizes,std::function<MSTensorVec (MSTensorVec)> element_length_function,const std::map<std::vector<char>,std::pair<std::vector<int64_t>,MSTensor>> & pad_info,bool pad_to_bucket_boundary,bool drop_remainder)406 BucketBatchByLengthDataset::BucketBatchByLengthDataset(
407 std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &column_names,
408 const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
409 std::function<MSTensorVec(MSTensorVec)> element_length_function,
410 const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary,
411 bool drop_remainder) {
412 std::shared_ptr<TensorOp> c_func = nullptr;
413 if (element_length_function != nullptr) {
414 c_func = std::make_shared<CFuncOp>(std::bind(FuncPtrConverter, element_length_function, std::placeholders::_1));
415 }
416
417 std::map<std::vector<char>, std::pair<TensorShape, std::shared_ptr<Tensor>>> map;
418 for (auto const &p : pad_info) {
419 const MSTensor &t = p.second.second;
420 std::shared_ptr<Tensor> rt;
421 Status rc = Tensor::CreateFromMemory(TensorShape(t.Shape()), MSTypeToDEType(static_cast<TypeId>(t.DataType())),
422 (const uchar *)(t.Data().get()), t.DataSize(), &rt);
423 if (rc.IsError()) {
424 MS_LOG_ERROR << "Fail to create DETensor from MSTensor for pad_info: " << rc.ToString() << ".";
425 map.clear();
426 break;
427 }
428 map.insert({p.first, {TensorShape(p.second.first), rt}});
429 }
430
431 auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), VectorCharToString(column_names),
432 bucket_boundaries, bucket_batch_sizes, c_func,
433 PadInfoCharToString(map), pad_to_bucket_boundary, drop_remainder);
434
435 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
436 }
437
ConcatDataset(const std::vector<std::shared_ptr<Dataset>> & datasets)438 ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
439 std::vector<std::shared_ptr<DatasetNode>> all_datasets;
440 (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
441 [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
442 return (dataset != nullptr) ? dataset->IRNode() : nullptr;
443 });
444
445 auto ds = std::make_shared<ConcatNode>(all_datasets);
446
447 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
448 }
449
FilterDataset(std::shared_ptr<Dataset> input,std::function<MSTensorVec (MSTensorVec)> predicate,const std::vector<std::vector<char>> & input_columns)450 FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTensorVec(MSTensorVec)> predicate,
451 const std::vector<std::vector<char>> &input_columns) {
452 std::shared_ptr<TensorOp> c_func = nullptr;
453 if (predicate) {
454 c_func = std::make_shared<CFuncOp>(std::bind(FuncPtrConverter, predicate, std::placeholders::_1));
455 }
456 auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, VectorCharToString(input_columns));
457
458 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
459 }
460 #endif
461
MapDataset(std::shared_ptr<Dataset> input,std::vector<std::shared_ptr<TensorOperation>> operations,const std::vector<std::vector<char>> & input_columns,const std::vector<std::vector<char>> & output_columns,const std::vector<std::vector<char>> & project_columns,const std::shared_ptr<DatasetCache> & cache,std::vector<std::shared_ptr<DSCallback>> callbacks)462 MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
463 const std::vector<std::vector<char>> &input_columns,
464 const std::vector<std::vector<char>> &output_columns,
465 const std::vector<std::vector<char>> &project_columns,
466 const std::shared_ptr<DatasetCache> &cache, std::vector<std::shared_ptr<DSCallback>> callbacks) {
467 auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns),
468 VectorCharToString(output_columns), VectorCharToString(project_columns), cache,
469 callbacks);
470
471 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
472 }
473
ProjectDataset(std::shared_ptr<Dataset> input,const std::vector<std::vector<char>> & columns)474 ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns) {
475 auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns));
476
477 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
478 }
479 #ifndef ENABLE_ANDROID
RenameDataset(std::shared_ptr<Dataset> input,const std::vector<std::vector<char>> & input_columns,const std::vector<std::vector<char>> & output_columns)480 RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
481 const std::vector<std::vector<char>> &output_columns) {
482 auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns),
483 VectorCharToString(output_columns));
484
485 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
486 }
487 #endif
488
RepeatDataset(std::shared_ptr<Dataset> input,int32_t count)489 RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
490 auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
491
492 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
493 }
494
ShuffleDataset(std::shared_ptr<Dataset> input,int32_t buffer_size)495 ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) {
496 // Pass in reshuffle_each_epoch with true
497 auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
498
499 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
500 }
501
502 #ifndef ENABLE_ANDROID
SkipDataset(std::shared_ptr<Dataset> input,int32_t count)503 SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
504 auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
505
506 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
507 }
508
TakeDataset(std::shared_ptr<Dataset> input,int32_t count)509 TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
510 auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
511
512 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
513 }
514
ZipDataset(const std::vector<std::shared_ptr<Dataset>> & datasets)515 ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
516 std::vector<std::shared_ptr<DatasetNode>> all_datasets;
517 (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
518 [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
519 return (dataset != nullptr) ? dataset->IRNode() : nullptr;
520 });
521 auto ds = std::make_shared<ZipNode>(all_datasets);
522
523 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
524 }
525 #endif
GetBatchSize()526 int64_t Dataset::GetBatchSize() {
527 int64_t batch_size = -1;
528 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
529 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
530 RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
531 RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
532 return batch_size;
533 }
534
GetRepeatCount()535 int64_t Dataset::GetRepeatCount() {
536 int64_t repeat_count = 0;
537 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
538 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
539 RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
540 RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
541 return repeat_count;
542 }
543
SetNumWorkers(int32_t num_workers)544 std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
545 if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
546 return nullptr;
547 }
548 return shared_from_this();
549 }
550
551 #ifndef ENABLE_ANDROID
BuildSentencePieceVocabCharIF(const std::vector<std::vector<char>> & col_names,int32_t vocab_size,float character_coverage,SentencePieceModel model_type,const std::map<std::vector<char>,std::vector<char>> & params)552 std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
553 const std::vector<std::vector<char>> &col_names, int32_t vocab_size, float character_coverage,
554 SentencePieceModel model_type, const std::map<std::vector<char>, std::vector<char>> ¶ms) {
555 auto vocab = std::make_shared<SentencePieceVocab>();
556 auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, VectorCharToString(col_names), vocab_size,
557 character_coverage, model_type, UnorderedMapCharToString(params));
558
559 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
560 Status rc = runtime_context->Init();
561 if (rc.IsError()) {
562 MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
563 return nullptr;
564 }
565
566 auto consumer = std::make_unique<BuildVocabConsumer>();
567 BuildVocabConsumer *bv_consumer = consumer.get();
568 if (bv_consumer == nullptr) {
569 MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
570 return nullptr;
571 }
572 rc = consumer->Init(ds);
573 if (rc.IsError()) {
574 MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
575 return nullptr;
576 }
577 runtime_context->AssignConsumer(std::move(consumer));
578
579 // Run tree here to starting building SentencePieceVocab
580 rc = bv_consumer->Start();
581 if (rc.IsError()) {
582 MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
583 return nullptr;
584 }
585 return vocab;
586 }
587
BuildVocabCharIF(const std::vector<std::vector<char>> & columns,const std::pair<int64_t,int64_t> & freq_range,int64_t top_k,const std::vector<std::vector<char>> & special_tokens,bool special_first)588 std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<char>> &columns,
589 const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
590 const std::vector<std::vector<char>> &special_tokens,
591 bool special_first) {
592 auto vocab = std::make_shared<Vocab>();
593 auto ds = std::make_shared<BuildVocabNode>(IRNode(), vocab, VectorCharToString(columns), freq_range, top_k,
594 VectorCharToString(special_tokens), special_first);
595
596 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
597 Status rc = runtime_context->Init();
598 if (rc.IsError()) {
599 MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
600 return nullptr;
601 }
602
603 auto consumer = std::make_unique<BuildVocabConsumer>();
604 BuildVocabConsumer *bv_consumer = consumer.get();
605 if (bv_consumer == nullptr) {
606 MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
607 return nullptr;
608 }
609 rc = consumer->Init(ds);
610 if (rc.IsError()) {
611 MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
612 return nullptr;
613 }
614 runtime_context->AssignConsumer(std::move(consumer));
615
616 // Run tree here to starting building vocab
617 rc = bv_consumer->Start();
618 if (rc.IsError()) {
619 MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
620 return nullptr;
621 }
622 return vocab;
623 }
624 #endif
625
Batch(int32_t batch_size,bool drop_remainder)626 std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
627 return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
628 }
629
630 struct SchemaObj::Data {
631 int32_t num_rows_;
632 std::string dataset_type_;
633 std::string schema_file_;
634 nlohmann::json columns_;
635 };
636
SchemaObj(const std::vector<char> & schema_file)637 SchemaObj::SchemaObj(const std::vector<char> &schema_file) : data_(std::make_shared<Data>()) {
638 data_->schema_file_ = CharToString(schema_file);
639 data_->dataset_type_ = "";
640 data_->num_rows_ = 0;
641 }
642
643 // SchemaObj Init function
Init()644 Status SchemaObj::Init() {
645 if (data_ != nullptr && !data_->schema_file_.empty()) {
646 std::string real_path;
647 RETURN_IF_NOT_OK(Path::RealPath(data_->schema_file_, real_path));
648 Path schema_file(real_path);
649 CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
650 "The file " + data_->schema_file_ + " does not exist or permission denied!");
651
652 nlohmann::json js;
653 try {
654 std::ifstream in(real_path);
655 in >> js;
656 CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
657 "\"columns\" node is required in the schema json file.");
658 } catch (const std::exception &err) {
659 std::string err_msg = "Schema file failed to load: ";
660 RETURN_STATUS_SYNTAX_ERROR(err_msg);
661 }
662 return from_json(js);
663 }
664 return Status::OK();
665 }
666
667 // Function to add a column to schema with a mstype de_type and known shape
add_column_char(const std::vector<char> & name,mindspore::DataType de_type,const std::vector<int32_t> & shape)668 Status SchemaObj::add_column_char(const std::vector<char> &name, mindspore::DataType de_type,
669 const std::vector<int32_t> &shape) {
670 DataType data_type = dataset::MSTypeToDEType(static_cast<TypeId>(de_type));
671 return add_column_char(name, StringToChar(data_type.ToString()), shape);
672 }
673
674 // Function to add a column to schema with a string de_type and known shape
add_column_char(const std::vector<char> & name,const std::vector<char> & de_type,const std::vector<int32_t> & shape)675 Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vector<char> &de_type,
676 const std::vector<int32_t> &shape) {
677 DataType data_type(CharToString(de_type));
678 CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
679
680 nlohmann::json new_column;
681 new_column["name"] = CharToString(name);
682 new_column["type"] = data_type.ToString();
683 new_column["shape"] = shape;
684 new_column["rank"] = shape.size();
685
686 data_->columns_.push_back(new_column);
687 return Status::OK();
688 }
689
690 // Function to add a column to schema with a mstype de_type and without shape
add_column_char(const std::vector<char> & name,mindspore::DataType de_type)691 Status SchemaObj::add_column_char(const std::vector<char> &name, mindspore::DataType de_type) {
692 DataType data_type = dataset::MSTypeToDEType(static_cast<TypeId>(de_type));
693 return add_column_char(name, StringToChar(data_type.ToString()));
694 }
695
696 // Function to add a column to schema with a string de_type and without shape
add_column_char(const std::vector<char> & name,const std::vector<char> & de_type)697 Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vector<char> &de_type) {
698 DataType data_type(CharToString(de_type));
699 CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
700
701 nlohmann::json new_column;
702 new_column["name"] = CharToString(name);
703 new_column["type"] = data_type.ToString();
704 new_column["rank"] = 1;
705
706 data_->columns_.push_back(new_column);
707 return Status::OK();
708 }
709
schema_to_json(nlohmann::json * out_json)710 Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
711 nlohmann::json json_file;
712 json_file["columns"] = data_->columns_;
713 std::string str_dataset_type_(data_->dataset_type_);
714 if (str_dataset_type_ != "") {
715 json_file["datasetType"] = str_dataset_type_;
716 }
717
718 if (data_->num_rows_ > 0) {
719 json_file["numRows"] = data_->num_rows_;
720 }
721 *out_json = json_file;
722 return Status::OK();
723 }
724
to_json_char()725 const std::vector<char> SchemaObj::to_json_char() {
726 nlohmann::json json_file;
727 this->schema_to_json(&json_file);
728 return StringToChar(json_file.dump(2));
729 }
730
set_dataset_type(std::string dataset_type)731 void SchemaObj::set_dataset_type(std::string dataset_type) { data_->dataset_type_ = dataset_type.data(); }
732
set_num_rows(int32_t num_rows)733 void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; }
734
get_num_rows() const735 int32_t SchemaObj::get_num_rows() const { return data_->num_rows_; }
736
parse_column(nlohmann::json columns)737 Status SchemaObj::parse_column(nlohmann::json columns) {
738 std::string name, de_type;
739 std::vector<int32_t> shape;
740
741 data_->columns_.clear();
742 if (columns.type() == nlohmann::json::value_t::array) {
743 // reference to python list
744 for (auto column : columns) {
745 auto key_name = column.find("name");
746 if (key_name == column.end()) {
747 RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
748 }
749 name = *key_name;
750
751 auto key_type = column.find("type");
752 if (key_type == column.end()) {
753 RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
754 }
755 de_type = *key_type;
756
757 shape.clear();
758 auto key_shape = column.find("shape");
759 if (key_shape != column.end()) {
760 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
761 }
762 RETURN_IF_NOT_OK(add_column(name, de_type, shape));
763 }
764 } else if (columns.type() == nlohmann::json::value_t::object) {
765 for (const auto &it_child : columns.items()) {
766 name = it_child.key();
767 auto key_type = it_child.value().find("type");
768 if (key_type == it_child.value().end()) {
769 RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
770 }
771 de_type = *key_type;
772
773 shape.clear();
774 auto key_shape = it_child.value().find("shape");
775 if (key_shape != it_child.value().end()) {
776 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
777 }
778
779 RETURN_IF_NOT_OK(add_column(name, de_type, shape));
780 }
781 } else {
782 RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
783 }
784 return Status::OK();
785 }
786
from_json(nlohmann::json json_obj)787 Status SchemaObj::from_json(nlohmann::json json_obj) {
788 for (const auto &it_child : json_obj.items()) {
789 if (it_child.key() == "datasetType") {
790 std::string str_dataset_type_ = it_child.value();
791 data_->dataset_type_ = str_dataset_type_.data();
792 } else if (it_child.key() == "numRows") {
793 data_->num_rows_ = it_child.value();
794 } else if (it_child.key() == "columns") {
795 RETURN_IF_NOT_OK(parse_column(it_child.value()));
796 } else {
797 RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
798 }
799 }
800 if (data_->columns_.empty()) {
801 RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
802 }
803 if (data_->num_rows_ < 0) {
804 RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
805 }
806
807 return Status::OK();
808 }
809
FromJSONStringCharIF(const std::vector<char> & json_string)810 Status SchemaObj::FromJSONStringCharIF(const std::vector<char> &json_string) {
811 try {
812 nlohmann::json js = nlohmann::json::parse(CharToString(json_string));
813 CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
814 "\"columns\" node is required in the schema json JSON.");
815 RETURN_IF_NOT_OK(from_json(js));
816 } catch (const std::exception &err) {
817 std::string err_msg = "FromJSONString: JSON string failed to parse: ";
818 err_msg += err.what();
819 RETURN_STATUS_SYNTAX_ERROR(err_msg);
820 }
821 return Status::OK();
822 }
823
ParseColumnStringCharIF(const std::vector<char> & json_string)824 Status SchemaObj::ParseColumnStringCharIF(const std::vector<char> &json_string) {
825 try {
826 nlohmann::json js = nlohmann::json::parse(CharToString(json_string));
827 RETURN_IF_NOT_OK(parse_column(js));
828 } catch (const std::exception &err) {
829 std::string err_msg = "ParseColumnString: JSON string failed to parse: ";
830 err_msg += err.what();
831 RETURN_STATUS_SYNTAX_ERROR(err_msg);
832 }
833 return Status::OK();
834 }
835
836 // OTHER FUNCTIONS
837
838 #ifndef ENABLE_ANDROID
839
CreateDatasetCacheCharIF(session_id_type id,uint64_t mem_sz,bool spill,std::optional<std::vector<char>> hostname,std::optional<int32_t> port,std::optional<int32_t> num_connections,std::optional<int32_t> prefetch_sz)840 std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
841 std::optional<std::vector<char>> hostname,
842 std::optional<int32_t> port,
843 std::optional<int32_t> num_connections,
844 std::optional<int32_t> prefetch_sz) {
845 auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
846 return cache;
847 }
848 #endif
849
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)850 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
851 const std::vector<std::vector<char>> &column_names, bool decode,
852 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
853 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
854 auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
855 VectorCharToString(column_names), decode, sampler_obj, cache);
856 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
857 }
858
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)859 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
860 const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
861 const std::shared_ptr<DatasetCache> &cache) {
862 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
863 auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
864 VectorCharToString(column_names), decode, sampler_obj, cache);
865 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
866 }
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)867 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
868 const std::vector<std::vector<char>> &column_names, bool decode,
869 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
870 auto sampler_obj = sampler.get().Parse();
871 auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
872 VectorCharToString(column_names), decode, sampler_obj, cache);
873 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
874 }
875
876 #ifndef ENABLE_ANDROID
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)877 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
878 const std::shared_ptr<Sampler> &sampler, bool decode,
879 const std::set<std::vector<char>> &extensions,
880 const std::shared_ptr<DatasetCache> &cache) {
881 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
882 auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
883 SetCharToString(extensions), cache);
884 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
885 }
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)886 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
887 const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions,
888 const std::shared_ptr<DatasetCache> &cache) {
889 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
890 auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
891 SetCharToString(extensions), cache);
892 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
893 }
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)894 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
895 const std::reference_wrapper<Sampler> sampler, bool decode,
896 const std::set<std::vector<char>> &extensions,
897 const std::shared_ptr<DatasetCache> &cache) {
898 auto sampler_obj = sampler.get().Parse();
899 auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
900 SetCharToString(extensions), cache);
901 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
902 }
903
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)904 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
905 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
906 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
907 auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
908 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
909 }
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)910 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
911 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
912 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
913 auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
914 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
915 }
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)916 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
917 const std::reference_wrapper<Sampler> sampler,
918 const std::shared_ptr<DatasetCache> &cache) {
919 auto sampler_obj = sampler.get().Parse();
920 auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
921 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
922 }
923
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)924 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
925 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
926 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
927 auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
928 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
929 }
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)930 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
931 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
932 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
933 auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
934 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
935 }
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)936 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
937 const std::reference_wrapper<Sampler> sampler,
938 const std::shared_ptr<DatasetCache> &cache) {
939 auto sampler_obj = sampler.get().Parse();
940 auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
941 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
942 }
943
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)944 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
945 const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
946 const std::shared_ptr<Sampler> &sampler,
947 const std::shared_ptr<DatasetCache> &cache) {
948 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
949 auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
950 CharToString(task), decode, sampler_obj, cache);
951 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
952 }
953
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)954 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
955 const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
956 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
957 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
958 auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
959 CharToString(task), decode, sampler_obj, cache);
960 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
961 }
962
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)963 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
964 const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
965 const std::reference_wrapper<Sampler> sampler,
966 const std::shared_ptr<DatasetCache> &cache) {
967 auto sampler_obj = sampler.get().Parse();
968 auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
969 CharToString(task), decode, sampler_obj, cache);
970 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
971 }
972
CLUEDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<char> & task,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)973 CLUEDataset::CLUEDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &task,
974 const std::vector<char> &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
975 int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
976 auto ds = std::make_shared<CLUENode>(VectorCharToString(dataset_files), CharToString(task), CharToString(usage),
977 num_samples, shuffle, num_shards, shard_id, cache);
978 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
979 }
980
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)981 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
982 const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
983 const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
984 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
985 auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
986 decode, sampler_obj, cache, extra_metadata);
987 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
988 }
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)989 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
990 const std::vector<char> &task, const bool &decode, const Sampler *sampler,
991 const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
992 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
993 auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
994 decode, sampler_obj, cache, extra_metadata);
995 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
996 }
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)997 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
998 const std::vector<char> &task, const bool &decode,
999 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache,
1000 const bool &extra_metadata) {
1001 auto sampler_obj = sampler.get().Parse();
1002 auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
1003 decode, sampler_obj, cache, extra_metadata);
1004 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1005 }
1006
CSVDataset(const std::vector<std::vector<char>> & dataset_files,char field_delim,const std::vector<std::shared_ptr<CsvBase>> & column_defaults,const std::vector<std::vector<char>> & column_names,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1007 CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char field_delim,
1008 const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
1009 const std::vector<std::vector<char>> &column_names, int64_t num_samples, ShuffleMode shuffle,
1010 int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
1011 auto ds =
1012 std::make_shared<CSVNode>(VectorCharToString(dataset_files), field_delim, column_defaults,
1013 VectorCharToString(column_names), num_samples, shuffle, num_shards, shard_id, cache);
1014 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1015 }
1016
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1017 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1018 const std::vector<char> &downgrade, int32_t scale, bool decode,
1019 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1020 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1021 auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1022 decode, sampler_obj, cache);
1023 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1024 }
1025
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1026 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1027 const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
1028 const std::shared_ptr<DatasetCache> &cache) {
1029 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1030 auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1031 decode, sampler_obj, cache);
1032 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1033 }
1034
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1035 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1036 const std::vector<char> &downgrade, int32_t scale, bool decode,
1037 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
1038 auto sampler_obj = sampler.get().Parse();
1039 auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1040 decode, sampler_obj, cache);
1041 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1042 }
1043
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1044 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1045 bool decode, const std::shared_ptr<Sampler> &sampler,
1046 const std::shared_ptr<DatasetCache> &cache) {
1047 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1048 auto ds =
1049 std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1050 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1051 }
1052
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1053 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1054 bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1055 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1056 auto ds =
1057 std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1058 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1059 }
1060
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1061 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1062 bool decode, const std::reference_wrapper<Sampler> sampler,
1063 const std::shared_ptr<DatasetCache> &cache) {
1064 auto sampler_obj = sampler.get().Parse();
1065 auto ds =
1066 std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1067 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1068 }
1069
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1070 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
1071 const std::shared_ptr<Sampler> &sampler,
1072 const std::set<std::vector<char>> &extensions,
1073 const std::map<std::vector<char>, int32_t> &class_indexing,
1074 const std::shared_ptr<DatasetCache> &cache) {
1075 // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1076 bool recursive = false;
1077
1078 // Create logical representation of ImageFolderDataset.
1079 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1080 auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1081 SetCharToString(extensions), MapCharToString(class_indexing), cache);
1082 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1083 }
1084
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1085 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
1086 const std::set<std::vector<char>> &extensions,
1087 const std::map<std::vector<char>, int32_t> &class_indexing,
1088 const std::shared_ptr<DatasetCache> &cache) {
1089 // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1090 bool recursive = false;
1091
1092 // Create logical representation of ImageFolderDataset.
1093 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1094 auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1095 SetCharToString(extensions), MapCharToString(class_indexing), cache);
1096 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1097 }
1098
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1099 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
1100 const std::reference_wrapper<Sampler> sampler,
1101 const std::set<std::vector<char>> &extensions,
1102 const std::map<std::vector<char>, int32_t> &class_indexing,
1103 const std::shared_ptr<DatasetCache> &cache) {
1104 // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1105 bool recursive = false;
1106
1107 // Create logical representation of ImageFolderDataset.
1108 auto sampler_obj = sampler.get().Parse();
1109 auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1110 SetCharToString(extensions), MapCharToString(class_indexing), cache);
1111 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1112 }
1113
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1114 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1115 const std::shared_ptr<Sampler> &sampler,
1116 const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
1117 const std::shared_ptr<DatasetCache> &cache) {
1118 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1119 auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1120 MapCharToString(class_indexing), decode, cache);
1121 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1122 }
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const Sampler * sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1123 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1124 const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
1125 bool decode, const std::shared_ptr<DatasetCache> &cache) {
1126 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1127 auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1128 MapCharToString(class_indexing), decode, cache);
1129 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1130 }
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1131 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1132 const std::reference_wrapper<Sampler> sampler,
1133 const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
1134 const std::shared_ptr<DatasetCache> &cache) {
1135 auto sampler_obj = sampler.get().Parse();
1136 auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1137 MapCharToString(class_indexing), decode, cache);
1138 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1139 }
1140
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1141 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1142 const std::vector<std::vector<char>> &columns_list,
1143 const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample,
1144 int64_t num_padded, ShuffleMode shuffle_mode,
1145 const std::shared_ptr<DatasetCache> &cache) {
1146 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1147 nlohmann::json sample = nullptr;
1148 if (padded_sample) {
1149 sample = *padded_sample;
1150 }
1151 auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1152 sample, num_padded, shuffle_mode, cache);
1153 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1154 }
1155
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const Sampler * sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1156 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1157 const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
1158 const nlohmann::json *padded_sample, int64_t num_padded, ShuffleMode shuffle_mode,
1159 const std::shared_ptr<DatasetCache> &cache) {
1160 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1161 nlohmann::json sample = nullptr;
1162 if (padded_sample) {
1163 sample = *padded_sample;
1164 }
1165 auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1166 sample, num_padded, shuffle_mode, cache);
1167 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1168 }
1169
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const std::reference_wrapper<Sampler> sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1170 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1171 const std::vector<std::vector<char>> &columns_list,
1172 const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
1173 int64_t num_padded, ShuffleMode shuffle_mode,
1174 const std::shared_ptr<DatasetCache> &cache) {
1175 auto sampler_obj = sampler.get().Parse();
1176 nlohmann::json sample = nullptr;
1177 if (padded_sample) {
1178 sample = *padded_sample;
1179 }
1180
1181 auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1182 sample, num_padded, shuffle_mode, cache);
1183 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1184 }
1185
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1186 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1187 const std::vector<std::vector<char>> &columns_list,
1188 const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample,
1189 int64_t num_padded, ShuffleMode shuffle_mode,
1190 const std::shared_ptr<DatasetCache> &cache) {
1191 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1192 nlohmann::json sample = nullptr;
1193 if (padded_sample) {
1194 sample = *padded_sample;
1195 }
1196
1197 auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1198 sampler_obj, sample, num_padded, shuffle_mode, cache);
1199 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1200 }
1201
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const Sampler * sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1202 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1203 const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
1204 const nlohmann::json *padded_sample, int64_t num_padded, ShuffleMode shuffle_mode,
1205 const std::shared_ptr<DatasetCache> &cache) {
1206 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1207 nlohmann::json sample = nullptr;
1208 if (padded_sample) {
1209 sample = *padded_sample;
1210 }
1211
1212 auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1213 sampler_obj, sample, num_padded, shuffle_mode, cache);
1214 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1215 }
1216
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const std::reference_wrapper<Sampler> sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1217 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1218 const std::vector<std::vector<char>> &columns_list,
1219 const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
1220 int64_t num_padded, ShuffleMode shuffle_mode,
1221 const std::shared_ptr<DatasetCache> &cache) {
1222 auto sampler_obj = sampler.get().Parse();
1223 nlohmann::json sample = nullptr;
1224 if (padded_sample) {
1225 sample = *padded_sample;
1226 }
1227 auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1228 sampler_obj, sample, num_padded, shuffle_mode, cache);
1229 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1230 }
1231 #endif
1232
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1233 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1234 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1235 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1236 auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1237 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1238 }
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1239 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
1240 const std::shared_ptr<DatasetCache> &cache) {
1241 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1242 auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1243 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1244 }
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1245 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1246 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
1247 auto sampler_obj = sampler.get().Parse();
1248 auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1249 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1250 }
1251
1252 #ifndef ENABLE_ANDROID
TextFileDataset(const std::vector<std::vector<char>> & dataset_files,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1253 TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
1254 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1255 const std::shared_ptr<DatasetCache> &cache) {
1256 auto ds = std::make_shared<TextFileNode>(VectorCharToString(dataset_files), num_samples, shuffle, num_shards,
1257 shard_id, cache);
1258 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1259 }
1260
USPSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1261 USPSDataset::USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
1262 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1263 const std::shared_ptr<DatasetCache> &cache) {
1264 auto ds = std::make_shared<USPSNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
1265 shard_id, cache);
1266 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1267 }
1268
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1269 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1270 const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1271 bool decode, const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
1272 bool extra_metadata) {
1273 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1274 auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1275 MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1276 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1277 }
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1278 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1279 const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1280 bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache,
1281 bool extra_metadata) {
1282 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1283 auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1284 MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1285 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1286 }
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1287 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1288 const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1289 bool decode, const std::reference_wrapper<Sampler> sampler,
1290 const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) {
1291 auto sampler_obj = sampler.get().Parse();
1292 auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1293 MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1294 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1295 } // namespace dataset
1296
RandomDataDataset(const int32_t & total_rows,std::shared_ptr<SchemaObj> schema,const std::vector<std::vector<char>> & columns_list,std::shared_ptr<DatasetCache> cache)1297 RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
1298 const std::vector<std::vector<char>> &columns_list,
1299 std::shared_ptr<DatasetCache> cache) {
1300 auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache);
1301 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1302 }
RandomDataDataset(const int32_t & total_rows,const std::vector<char> & schema_path,const std::vector<std::vector<char>> & columns_list,std::shared_ptr<DatasetCache> cache)1303 RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
1304 const std::vector<std::vector<char>> &columns_list,
1305 std::shared_ptr<DatasetCache> cache) {
1306 auto ds =
1307 std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache);
1308 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1309 }
1310
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1311 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
1312 const std::shared_ptr<DatasetCache> &cache) {
1313 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1314 auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
1315 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1316 }
1317
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1318 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
1319 const std::shared_ptr<DatasetCache> &cache) {
1320 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1321 auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
1322 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1323 }
1324
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> sampler,const std::shared_ptr<DatasetCache> & cache)1325 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
1326 const std::shared_ptr<DatasetCache> &cache) {
1327 auto sampler_obj = sampler.get().Parse();
1328 auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
1329 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1330 }
1331
TFRecordDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<char> & schema,const std::vector<std::vector<char>> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,std::shared_ptr<DatasetCache> cache)1332 TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
1333 const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
1334 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
1335 std::shared_ptr<DatasetCache> cache) {
1336 auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema),
1337 VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id,
1338 shard_equal_rows, cache);
1339 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1340 }
TFRecordDataset(const std::vector<std::vector<char>> & dataset_files,std::shared_ptr<SchemaObj> schema,const std::vector<std::vector<char>> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,std::shared_ptr<DatasetCache> cache)1341 TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, std::shared_ptr<SchemaObj> schema,
1342 const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
1343 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
1344 std::shared_ptr<DatasetCache> cache) {
1345 auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list),
1346 num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache);
1347 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1348 }
1349
1350 #endif
1351 } // namespace dataset
1352 } // namespace mindspore
1353