• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_
18 
#include <cstdint>
#include <deque>
#include <limits>
#include <memory>
#include <random>
#include <vector>
23 
24 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 // Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights).
29 class WeightedRandomSamplerRT : public SamplerRT {
30  public:
31   // Constructor.
32   // @param weights A lift of sample weights.
33   // @param num_samples Number of samples to be drawn.
34   // @param replacement Determine if samples are drawn with/without replacement.
35   // @param samples_per_tensor The number of ids we draw on each call to GetNextSample().
36   // When samples_per_tensor=0, GetNextSample() will draw all the sample ids and return them at once.
37   WeightedRandomSamplerRT(const std::vector<double> &weights, int64_t num_samples, bool replacement,
38                           int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
39 
40   // Destructor.
41   ~WeightedRandomSamplerRT() = default;
42 
43   // Initialize the sampler.
44   // @param op (Not used in this sampler)
45   // @return Status
46   Status InitSampler() override;
47 
48   // Reset the internal variable to the initial state and reshuffle the indices.
49   Status ResetSampler() override;
50 
51   // Get the sample ids.
52   // @param[out] TensorRow where the sample ids will be placed.
53   // @note the sample ids (int64_t) will be placed in one Tensor
54   Status GetNextSample(TensorRow *out) override;
55 
56   // Printer for debugging purposes.
57   // @param out - output stream to write to
58   // @param show_all - bool to show detailed vs summary
59   void SamplerPrint(std::ostream &out, bool show_all) const override;
60 
61   /// \brief Get the arguments of node
62   /// \param[out] out_json JSON string of all attributes
63   /// \return Status of the function
64   Status to_json(nlohmann::json *out_json) override;
65 
66  private:
67   // A list of weights for each sample.
68   std::vector<double> weights_;
69 
70   // A flag indicating if samples are drawn with/without replacement.
71   bool replacement_;
72 
73   // Current sample id.
74   int64_t sample_id_;
75 
76   // Random engine and device
77   std::mt19937 rand_gen_;
78 
79   // Discrete distribution for generating weighted random numbers with replacement.
80   std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
81 
82   // Exponential distribution for generating weighted random numbers without replacement.
83   // based on "Accelerating weighted random sampling without replacement" by Kirill Muller.
84   std::unique_ptr<std::exponential_distribution<>> exp_dist_;
85 
86   // Initialized the computation for generating weighted random numbers without replacement
87   // using onepass method.
88   void InitOnePassSampling();
89 
90   // Store the random weighted ids generated by onepass method in `InitOnePassSampling`
91   std::deque<int64_t> onepass_ids_;
92 };
93 }  // namespace dataset
94 }  // namespace mindspore
95 
96 #endif
97