• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_
18 
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <random>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 /// Randomly samples elements from a given list of indices, without replacement.
29 class SubsetRandomSamplerRT : public SubsetSamplerRT {
30  public:
31   /// Constructor.
32   /// \param indices List of indices from where we will randomly draw samples.
33   /// \param num_samples The number of samples to draw. 0 for the full amount.
34   /// \param samples_per_tensor The number of ids we draw on each call to GetNextSample().
35   /// When samples_per_tensor=0, GetNextSample() will draw all the sample ids and return them at once.
36   SubsetRandomSamplerRT(const std::vector<int64_t> &indices, int64_t num_samples,
37                         std::int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
38 
39   /// Destructor.
40   ~SubsetRandomSamplerRT() = default;
41 
42   /// Initialize the sampler.
43   /// \return Status
44   Status InitSampler() override;
45 
46   /// Reset the internal variable to the initial state and reshuffle the indices.
47   /// \return Status
48   Status ResetSampler() override;
49 
50   /// Printer for debugging purposes.
51   /// \param out - output stream to write to
52   /// \param show_all - bool to show detailed vs summary
53   void SamplerPrint(std::ostream &out, bool show_all) const override;
54 
55   /// \brief Get the arguments of node
56   /// \param[out] out_json JSON string of all attributes
57   /// \return Status of the function
58   Status to_json(nlohmann::json *out_json) override;
59 
60  private:
61   // A random number generator.
62   std::mt19937 rand_gen_;
63 };
64 }  // namespace dataset
65 }  // namespace mindspore
66 
67 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_
68