1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h"

#include "minddata/dataset/api/python/pybind_register.h"

#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
29
30 namespace mindspore {
31 namespace dataset {
__anonaf517c3c0102(const py::module *m) 32 PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
33 (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
34 *m, "ShardOperator")
35 .def("add_child",
36 [](const std::shared_ptr<mindrecord::ShardOperator> &self,
37 const std::shared_ptr<mindrecord::ShardOperator> &child) {
38 THROW_IF_ERROR(self->SetChildOp(child));
39 })
40 .def("set_num_samples", [](const std::shared_ptr<mindrecord::ShardOperator> &self,
41 int64_t num_samples) { self->SetNumSamples(num_samples); });
42 }));
43
__anonaf517c3c0402(const py::module *m) 44 PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
45 (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
46 std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
47 "MindrecordDistributedSampler")
48 .def(py::init<int64_t, int64_t, bool, uint32_t, int64_t, int64_t>());
49 }));
50
51 PYBIND_REGISTER(
__anonaf517c3c0502(const py::module *m) 52 ShardPkSample, 1, ([](const py::module *m) {
53 (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
54 *m, "MindrecordPkSampler")
55 .def(py::init([](int64_t kVal, const std::string &kColumn, bool shuffle, int64_t num_samples) {
56 if (shuffle == true) {
57 return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
58 GetSeed(), num_samples);
59 } else {
60 return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
61 }
62 }));
63 }));
64
65 PYBIND_REGISTER(
__anonaf517c3c0702(const py::module *m) 66 ShardSample, 0, ([](const py::module *m) {
67 (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
68 *m, "MindrecordSubsetSampler")
69 .def(py::init<std::vector<int64_t>, uint32_t>())
70 .def(py::init<std::vector<int64_t>>());
71 }));
72
__anonaf517c3c0802(const py::module *m) 73 PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
74 (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
75 std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
76 "MindrecordSequentialSampler")
77 .def(py::init([](int64_t num_samples, int64_t start_index) {
78 return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
79 }));
80 }));
81
82 PYBIND_REGISTER(
__anonaf517c3c0a02(const py::module *m) 83 ShardShuffle, 1, ([](const py::module *m) {
84 (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
85 *m, "MindrecordRandomSampler")
86 .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
87 return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
88 }));
89 }));
90
__anonaf517c3c0c02(const py::module *m) 91 PYBIND_REGISTER(ShuffleMode, 1, ([](const py::module *m) {
92 (void)py::enum_<ShuffleMode>(*m, "ShuffleMode", py::arithmetic())
93 .value("FALSE", ShuffleMode::kFalse)
94 .value("FILES", ShuffleMode::kFiles)
95 .value("GLOBAL", ShuffleMode::kGlobal)
96 .value("INFILE", ShuffleMode::kInfile)
97 .export_values();
98 }));
99 } // namespace dataset
100 } // namespace mindspore
101