1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pybind11/pybind11.h"
17
18 #include "minddata/dataset/api/python/pybind_conversion.h"
19 #include "minddata/dataset/api/python/pybind_register.h"
20 #include "minddata/dataset/include/dataset/constants.h"
21 #include "minddata/dataset/include/dataset/datasets.h"
22
23 #include "minddata/dataset/core/config_manager.h"
24 #include "minddata/dataset/core/data_type.h"
25 #include "minddata/dataset/util/path.h"
26
27 // IR leaf nodes
28 #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
29 #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
30 #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
31 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
32 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
33 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
34 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
35 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
36 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
37 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
38 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
39 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
40 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
41 #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
42
43 // IR leaf nodes disabled for android
44 #ifndef ENABLE_ANDROID
45 #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
46 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
47 #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
48 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
49 #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
50 #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
51 #endif
52
53 namespace mindspore {
54 namespace dataset {
55
56 // PYBIND FOR LEAF NODES
57 // (In alphabetical order)
58
__anon9c4d722b0102(const py::module *m) 59 PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
60 (void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
61 "to create a CelebANode")
62 .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
63 py::list extensions) {
64 auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
65 toStringSet(extensions), nullptr);
66 THROW_IF_ERROR(celebA->ValidateParams());
67 return celebA;
68 }));
69 }));
70
__anon9c4d722b0302(const py::module *m) 71 PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
72 (void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
73 "to create a Cifar10Node")
74 .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
75 auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
76 THROW_IF_ERROR(cifar10->ValidateParams());
77 return cifar10;
78 }));
79 }));
80
__anon9c4d722b0502(const py::module *m) 81 PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
82 (void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
83 "to create a Cifar100Node")
84 .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
85 auto cifar100 =
86 std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
87 THROW_IF_ERROR(cifar100->ValidateParams());
88 return cifar100;
89 }));
90 }));
91
__anon9c4d722b0702(const py::module *m) 92 PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
93 (void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>(
94 *m, "CityscapesNode", "to create a CityscapesNode")
95 .def(py::init([](std::string dataset_dir, std::string usage, std::string quality_mode,
96 std::string task, bool decode, const py::handle &sampler) {
97 auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode,
98 toSamplerObj(sampler), nullptr);
99 THROW_IF_ERROR(cityscapes->ValidateParams());
100 return cityscapes;
101 }));
102 }));
103
__anon9c4d722b0902(const py::module *m) 104 PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
105 (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
106 "to create a CLUENode")
107 .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples,
108 int32_t shuffle, int32_t num_shards, int32_t shard_id) {
109 std::shared_ptr<CLUENode> clue_node =
110 std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
111 toShuffleMode(shuffle), num_shards, shard_id, nullptr);
112 THROW_IF_ERROR(clue_node->ValidateParams());
113 return clue_node;
114 }));
115 }));
116
__anon9c4d722b0b02(const py::module *m) 117 PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
118 (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
119 "to create a CocoNode")
120 .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
121 bool decode, const py::handle &sampler, bool extra_metadata) {
122 std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
123 dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
124 THROW_IF_ERROR(coco->ValidateParams());
125 return coco;
126 }));
127 }));
128
__anon9c4d722b0d02(const py::module *m) 129 PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
130 (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
131 .def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
132 std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
133 int32_t num_shards, int32_t shard_id) {
134 auto csv =
135 std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
136 num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
137 THROW_IF_ERROR(csv->ValidateParams());
138 return csv;
139 }));
140 }));
141
__anon9c4d722b0f02(const py::module *m) 142 PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
143 (void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
144 "to create a DIV2KNode")
145 .def(py::init([](std::string dataset_dir, std::string usage, std::string downgrade, int32_t scale,
146 bool decode, py::handle sampler) {
147 auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
148 toSamplerObj(sampler), nullptr);
149 THROW_IF_ERROR(div2k->ValidateParams());
150 return div2k;
151 }));
152 }));
153
154 PYBIND_REGISTER(
__anon9c4d722b1102(const py::module *m) 155 FlickrNode, 2, ([](const py::module *m) {
156 (void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
157 .def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, const py::handle &sampler) {
158 auto flickr =
159 std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, toSamplerObj(sampler), nullptr);
160 THROW_IF_ERROR(flickr->ValidateParams());
161 return flickr;
162 }));
163 }));
164
__anon9c4d722b1302(const py::module *m) 165 PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
166 (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
167 *m, "GeneratorNode", "to create a GeneratorNode")
168 .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
169 const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler,
170 uint32_t num_parallel_workers) {
171 auto gen =
172 std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len,
173 toSamplerObj(sampler), num_parallel_workers);
174 THROW_IF_ERROR(gen->ValidateParams());
175 return gen;
176 }))
177 .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
178 int64_t dataset_len, py::handle sampler, uint32_t num_parallel_workers) {
179 auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len,
180 toSamplerObj(sampler), num_parallel_workers);
181 THROW_IF_ERROR(gen->ValidateParams());
182 return gen;
183 }));
184 }));
185
__anon9c4d722b1602(const py::module *m) 186 PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
187 (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
188 *m, "ImageFolderNode", "to create an ImageFolderNode")
189 .def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
190 py::dict class_indexing) {
191 // Don't update recursive to true
192 bool recursive = false; // Will be removed in future PR
193 auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
194 recursive, toStringSet(extensions),
195 toStringMap(class_indexing), nullptr);
196 THROW_IF_ERROR(imagefolder->ValidateParams());
197 return imagefolder;
198 }));
199 }));
200
__anon9c4d722b1802(const py::module *m) 201 PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
202 (void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
203 "to create a ManifestNode")
204 .def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
205 py::dict class_indexing, bool decode) {
206 auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
207 toStringMap(class_indexing), decode, nullptr);
208 THROW_IF_ERROR(manifest->ValidateParams());
209 return manifest;
210 }));
211 }));
212
__anon9c4d722b1a02(const py::module *m) 213 PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
214 (void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
215 "to create a MindDataNode")
216 .def(py::init([](std::string dataset_file, py::list columns_list, py::handle sampler,
217 const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
218 nlohmann::json padded_sample_json;
219 std::map<std::string, std::string> sample_bytes;
220 THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
221 auto minddata = std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
222 toSamplerObj(sampler, true), padded_sample_json,
223 num_padded, shuffle_mode, nullptr);
224 minddata->SetSampleBytes(&sample_bytes);
225 THROW_IF_ERROR(minddata->ValidateParams());
226 return minddata;
227 }))
228 .def(py::init([](py::list dataset_file, py::list columns_list, py::handle sampler,
229 const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
230 nlohmann::json padded_sample_json;
231 std::map<std::string, std::string> sample_bytes;
232 THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
233 auto minddata = std::make_shared<MindDataNode>(
234 toStringVector(dataset_file), toStringVector(columns_list), toSamplerObj(sampler, true),
235 padded_sample_json, num_padded, shuffle_mode, nullptr);
236 minddata->SetSampleBytes(&sample_bytes);
237 THROW_IF_ERROR(minddata->ValidateParams());
238 return minddata;
239 }));
240 }));
241
__anon9c4d722b1d02(const py::module *m) 242 PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
243 (void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
244 "to create an MnistNode")
245 .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
246 auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
247 THROW_IF_ERROR(mnist->ValidateParams());
248 return mnist;
249 }));
250 }));
251
__anon9c4d722b1f02(const py::module *m) 252 PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
253 (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
254 "to create a RandomNode")
255 .def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) {
256 auto random_node =
257 std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
258 THROW_IF_ERROR(random_node->ValidateParams());
259 return random_node;
260 }))
261 .def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) {
262 auto random_node =
263 std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
264 THROW_IF_ERROR(random_node->ValidateParams());
265 return random_node;
266 }));
267 }));
268
__anon9c4d722b2202(const py::module *m) 269 PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
270 (void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
271 "to create an SBUNode")
272 .def(py::init([](std::string dataset_dir, bool decode, const py::handle &sampler) {
273 auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
274 THROW_IF_ERROR(sbu->ValidateParams());
275 return sbu;
276 }));
277 }));
278
__anon9c4d722b2402(const py::module *m) 279 PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
280 (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
281 "to create a TextFileNode")
282 .def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
283 int32_t shard_id) {
284 std::shared_ptr<TextFileNode> textfile_node =
285 std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
286 toShuffleMode(shuffle), num_shards, shard_id, nullptr);
287 THROW_IF_ERROR(textfile_node->ValidateParams());
288 return textfile_node;
289 }));
290 }));
291
__anon9c4d722b2602(const py::module *m) 292 PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
293 (void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
294 "to create a TFRecordNode")
295 .def(py::init([](const py::list dataset_files, std::shared_ptr<SchemaObj> schema,
296 const py::list columns_list, int64_t num_samples, int32_t shuffle,
297 int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
298 std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
299 toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
300 toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
301 THROW_IF_ERROR(tfrecord->ValidateParams());
302 return tfrecord;
303 }))
304 .def(py::init([](const py::list dataset_files, std::string schema, const py::list columns_list,
305 int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
306 bool shard_equal_rows) {
307 std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
308 toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
309 toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
310 THROW_IF_ERROR(tfrecord->ValidateParams());
311 return tfrecord;
312 }));
313 }));
314
__anon9c4d722b2902(const py::module *m) 315 PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
316 (void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
317 "to create an USPSNode")
318 .def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
319 int32_t num_shards, int32_t shard_id) {
320 auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
321 num_shards, shard_id, nullptr);
322 THROW_IF_ERROR(usps->ValidateParams());
323 return usps;
324 }));
325 }));
326
__anon9c4d722b2b02(const py::module *m) 327 PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
328 (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
329 .def(py::init([](std::string dataset_dir, std::string task, std::string usage,
330 const py::dict &class_indexing, bool decode, const py::handle &sampler,
331 bool extra_metadata) {
332 std::shared_ptr<VOCNode> voc =
333 std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
334 toSamplerObj(sampler), nullptr, extra_metadata);
335 THROW_IF_ERROR(voc->ValidateParams());
336 return voc;
337 }));
338 }));
339
340 } // namespace dataset
341 } // namespace mindspore
342