• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
17 
18 #include "minddata/dataset/core/global_context.h"
19 #include "minddata/dataset/util/random.h"
20 
21 namespace mindspore {
22 namespace dataset {
23 //  Constructor.
WeightedRandomSamplerRT(const std::vector<double> & weights,int64_t num_samples,bool replacement,int64_t samples_per_tensor)24 WeightedRandomSamplerRT::WeightedRandomSamplerRT(const std::vector<double> &weights, int64_t num_samples,
25                                                  bool replacement, int64_t samples_per_tensor)
26     : SamplerRT(num_samples, samples_per_tensor), weights_(weights), replacement_(replacement), sample_id_(0) {}
27 
28 // Initialize this sampler.
InitSampler()29 Status WeightedRandomSamplerRT::InitSampler() {
30   if (is_initialized) {
31     return Status::OK();
32   }
33   // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
34   // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
35   if (num_samples_ == 0 || num_samples_ > num_rows_) {
36     num_samples_ = num_rows_;
37   }
38   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_,
39                                "[Internal ERROR] num_samples and num_rows must be greater than 0, but got num_rows: " +
40                                  std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_));
41   CHECK_FAIL_RETURN_UNEXPECTED(samples_per_tensor_ > 0,
42                                "Invalid parameter, samples_per_tensor(num_samples) must be greater than 0, but got " +
43                                  std::to_string(samples_per_tensor_) + ".\n");
44 
45   if (weights_.size() > static_cast<size_t>(num_rows_)) {
46     RETURN_STATUS_UNEXPECTED(
47       "Invalid parameter, size of sample weights must be less than or equal to num of data, "
48       "otherwise might cause generated id out of bound or other errors, but got weight size: " +
49       std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
50   }
51   if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
52     RETURN_STATUS_UNEXPECTED(
53       "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
54       "but got weight size: " +
55       std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
56   }
57 
58   // Initialize random generator with seed from config manager
59   rand_gen_.seed(GetSeed());
60 
61   samples_per_tensor_ = (samples_per_tensor_ > num_samples_) ? num_samples_ : samples_per_tensor_;
62 
63   if (!replacement_) {
64     exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
65     InitOnePassSampling();
66   } else {
67     discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
68   }
69 
70   is_initialized = true;
71   return Status::OK();
72 }
73 
74 // Initialize the computation for generating weighted random numbers without replacement using the one-pass method.
InitOnePassSampling()75 void WeightedRandomSamplerRT::InitOnePassSampling() {
76   exp_dist_->reset();
77   onepass_ids_.clear();
78   std::vector<std::pair<double, int64_t>> val_idx;
79   for (size_t i = 0; i < weights_.size(); i++) {
80     val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i));
81   }
82 
83   // Partial sort the first `numSamples` elements.
84   std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
85   for (int64_t i = 0; i < num_samples_; i++) {
86     onepass_ids_.push_back(val_idx[i].second);
87   }
88 }
89 
90 // Reset the internal variable to the initial state and reshuffle the indices.
ResetSampler(const bool failover_reset)91 Status WeightedRandomSamplerRT::ResetSampler(const bool failover_reset) {
92   sample_id_ = 0;
93   rand_gen_.seed(GetSeed());
94   if (!replacement_) {
95     InitOnePassSampling();
96   } else {
97     discrete_dist_->reset();
98   }
99 
100   if (HasChildSampler()) {
101     RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
102   }
103 
104   return Status::OK();
105 }
106 
107 // Get the sample ids.
GetNextSample(TensorRow * out)108 Status WeightedRandomSamplerRT::GetNextSample(TensorRow *out) {
109   RETURN_UNEXPECTED_IF_NULL(out);
110   if (weights_.size() > static_cast<size_t>(num_rows_)) {
111     RETURN_STATUS_UNEXPECTED(
112       "Invalid parameter, size of sample weights must be less than or equal to num of data, "
113       "otherwise might cause generated id out of bound or other errors, but got weight size: " +
114       std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
115   }
116 
117   if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
118     RETURN_STATUS_UNEXPECTED(
119       "Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
120       "but got weight size: " +
121       std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
122   }
123 
124   if (sample_id_ == num_samples_) {
125     (*out) = TensorRow(TensorRow::kFlagEOE);
126   } else {
127     if (HasChildSampler()) {
128       RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
129     }
130 
131     std::shared_ptr<Tensor> outputIds;
132 
133     int64_t last_id = sample_id_ + samples_per_tensor_;
134     // Handling the return all samples at once, and when last draw is not a full batch.
135     if (last_id > num_samples_) {
136       last_id = num_samples_;
137     }
138 
139     // Allocate tensor.
140     RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
141 
142     // Initialize tensor.
143     auto id_ptr = outputIds->begin<int64_t>();
144     // Assign the data to tensor element.
145     while (sample_id_ < last_id) {
146       int64_t genId;
147       if (replacement_) {
148         genId = (*discrete_dist_)(rand_gen_);
149       } else {
150         // Draw sample without replacement.
151         genId = onepass_ids_.front();
152         onepass_ids_.pop_front();
153       }
154 
155       if (genId >= num_rows_) {
156         RETURN_STATUS_UNEXPECTED(
157           "[Internal ERROR] Generated indice is out of bound, expect range [0, num_data-1], got indice: " +
158           std::to_string(genId) + ", num_data: " + std::to_string(num_rows_ - 1));
159       }
160 
161       if (HasChildSampler()) {
162         RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
163       }
164 
165       *id_ptr = genId;
166       ++id_ptr;
167       sample_id_++;
168     }
169 
170     (*out) = {outputIds};
171   }
172 
173   return Status::OK();
174 }
175 
SamplerPrint(std::ostream & out,bool show_all) const176 void WeightedRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
177   out << "\nSampler: WeightedRandomSampler";
178   if (show_all) {
179     // Call the super class for displaying any common detailed info
180     SamplerRT::SamplerPrint(out, show_all);
181     // Then add our own info if any
182   }
183 }
184 
to_json(nlohmann::json * out_json)185 Status WeightedRandomSamplerRT::to_json(nlohmann::json *out_json) {
186   RETURN_UNEXPECTED_IF_NULL(out_json);
187   nlohmann::json args;
188   RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
189   args["sampler_name"] = "WeightedRandomSampler";
190   args["weights"] = weights_;
191   args["replacement"] = replacement_;
192   *out_json = args;
193   return Status::OK();
194 }
195 }  // namespace dataset
196 }  // namespace mindspore
197