1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/include/dataset/datasets.h"
18
19 #include <algorithm>
20 #include <fstream>
21 #include <unordered_set>
22 #include <utility>
23
24 #include <nlohmann/json.hpp>
25
26 #include "minddata/dataset/core/tensor.h"
27 #include "minddata/dataset/core/type_id.h"
28 #include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
29 #include "minddata/dataset/engine/consumers/tree_consumer.h"
30 #include "minddata/dataset/engine/runtime_context.h"
31 #include "minddata/dataset/include/dataset/constants.h"
32 #include "minddata/dataset/include/dataset/iterator.h"
33 #include "minddata/dataset/include/dataset/samplers.h"
34 #include "minddata/dataset/kernels/c_func_op.h"
35 #include "minddata/dataset/kernels/tensor_op.h"
36 #include "minddata/dataset/util/path.h"
37 #include "minddata/dataset/util/random.h"
38 #include "minddata/dataset/util/status.h"
39 #ifndef ENABLE_ANDROID
40 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
41 #include "minddata/dataset/include/dataset/text.h"
42 #endif
43
44 // Sampler headers (in alphabetical order)
45 #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
46
47 // IR dataset node
48 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
49
50 // IR non-leaf nodes
51 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
52 #ifndef ENABLE_ANDROID
53 #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
54 #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
55 #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
56 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
57 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
58 #endif
59 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
60 #include "minddata/dataset/engine/ir/datasetops/project_node.h"
61 #ifndef ENABLE_ANDROID
62 #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
63 #endif
64 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
65 #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
66 #ifndef ENABLE_ANDROID
67 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
68 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
69 #include "minddata/dataset/engine/ir/datasetops/data_queue_node.h"
70 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
71 #endif
72
73 // IR leaf nodes
74 #include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
75 #include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
76 #ifndef ENABLE_ANDROID
77 #include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
78 #include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
79 #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
80 #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
81 #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
82 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
83 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
84 #include "minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h"
85 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
86 #include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h"
87 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
88 #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
89 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
90 #include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
91 #include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h"
92 #include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
93 #include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
94 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
95 #include "minddata/dataset/engine/ir/datasetops/source/food101_node.h"
96 #include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
97 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
98 #include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
99 #include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
100 #include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
101 #include "minddata/dataset/engine/ir/datasetops/source/kitti_node.h"
102 #include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
103 #include "minddata/dataset/engine/ir/datasetops/source/lfw_node.h"
104 #include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h"
105 #include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
106 #include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
107 #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
108 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
109 #endif
110 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
111 #ifndef ENABLE_ANDROID
112 #include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
113 #include "minddata/dataset/engine/ir/datasetops/source/omniglot_node.h"
114 #include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
115 #include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
116 #include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
117 #include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
118 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
119 #include "minddata/dataset/engine/ir/datasetops/source/rendered_sst2_node.h"
120 #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
121 #include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
122 #include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h"
123 #include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
124 #include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
125 #include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
126 #include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
127 #include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
128 #include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
129 #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
130 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
131 #include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
132 #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
133 #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
134 #include "minddata/dataset/engine/ir/datasetops/source/wider_face_node.h"
135 #include "minddata/dataset/engine/ir/datasetops/source/wiki_text_node.h"
136 #include "minddata/dataset/engine/ir/datasetops/source/yahoo_answers_node.h"
137 #include "minddata/dataset/engine/ir/datasetops/source/yelp_review_node.h"
138 #include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
139 #endif
140
141 namespace mindspore {
142 namespace dataset {
143 // convert MSTensorVec to DE TensorRow, return empty if fails
VecToRow(const MSTensorVec & v)144 TensorRow VecToRow(const MSTensorVec &v) {
145 TensorRow row;
146 row.reserve(v.size());
147 for (const MSTensor &t : v) {
148 std::shared_ptr<Tensor> rt;
149 Status rc = Tensor::CreateFromMSTensor(t, &rt);
150 if (rc.IsError()) {
151 MS_LOG(ERROR) << "Convert from MSTensor to DETensor failed:" << rc.ToString() << ".";
152 return {};
153 }
154 row.emplace_back(rt);
155 }
156 return row;
157 }
158
159 // convert DE TensorRow to MSTensorVec, won't fail
RowToVec(const TensorRow & v)160 MSTensorVec RowToVec(const TensorRow &v) {
161 MSTensorVec rv;
162 rv.reserve(v.size());
163 std::transform(v.begin(), v.end(), std::back_inserter(rv), [](const std::shared_ptr<Tensor> &t) -> MSTensor {
164 return mindspore::MSTensor(std::make_shared<DETensor>(t));
165 });
166 return rv;
167 }
168
169 // Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper
FuncPtrConverter(const std::function<MSTensorVec (MSTensorVec)> & func,const TensorRow & in_row)170 TensorRow FuncPtrConverter(const std::function<MSTensorVec(MSTensorVec)> &func, const TensorRow &in_row) {
171 return VecToRow(func(RowToVec(in_row)));
172 }
173
174 // Function to create the iterator, which will build and launch the execution tree.
CreateIteratorCharIF(int32_t num_epochs)175 std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(int32_t num_epochs) {
176 std::shared_ptr<Iterator> iter;
177 try {
178 auto ds = shared_from_this();
179 iter = std::make_shared<Iterator>();
180 Status rc = iter->BuildAndLaunchTree(ds, num_epochs);
181 if (rc.IsError()) {
182 MS_LOG(ERROR) << "CreateIterator failed." << rc;
183 return nullptr;
184 }
185 } catch (const std::exception &err) {
186 MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
187 return nullptr;
188 }
189
190 return iter;
191 }
192
193 // Function to create the iterator, which will build and launch the execution tree.
CreatePullBasedIterator()194 std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator() {
195 auto ds = shared_from_this();
196 std::shared_ptr<PullIterator> iter = std::make_shared<PullIterator>();
197 Status rc = iter->BuildAndLaunchTree(ds, 1);
198 if (rc.IsError()) {
199 MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << rc;
200 }
201 RETURN_SECOND_IF_ERROR(rc, nullptr);
202 return iter;
203 }
204
205 // Function to return a transferred Node that transfers data through a device.
DeviceQueueCharIF(const std::vector<char> & queue_name,const std::vector<char> & device_type,int32_t device_id,int32_t num_epochs,bool send_epoch_end,int32_t total_batches,bool create_data_info_queue)206 bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type,
207 int32_t device_id, int32_t num_epochs, bool send_epoch_end, int32_t total_batches,
208 bool create_data_info_queue) {
209 #ifndef ENABLE_ANDROID
210 Status rc;
211
212 // Build and launch tree
213 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
214 rc = runtime_context->Init();
215 if (rc.IsError()) {
216 MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
217 return false;
218 }
219
220 // Add DataQueueNode IR on top of dataset
221 auto ds =
222 std::make_shared<DataQueueNode>(shared_from_this()->IRNode(), CharToString(queue_name), CharToString(device_type),
223 device_id, send_epoch_end, total_batches, create_data_info_queue);
224
225 // Get ToDevice consumer
226 auto consumer = std::make_unique<ToDevice>(num_epochs);
227 ToDevice *consumer_ptr = consumer.get();
228 if (consumer_ptr == nullptr) {
229 MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
230 return false;
231 }
232 rc = consumer->Init(ds);
233 if (rc.IsError()) {
234 MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
235 return false;
236 }
237 runtime_context->AssignConsumer(std::move(consumer));
238
239 // Send data to device
240 rc = consumer_ptr->Send();
241 if (rc.IsError()) {
242 MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
243 return false;
244 }
245
246 return true;
247 #else
248 MS_LOG(ERROR) << "DeviceQueueCharIF is not support for Android.";
249 return false;
250 #endif
251 }
252
253 #ifndef ENABLE_ANDROID
254 // Function to create the saver, which will build and launch the execution tree and save data
SaveCharIF(const std::vector<char> & dataset_path,int32_t num_files,const std::vector<char> & dataset_type)255 bool Dataset::SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files,
256 const std::vector<char> &dataset_type) {
257 Status rc;
258 // Build and launch tree
259 auto ds = shared_from_this();
260 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
261 rc = runtime_context->Init();
262 if (rc.IsError()) {
263 MS_LOG(ERROR) << "CreateSaver failed." << rc;
264 return false;
265 }
266
267 // Get SaveToDisk consumer
268 auto consumer = std::make_unique<SaveToDisk>(CharToString(dataset_path), num_files, CharToString(dataset_type));
269 rc = consumer->ValidateParams();
270 if (rc.IsError()) {
271 MS_LOG(ERROR) << "CreateSaver failed." << rc;
272 return false;
273 }
274 SaveToDisk *consumer_ptr = consumer.get();
275 if (consumer_ptr == nullptr) {
276 MS_LOG(ERROR) << "ToDevice: Failed to get consumer.";
277 return false;
278 }
279 rc = consumer->Init(ds->IRNode());
280 if (rc.IsError()) {
281 MS_LOG(ERROR) << "CreateSaver failed." << rc;
282 return false;
283 }
284
285 runtime_context->AssignConsumer(std::move(consumer));
286
287 // Save data into file
288 rc = consumer_ptr->Save();
289 if (rc.IsError()) {
290 MS_LOG(ERROR) << "Saver: Failed to save data into file. Error status: " << rc;
291 return false;
292 }
293
294 // Shut down the data pipeline
295 rc = runtime_context->Terminate();
296 if (rc.IsError()) {
297 MS_LOG(ERROR) << "Saver: Failed to shut down pipeline. Error status: " << rc;
298 return false;
299 }
300
301 return true;
302 }
303 #endif
304
305 // Constructor
Dataset()306 Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
307
GetDatasetSize(bool estimate)308 int64_t Dataset::GetDatasetSize(bool estimate) {
309 int64_t dataset_size = -1;
310 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
311 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
312 std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
313 DatasetSizeGetter *consumer = size_getter.get();
314 if (consumer == nullptr) {
315 MS_LOG(ERROR) << "DatasetSizeGetter: Failed to get consumer.";
316 return -1;
317 }
318 runtime_context->AssignConsumer(size_getter);
319 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
320 RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
321 return dataset_size;
322 }
323
GetOutputTypes()324 std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
325 std::vector<DataType> types;
326 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
327 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
328 TreeGetters *consumer = tree_getters_.get();
329 if (consumer == nullptr) {
330 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
331 return std::vector<mindspore::DataType>();
332 }
333 runtime_context->AssignConsumer(tree_getters_);
334 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
335 RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
336 std::vector<mindspore::DataType> ret_types;
337 std::transform(
338 types.begin(), types.end(), std::back_inserter(ret_types),
339 [](const DataType &d) -> mindspore::DataType { return static_cast<mindspore::DataType>(DETypeToMSType(d)); });
340 return ret_types;
341 }
342
GetOutputShapes()343 std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
344 std::vector<TensorShape> shapes;
345 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
346 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
347 TreeGetters *consumer = tree_getters_.get();
348 if (consumer == nullptr) {
349 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
350 return std::vector<std::vector<int64_t>>();
351 }
352 runtime_context->AssignConsumer(tree_getters_);
353 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
354 RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
355 std::vector<std::vector<int64_t>> ret_shapes;
356 std::transform(shapes.begin(), shapes.end(), std::back_inserter(ret_shapes),
357 [](const TensorShape &s) -> std::vector<int64_t> { return s.AsVector(); });
358 return ret_shapes;
359 }
360
GetNumClasses()361 int64_t Dataset::GetNumClasses() {
362 int64_t num_classes = -1;
363 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
364 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
365 TreeGetters *consumer = tree_getters_.get();
366 if (consumer == nullptr) {
367 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
368 return -1;
369 }
370 runtime_context->AssignConsumer(tree_getters_);
371 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
372 RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
373 return num_classes;
374 }
375
GetColumnNamesCharIF()376 std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
377 std::vector<std::string> col_names;
378 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
379 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
380 TreeGetters *consumer = tree_getters_.get();
381 if (consumer == nullptr) {
382 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
383 return std::vector<std::vector<char>>();
384 }
385 runtime_context->AssignConsumer(tree_getters_);
386 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
387 RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
388 return VectorStringToChar(col_names);
389 }
390
GetClassIndexingCharIF()391 std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClassIndexingCharIF() {
392 std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
393 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
394 RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
395 TreeGetters *consumer = tree_getters_.get();
396 if (consumer == nullptr) {
397 MS_LOG(ERROR) << "TreeGetters: Failed to get consumer.";
398 return std::vector<std::pair<std::vector<char>, std::vector<int32_t>>>();
399 }
400 runtime_context->AssignConsumer(tree_getters_);
401 RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
402 RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
403 return ClassIndexStringToChar(output_class_indexing);
404 }
405
406 /// \brief Function to create a SchemaObj
407 /// \param[in] schema_file Path of schema file
408 /// \return Shared pointer to the current schema
SchemaCharIF(const std::vector<char> & schema_file)409 std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file) {
410 auto schema = std::make_shared<SchemaObj>(CharToString(schema_file));
411 return schema->Init() ? schema : nullptr;
412 }
413
414 // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
415 // (In alphabetical order)
416
417 // Function to create a Batch dataset
BatchDataset(const std::shared_ptr<Dataset> & input,int32_t batch_size,bool drop_remainder)418 BatchDataset::BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder) {
419 // Default values
420 if (input == nullptr) {
421 ir_node_ = nullptr;
422 } else {
423 auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
424 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
425 }
426 }
427
428 #ifndef ENABLE_ANDROID
429 // Function to create a BucketBatchByLength dataset
BucketBatchByLengthDataset(const std::shared_ptr<Dataset> & input,const std::vector<std::vector<char>> & column_names,const std::vector<int32_t> & bucket_boundaries,const std::vector<int32_t> & bucket_batch_sizes,const std::function<MSTensorVec (MSTensorVec)> & element_length_function,const std::map<std::vector<char>,std::pair<std::vector<int64_t>,MSTensor>> & pad_info,bool pad_to_bucket_boundary,bool drop_remainder)430 BucketBatchByLengthDataset::BucketBatchByLengthDataset(
431 const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &column_names,
432 const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
433 const std::function<MSTensorVec(MSTensorVec)> &element_length_function,
434 const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary,
435 bool drop_remainder) {
436 std::shared_ptr<TensorOp> c_func = nullptr;
437 if (element_length_function != nullptr) {
438 c_func = std::make_shared<CFuncOp>(std::bind(FuncPtrConverter, element_length_function, std::placeholders::_1));
439 }
440
441 std::map<std::vector<char>, std::pair<TensorShape, std::shared_ptr<Tensor>>> map;
442 for (auto const &p : pad_info) {
443 const MSTensor &t = p.second.second;
444 std::shared_ptr<Tensor> rt;
445 Status rc = Tensor::CreateFromMemory(TensorShape(t.Shape()), MSTypeToDEType(static_cast<TypeId>(t.DataType())),
446 (const uchar *)(t.Data().get()), t.DataSize(), &rt);
447 if (rc.IsError()) {
448 MS_LOG(ERROR) << "Fail to create DETensor from MSTensor for pad_info: " << rc.ToString() << ".";
449 map.clear();
450 break;
451 }
452 map.insert({p.first, {TensorShape(p.second.first), rt}});
453 }
454 if (input == nullptr) {
455 ir_node_ = nullptr;
456 } else {
457 auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), VectorCharToString(column_names),
458 bucket_boundaries, bucket_batch_sizes, c_func,
459 MapCharToString(map), pad_to_bucket_boundary, drop_remainder);
460
461 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
462 }
463 }
464
ConcatDataset(const std::vector<std::shared_ptr<Dataset>> & datasets)465 ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
466 std::vector<std::shared_ptr<DatasetNode>> all_datasets;
467 (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
468 [](const std::shared_ptr<Dataset> &dataset) -> std::shared_ptr<DatasetNode> {
469 return (dataset != nullptr) ? dataset->IRNode() : nullptr;
470 });
471 auto ds = std::make_shared<ConcatNode>(all_datasets);
472
473 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
474 }
475
FilterDataset(const std::shared_ptr<Dataset> & input,const std::function<MSTensorVec (MSTensorVec)> & predicate,const std::vector<std::vector<char>> & input_columns)476 FilterDataset::FilterDataset(const std::shared_ptr<Dataset> &input,
477 const std::function<MSTensorVec(MSTensorVec)> &predicate,
478 const std::vector<std::vector<char>> &input_columns) {
479 std::shared_ptr<TensorOp> c_func = nullptr;
480 if (predicate) {
481 c_func = std::make_shared<CFuncOp>(std::bind(FuncPtrConverter, predicate, std::placeholders::_1));
482 }
483 if (input == nullptr) {
484 ir_node_ = nullptr;
485 } else {
486 auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, VectorCharToString(input_columns));
487 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
488 }
489 }
490 #endif
491
MapDataset(const std::shared_ptr<Dataset> & input,const std::vector<std::shared_ptr<TensorOperation>> & operations,const std::vector<std::vector<char>> & input_columns,const std::vector<std::vector<char>> & output_columns,const std::shared_ptr<DatasetCache> & cache,const std::vector<std::shared_ptr<DSCallback>> & callbacks)492 MapDataset::MapDataset(const std::shared_ptr<Dataset> &input,
493 const std::vector<std::shared_ptr<TensorOperation>> &operations,
494 const std::vector<std::vector<char>> &input_columns,
495 const std::vector<std::vector<char>> &output_columns, const std::shared_ptr<DatasetCache> &cache,
496 const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
497 if (input == nullptr) {
498 ir_node_ = nullptr;
499 } else {
500 auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns),
501 VectorCharToString(output_columns), cache, callbacks);
502
503 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
504 }
505 }
506
ProjectDataset(const std::shared_ptr<Dataset> & input,const std::vector<std::vector<char>> & columns)507 ProjectDataset::ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns) {
508 if (input == nullptr) {
509 ir_node_ = nullptr;
510 } else {
511 auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns));
512
513 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
514 }
515 }
516
517 #ifndef ENABLE_ANDROID
RenameDataset(const std::shared_ptr<Dataset> & input,const std::vector<std::vector<char>> & input_columns,const std::vector<std::vector<char>> & output_columns)518 RenameDataset::RenameDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &input_columns,
519 const std::vector<std::vector<char>> &output_columns) {
520 if (input == nullptr) {
521 ir_node_ = nullptr;
522 } else {
523 auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns),
524 VectorCharToString(output_columns));
525
526 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
527 }
528 }
529 #endif
530
RepeatDataset(const std::shared_ptr<Dataset> & input,int32_t count)531 RepeatDataset::RepeatDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
532 if (input == nullptr) {
533 ir_node_ = nullptr;
534 } else {
535 auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
536
537 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
538 }
539 }
540
ShuffleDataset(const std::shared_ptr<Dataset> & input,int32_t buffer_size)541 ShuffleDataset::ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size) {
542 if (input == nullptr) {
543 ir_node_ = nullptr;
544 } else {
545 // Pass in reshuffle_each_epoch with true
546 auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
547
548 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
549 }
550 }
551
552 #ifndef ENABLE_ANDROID
SkipDataset(const std::shared_ptr<Dataset> & input,int32_t count)553 SkipDataset::SkipDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
554 if (input == nullptr) {
555 ir_node_ = nullptr;
556 } else {
557 auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
558
559 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
560 }
561 }
562
TakeDataset(const std::shared_ptr<Dataset> & input,int32_t count)563 TakeDataset::TakeDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
564 if (input == nullptr) {
565 ir_node_ = nullptr;
566 } else {
567 auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
568
569 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
570 }
571 }
572
ZipDataset(const std::vector<std::shared_ptr<Dataset>> & datasets)573 ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
574 std::vector<std::shared_ptr<DatasetNode>> all_datasets;
575 (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
576 [](const std::shared_ptr<Dataset> &dataset) -> std::shared_ptr<DatasetNode> {
577 return (dataset != nullptr) ? dataset->IRNode() : nullptr;
578 });
579 auto ds = std::make_shared<ZipNode>(all_datasets);
580
581 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
582 }
583 #endif
584
GetBatchSize()585 int64_t Dataset::GetBatchSize() {
586 int64_t batch_size = -1;
587 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
588 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
589 RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
590 RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
591 return batch_size;
592 }
593
GetRepeatCount()594 int64_t Dataset::GetRepeatCount() {
595 int64_t repeat_count = 0;
596 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
597 RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
598 RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
599 RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
600 return repeat_count;
601 }
602
SetNumWorkers(int32_t num_workers)603 std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
604 if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
605 return nullptr;
606 }
607 return shared_from_this();
608 }
609
610 #ifndef ENABLE_ANDROID
BuildSentencePieceVocabCharIF(const std::vector<std::vector<char>> & col_names,int32_t vocab_size,float character_coverage,SentencePieceModel model_type,const std::map<std::vector<char>,std::vector<char>> & params)611 std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocabCharIF(
612 const std::vector<std::vector<char>> &col_names, int32_t vocab_size, float character_coverage,
613 SentencePieceModel model_type, const std::map<std::vector<char>, std::vector<char>> ¶ms) {
614 auto vocab = std::make_shared<SentencePieceVocab>();
615 auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, VectorCharToString(col_names), vocab_size,
616 character_coverage, model_type, UnorderedMapCharToString(params));
617
618 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
619 Status rc = runtime_context->Init();
620 if (rc.IsError()) {
621 MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
622 return nullptr;
623 }
624
625 auto consumer = std::make_unique<BuildVocabConsumer>();
626 BuildVocabConsumer *bv_consumer = consumer.get();
627 if (bv_consumer == nullptr) {
628 MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
629 return nullptr;
630 }
631 rc = consumer->Init(ds);
632 if (rc.IsError()) {
633 MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
634 return nullptr;
635 }
636 runtime_context->AssignConsumer(std::move(consumer));
637
638 // Run tree here to starting building SentencePieceVocab
639 rc = bv_consumer->Start();
640 if (rc.IsError()) {
641 MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
642 return nullptr;
643 }
644 return vocab;
645 }
646
BuildVocabCharIF(const std::vector<std::vector<char>> & columns,const std::pair<int64_t,int64_t> & freq_range,int64_t top_k,const std::vector<std::vector<char>> & special_tokens,bool special_first)647 std::shared_ptr<Vocab> Dataset::BuildVocabCharIF(const std::vector<std::vector<char>> &columns,
648 const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
649 const std::vector<std::vector<char>> &special_tokens,
650 bool special_first) {
651 auto vocab = std::make_shared<Vocab>();
652 auto ds = std::make_shared<BuildVocabNode>(IRNode(), vocab, VectorCharToString(columns), freq_range, top_k,
653 VectorCharToString(special_tokens), special_first);
654
655 std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
656 Status rc = runtime_context->Init();
657 if (rc.IsError()) {
658 MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
659 return nullptr;
660 }
661
662 auto consumer = std::make_unique<BuildVocabConsumer>();
663 BuildVocabConsumer *bv_consumer = consumer.get();
664 if (bv_consumer == nullptr) {
665 MS_LOG(ERROR) << "BuildVocabConsumer: Failed to get bv_consumer.";
666 return nullptr;
667 }
668 rc = consumer->Init(ds);
669 if (rc.IsError()) {
670 MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
671 return nullptr;
672 }
673 runtime_context->AssignConsumer(std::move(consumer));
674
675 // Run tree here to starting building vocab
676 rc = bv_consumer->Start();
677 if (rc.IsError()) {
678 MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
679 return nullptr;
680 }
681 return vocab;
682 }
683 #endif
684
Batch(int32_t batch_size,bool drop_remainder)685 std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
686 return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
687 }
688
689 struct SchemaObj::Data {
690 int32_t num_rows_;
691 std::string dataset_type_;
692 std::string schema_file_;
693 nlohmann::json columns_;
694 };
695
SchemaObj(const std::vector<char> & schema_file)696 SchemaObj::SchemaObj(const std::vector<char> &schema_file) : data_(std::make_shared<Data>()) {
697 data_->schema_file_ = CharToString(schema_file);
698 data_->dataset_type_ = "";
699 data_->num_rows_ = 0;
700 }
701
702 // SchemaObj Init function
Init()703 Status SchemaObj::Init() {
704 if (data_ != nullptr && !data_->schema_file_.empty()) {
705 std::string real_path;
706 RETURN_IF_NOT_OK(Path::RealPath(data_->schema_file_, real_path));
707 Path schema_file(real_path);
708 CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
709 "The file " + data_->schema_file_ + " does not exist or permission denied!");
710
711 nlohmann::json js;
712 try {
713 std::ifstream in(real_path, std::ifstream::in);
714 in >> js;
715 CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
716 "\"columns\" node is required in the schema json file.");
717 in.close();
718 } catch (const std::exception &err) {
719 std::string err_msg = "Schema file failed to load: ";
720 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
721 }
722 return from_json(js);
723 }
724 return Status::OK();
725 }
726
727 // Function to add a column to schema with a mstype de_type and known shape
add_column_char(const std::vector<char> & name,mindspore::DataType de_type,const std::vector<int32_t> & shape)728 Status SchemaObj::add_column_char(const std::vector<char> &name, mindspore::DataType de_type,
729 const std::vector<int32_t> &shape) {
730 DataType data_type = dataset::MSTypeToDEType(static_cast<TypeId>(de_type));
731 return add_column_char(name, StringToChar(data_type.ToString()), shape);
732 }
733
734 // Function to add a column to schema with a string de_type and known shape
add_column_char(const std::vector<char> & name,const std::vector<char> & de_type,const std::vector<int32_t> & shape)735 Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vector<char> &de_type,
736 const std::vector<int32_t> &shape) {
737 DataType data_type(CharToString(de_type));
738 CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
739
740 nlohmann::json new_column;
741 new_column["name"] = CharToString(name);
742 new_column["type"] = data_type.ToString();
743 new_column["shape"] = shape;
744 new_column["rank"] = shape.size();
745
746 data_->columns_.push_back(new_column);
747 return Status::OK();
748 }
749
750 // Function to add a column to schema with a mstype de_type and without shape
add_column_char(const std::vector<char> & name,mindspore::DataType de_type)751 Status SchemaObj::add_column_char(const std::vector<char> &name, mindspore::DataType de_type) {
752 DataType data_type = dataset::MSTypeToDEType(static_cast<TypeId>(de_type));
753 return add_column_char(name, StringToChar(data_type.ToString()));
754 }
755
756 // Function to add a column to schema with a string de_type and without shape
add_column_char(const std::vector<char> & name,const std::vector<char> & de_type)757 Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vector<char> &de_type) {
758 DataType data_type(CharToString(de_type));
759 CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
760
761 nlohmann::json new_column;
762 new_column["name"] = CharToString(name);
763 new_column["type"] = data_type.ToString();
764 new_column["rank"] = 1;
765
766 data_->columns_.push_back(new_column);
767 return Status::OK();
768 }
769
schema_to_json(nlohmann::json * out_json)770 Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
771 RETURN_UNEXPECTED_IF_NULL(out_json);
772 nlohmann::json json_file;
773 json_file["columns"] = data_->columns_;
774 std::string str_dataset_type_(data_->dataset_type_);
775 if (!str_dataset_type_.empty()) {
776 json_file["datasetType"] = str_dataset_type_;
777 }
778
779 if (data_->num_rows_ > 0) {
780 json_file["numRows"] = data_->num_rows_;
781 }
782 *out_json = json_file;
783 return Status::OK();
784 }
785
to_json_char()786 std::vector<char> SchemaObj::to_json_char() {
787 nlohmann::json json_file;
788 this->schema_to_json(&json_file);
789 return StringToChar(json_file.dump(2));
790 }
791
set_dataset_type(const std::string & dataset_type)792 void SchemaObj::set_dataset_type(const std::string &dataset_type) { data_->dataset_type_ = dataset_type; }
793
set_num_rows(int32_t num_rows)794 void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; }
795
get_num_rows() const796 int32_t SchemaObj::get_num_rows() const { return data_->num_rows_; }
797
parse_column(nlohmann::json columns)798 Status SchemaObj::parse_column(nlohmann::json columns) {
799 std::string name, de_type;
800 std::vector<int32_t> shape;
801
802 data_->columns_.clear();
803 if (columns.type() == nlohmann::json::value_t::array) {
804 // reference to python list
805 for (auto column : columns) {
806 auto key_name = column.find("name");
807 if (key_name == column.end()) {
808 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
809 }
810 name = *key_name;
811
812 auto key_type = column.find("type");
813 if (key_type == column.end()) {
814 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
815 }
816 de_type = *key_type;
817
818 shape.clear();
819 auto key_shape = column.find("shape");
820 if (key_shape != column.end()) {
821 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
822 }
823 RETURN_IF_NOT_OK(add_column(name, de_type, shape));
824 }
825 } else if (columns.type() == nlohmann::json::value_t::object) {
826 for (const auto &it_child : columns.items()) {
827 name = it_child.key();
828 auto key_type = it_child.value().find("type");
829 if (key_type == it_child.value().end()) {
830 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
831 }
832 de_type = *key_type;
833
834 shape.clear();
835 auto key_shape = it_child.value().find("shape");
836 if (key_shape != it_child.value().end()) {
837 shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
838 }
839
840 RETURN_IF_NOT_OK(add_column(name, de_type, shape));
841 }
842 } else {
843 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
844 }
845 return Status::OK();
846 }
847
from_json(nlohmann::json json_obj)848 Status SchemaObj::from_json(nlohmann::json json_obj) {
849 for (const auto &it_child : json_obj.items()) {
850 if (it_child.key() == "datasetType") {
851 std::string str_dataset_type_ = it_child.value();
852 data_->dataset_type_ = str_dataset_type_;
853 } else if (it_child.key() == "numRows") {
854 data_->num_rows_ = it_child.value();
855 } else if (it_child.key() == "columns") {
856 RETURN_IF_NOT_OK(parse_column(it_child.value()));
857 } else {
858 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
859 }
860 }
861 if (data_->columns_.empty()) {
862 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
863 }
864 if (data_->num_rows_ < 0) {
865 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
866 }
867
868 return Status::OK();
869 }
870
FromJSONStringCharIF(const std::vector<char> & json_string)871 Status SchemaObj::FromJSONStringCharIF(const std::vector<char> &json_string) {
872 try {
873 nlohmann::json js = nlohmann::json::parse(CharToString(json_string));
874 CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
875 "\"columns\" node is required in the schema json JSON.");
876 RETURN_IF_NOT_OK(from_json(js));
877 } catch (const std::exception &err) {
878 std::string err_msg = "FromJSONString: JSON string failed to parse: ";
879 err_msg += err.what();
880 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
881 }
882 return Status::OK();
883 }
884
ParseColumnStringCharIF(const std::vector<char> & json_string)885 Status SchemaObj::ParseColumnStringCharIF(const std::vector<char> &json_string) {
886 try {
887 nlohmann::json js = nlohmann::json::parse(CharToString(json_string));
888 RETURN_IF_NOT_OK(parse_column(js));
889 } catch (const std::exception &err) {
890 std::string err_msg = "ParseColumnString: JSON string failed to parse: ";
891 err_msg += err.what();
892 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
893 }
894 return Status::OK();
895 }
896
897 // OTHER FUNCTIONS
898
899 #ifndef ENABLE_ANDROID
900
CreateDatasetCacheCharIF(session_id_type id,uint64_t mem_sz,bool spill,const std::optional<std::vector<char>> & hostname,const std::optional<int32_t> & port,const std::optional<int32_t> & num_connections,const std::optional<int32_t> & prefetch_sz)901 std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
902 const std::optional<std::vector<char>> &hostname,
903 const std::optional<int32_t> &port,
904 const std::optional<int32_t> &num_connections,
905 const std::optional<int32_t> &prefetch_sz) {
906 auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
907 return cache;
908 }
909
AGNewsDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)910 AGNewsDataset::AGNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
911 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
912 const std::shared_ptr<DatasetCache> &cache) {
913 auto ds = std::make_shared<AGNewsNode>(CharToString(dataset_dir), num_samples, shuffle, CharToString(usage),
914 num_shards, shard_id, cache);
915 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
916 }
917 #endif
918
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)919 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
920 const std::vector<std::vector<char>> &column_names, bool decode,
921 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
922 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
923 auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
924 VectorCharToString(column_names), decode, sampler_obj, cache);
925 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
926 }
927
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)928 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
929 const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
930 const std::shared_ptr<DatasetCache> &cache) {
931 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
932 auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
933 VectorCharToString(column_names), decode, sampler_obj, cache);
934 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
935 }
936
AlbumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & data_schema,const std::vector<std::vector<char>> & column_names,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)937 AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
938 const std::vector<std::vector<char>> &column_names, bool decode,
939 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
940 auto sampler_obj = sampler.get().Parse();
941 auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
942 VectorCharToString(column_names), decode, sampler_obj, cache);
943 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
944 }
945
946 #ifndef ENABLE_ANDROID
AmazonReviewDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)947 AmazonReviewDataset::AmazonReviewDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
948 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
949 const std::shared_ptr<DatasetCache> &cache) {
950 auto ds = std::make_shared<AmazonReviewNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
951 num_shards, shard_id, cache);
952 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
953 }
954
Caltech256Dataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)955 Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode,
956 const std::shared_ptr<Sampler> &sampler,
957 const std::shared_ptr<DatasetCache> &cache) {
958 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
959 auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
960 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
961 }
962
Caltech256Dataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)963 Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
964 const std::shared_ptr<DatasetCache> &cache) {
965 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
966 auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
967 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
968 }
969
Caltech256Dataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)970 Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode,
971 const std::reference_wrapper<Sampler> &sampler,
972 const std::shared_ptr<DatasetCache> &cache) {
973 auto sampler_obj = sampler.get().Parse();
974 auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
975 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
976 }
977
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)978 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
979 const std::shared_ptr<Sampler> &sampler, bool decode,
980 const std::set<std::vector<char>> &extensions,
981 const std::shared_ptr<DatasetCache> &cache) {
982 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
983 auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
984 SetCharToString(extensions), cache);
985 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
986 }
987
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)988 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
989 const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions,
990 const std::shared_ptr<DatasetCache> &cache) {
991 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
992 auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
993 SetCharToString(extensions), cache);
994 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
995 }
996
CelebADataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,bool decode,const std::set<std::vector<char>> & extensions,const std::shared_ptr<DatasetCache> & cache)997 CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
998 const std::reference_wrapper<Sampler> &sampler, bool decode,
999 const std::set<std::vector<char>> &extensions,
1000 const std::shared_ptr<DatasetCache> &cache) {
1001 auto sampler_obj = sampler.get().Parse();
1002 auto ds = std::make_shared<CelebANode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, decode,
1003 SetCharToString(extensions), cache);
1004 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1005 }
1006
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1007 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1008 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1009 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1010 auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1011 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1012 }
1013
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1014 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1015 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1016 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1017 auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1018 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1019 }
1020
Cifar10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1021 Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1022 const std::reference_wrapper<Sampler> &sampler,
1023 const std::shared_ptr<DatasetCache> &cache) {
1024 auto sampler_obj = sampler.get().Parse();
1025 auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1026 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1027 }
1028
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1029 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1030 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1031 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1032 auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1033 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1034 }
1035
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1036 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1037 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1038 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1039 auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1040 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1041 }
1042
Cifar100Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1043 Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1044 const std::reference_wrapper<Sampler> &sampler,
1045 const std::shared_ptr<DatasetCache> &cache) {
1046 auto sampler_obj = sampler.get().Parse();
1047 auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1048 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1049 }
1050
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1051 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1052 const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
1053 const std::shared_ptr<Sampler> &sampler,
1054 const std::shared_ptr<DatasetCache> &cache) {
1055 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1056 auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
1057 CharToString(task), decode, sampler_obj, cache);
1058 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1059 }
1060
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1061 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1062 const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
1063 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1064 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1065 auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
1066 CharToString(task), decode, sampler_obj, cache);
1067 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1068 }
1069
CityscapesDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & quality_mode,const std::vector<char> & task,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1070 CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1071 const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
1072 const std::reference_wrapper<Sampler> &sampler,
1073 const std::shared_ptr<DatasetCache> &cache) {
1074 auto sampler_obj = sampler.get().Parse();
1075 auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
1076 CharToString(task), decode, sampler_obj, cache);
1077 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1078 }
1079
CLUEDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<char> & task,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1080 CLUEDataset::CLUEDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &task,
1081 const std::vector<char> &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
1082 int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
1083 auto ds = std::make_shared<CLUENode>(VectorCharToString(dataset_files), CharToString(task), CharToString(usage),
1084 num_samples, shuffle, num_shards, shard_id, cache);
1085 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1086 }
1087
CMUArcticDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1088 CMUArcticDataset::CMUArcticDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1089 const std::shared_ptr<Sampler> &sampler,
1090 const std::shared_ptr<DatasetCache> &cache) {
1091 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1092 auto ds = std::make_shared<CMUArcticNode>(CharToString(dataset_dir), CharToString(name), sampler_obj, cache);
1093 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1094 }
1095
CMUArcticDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1096 CMUArcticDataset::CMUArcticDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1097 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1098 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1099 auto ds = std::make_shared<CMUArcticNode>(CharToString(dataset_dir), CharToString(name), sampler_obj, cache);
1100 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1101 }
1102
CMUArcticDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1103 CMUArcticDataset::CMUArcticDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1104 const std::reference_wrapper<Sampler> &sampler,
1105 const std::shared_ptr<DatasetCache> &cache) {
1106 auto sampler_obj = sampler.get().Parse();
1107 auto ds = std::make_shared<CMUArcticNode>(CharToString(dataset_dir), CharToString(name), sampler_obj, cache);
1108 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1109 }
1110
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)1111 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1112 const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
1113 const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
1114 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1115 auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
1116 decode, sampler_obj, cache, extra_metadata);
1117 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1118 }
1119
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)1120 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1121 const std::vector<char> &task, const bool &decode, const Sampler *sampler,
1122 const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
1123 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1124 auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
1125 decode, sampler_obj, cache, extra_metadata);
1126 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1127 }
1128
CocoDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,const std::vector<char> & task,const bool & decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,const bool & extra_metadata)1129 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1130 const std::vector<char> &task, const bool &decode,
1131 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
1132 const bool &extra_metadata) {
1133 auto sampler_obj = sampler.get().Parse();
1134 auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
1135 decode, sampler_obj, cache, extra_metadata);
1136 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1137 }
1138
CoNLL2000Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1139 CoNLL2000Dataset::CoNLL2000Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1140 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1141 const std::shared_ptr<DatasetCache> &cache) {
1142 auto ds = std::make_shared<CoNLL2000Node>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
1143 num_shards, shard_id, cache);
1144 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1145 }
1146
CSVDataset(const std::vector<std::vector<char>> & dataset_files,char field_delim,const std::vector<std::shared_ptr<CsvBase>> & column_defaults,const std::vector<std::vector<char>> & column_names,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1147 CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char field_delim,
1148 const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
1149 const std::vector<std::vector<char>> &column_names, int64_t num_samples, ShuffleMode shuffle,
1150 int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
1151 auto ds =
1152 std::make_shared<CSVNode>(VectorCharToString(dataset_files), field_delim, column_defaults,
1153 VectorCharToString(column_names), num_samples, shuffle, num_shards, shard_id, cache);
1154 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1155 }
1156
DBpediaDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1157 DBpediaDataset::DBpediaDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1158 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1159 const std::shared_ptr<DatasetCache> &cache) {
1160 auto ds = std::make_shared<DBpediaNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
1161 num_shards, shard_id, cache);
1162 ir_node_ = std::static_pointer_cast<DBpediaNode>(ds);
1163 }
1164
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1165 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1166 const std::vector<char> &downgrade, int32_t scale, bool decode,
1167 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1168 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1169 auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1170 decode, sampler_obj, cache);
1171 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1172 }
1173
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1174 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1175 const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
1176 const std::shared_ptr<DatasetCache> &cache) {
1177 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1178 auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1179 decode, sampler_obj, cache);
1180 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1181 }
1182
DIV2KDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<char> & downgrade,int32_t scale,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1183 DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1184 const std::vector<char> &downgrade, int32_t scale, bool decode,
1185 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1186 auto sampler_obj = sampler.get().Parse();
1187 auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
1188 decode, sampler_obj, cache);
1189 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1190 }
1191
EMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1192 EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1193 const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
1194 const std::shared_ptr<DatasetCache> &cache) {
1195 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1196 auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
1197 sampler_obj, cache);
1198 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1199 }
1200
EMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1201 EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1202 const std::vector<char> &usage, const Sampler *sampler,
1203 const std::shared_ptr<DatasetCache> &cache) {
1204 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1205 auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
1206 sampler_obj, cache);
1207 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1208 }
1209
EMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1210 EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1211 const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
1212 const std::shared_ptr<DatasetCache> &cache) {
1213 auto sampler_obj = sampler.get().Parse();
1214 auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
1215 sampler_obj, cache);
1216 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1217 }
1218
EnWik9Dataset(const std::vector<char> & dataset_dir,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1219 EnWik9Dataset::EnWik9Dataset(const std::vector<char> &dataset_dir, int64_t num_samples, ShuffleMode shuffle,
1220 int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
1221 auto ds = std::make_shared<EnWik9Node>(CharToString(dataset_dir), num_samples, shuffle, num_shards, shard_id, cache);
1222 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1223 }
1224
FakeImageDataset(int32_t num_images,const std::vector<int32_t> & image_size,int32_t num_classes,int32_t base_seed,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1225 FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
1226 int32_t base_seed, const std::shared_ptr<Sampler> &sampler,
1227 const std::shared_ptr<DatasetCache> &cache) {
1228 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1229 auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
1230 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1231 }
1232
FakeImageDataset(int32_t num_images,const std::vector<int32_t> & image_size,int32_t num_classes,int32_t base_seed,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1233 FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
1234 int32_t base_seed, const Sampler *sampler,
1235 const std::shared_ptr<DatasetCache> &cache) {
1236 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1237 auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
1238 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1239 }
1240
FakeImageDataset(int32_t num_images,const std::vector<int32_t> & image_size,int32_t num_classes,int32_t base_seed,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1241 FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
1242 int32_t base_seed, const std::reference_wrapper<Sampler> &sampler,
1243 const std::shared_ptr<DatasetCache> &cache) {
1244 auto sampler_obj = sampler.get().Parse();
1245 auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
1246 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1247 }
1248
FashionMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1249 FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1250 const std::shared_ptr<Sampler> &sampler,
1251 const std::shared_ptr<DatasetCache> &cache) {
1252 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1253 auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1254 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1255 }
1256
FashionMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1257 FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1258 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1259 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1260 auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1261 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1262 }
1263
FashionMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1264 FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1265 const std::reference_wrapper<Sampler> &sampler,
1266 const std::shared_ptr<DatasetCache> &cache) {
1267 auto sampler_obj = sampler.get().Parse();
1268 auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1269 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1270 }
1271
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1272 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1273 bool decode, const std::shared_ptr<Sampler> &sampler,
1274 const std::shared_ptr<DatasetCache> &cache) {
1275 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1276 auto ds =
1277 std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1278 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1279 }
1280
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1281 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1282 bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1283 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1284 auto ds =
1285 std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1286 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1287 }
1288
FlickrDataset(const std::vector<char> & dataset_dir,const std::vector<char> & annotation_file,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1289 FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
1290 bool decode, const std::reference_wrapper<Sampler> &sampler,
1291 const std::shared_ptr<DatasetCache> &cache) {
1292 auto sampler_obj = sampler.get().Parse();
1293 auto ds =
1294 std::make_shared<FlickrNode>(CharToString(dataset_dir), CharToString(annotation_file), decode, sampler_obj, cache);
1295 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1296 }
1297
Food101Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1298 Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
1299 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1300 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1301 auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
1302 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1303 }
1304
Food101Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1305 Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
1306 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1307 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1308 auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
1309 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1310 }
1311
Food101Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1312 Food101Dataset::Food101Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
1313 const std::reference_wrapper<Sampler> &sampler,
1314 const std::shared_ptr<DatasetCache> &cache) {
1315 auto sampler_obj = sampler.get().Parse();
1316 auto ds = std::make_shared<Food101Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
1317 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1318 }
1319
GTZANDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1320 GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1321 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1322 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1323 auto ds = std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1324 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1325 }
1326
GTZANDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1327 GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
1328 const std::shared_ptr<DatasetCache> &cache) {
1329 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1330 auto ds = std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1331 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1332 }
1333
GTZANDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1334 GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1335 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1336 auto sampler_obj = sampler.get().Parse();
1337 auto ds = std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1338 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1339 }
1340
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1341 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
1342 const std::shared_ptr<Sampler> &sampler,
1343 const std::set<std::vector<char>> &extensions,
1344 const std::map<std::vector<char>, int32_t> &class_indexing,
1345 const std::shared_ptr<DatasetCache> &cache) {
1346 // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1347 bool recursive = false;
1348
1349 // Create logical representation of ImageFolderDataset.
1350 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1351 auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1352 SetCharToString(extensions), MapCharToString(class_indexing), cache);
1353 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1354 }
1355
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1356 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
1357 const std::set<std::vector<char>> &extensions,
1358 const std::map<std::vector<char>, int32_t> &class_indexing,
1359 const std::shared_ptr<DatasetCache> &cache) {
1360 // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1361 bool recursive = false;
1362
1363 // Create logical representation of ImageFolderDataset.
1364 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1365 auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1366 SetCharToString(extensions), MapCharToString(class_indexing), cache);
1367 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1368 }
1369
ImageFolderDataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::set<std::vector<char>> & extensions,const std::map<std::vector<char>,int32_t> & class_indexing,const std::shared_ptr<DatasetCache> & cache)1370 ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
1371 const std::reference_wrapper<Sampler> &sampler,
1372 const std::set<std::vector<char>> &extensions,
1373 const std::map<std::vector<char>, int32_t> &class_indexing,
1374 const std::shared_ptr<DatasetCache> &cache) {
1375 // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
1376 bool recursive = false;
1377
1378 // Create logical representation of ImageFolderDataset.
1379 auto sampler_obj = sampler.get().Parse();
1380 auto ds = std::make_shared<ImageFolderNode>(CharToString(dataset_dir), decode, sampler_obj, recursive,
1381 SetCharToString(extensions), MapCharToString(class_indexing), cache);
1382 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1383 }
1384
IMDBDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1385 IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1386 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1387 // Create logical representation of IMDBDataset.
1388 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1389 auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1390 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1391 }
1392
IMDBDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1393 IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
1394 const std::shared_ptr<DatasetCache> &cache) {
1395 // Create logical representation of IMDBDataset.
1396 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1397 auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1398 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1399 }
1400
IMDBDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1401 IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1402 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1403 // Create logical representation of IMDBDataset.
1404 auto sampler_obj = sampler.get().Parse();
1405 auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1406 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1407 }
1408
IWSLT2016Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<std::vector<char>> & language_pair,const std::vector<char> & valid_set,const std::vector<char> & test_set,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1409 IWSLT2016Dataset::IWSLT2016Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1410 const std::vector<std::vector<char>> &language_pair,
1411 const std::vector<char> &valid_set, const std::vector<char> &test_set,
1412 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1413 const std::shared_ptr<DatasetCache> &cache) {
1414 auto ds = std::make_shared<IWSLT2016Node>(CharToString(dataset_dir), CharToString(usage),
1415 VectorCharToString(language_pair), CharToString(valid_set),
1416 CharToString(test_set), num_samples, shuffle, num_shards, shard_id, cache);
1417 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1418 }
1419
IWSLT2017Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<std::vector<char>> & language_pair,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1420 IWSLT2017Dataset::IWSLT2017Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1421 const std::vector<std::vector<char>> &language_pair, int64_t num_samples,
1422 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1423 const std::shared_ptr<DatasetCache> &cache) {
1424 auto ds =
1425 std::make_shared<IWSLT2017Node>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(language_pair),
1426 num_samples, shuffle, num_shards, shard_id, cache);
1427 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1428 }
1429
KITTIDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1430 KITTIDataset::KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
1431 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1432 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1433 auto ds = std::make_shared<KITTINode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
1434 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1435 }
1436
KITTIDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1437 KITTIDataset::KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
1438 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1439 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1440 auto ds = std::make_shared<KITTINode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
1441 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1442 }
1443
KITTIDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1444 KITTIDataset::KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
1445 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1446 auto sampler_obj = sampler.get().Parse();
1447 auto ds = std::make_shared<KITTINode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
1448 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1449 }
1450
KMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1451 KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1452 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1453 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1454 auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1455 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1456 }
1457
KMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1458 KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1459 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1460 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1461 auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1462 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1463 }
1464
KMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1465 KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1466 const std::reference_wrapper<Sampler> &sampler,
1467 const std::shared_ptr<DatasetCache> &cache) {
1468 auto sampler_obj = sampler.get().Parse();
1469 auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1470 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1471 }
1472
LFWDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::vector<char> & image_set,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1473 LFWDataset::LFWDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1474 const std::vector<char> &usage, const std::vector<char> &image_set, bool decode,
1475 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1476 // Create logical representation of LFWDataset.
1477 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1478 auto ds = std::make_shared<LFWNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1479 CharToString(image_set), decode, sampler_obj, cache);
1480 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1481 }
1482
LFWDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::vector<char> & image_set,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1483 LFWDataset::LFWDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1484 const std::vector<char> &usage, const std::vector<char> &image_set, bool decode,
1485 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1486 // Create logical representation of LFWDataset.
1487 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1488 auto ds = std::make_shared<LFWNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1489 CharToString(image_set), decode, sampler_obj, cache);
1490 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1491 }
1492
LFWDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::vector<char> & image_set,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1493 LFWDataset::LFWDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1494 const std::vector<char> &usage, const std::vector<char> &image_set, bool decode,
1495 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1496 // Create logical representation of LFWDataset.
1497 auto sampler_obj = sampler.get().Parse();
1498 auto ds = std::make_shared<LFWNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1499 CharToString(image_set), decode, sampler_obj, cache);
1500 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1501 }
1502
LibriTTSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1503 LibriTTSDataset::LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1504 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1505 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1506 auto ds = std::make_shared<LibriTTSNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1507 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1508 }
1509
LibriTTSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1510 LibriTTSDataset::LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1511 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1512 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1513 auto ds = std::make_shared<LibriTTSNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1514 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1515 }
1516
LibriTTSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1517 LibriTTSDataset::LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1518 const std::reference_wrapper<Sampler> &sampler,
1519 const std::shared_ptr<DatasetCache> &cache) {
1520 auto sampler_obj = sampler.get().Parse();
1521 auto ds = std::make_shared<LibriTTSNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1522 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1523 }
1524
LJSpeechDataset(const std::vector<char> & dataset_dir,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1525 LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
1526 const std::shared_ptr<DatasetCache> &cache) {
1527 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1528 auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache);
1529 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1530 }
1531
LJSpeechDataset(const std::vector<char> & dataset_dir,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1532 LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
1533 const std::shared_ptr<DatasetCache> &cache) {
1534 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1535 auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache);
1536 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1537 }
1538
LJSpeechDataset(const std::vector<char> & dataset_dir,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1539 LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
1540 const std::shared_ptr<DatasetCache> &cache) {
1541 auto sampler_obj = sampler.get().Parse();
1542 auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache);
1543 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1544 }
1545
LSUNDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<std::vector<char>> & classes,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1546 LSUNDataset::LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1547 const std::vector<std::vector<char>> &classes, bool decode,
1548 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1549 // Create logical representation of LSUNDataset.
1550 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1551 auto ds = std::make_shared<LSUNNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(classes),
1552 decode, sampler_obj, cache);
1553 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1554 }
1555
LSUNDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<std::vector<char>> & classes,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1556 LSUNDataset::LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1557 const std::vector<std::vector<char>> &classes, bool decode, const Sampler *sampler,
1558 const std::shared_ptr<DatasetCache> &cache) {
1559 // Create logical representation of LSUNDataset.
1560 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1561 auto ds = std::make_shared<LSUNNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(classes),
1562 decode, sampler_obj, cache);
1563 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1564 }
1565
LSUNDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<std::vector<char>> & classes,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1566 LSUNDataset::LSUNDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1567 const std::vector<std::vector<char>> &classes, bool decode,
1568 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1569 // Create logical representation of LSUNDataset.
1570 auto sampler_obj = sampler.get().Parse();
1571 auto ds = std::make_shared<LSUNNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(classes),
1572 decode, sampler_obj, cache);
1573 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1574 }
1575
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1576 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1577 const std::shared_ptr<Sampler> &sampler,
1578 const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
1579 const std::shared_ptr<DatasetCache> &cache) {
1580 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1581 auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1582 MapCharToString(class_indexing), decode, cache);
1583 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1584 }
1585
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const Sampler * sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1586 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1587 const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
1588 bool decode, const std::shared_ptr<DatasetCache> &cache) {
1589 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1590 auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1591 MapCharToString(class_indexing), decode, cache);
1592 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1593 }
1594
ManifestDataset(const std::vector<char> & dataset_file,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<DatasetCache> & cache)1595 ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
1596 const std::reference_wrapper<Sampler> &sampler,
1597 const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
1598 const std::shared_ptr<DatasetCache> &cache) {
1599 auto sampler_obj = sampler.get().Parse();
1600 auto ds = std::make_shared<ManifestNode>(CharToString(dataset_file), CharToString(usage), sampler_obj,
1601 MapCharToString(class_indexing), decode, cache);
1602 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1603 }
1604
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1605 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1606 const std::vector<std::vector<char>> &columns_list,
1607 const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample,
1608 int64_t num_padded, ShuffleMode shuffle_mode,
1609 const std::shared_ptr<DatasetCache> &cache) {
1610 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1611 nlohmann::json sample = nullptr;
1612 if (padded_sample) {
1613 sample = *padded_sample;
1614 }
1615 auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1616 sample, num_padded, shuffle_mode, cache);
1617 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1618 }
1619
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const Sampler * sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1620 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1621 const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
1622 const nlohmann::json *padded_sample, int64_t num_padded, ShuffleMode shuffle_mode,
1623 const std::shared_ptr<DatasetCache> &cache) {
1624 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1625 nlohmann::json sample = nullptr;
1626 if (padded_sample) {
1627 sample = *padded_sample;
1628 }
1629 auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1630 sample, num_padded, shuffle_mode, cache);
1631 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1632 }
1633
MindDataDataset(const std::vector<char> & dataset_file,const std::vector<std::vector<char>> & columns_list,const std::reference_wrapper<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1634 MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
1635 const std::vector<std::vector<char>> &columns_list,
1636 const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
1637 int64_t num_padded, ShuffleMode shuffle_mode,
1638 const std::shared_ptr<DatasetCache> &cache) {
1639 auto sampler_obj = sampler.get().Parse();
1640 nlohmann::json sample = nullptr;
1641 if (padded_sample) {
1642 sample = *padded_sample;
1643 }
1644
1645 auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
1646 sample, num_padded, shuffle_mode, cache);
1647 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1648 }
1649
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1650 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1651 const std::vector<std::vector<char>> &columns_list,
1652 const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample,
1653 int64_t num_padded, ShuffleMode shuffle_mode,
1654 const std::shared_ptr<DatasetCache> &cache) {
1655 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1656 nlohmann::json sample = nullptr;
1657 if (padded_sample) {
1658 sample = *padded_sample;
1659 }
1660
1661 auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1662 sampler_obj, sample, num_padded, shuffle_mode, cache);
1663 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1664 }
1665
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const Sampler * sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1666 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1667 const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
1668 const nlohmann::json *padded_sample, int64_t num_padded, ShuffleMode shuffle_mode,
1669 const std::shared_ptr<DatasetCache> &cache) {
1670 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1671 nlohmann::json sample = nullptr;
1672 if (padded_sample) {
1673 sample = *padded_sample;
1674 }
1675
1676 auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1677 sampler_obj, sample, num_padded, shuffle_mode, cache);
1678 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1679 }
1680
MindDataDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<std::vector<char>> & columns_list,const std::reference_wrapper<Sampler> & sampler,const nlohmann::json * padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,const std::shared_ptr<DatasetCache> & cache)1681 MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
1682 const std::vector<std::vector<char>> &columns_list,
1683 const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
1684 int64_t num_padded, ShuffleMode shuffle_mode,
1685 const std::shared_ptr<DatasetCache> &cache) {
1686 auto sampler_obj = sampler.get().Parse();
1687 nlohmann::json sample = nullptr;
1688 if (padded_sample) {
1689 sample = *padded_sample;
1690 }
1691 auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
1692 sampler_obj, sample, num_padded, shuffle_mode, cache);
1693 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1694 }
1695 #endif
1696
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1697 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1698 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1699 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1700 auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1701 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1702 }
1703
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1704 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
1705 const std::shared_ptr<DatasetCache> &cache) {
1706 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1707 auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1708 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1709 }
1710
MnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1711 MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1712 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1713 auto sampler_obj = sampler.get().Parse();
1714 auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1715 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1716 }
1717
1718 #ifndef ENABLE_ANDROID
Multi30kDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::vector<std::vector<char>> & language_pair,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1719 Multi30kDataset::Multi30kDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1720 const std::vector<std::vector<char>> &language_pair, int64_t num_samples,
1721 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1722 const std::shared_ptr<DatasetCache> &cache) {
1723 auto ds =
1724 std::make_shared<Multi30kNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(language_pair),
1725 num_samples, shuffle, num_shards, shard_id, cache);
1726 ir_node_ = std::static_pointer_cast<Multi30kNode>(ds);
1727 }
1728
OmniglotDataset(const std::vector<char> & dataset_dir,bool background,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1729 OmniglotDataset::OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode,
1730 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1731 // Create logical representation of OmniglotDataset.
1732 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1733 auto ds = std::make_shared<OmniglotNode>(CharToString(dataset_dir), background, decode, sampler_obj, cache);
1734 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1735 }
1736
OmniglotDataset(const std::vector<char> & dataset_dir,bool background,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1737 OmniglotDataset::OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode,
1738 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1739 // Create logical representation of OmniglotDataset.
1740 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1741 auto ds = std::make_shared<OmniglotNode>(CharToString(dataset_dir), background, decode, sampler_obj, cache);
1742 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1743 }
1744
OmniglotDataset(const std::vector<char> & dataset_dir,bool background,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1745 OmniglotDataset::OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode,
1746 const std::reference_wrapper<Sampler> &sampler,
1747 const std::shared_ptr<DatasetCache> &cache) {
1748 // Create logical representation of OmniglotDataset.
1749 auto sampler_obj = sampler.get().Parse();
1750 auto ds = std::make_shared<OmniglotNode>(CharToString(dataset_dir), background, decode, sampler_obj, cache);
1751 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1752 }
1753
PennTreebankDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1754 PennTreebankDataset::PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1755 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1756 const std::shared_ptr<DatasetCache> &cache) {
1757 auto ds = std::make_shared<PennTreebankNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
1758 num_shards, shard_id, cache);
1759 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1760 }
1761
PhotoTourDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1762 PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1763 const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
1764 const std::shared_ptr<DatasetCache> &cache) {
1765 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1766 auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
1767 sampler_obj, cache);
1768 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1769 }
1770
PhotoTourDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1771 PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1772 const std::vector<char> &usage, const Sampler *sampler,
1773 const std::shared_ptr<DatasetCache> &cache) {
1774 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1775 auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
1776 sampler_obj, cache);
1777 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1778 }
1779
PhotoTourDataset(const std::vector<char> & dataset_dir,const std::vector<char> & name,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1780 PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
1781 const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
1782 const std::shared_ptr<DatasetCache> &cache) {
1783 auto sampler_obj = sampler.get().Parse();
1784 auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
1785 sampler_obj, cache);
1786 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1787 }
1788
Places365Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const bool small,const bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1789 Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1790 const bool small, const bool decode, const std::shared_ptr<Sampler> &sampler,
1791 const std::shared_ptr<DatasetCache> &cache) {
1792 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1793 auto ds =
1794 std::make_shared<Places365Node>(CharToString(dataset_dir), CharToString(usage), small, decode, sampler_obj, cache);
1795 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1796 }
1797
Places365Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const bool small,const bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1798 Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1799 const bool small, const bool decode, const Sampler *sampler,
1800 const std::shared_ptr<DatasetCache> &cache) {
1801 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1802 auto ds =
1803 std::make_shared<Places365Node>(CharToString(dataset_dir), CharToString(usage), small, decode, sampler_obj, cache);
1804 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1805 }
1806
Places365Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const bool small,const bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1807 Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1808 const bool small, const bool decode, const std::reference_wrapper<Sampler> &sampler,
1809 const std::shared_ptr<DatasetCache> &cache) {
1810 auto sampler_obj = sampler.get().Parse();
1811 auto ds =
1812 std::make_shared<Places365Node>(CharToString(dataset_dir), CharToString(usage), small, decode, sampler_obj, cache);
1813 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1814 }
1815
QMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool compat,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1816 QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
1817 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1818 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1819 auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
1820 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1821 }
1822
QMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool compat,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1823 QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
1824 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1825 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1826 auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
1827 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1828 }
1829
QMnistDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool compat,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1830 QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
1831 const std::reference_wrapper<Sampler> &sampler,
1832 const std::shared_ptr<DatasetCache> &cache) {
1833 auto sampler_obj = sampler.get().Parse();
1834 auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
1835 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1836 }
1837
SemeionDataset(const std::vector<char> & dataset_dir,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1838 SemeionDataset::SemeionDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
1839 const std::shared_ptr<DatasetCache> &cache) {
1840 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1841 auto ds = std::make_shared<SemeionNode>(CharToString(dataset_dir), sampler_obj, cache);
1842 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1843 }
1844
SemeionDataset(const std::vector<char> & dataset_dir,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1845 SemeionDataset::SemeionDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
1846 const std::shared_ptr<DatasetCache> &cache) {
1847 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1848 auto ds = std::make_shared<SemeionNode>(CharToString(dataset_dir), sampler_obj, cache);
1849 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1850 }
1851
SemeionDataset(const std::vector<char> & dataset_dir,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1852 SemeionDataset::SemeionDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
1853 const std::shared_ptr<DatasetCache> &cache) {
1854 auto sampler_obj = sampler.get().Parse();
1855 auto ds = std::make_shared<SemeionNode>(CharToString(dataset_dir), sampler_obj, cache);
1856 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1857 }
1858
SQuADDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1859 SQuADDataset::SQuADDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
1860 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1861 const std::shared_ptr<DatasetCache> &cache) {
1862 auto ds = std::make_shared<SQuADNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
1863 num_shards, shard_id, cache);
1864 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1865 }
1866
SST2Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1867 SST2Dataset::SST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
1868 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1869 const std::shared_ptr<DatasetCache> &cache) {
1870 auto ds = std::make_shared<SST2Node>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
1871 shard_id, cache);
1872 ir_node_ = std::static_pointer_cast<SST2Node>(ds);
1873 }
1874
TedliumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & release,const std::vector<char> & usage,const std::vector<char> & extensions,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1875 TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
1876 const std::vector<char> &usage, const std::vector<char> &extensions,
1877 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1878 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1879 auto ds = std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
1880 CharToString(extensions), sampler_obj, cache);
1881 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1882 }
1883
TedliumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & release,const std::vector<char> & usage,const std::vector<char> & extensions,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1884 TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
1885 const std::vector<char> &usage, const std::vector<char> &extensions,
1886 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
1887 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1888 auto ds = std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
1889 CharToString(extensions), sampler_obj, cache);
1890 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1891 }
1892
TedliumDataset(const std::vector<char> & dataset_dir,const std::vector<char> & release,const std::vector<char> & usage,const std::vector<char> & extensions,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1893 TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
1894 const std::vector<char> &usage, const std::vector<char> &extensions,
1895 const std::reference_wrapper<Sampler> &sampler,
1896 const std::shared_ptr<DatasetCache> &cache) {
1897 auto sampler_obj = sampler.get().Parse();
1898 auto ds = std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
1899 CharToString(extensions), sampler_obj, cache);
1900 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1901 }
1902
STL10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1903 STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1904 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1905 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1906 auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1907 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1908 }
1909
STL10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1910 STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
1911 const std::shared_ptr<DatasetCache> &cache) {
1912 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1913 auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1914 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1915 }
1916
STL10Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1917 STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1918 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
1919 auto sampler_obj = sampler.get().Parse();
1920 auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
1921 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1922 }
1923
SUN397Dataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1924 SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
1925 const std::shared_ptr<DatasetCache> &cache) {
1926 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1927 auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
1928 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1929 }
1930
SUN397Dataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)1931 SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
1932 const std::shared_ptr<DatasetCache> &cache) {
1933 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1934 auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
1935 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1936 }
1937
SUN397Dataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)1938 SUN397Dataset::SUN397Dataset(const std::vector<char> &dataset_dir, bool decode,
1939 const std::reference_wrapper<Sampler> &sampler,
1940 const std::shared_ptr<DatasetCache> &cache) {
1941 auto sampler_obj = sampler.get().Parse();
1942 auto ds = std::make_shared<SUN397Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
1943 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1944 }
1945
TextFileDataset(const std::vector<std::vector<char>> & dataset_files,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1946 TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
1947 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1948 const std::shared_ptr<DatasetCache> &cache) {
1949 auto ds = std::make_shared<TextFileNode>(VectorCharToString(dataset_files), num_samples, shuffle, num_shards,
1950 shard_id, cache);
1951 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1952 }
1953
USPSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1954 USPSDataset::USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
1955 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1956 const std::shared_ptr<DatasetCache> &cache) {
1957 auto ds = std::make_shared<USPSNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
1958 shard_id, cache);
1959 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1960 }
1961
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1962 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1963 const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1964 bool decode, const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
1965 bool extra_metadata) {
1966 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1967 auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1968 MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1969 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1970 }
1971
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1972 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1973 const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1974 bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache,
1975 bool extra_metadata) {
1976 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
1977 auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1978 MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1979 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1980 }
1981
VOCDataset(const std::vector<char> & dataset_dir,const std::vector<char> & task,const std::vector<char> & usage,const std::map<std::vector<char>,int32_t> & class_indexing,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache,bool extra_metadata)1982 VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
1983 const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
1984 bool decode, const std::reference_wrapper<Sampler> &sampler,
1985 const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) {
1986 auto sampler_obj = sampler.get().Parse();
1987 auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
1988 MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
1989 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1990 }
1991
WikiTextDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)1992 WikiTextDataset::WikiTextDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
1993 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
1994 const std::shared_ptr<DatasetCache> &cache) {
1995 auto ds = std::make_shared<WikiTextNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
1996 num_shards, shard_id, cache);
1997 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
1998 }
1999
RandomDataDataset(const int32_t & total_rows,std::shared_ptr<SchemaObj> schema,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<DatasetCache> & cache)2000 RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
2001 const std::vector<std::vector<char>> &columns_list,
2002 const std::shared_ptr<DatasetCache> &cache) {
2003 auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache);
2004 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2005 }
2006
RandomDataDataset(const int32_t & total_rows,const std::vector<char> & schema_path,const std::vector<std::vector<char>> & columns_list,const std::shared_ptr<DatasetCache> & cache)2007 RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
2008 const std::vector<std::vector<char>> &columns_list,
2009 const std::shared_ptr<DatasetCache> &cache) {
2010 auto ds =
2011 std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache);
2012 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2013 }
2014
RenderedSST2Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2015 RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2016 bool decode, const std::shared_ptr<Sampler> &sampler,
2017 const std::shared_ptr<DatasetCache> &cache) {
2018 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2019 auto ds =
2020 std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
2021 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2022 }
2023
RenderedSST2Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)2024 RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2025 bool decode, const Sampler *sampler,
2026 const std::shared_ptr<DatasetCache> &cache) {
2027 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2028 auto ds =
2029 std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
2030 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2031 }
2032
RenderedSST2Dataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2033 RenderedSST2Dataset::RenderedSST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2034 bool decode, const std::reference_wrapper<Sampler> &sampler,
2035 const std::shared_ptr<DatasetCache> &cache) {
2036 auto sampler_obj = sampler.get().Parse();
2037 auto ds =
2038 std::make_shared<RenderedSST2Node>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
2039 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2040 }
2041
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2042 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
2043 const std::shared_ptr<DatasetCache> &cache) {
2044 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2045 auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
2046 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2047 }
2048
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)2049 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
2050 const std::shared_ptr<DatasetCache> &cache) {
2051 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2052 auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
2053 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2054 }
2055
SBUDataset(const std::vector<char> & dataset_dir,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2056 SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode,
2057 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
2058 auto sampler_obj = sampler.get().Parse();
2059 auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
2060 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2061 }
2062
SogouNewsDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)2063 SogouNewsDataset::SogouNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2064 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
2065 const std::shared_ptr<DatasetCache> &cache) {
2066 auto ds = std::make_shared<SogouNewsNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
2067 num_shards, shard_id, cache);
2068 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2069 }
2070
SpeechCommandsDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2071 SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2072 const std::shared_ptr<Sampler> &sampler,
2073 const std::shared_ptr<DatasetCache> &cache) {
2074 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2075 auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
2076 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2077 }
2078
SpeechCommandsDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)2079 SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2080 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
2081 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2082 auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
2083 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2084 }
2085
SpeechCommandsDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2086 SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2087 const std::reference_wrapper<Sampler> &sampler,
2088 const std::shared_ptr<DatasetCache> &cache) {
2089 auto sampler_obj = sampler.get().Parse();
2090 auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
2091 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2092 }
2093
TFRecordDataset(const std::vector<std::vector<char>> & dataset_files,const std::vector<char> & schema,const std::vector<std::vector<char>> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,const std::shared_ptr<DatasetCache> & cache,const std::vector<char> & compression_type)2094 TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
2095 const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
2096 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
2097 const std::shared_ptr<DatasetCache> &cache,
2098 const std::vector<char> &compression_type) {
2099 auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema),
2100 VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id,
2101 shard_equal_rows, cache, CharToString(compression_type));
2102 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2103 }
2104
TFRecordDataset(const std::vector<std::vector<char>> & dataset_files,const std::shared_ptr<SchemaObj> & schema,const std::vector<std::vector<char>> & columns_list,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,bool shard_equal_rows,const std::shared_ptr<DatasetCache> & cache,const std::vector<char> & compression_type)2105 TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files,
2106 const std::shared_ptr<SchemaObj> &schema,
2107 const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
2108 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
2109 const std::shared_ptr<DatasetCache> &cache,
2110 const std::vector<char> &compression_type) {
2111 auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list),
2112 num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache,
2113 CharToString(compression_type));
2114 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2115 }
2116
UDPOSDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)2117 UDPOSDataset::UDPOSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
2118 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
2119 const std::shared_ptr<DatasetCache> &cache) {
2120 auto ds = std::make_shared<UDPOSNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
2121 num_shards, shard_id, cache);
2122 ir_node_ = std::static_pointer_cast<UDPOSNode>(ds);
2123 }
2124
WIDERFaceDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2125 WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
2126 const std::shared_ptr<Sampler> &sampler,
2127 const std::shared_ptr<DatasetCache> &cache) {
2128 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2129 auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
2130 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2131 }
2132
WIDERFaceDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)2133 WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
2134 const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
2135 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2136 auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
2137 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2138 }
2139
WIDERFaceDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,bool decode,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2140 WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
2141 const std::reference_wrapper<Sampler> &sampler,
2142 const std::shared_ptr<DatasetCache> &cache) {
2143 auto sampler_obj = sampler.get().Parse();
2144 auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
2145 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2146 }
2147
YahooAnswersDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)2148 YahooAnswersDataset::YahooAnswersDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2149 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
2150 const std::shared_ptr<DatasetCache> &cache) {
2151 auto ds = std::make_shared<YahooAnswersNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
2152 num_shards, shard_id, cache);
2153 ir_node_ = std::static_pointer_cast<YahooAnswersNode>(ds);
2154 }
2155
YelpReviewDataset(const std::vector<char> & dataset_dir,const std::vector<char> & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)2156 YelpReviewDataset::YelpReviewDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
2157 int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
2158 const std::shared_ptr<DatasetCache> &cache) {
2159 auto ds = std::make_shared<YelpReviewNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
2160 num_shards, shard_id, cache);
2161 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2162 }
2163
YesNoDataset(const std::vector<char> & dataset_dir,const std::shared_ptr<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2164 YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
2165 const std::shared_ptr<DatasetCache> &cache) {
2166 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2167 auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
2168 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2169 }
2170
YesNoDataset(const std::vector<char> & dataset_dir,const Sampler * sampler,const std::shared_ptr<DatasetCache> & cache)2171 YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
2172 const std::shared_ptr<DatasetCache> &cache) {
2173 auto sampler_obj = sampler ? sampler->Parse() : nullptr;
2174 auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
2175 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2176 }
2177
YesNoDataset(const std::vector<char> & dataset_dir,const std::reference_wrapper<Sampler> & sampler,const std::shared_ptr<DatasetCache> & cache)2178 YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
2179 const std::shared_ptr<DatasetCache> &cache) {
2180 auto sampler_obj = sampler.get().Parse();
2181 auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
2182 ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
2183 }
2184 #endif
2185 } // namespace dataset
2186 } // namespace mindspore
2187