1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
17
18 #include <limits>
19 #include <memory>
20
21 #include "minddata/dataset/util/random.h"
22
23 namespace mindspore {
24 namespace dataset {
// Constructor.
// @param num_shards Total number of shards (devices) the dataset is split across.
// @param shard_id Id of the shard this sampler produces ids for (0 <= shard_id < num_shards).
// @param shuffle Whether to shuffle row ids before sharding.
// @param num_samples Maximum number of samples to draw; 0 means "use all rows" (resolved in InitSampler).
// @param offset Starting device offset used when distributing rows; -1 means "not set" (resolved in InitSampler).
// @param even_dist Whether rows are distributed evenly (padded) across shards.
DistributedSamplerRT::DistributedSamplerRT(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
                                           uint32_t seed, int64_t offset, bool even_dist)
    : SamplerRT(num_samples, std::numeric_limits<int64_t>::max()),
      cnt_(0),
      // The max uint32_t value is a sentinel meaning "no seed given": fall back to the global seed.
      seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
      device_id_(shard_id),
      num_devices_(num_shards),
      shuffle_(shuffle),
      even_dist_(even_dist),
      offset_(offset),
      non_empty_(true) {
  // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
  // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
  // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
  // PreBuildSampler is phased out, this can be cleaned up.
  GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_devices_);
}
42
InitSampler()43 Status DistributedSamplerRT::InitSampler() {
44 if (is_initialized) {
45 return Status::OK();
46 }
47 // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
48 // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
49 if (num_samples_ == 0 || num_samples_ > num_rows_) {
50 num_samples_ = num_rows_;
51 }
52 CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "Invalid parameter, num_samples must be greater than 0, but got " +
53 std::to_string(num_samples_) + ".\n");
54 CHECK_FAIL_RETURN_UNEXPECTED(
55 num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");
56 CHECK_FAIL_RETURN_UNEXPECTED(
57 device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
58 "Invalid parameter, num_shard must be greater than shard_id and greater than 0, got num_shard: " +
59 std::to_string(num_devices_) + ", shard_id: " + std::to_string(device_id_) + ".\n");
60 rnd_.seed(seed_++);
61
62 if (offset_ != -1 || !even_dist_) {
63 if (offset_ == -1) {
64 offset_ = 0;
65 }
66 samples_per_tensor_ = (num_rows_ + offset_) / num_devices_;
67 int64_t remainder = (num_rows_ + offset_) % num_devices_;
68 if (device_id_ < remainder) {
69 samples_per_tensor_++;
70 }
71 if (device_id_ < offset_) {
72 samples_per_tensor_--;
73 }
74 } else {
75 offset_ = 0;
76 samples_per_tensor_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
77 }
78 samples_per_tensor_ = num_samples_ < samples_per_tensor_ ? num_samples_ : samples_per_tensor_;
79 if (shuffle_) {
80 shuffle_vec_.reserve(num_rows_);
81 for (int64_t i = 0; i < num_rows_; i++) {
82 shuffle_vec_.push_back(i);
83 }
84 std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
85 }
86 if (!samples_per_tensor_) {
87 non_empty_ = false;
88 }
89
90 is_initialized = true;
91 return Status::OK();
92 }
93
// Produces the next batch of sample ids for this shard, or an EOE row when exhausted.
// Ids follow the pattern (num_devices_ * cnt_ + device_id_ - offset_) % num_rows_, optionally
// mapped through the shuffle permutation and a child sampler.
// @param out Output TensorRow holding either the sample-id tensor or an EOE flag.
// @return Status - the error code returned.
Status DistributedSamplerRT::GetNextSample(TensorRow *out) {
  if (cnt_ > samples_per_tensor_) {
    // cnt_ past the shard size means the caller kept pulling after EOE: internal misuse.
    RETURN_STATUS_UNEXPECTED(
      "Sampler index must be less than or equal to num_samples(total rows in dataset), but got:" +
      std::to_string(cnt_) + ", samples_per_tensor(num_samples): " + std::to_string(samples_per_tensor_));
  } else if (cnt_ == samples_per_tensor_ && (non_empty_ || !even_dist_)) {
    // All samples handed out (or shard legitimately empty under uneven distribution): signal EOE.
    (*out) = TensorRow(TensorRow::kFlagEOE);
    if (!samples_per_tensor_) {
      // Remember this shard is empty so the next call takes the placeholder branch below.
      non_empty_ = false;
    }
  } else if (!samples_per_tensor_ && !non_empty_) {
    // If the Tensor is empty, we add samples with subscript 0 in the current dataset.
    // This step is to make up for the solution that the code default Tensor is not empty before.
    // We will remove this value in the concat phase
    non_empty_ = true;
    std::shared_ptr<Tensor> sample_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, 1));
    auto id_ptr = sample_ids->begin<int64_t>();
    // add index 0
    *id_ptr = 0;
    (*out) = {sample_ids};
  } else {
    // Normal path: build a tensor of samples_per_tensor_ ids for this shard.
    if (HasChildSampler()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
    }

    std::shared_ptr<Tensor> sample_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_tensor_));
    auto id_ptr = sample_ids->begin<int64_t>();
    bool flag_add_1 = false;
    while (cnt_ < samples_per_tensor_ && id_ptr != sample_ids->end<int64_t>()) {
      int64_t middle_value = num_devices_ * cnt_ + device_id_ - offset_;
      // if index < 0, we move back one place
      // (temporarily bump samples_per_tensor_/cnt_ so the loop still fills the tensor;
      // the bump is undone after the loop via flag_add_1)
      if (middle_value < 0) {
        samples_per_tensor_++;
        cnt_++;
        flag_add_1 = true;
        middle_value = num_devices_ * cnt_ + device_id_ - offset_;
      }
      // Wrap around the dataset so every shard gets a full complement of ids.
      int64_t sampled_id = middle_value % num_rows_;

      if (shuffle_) {
        sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
      }

      if (HasChildSampler()) {
        // Remap through the child sampler's id association.
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }

      *id_ptr = sampled_id;
      ++id_ptr;
      cnt_++;
    }

    // If 1 was added before, we will cut off 1 here
    if (flag_add_1) {
      samples_per_tensor_--;
      cnt_--;
    }
    (*out) = {sample_ids};
  }
  return Status::OK();
}
157
ResetSampler()158 Status DistributedSamplerRT::ResetSampler() {
159 CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_tensor_, "[Internal ERROR] Reset() Sampler called early or late.");
160 cnt_ = 0;
161
162 if (shuffle_ == true) {
163 rnd_.seed(seed_);
164 seed_++;
165 std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
166 }
167
168 if (HasChildSampler()) {
169 RETURN_IF_NOT_OK(child_[0]->ResetSampler());
170 }
171
172 return Status::OK();
173 }
174
CalculateNumSamples(int64_t num_rows)175 int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
176 int64_t child_num_rows = num_rows;
177 if (!child_.empty()) {
178 child_num_rows = child_[0]->CalculateNumSamples(num_rows);
179 }
180 int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
181 int64_t remainder = (child_num_rows + offset_) % num_devices_;
182 int64_t shard_size = (child_num_rows + offset_) / num_devices_;
183 if (offset_ != -1 || !even_dist_) {
184 if (offset_ == -1) {
185 offset_ = 0;
186 }
187 if (device_id_ < remainder) {
188 shard_size++;
189 }
190 if (device_id_ < offset_) {
191 shard_size--;
192 }
193 } else {
194 shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
195 }
196 // add 1 to an empty shard
197 // this logic is needed to follow the logic in initSampler that is written for ConcatDataset
198 if (shard_size == 0) {
199 shard_size++;
200 }
201
202 return std::min(num_samples, shard_size);
203 }
204
SamplerPrint(std::ostream & out,bool show_all) const205 void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
206 out << "\nSampler: DistributedSampler";
207 if (show_all) {
208 SamplerRT::SamplerPrint(out, show_all);
209 out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
210 << "\nshuffle: " << shuffle_;
211 }
212 }
213
to_json(nlohmann::json * out_json)214 Status DistributedSamplerRT::to_json(nlohmann::json *out_json) {
215 nlohmann::json args;
216 RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
217 args["sampler_name"] = "DistributedSampler";
218 args["num_shards"] = num_devices_;
219 args["shard_id"] = device_id_;
220 args["shuffle"] = shuffle_;
221 args["offset"] = offset_;
222 *out_json = args;
223 return Status::OK();
224 }
225
226 } // namespace dataset
227 } // namespace mindspore
228