1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pybind11/pybind11.h"
17 #include "pybind11/stl_bind.h"
18
19 #include "minddata/dataset/api/python/pybind_register.h"
20
21 #include "minddata/dataset/include/dataset/constants.h"
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29
30 namespace mindspore {
31 namespace dataset {
32
__anon4d523be80102(const py::module *m) 33 PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
34 (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
35 *m, "ShardOperator")
36 .def("add_child",
37 [](std::shared_ptr<mindrecord::ShardOperator> self,
38 std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
39 }));
40
__anon4d523be80302(const py::module *m) 41 PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
42 (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
43 std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
44 "MindrecordDistributedSampler")
45 .def(py::init<int64_t, int64_t, bool, uint32_t, int64_t, int64_t>());
46 }));
47
48 PYBIND_REGISTER(
__anon4d523be80402(const py::module *m) 49 ShardPkSample, 1, ([](const py::module *m) {
50 (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
51 *m, "MindrecordPkSampler")
52 .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
53 if (shuffle == true) {
54 return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
55 GetSeed(), num_samples);
56 } else {
57 return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
58 }
59 }));
60 }));
61
62 PYBIND_REGISTER(
__anon4d523be80602(const py::module *m) 63 ShardSample, 0, ([](const py::module *m) {
64 (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
65 *m, "MindrecordSubsetSampler")
66 .def(py::init<std::vector<int64_t>, uint32_t>())
67 .def(py::init<std::vector<int64_t>>());
68 }));
69
__anon4d523be80702(const py::module *m) 70 PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
71 (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
72 std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
73 "MindrecordSequentialSampler")
74 .def(py::init([](int64_t num_samples, int64_t start_index) {
75 return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
76 }));
77 }));
78
79 PYBIND_REGISTER(
__anon4d523be80902(const py::module *m) 80 ShardShuffle, 1, ([](const py::module *m) {
81 (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
82 *m, "MindrecordRandomSampler")
83 .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
84 return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
85 }));
86 }));
87
__anon4d523be80b02(const py::module *m) 88 PYBIND_REGISTER(ShuffleMode, 1, ([](const py::module *m) {
89 (void)py::enum_<ShuffleMode>(*m, "ShuffleMode", py::arithmetic())
90 .value("FALSE", ShuffleMode::kFalse)
91 .value("FILES", ShuffleMode::kFiles)
92 .value("GLOBAL", ShuffleMode::kGlobal)
93 .value("INFILE", ShuffleMode::kInfile)
94 .export_values();
95 }));
96
97 } // namespace dataset
98 } // namespace mindspore
99