• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pybind11/pybind11.h"
17 #include "pybind11/stl_bind.h"
18 
19 #include "minddata/dataset/api/python/pybind_register.h"
20 
21 #include "minddata/dataset/include/dataset/constants.h"
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 
__anon4d523be80102(const py::module *m) 33 PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
34                   (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
35                     *m, "ShardOperator")
36                     .def("add_child",
37                          [](std::shared_ptr<mindrecord::ShardOperator> self,
38                             std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
39                 }));
40 
__anon4d523be80302(const py::module *m) 41 PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
42                   (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
43                                    std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
44                                                                                         "MindrecordDistributedSampler")
45                     .def(py::init<int64_t, int64_t, bool, uint32_t, int64_t, int64_t>());
46                 }));
47 
48 PYBIND_REGISTER(
__anon4d523be80402(const py::module *m) 49   ShardPkSample, 1, ([](const py::module *m) {
50     (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
51       *m, "MindrecordPkSampler")
52       .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
53         if (shuffle == true) {
54           return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
55                                                              GetSeed(), num_samples);
56         } else {
57           return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
58         }
59       }));
60   }));
61 
62 PYBIND_REGISTER(
__anon4d523be80602(const py::module *m) 63   ShardSample, 0, ([](const py::module *m) {
64     (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
65       *m, "MindrecordSubsetSampler")
66       .def(py::init<std::vector<int64_t>, uint32_t>())
67       .def(py::init<std::vector<int64_t>>());
68   }));
69 
__anon4d523be80702(const py::module *m) 70 PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
71                   (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
72                                    std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
73                                                                                        "MindrecordSequentialSampler")
74                     .def(py::init([](int64_t num_samples, int64_t start_index) {
75                       return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
76                     }));
77                 }));
78 
79 PYBIND_REGISTER(
__anon4d523be80902(const py::module *m) 80   ShardShuffle, 1, ([](const py::module *m) {
81     (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
82       *m, "MindrecordRandomSampler")
83       .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
84         return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
85       }));
86   }));
87 
__anon4d523be80b02(const py::module *m) 88 PYBIND_REGISTER(ShuffleMode, 1, ([](const py::module *m) {
89                   (void)py::enum_<ShuffleMode>(*m, "ShuffleMode", py::arithmetic())
90                     .value("FALSE", ShuffleMode::kFalse)
91                     .value("FILES", ShuffleMode::kFiles)
92                     .value("GLOBAL", ShuffleMode::kGlobal)
93                     .value("INFILE", ShuffleMode::kInfile)
94                     .export_values();
95                 }));
96 
97 }  // namespace dataset
98 }  // namespace mindspore
99