1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
17
18 #include "minddata/dataset/core/global_context.h"
19 #include "minddata/dataset/util/random.h"
20
21 namespace mindspore {
22 namespace dataset {
23 // Constructor.
WeightedRandomSamplerRT(const std::vector<double> & weights,int64_t num_samples,bool replacement,int64_t samples_per_tensor)24 WeightedRandomSamplerRT::WeightedRandomSamplerRT(const std::vector<double> &weights, int64_t num_samples,
25 bool replacement, int64_t samples_per_tensor)
26 : SamplerRT(num_samples, samples_per_tensor), weights_(weights), replacement_(replacement), sample_id_(0) {}
27
28 // Initialized this Sampler.
InitSampler()29 Status WeightedRandomSamplerRT::InitSampler() {
30 if (is_initialized) {
31 return Status::OK();
32 }
33 // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
34 // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
35 if (num_samples_ == 0 || num_samples_ > num_rows_) {
36 num_samples_ = num_rows_;
37 }
38 CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_,
39 "[Internal ERROR] num_samples and num_rows must be greater than 0, but got num_rows: " +
40 std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_));
41 CHECK_FAIL_RETURN_UNEXPECTED(samples_per_tensor_ > 0,
42 "Invalid parameter, samples_per_tensor(num_samples) must be greater than 0, but got " +
43 std::to_string(samples_per_tensor_) + ".\n");
44
45 if (weights_.size() > static_cast<size_t>(num_rows_)) {
46 RETURN_STATUS_UNEXPECTED(
47 "Invalid parameter, size of sample weights must be less than or equal to num of data, "
48 "otherwise might cause generated id out of bound or other errors, but got weight size: " +
49 std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
50 }
51 if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
52 RETURN_STATUS_UNEXPECTED(
53 "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
54 "but got weight size: " +
55 std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
56 }
57
58 // Initialize random generator with seed from config manager
59 rand_gen_.seed(GetSeed());
60
61 samples_per_tensor_ = (samples_per_tensor_ > num_samples_) ? num_samples_ : samples_per_tensor_;
62
63 if (!replacement_) {
64 exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
65 InitOnePassSampling();
66 } else {
67 discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
68 }
69
70 is_initialized = true;
71 return Status::OK();
72 }
73
74 // Initialized the computation for generating weighted random numbers without replacement using onepass method.
InitOnePassSampling()75 void WeightedRandomSamplerRT::InitOnePassSampling() {
76 exp_dist_->reset();
77 onepass_ids_.clear();
78 std::vector<std::pair<double, int64_t>> val_idx;
79 for (size_t i = 0; i < weights_.size(); i++) {
80 val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i));
81 }
82
83 // Partial sort the first `numSamples` elements.
84 std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
85 for (int64_t i = 0; i < num_samples_; i++) {
86 onepass_ids_.push_back(val_idx[i].second);
87 }
88 }
89
90 // Reset the internal variable to the initial state and reshuffle the indices.
ResetSampler(const bool failover_reset)91 Status WeightedRandomSamplerRT::ResetSampler(const bool failover_reset) {
92 sample_id_ = 0;
93 rand_gen_.seed(GetSeed());
94 if (!replacement_) {
95 InitOnePassSampling();
96 } else {
97 discrete_dist_->reset();
98 }
99
100 if (HasChildSampler()) {
101 RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
102 }
103
104 return Status::OK();
105 }
106
107 // Get the sample ids.
GetNextSample(TensorRow * out)108 Status WeightedRandomSamplerRT::GetNextSample(TensorRow *out) {
109 RETURN_UNEXPECTED_IF_NULL(out);
110 if (weights_.size() > static_cast<size_t>(num_rows_)) {
111 RETURN_STATUS_UNEXPECTED(
112 "Invalid parameter, size of sample weights must be less than or equal to num of data, "
113 "otherwise might cause generated id out of bound or other errors, but got weight size: " +
114 std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
115 }
116
117 if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
118 RETURN_STATUS_UNEXPECTED(
119 "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
120 "but got weight size: " +
121 std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
122 }
123
124 if (sample_id_ == num_samples_) {
125 (*out) = TensorRow(TensorRow::kFlagEOE);
126 } else {
127 if (HasChildSampler()) {
128 RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
129 }
130
131 std::shared_ptr<Tensor> outputIds;
132
133 int64_t last_id = sample_id_ + samples_per_tensor_;
134 // Handling the return all samples at once, and when last draw is not a full batch.
135 if (last_id > num_samples_) {
136 last_id = num_samples_;
137 }
138
139 // Allocate tensor.
140 RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
141
142 // Initialize tensor.
143 auto id_ptr = outputIds->begin<int64_t>();
144 // Assign the data to tensor element.
145 while (sample_id_ < last_id) {
146 int64_t genId;
147 if (replacement_) {
148 genId = (*discrete_dist_)(rand_gen_);
149 } else {
150 // Draw sample without replacement.
151 genId = onepass_ids_.front();
152 onepass_ids_.pop_front();
153 }
154
155 if (genId >= num_rows_) {
156 RETURN_STATUS_UNEXPECTED(
157 "[Internal ERROR] Generated indice is out of bound, expect range [0, num_data-1], got indice: " +
158 std::to_string(genId) + ", num_data: " + std::to_string(num_rows_ - 1));
159 }
160
161 if (HasChildSampler()) {
162 RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
163 }
164
165 *id_ptr = genId;
166 ++id_ptr;
167 sample_id_++;
168 }
169
170 (*out) = {outputIds};
171 }
172
173 return Status::OK();
174 }
175
SamplerPrint(std::ostream & out,bool show_all) const176 void WeightedRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
177 out << "\nSampler: WeightedRandomSampler";
178 if (show_all) {
179 // Call the super class for displaying any common detailed info
180 SamplerRT::SamplerPrint(out, show_all);
181 // Then add our own info if any
182 }
183 }
184
to_json(nlohmann::json * out_json)185 Status WeightedRandomSamplerRT::to_json(nlohmann::json *out_json) {
186 RETURN_UNEXPECTED_IF_NULL(out_json);
187 nlohmann::json args;
188 RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
189 args["sampler_name"] = "WeightedRandomSampler";
190 args["weights"] = weights_;
191 args["replacement"] = replacement_;
192 *out_json = args;
193 return Status::OK();
194 }
195 } // namespace dataset
196 } // namespace mindspore
197