• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pybind11/pybind11.h"
17 #include "pybind11/stl_bind.h"
18 
19 #include "minddata/dataset/api/python/pybind_register.h"
20 
21 #include "minddata/dataset/include/dataset/constants.h"
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 
30 namespace mindspore {
31 namespace dataset {
__anonaf517c3c0102(const py::module *m) 32 PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
33                   (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
34                     *m, "ShardOperator")
35                     .def("add_child",
36                          [](const std::shared_ptr<mindrecord::ShardOperator> &self,
37                             const std::shared_ptr<mindrecord::ShardOperator> &child) {
38                            THROW_IF_ERROR(self->SetChildOp(child));
39                          })
40                     .def("set_num_samples", [](const std::shared_ptr<mindrecord::ShardOperator> &self,
41                                                int64_t num_samples) { self->SetNumSamples(num_samples); });
42                 }));
43 
__anonaf517c3c0402(const py::module *m) 44 PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
45                   (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
46                                    std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
47                                                                                         "MindrecordDistributedSampler")
48                     .def(py::init<int64_t, int64_t, bool, uint32_t, int64_t, int64_t>());
49                 }));
50 
51 PYBIND_REGISTER(
__anonaf517c3c0502(const py::module *m) 52   ShardPkSample, 1, ([](const py::module *m) {
53     (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
54       *m, "MindrecordPkSampler")
55       .def(py::init([](int64_t kVal, const std::string &kColumn, bool shuffle, int64_t num_samples) {
56         if (shuffle == true) {
57           return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
58                                                              GetSeed(), num_samples);
59         } else {
60           return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
61         }
62       }));
63   }));
64 
65 PYBIND_REGISTER(
__anonaf517c3c0702(const py::module *m) 66   ShardSample, 0, ([](const py::module *m) {
67     (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
68       *m, "MindrecordSubsetSampler")
69       .def(py::init<std::vector<int64_t>, uint32_t>())
70       .def(py::init<std::vector<int64_t>>());
71   }));
72 
__anonaf517c3c0802(const py::module *m) 73 PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
74                   (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
75                                    std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
76                                                                                        "MindrecordSequentialSampler")
77                     .def(py::init([](int64_t num_samples, int64_t start_index) {
78                       return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
79                     }));
80                 }));
81 
82 PYBIND_REGISTER(
__anonaf517c3c0a02(const py::module *m) 83   ShardShuffle, 1, ([](const py::module *m) {
84     (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
85       *m, "MindrecordRandomSampler")
86       .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
87         return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
88       }));
89   }));
90 
__anonaf517c3c0c02(const py::module *m) 91 PYBIND_REGISTER(ShuffleMode, 1, ([](const py::module *m) {
92                   (void)py::enum_<ShuffleMode>(*m, "ShuffleMode", py::arithmetic())
93                     .value("FALSE", ShuffleMode::kFalse)
94                     .value("FILES", ShuffleMode::kFiles)
95                     .value("GLOBAL", ShuffleMode::kGlobal)
96                     .value("INFILE", ShuffleMode::kInfile)
97                     .export_values();
98                 }));
99 }  // namespace dataset
100 }  // namespace mindspore
101