• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <random>
21 #include <utility>
22 #include <vector>
23 
24 #include "minddata/dataset/core/global_context.h"
25 #include "minddata/dataset/util/random.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 //  Constructor.
WeightedRandomSamplerRT(const std::vector<double> & weights,int64_t num_samples,bool replacement,int64_t samples_per_tensor)30 WeightedRandomSamplerRT::WeightedRandomSamplerRT(const std::vector<double> &weights, int64_t num_samples,
31                                                  bool replacement, int64_t samples_per_tensor)
32     : SamplerRT(num_samples, samples_per_tensor), weights_(weights), replacement_(replacement), sample_id_(0) {}
33 
34 // Initialized this Sampler.
InitSampler()35 Status WeightedRandomSamplerRT::InitSampler() {
36   if (is_initialized) {
37     return Status::OK();
38   }
39   // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
40   // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
41   if (num_samples_ == 0 || num_samples_ > num_rows_) {
42     num_samples_ = num_rows_;
43   }
44   CHECK_FAIL_RETURN_UNEXPECTED(
45     num_rows_ > 0 && num_samples_,
46     "Invalid parameter, num_samples and num_rows must be greater than 0, but got num_rows: " +
47       std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_));
48   CHECK_FAIL_RETURN_UNEXPECTED(samples_per_tensor_ > 0,
49                                "Invalid parameter, samples_per_tensor(num_samples) must be greater than 0, but got " +
50                                  std::to_string(samples_per_tensor_) + ".\n");
51 
52   if (weights_.size() > static_cast<size_t>(num_rows_)) {
53     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
54                   "Invalid parameter, size of sample weights must be less than or equal to num of data, "
55                   "otherwise might cause generated id out of bound or other errors, but got weight size: " +
56                     std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
57   }
58   if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
59     RETURN_STATUS_UNEXPECTED(
60       "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
61       "but got weight size: " +
62       std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
63   }
64 
65   // Initialize random generator with seed from config manager
66   rand_gen_.seed(GetSeed());
67 
68   samples_per_tensor_ = (samples_per_tensor_ > num_samples_) ? num_samples_ : samples_per_tensor_;
69 
70   if (!replacement_) {
71     exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
72     InitOnePassSampling();
73   } else {
74     discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
75   }
76 
77   is_initialized = true;
78   return Status::OK();
79 }
80 
81 // Initialized the computation for generating weighted random numbers without replacement using onepass method.
InitOnePassSampling()82 void WeightedRandomSamplerRT::InitOnePassSampling() {
83   exp_dist_->reset();
84   onepass_ids_.clear();
85   std::vector<std::pair<double, int64_t>> val_idx;
86   for (size_t i = 0; i < weights_.size(); i++) {
87     val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i));
88   }
89 
90   // Partial sort the first `numSamples` elements.
91   std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
92   for (int64_t i = 0; i < num_samples_; i++) {
93     onepass_ids_.push_back(val_idx[i].second);
94   }
95 }
96 
97 // Reset the internal variable to the initial state and reshuffle the indices.
ResetSampler()98 Status WeightedRandomSamplerRT::ResetSampler() {
99   sample_id_ = 0;
100   rand_gen_.seed(GetSeed());
101   if (!replacement_) {
102     InitOnePassSampling();
103   } else {
104     discrete_dist_->reset();
105   }
106 
107   if (HasChildSampler()) {
108     RETURN_IF_NOT_OK(child_[0]->ResetSampler());
109   }
110 
111   return Status::OK();
112 }
113 
114 // Get the sample ids.
GetNextSample(TensorRow * out)115 Status WeightedRandomSamplerRT::GetNextSample(TensorRow *out) {
116   if (weights_.size() > static_cast<size_t>(num_rows_)) {
117     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
118                   "Invalid parameter, size of sample weights must be less than or equal to num of data, "
119                   "otherwise might cause generated id out of bound or other errors, but got weight size: " +
120                     std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
121   }
122 
123   if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
124     RETURN_STATUS_UNEXPECTED(
125       "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
126       "but got weight size: " +
127       std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
128   }
129 
130   if (sample_id_ == num_samples_) {
131     (*out) = TensorRow(TensorRow::kFlagEOE);
132   } else {
133     if (HasChildSampler()) {
134       RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
135     }
136 
137     std::shared_ptr<Tensor> outputIds;
138 
139     int64_t last_id = sample_id_ + samples_per_tensor_;
140     // Handling the return all samples at once, and when last draw is not a full batch.
141     if (last_id > num_samples_) {
142       last_id = num_samples_;
143     }
144 
145     // Allocate tensor.
146     RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
147 
148     // Initialize tensor.
149     auto id_ptr = outputIds->begin<int64_t>();
150     // Assign the data to tensor element.
151     while (sample_id_ < last_id) {
152       int64_t genId;
153       if (replacement_) {
154         genId = (*discrete_dist_)(rand_gen_);
155       } else {
156         // Draw sample without replacement.
157         genId = onepass_ids_.front();
158         onepass_ids_.pop_front();
159       }
160 
161       if (genId >= num_rows_) {
162         RETURN_STATUS_UNEXPECTED("Generated indice is out of bound, expect range [0, num_data-1], got indice: " +
163                                  std::to_string(genId) + ", num_data: " + std::to_string(num_rows_ - 1));
164       }
165 
166       if (HasChildSampler()) {
167         RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
168       }
169 
170       *id_ptr = genId;
171       ++id_ptr;
172       sample_id_++;
173     }
174 
175     (*out) = {outputIds};
176   }
177 
178   return Status::OK();
179 }
180 
SamplerPrint(std::ostream & out,bool show_all) const181 void WeightedRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
182   out << "\nSampler: WeightedRandomSampler";
183   if (show_all) {
184     // Call the super class for displaying any common detailed info
185     SamplerRT::SamplerPrint(out, show_all);
186     // Then add our own info if any
187   }
188 }
189 
to_json(nlohmann::json * out_json)190 Status WeightedRandomSamplerRT::to_json(nlohmann::json *out_json) {
191   nlohmann::json args;
192   RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
193   args["sampler_name"] = "WeightedRandomSampler";
194   args["weights"] = weights_;
195   args["replacement"] = replacement_;
196   *out_json = args;
197   return Status::OK();
198 }
199 }  // namespace dataset
200 }  // namespace mindspore
201