/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <string>

#include "minddata/dataset/util/random.h"

namespace mindspore {
namespace dataset {
DistributedSamplerRT::DistributedSamplerRT(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
                                           uint32_t seed, int64_t offset, bool even_dist)
    : SamplerRT(num_samples, std::numeric_limits<int64_t>::max()),
      cnt_(0),
      seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
      device_id_(shard_id),
      num_devices_(num_shards),
      shuffle_(shuffle),
      even_dist_(even_dist),
      offset_(offset),
      non_empty_(true) {
  // Update num_shards_ in the global context. For now this number is only used by auto_num_worker_pass,
  // which is an experimental feature (user discretion is advised) and can still work even if num_shards_
  // isn't 100% correct. This is needed because PreBuildSampler doesn't offer a way to return num_shards;
  // once PreBuildSampler is phased out, this can be cleaned up.
  GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_devices_);
}

Status DistributedSamplerRT::InitSampler() {
  if (is_initialized) {
    return Status::OK();
  }
  // The special value 0 for num_samples means the user wants to sample the entire dataset.
  // If the user asked for more rows than exist in the dataset, adjust num_samples accordingly.
  if (num_samples_ == 0 || num_samples_ > num_rows_) {
    num_samples_ = num_rows_;
  }
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "Invalid parameter, num_samples must be greater than 0, but got " +
                                                   std::to_string(num_samples_) + ".\n");
  CHECK_FAIL_RETURN_UNEXPECTED(
    num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");
  CHECK_FAIL_RETURN_UNEXPECTED(
    device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
    "Invalid parameter, num_shards must be greater than shard_id and greater than 0, got num_shards: " +
      std::to_string(num_devices_) + ", shard_id: " + std::to_string(device_id_) + ".\n");
  rnd_.seed(seed_++);

  if (offset_ != -1 || !even_dist_) {
    if (offset_ == -1) {
      offset_ = 0;
    }
    samples_per_tensor_ = (num_rows_ + offset_) / num_devices_;
    int64_t remainder = (num_rows_ + offset_) % num_devices_;
    if (device_id_ < remainder) {
      samples_per_tensor_++;
    }
    if (device_id_ < offset_) {
      samples_per_tensor_--;
    }
  } else {
    offset_ = 0;
    samples_per_tensor_ = (num_rows_ + num_devices_ - 1) / num_devices_;  // equals to ceil(num_rows/num_devices)
  }
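  // Worked example (illustrative only): with num_rows_ = 10, num_devices_ = 4 and the default offset_ = -1
  // with even_dist_, every shard gets ceil(10 / 4) = 3 samples and the later shards wrap around to the start
  // of the dataset. With offset_ = 2 instead, (10 + 2) / 4 = 3 with remainder 0, and shards 0 and 1 (< offset_)
  // give up one sample each, so the shard sizes become 2, 2, 3, 3 and all 10 rows are covered exactly once.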
  samples_per_tensor_ = num_samples_ < samples_per_tensor_ ? num_samples_ : samples_per_tensor_;
  if (shuffle_) {
    shuffle_vec_.reserve(num_rows_);
    for (int64_t i = 0; i < num_rows_; i++) {
      shuffle_vec_.push_back(i);
    }
    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
  }
  if (!samples_per_tensor_) {
    non_empty_ = false;
  }

  is_initialized = true;
  return Status::OK();
}

Status DistributedSamplerRT::GetNextSample(TensorRow *out) {
  if (cnt_ > samples_per_tensor_) {
    RETURN_STATUS_UNEXPECTED(
      "Sampler index must be less than or equal to num_samples (total rows in dataset), but got: " +
      std::to_string(cnt_) + ", samples_per_tensor (num_samples): " + std::to_string(samples_per_tensor_));
  } else if (cnt_ == samples_per_tensor_ && (non_empty_ || !even_dist_)) {
    (*out) = TensorRow(TensorRow::kFlagEOE);
    if (!samples_per_tensor_) {
      non_empty_ = false;
    }
  } else if (!samples_per_tensor_ && !non_empty_) {
    // If this shard would otherwise be empty, emit a single placeholder sample with index 0.
    // This is a workaround for earlier code that assumes the sample Tensor is never empty;
    // the placeholder is removed later, during the concat phase.
    non_empty_ = true;
    std::shared_ptr<Tensor> sample_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, 1));
    auto id_ptr = sample_ids->begin<int64_t>();
    // add the placeholder index 0
    *id_ptr = 0;
    (*out) = {sample_ids};
  } else {
    if (HasChildSampler()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
    }

    std::shared_ptr<Tensor> sample_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_tensor_));
    auto id_ptr = sample_ids->begin<int64_t>();
    bool flag_add_1 = false;
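    // Illustration of the indexing below (assuming offset_ = 0): shard device_id_ draws the global row
    // indices device_id_, device_id_ + num_devices_, device_id_ + 2 * num_devices_, ... modulo num_rows_.
    // For example, with num_devices_ = 4 and device_id_ = 1 it visits rows 1, 5, 9, ...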
    while (cnt_ < samples_per_tensor_ && id_ptr != sample_ids->end<int64_t>()) {
      int64_t middle_value = num_devices_ * cnt_ + device_id_ - offset_;
      // If the index is negative (first iteration when device_id_ < offset_), skip one position ahead
      // and bump the counters; the bump is undone after the loop.
      if (middle_value < 0) {
        samples_per_tensor_++;
        cnt_++;
        flag_add_1 = true;
        middle_value = num_devices_ * cnt_ + device_id_ - offset_;
      }
      int64_t sampled_id = middle_value % num_rows_;

      if (shuffle_) {
        sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
      }

      if (HasChildSampler()) {
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }

      *id_ptr = sampled_id;
      ++id_ptr;
      cnt_++;
    }

    // If the counters were bumped above, undo that bump here
    if (flag_add_1) {
      samples_per_tensor_--;
      cnt_--;
    }
    (*out) = {sample_ids};
  }
  return Status::OK();
}

Status DistributedSamplerRT::ResetSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_tensor_, "[Internal ERROR] ResetSampler() called early or late.");
  cnt_ = 0;

  if (shuffle_) {
    rnd_.seed(seed_);
    seed_++;
    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
  }

  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
  }

  return Status::OK();
}

int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
  int64_t child_num_rows = num_rows;
  if (!child_.empty()) {
    child_num_rows = child_[0]->CalculateNumSamples(num_rows);
  }
  int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
  int64_t remainder = (child_num_rows + offset_) % num_devices_;
  int64_t shard_size = (child_num_rows + offset_) / num_devices_;
  if (offset_ != -1 || !even_dist_) {
    if (offset_ == -1) {
      offset_ = 0;
    }
    if (device_id_ < remainder) {
      shard_size++;
    }
    if (device_id_ < offset_) {
      shard_size--;
    }
  } else {
    shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
  }
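  // As in InitSampler(), for illustration: with child_num_rows = 10, num_devices_ = 4 and the default
  // even distribution, every shard reports shard_size = ceil(10 / 4) = 3 rows.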
  // Bump an empty shard up to one sample.
  // This mirrors the logic in InitSampler() that was added for ConcatDataset.
  if (shard_size == 0) {
    shard_size++;
  }

  return std::min(num_samples, shard_size);
}

void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
  out << "\nSampler: DistributedSampler";
  if (show_all) {
    SamplerRT::SamplerPrint(out, show_all);
    out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
        << "\nshuffle: " << shuffle_;
  }
}

Status DistributedSamplerRT::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
  args["sampler_name"] = "DistributedSampler";
  args["num_shards"] = num_devices_;
  args["shard_id"] = device_id_;
  args["shuffle"] = shuffle_;
  args["offset"] = offset_;
  *out_json = args;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore