• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pybind11/pybind11.h"
17 
18 #include "minddata/dataset/api/python/pybind_conversion.h"
19 #include "minddata/dataset/api/python/pybind_register.h"
20 #include "minddata/dataset/include/dataset/constants.h"
21 #include "minddata/dataset/include/dataset/datasets.h"
22 
23 #include "minddata/dataset/core/config_manager.h"
24 #include "minddata/dataset/core/data_type.h"
25 #include "minddata/dataset/util/path.h"
26 
27 // IR leaf nodes
28 #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
29 #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
30 #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
31 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
32 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
33 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
34 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
35 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
36 #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
37 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
38 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
39 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
40 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
41 #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
42 
43 // IR leaf nodes disabled for android
44 #ifndef ENABLE_ANDROID
45 #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
46 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
47 #include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
48 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
49 #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
50 #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
51 #endif
52 
53 namespace mindspore {
54 namespace dataset {
55 
56 // PYBIND FOR LEAF NODES
57 // (In alphabetical order)
58 
__anon9c4d722b0102(const py::module *m) 59 PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
60                   (void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
61                                                                                          "to create a CelebANode")
62                     .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
63                                      py::list extensions) {
64                       auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
65                                                                  toStringSet(extensions), nullptr);
66                       THROW_IF_ERROR(celebA->ValidateParams());
67                       return celebA;
68                     }));
69                 }));
70 
__anon9c4d722b0302(const py::module *m) 71 PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
72                   (void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
73                                                                                            "to create a Cifar10Node")
74                     .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
75                       auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
76                       THROW_IF_ERROR(cifar10->ValidateParams());
77                       return cifar10;
78                     }));
79                 }));
80 
__anon9c4d722b0502(const py::module *m) 81 PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
82                   (void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
83                                                                                              "to create a Cifar100Node")
84                     .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
85                       auto cifar100 =
86                         std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
87                       THROW_IF_ERROR(cifar100->ValidateParams());
88                       return cifar100;
89                     }));
90                 }));
91 
__anon9c4d722b0702(const py::module *m) 92 PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
93                   (void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>(
94                     *m, "CityscapesNode", "to create a CityscapesNode")
95                     .def(py::init([](std::string dataset_dir, std::string usage, std::string quality_mode,
96                                      std::string task, bool decode, const py::handle &sampler) {
97                       auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode,
98                                                                          toSamplerObj(sampler), nullptr);
99                       THROW_IF_ERROR(cityscapes->ValidateParams());
100                       return cityscapes;
101                     }));
102                 }));
103 
__anon9c4d722b0902(const py::module *m) 104 PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
105                   (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
106                                                                                      "to create a CLUENode")
107                     .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples,
108                                      int32_t shuffle, int32_t num_shards, int32_t shard_id) {
109                       std::shared_ptr<CLUENode> clue_node =
110                         std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
111                                                             toShuffleMode(shuffle), num_shards, shard_id, nullptr);
112                       THROW_IF_ERROR(clue_node->ValidateParams());
113                       return clue_node;
114                     }));
115                 }));
116 
__anon9c4d722b0b02(const py::module *m) 117 PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
118                   (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
119                                                                                      "to create a CocoNode")
120                     .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
121                                      bool decode, const py::handle &sampler, bool extra_metadata) {
122                       std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
123                         dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
124                       THROW_IF_ERROR(coco->ValidateParams());
125                       return coco;
126                     }));
127                 }));
128 
__anon9c4d722b0d02(const py::module *m) 129 PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
130                   (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
131                     .def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
132                                      std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
133                                      int32_t num_shards, int32_t shard_id) {
134                       auto csv =
135                         std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
136                                                   num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
137                       THROW_IF_ERROR(csv->ValidateParams());
138                       return csv;
139                     }));
140                 }));
141 
__anon9c4d722b0f02(const py::module *m) 142 PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
143                   (void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
144                                                                                        "to create a DIV2KNode")
145                     .def(py::init([](std::string dataset_dir, std::string usage, std::string downgrade, int32_t scale,
146                                      bool decode, py::handle sampler) {
147                       auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
148                                                                toSamplerObj(sampler), nullptr);
149                       THROW_IF_ERROR(div2k->ValidateParams());
150                       return div2k;
151                     }));
152                 }));
153 
154 PYBIND_REGISTER(
__anon9c4d722b1102(const py::module *m) 155   FlickrNode, 2, ([](const py::module *m) {
156     (void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
157       .def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, const py::handle &sampler) {
158         auto flickr =
159           std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, toSamplerObj(sampler), nullptr);
160         THROW_IF_ERROR(flickr->ValidateParams());
161         return flickr;
162       }));
163   }));
164 
__anon9c4d722b1302(const py::module *m) 165 PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
166                   (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
167                     *m, "GeneratorNode", "to create a GeneratorNode")
168                     .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
169                                      const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler,
170                                      uint32_t num_parallel_workers) {
171                       auto gen =
172                         std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len,
173                                                         toSamplerObj(sampler), num_parallel_workers);
174                       THROW_IF_ERROR(gen->ValidateParams());
175                       return gen;
176                     }))
177                     .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
178                                      int64_t dataset_len, py::handle sampler, uint32_t num_parallel_workers) {
179                       auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len,
180                                                                  toSamplerObj(sampler), num_parallel_workers);
181                       THROW_IF_ERROR(gen->ValidateParams());
182                       return gen;
183                     }));
184                 }));
185 
__anon9c4d722b1602(const py::module *m) 186 PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
187                   (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
188                     *m, "ImageFolderNode", "to create an ImageFolderNode")
189                     .def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
190                                      py::dict class_indexing) {
191                       // Don't update recursive to true
192                       bool recursive = false;  // Will be removed in future PR
193                       auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
194                                                                            recursive, toStringSet(extensions),
195                                                                            toStringMap(class_indexing), nullptr);
196                       THROW_IF_ERROR(imagefolder->ValidateParams());
197                       return imagefolder;
198                     }));
199                 }));
200 
__anon9c4d722b1802(const py::module *m) 201 PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
202                   (void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
203                                                                                              "to create a ManifestNode")
204                     .def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
205                                      py::dict class_indexing, bool decode) {
206                       auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
207                                                                      toStringMap(class_indexing), decode, nullptr);
208                       THROW_IF_ERROR(manifest->ValidateParams());
209                       return manifest;
210                     }));
211                 }));
212 
__anon9c4d722b1a02(const py::module *m) 213 PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
214                   (void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
215                                                                                              "to create a MindDataNode")
216                     .def(py::init([](std::string dataset_file, py::list columns_list, py::handle sampler,
217                                      const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
218                       nlohmann::json padded_sample_json;
219                       std::map<std::string, std::string> sample_bytes;
220                       THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
221                       auto minddata = std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
222                                                                      toSamplerObj(sampler, true), padded_sample_json,
223                                                                      num_padded, shuffle_mode, nullptr);
224                       minddata->SetSampleBytes(&sample_bytes);
225                       THROW_IF_ERROR(minddata->ValidateParams());
226                       return minddata;
227                     }))
228                     .def(py::init([](py::list dataset_file, py::list columns_list, py::handle sampler,
229                                      const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
230                       nlohmann::json padded_sample_json;
231                       std::map<std::string, std::string> sample_bytes;
232                       THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
233                       auto minddata = std::make_shared<MindDataNode>(
234                         toStringVector(dataset_file), toStringVector(columns_list), toSamplerObj(sampler, true),
235                         padded_sample_json, num_padded, shuffle_mode, nullptr);
236                       minddata->SetSampleBytes(&sample_bytes);
237                       THROW_IF_ERROR(minddata->ValidateParams());
238                       return minddata;
239                     }));
240                 }));
241 
__anon9c4d722b1d02(const py::module *m) 242 PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
243                   (void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
244                                                                                        "to create an MnistNode")
245                     .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
246                       auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
247                       THROW_IF_ERROR(mnist->ValidateParams());
248                       return mnist;
249                     }));
250                 }));
251 
__anon9c4d722b1f02(const py::module *m) 252 PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
253                   (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
254                                                                                          "to create a RandomNode")
255                     .def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) {
256                       auto random_node =
257                         std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
258                       THROW_IF_ERROR(random_node->ValidateParams());
259                       return random_node;
260                     }))
261                     .def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) {
262                       auto random_node =
263                         std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
264                       THROW_IF_ERROR(random_node->ValidateParams());
265                       return random_node;
266                     }));
267                 }));
268 
__anon9c4d722b2202(const py::module *m) 269 PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
270                   (void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
271                                                                                    "to create an SBUNode")
272                     .def(py::init([](std::string dataset_dir, bool decode, const py::handle &sampler) {
273                       auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
274                       THROW_IF_ERROR(sbu->ValidateParams());
275                       return sbu;
276                     }));
277                 }));
278 
__anon9c4d722b2402(const py::module *m) 279 PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
280                   (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
281                                                                                              "to create a TextFileNode")
282                     .def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
283                                      int32_t shard_id) {
284                       std::shared_ptr<TextFileNode> textfile_node =
285                         std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
286                                                        toShuffleMode(shuffle), num_shards, shard_id, nullptr);
287                       THROW_IF_ERROR(textfile_node->ValidateParams());
288                       return textfile_node;
289                     }));
290                 }));
291 
__anon9c4d722b2602(const py::module *m) 292 PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
293                   (void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
294                                                                                              "to create a TFRecordNode")
295                     .def(py::init([](const py::list dataset_files, std::shared_ptr<SchemaObj> schema,
296                                      const py::list columns_list, int64_t num_samples, int32_t shuffle,
297                                      int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
298                       std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
299                         toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
300                         toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
301                       THROW_IF_ERROR(tfrecord->ValidateParams());
302                       return tfrecord;
303                     }))
304                     .def(py::init([](const py::list dataset_files, std::string schema, const py::list columns_list,
305                                      int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
306                                      bool shard_equal_rows) {
307                       std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
308                         toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
309                         toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
310                       THROW_IF_ERROR(tfrecord->ValidateParams());
311                       return tfrecord;
312                     }));
313                 }));
314 
__anon9c4d722b2902(const py::module *m) 315 PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
316                   (void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
317                                                                                      "to create an USPSNode")
318                     .def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
319                                      int32_t num_shards, int32_t shard_id) {
320                       auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
321                                                              num_shards, shard_id, nullptr);
322                       THROW_IF_ERROR(usps->ValidateParams());
323                       return usps;
324                     }));
325                 }));
326 
__anon9c4d722b2b02(const py::module *m) 327 PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
328                   (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
329                     .def(py::init([](std::string dataset_dir, std::string task, std::string usage,
330                                      const py::dict &class_indexing, bool decode, const py::handle &sampler,
331                                      bool extra_metadata) {
332                       std::shared_ptr<VOCNode> voc =
333                         std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
334                                                   toSamplerObj(sampler), nullptr, extra_metadata);
335                       THROW_IF_ERROR(voc->ValidateParams());
336                       return voc;
337                     }));
338                 }));
339 
340 }  // namespace dataset
341 }  // namespace mindspore
342