1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <random>
21 #include <utility>
22 #include <vector>
23
24 #include "minddata/dataset/core/global_context.h"
25 #include "minddata/dataset/util/random.h"
26
27 namespace mindspore {
28 namespace dataset {
29 // Constructor.
WeightedRandomSamplerRT(const std::vector<double> & weights,int64_t num_samples,bool replacement,int64_t samples_per_tensor)30 WeightedRandomSamplerRT::WeightedRandomSamplerRT(const std::vector<double> &weights, int64_t num_samples,
31 bool replacement, int64_t samples_per_tensor)
32 : SamplerRT(num_samples, samples_per_tensor), weights_(weights), replacement_(replacement), sample_id_(0) {}
33
34 // Initialized this Sampler.
InitSampler()35 Status WeightedRandomSamplerRT::InitSampler() {
36 if (is_initialized) {
37 return Status::OK();
38 }
39 // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
40 // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
41 if (num_samples_ == 0 || num_samples_ > num_rows_) {
42 num_samples_ = num_rows_;
43 }
44 CHECK_FAIL_RETURN_UNEXPECTED(
45 num_rows_ > 0 && num_samples_,
46 "Invalid parameter, num_samples and num_rows must be greater than 0, but got num_rows: " +
47 std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_));
48 CHECK_FAIL_RETURN_UNEXPECTED(samples_per_tensor_ > 0,
49 "Invalid parameter, samples_per_tensor(num_samples) must be greater than 0, but got " +
50 std::to_string(samples_per_tensor_) + ".\n");
51
52 if (weights_.size() > static_cast<size_t>(num_rows_)) {
53 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
54 "Invalid parameter, size of sample weights must be less than or equal to num of data, "
55 "otherwise might cause generated id out of bound or other errors, but got weight size: " +
56 std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
57 }
58 if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
59 RETURN_STATUS_UNEXPECTED(
60 "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
61 "but got weight size: " +
62 std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
63 }
64
65 // Initialize random generator with seed from config manager
66 rand_gen_.seed(GetSeed());
67
68 samples_per_tensor_ = (samples_per_tensor_ > num_samples_) ? num_samples_ : samples_per_tensor_;
69
70 if (!replacement_) {
71 exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
72 InitOnePassSampling();
73 } else {
74 discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
75 }
76
77 is_initialized = true;
78 return Status::OK();
79 }
80
81 // Initialized the computation for generating weighted random numbers without replacement using onepass method.
InitOnePassSampling()82 void WeightedRandomSamplerRT::InitOnePassSampling() {
83 exp_dist_->reset();
84 onepass_ids_.clear();
85 std::vector<std::pair<double, int64_t>> val_idx;
86 for (size_t i = 0; i < weights_.size(); i++) {
87 val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i));
88 }
89
90 // Partial sort the first `numSamples` elements.
91 std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
92 for (int64_t i = 0; i < num_samples_; i++) {
93 onepass_ids_.push_back(val_idx[i].second);
94 }
95 }
96
97 // Reset the internal variable to the initial state and reshuffle the indices.
ResetSampler()98 Status WeightedRandomSamplerRT::ResetSampler() {
99 sample_id_ = 0;
100 rand_gen_.seed(GetSeed());
101 if (!replacement_) {
102 InitOnePassSampling();
103 } else {
104 discrete_dist_->reset();
105 }
106
107 if (HasChildSampler()) {
108 RETURN_IF_NOT_OK(child_[0]->ResetSampler());
109 }
110
111 return Status::OK();
112 }
113
114 // Get the sample ids.
GetNextSample(TensorRow * out)115 Status WeightedRandomSamplerRT::GetNextSample(TensorRow *out) {
116 if (weights_.size() > static_cast<size_t>(num_rows_)) {
117 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
118 "Invalid parameter, size of sample weights must be less than or equal to num of data, "
119 "otherwise might cause generated id out of bound or other errors, but got weight size: " +
120 std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
121 }
122
123 if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
124 RETURN_STATUS_UNEXPECTED(
125 "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
126 "but got weight size: " +
127 std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
128 }
129
130 if (sample_id_ == num_samples_) {
131 (*out) = TensorRow(TensorRow::kFlagEOE);
132 } else {
133 if (HasChildSampler()) {
134 RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
135 }
136
137 std::shared_ptr<Tensor> outputIds;
138
139 int64_t last_id = sample_id_ + samples_per_tensor_;
140 // Handling the return all samples at once, and when last draw is not a full batch.
141 if (last_id > num_samples_) {
142 last_id = num_samples_;
143 }
144
145 // Allocate tensor.
146 RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
147
148 // Initialize tensor.
149 auto id_ptr = outputIds->begin<int64_t>();
150 // Assign the data to tensor element.
151 while (sample_id_ < last_id) {
152 int64_t genId;
153 if (replacement_) {
154 genId = (*discrete_dist_)(rand_gen_);
155 } else {
156 // Draw sample without replacement.
157 genId = onepass_ids_.front();
158 onepass_ids_.pop_front();
159 }
160
161 if (genId >= num_rows_) {
162 RETURN_STATUS_UNEXPECTED("Generated indice is out of bound, expect range [0, num_data-1], got indice: " +
163 std::to_string(genId) + ", num_data: " + std::to_string(num_rows_ - 1));
164 }
165
166 if (HasChildSampler()) {
167 RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
168 }
169
170 *id_ptr = genId;
171 ++id_ptr;
172 sample_id_++;
173 }
174
175 (*out) = {outputIds};
176 }
177
178 return Status::OK();
179 }
180
SamplerPrint(std::ostream & out,bool show_all) const181 void WeightedRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
182 out << "\nSampler: WeightedRandomSampler";
183 if (show_all) {
184 // Call the super class for displaying any common detailed info
185 SamplerRT::SamplerPrint(out, show_all);
186 // Then add our own info if any
187 }
188 }
189
to_json(nlohmann::json * out_json)190 Status WeightedRandomSamplerRT::to_json(nlohmann::json *out_json) {
191 nlohmann::json args;
192 RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
193 args["sampler_name"] = "WeightedRandomSampler";
194 args["weights"] = weights_;
195 args["replacement"] = replacement_;
196 *out_json = args;
197 return Status::OK();
198 }
199 } // namespace dataset
200 } // namespace mindspore
201