1 /*
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_sample.h"
18
19 namespace mindspore {
20 namespace mindrecord {
ShardSample(int64_t n)21 ShardSample::ShardSample(int64_t n)
22 : numerator_(0),
23 denominator_(0),
24 partition_id_(0),
25 no_of_samples_(n),
26 indices_({}),
27 sampler_type_(kCustomTopNSampler),
28 offset_(-1) {}
29
ShardSample(int64_t num,int64_t den)30 ShardSample::ShardSample(int64_t num, int64_t den)
31 : numerator_(num),
32 denominator_(den),
33 partition_id_(0),
34 no_of_samples_(0),
35 indices_({}),
36 sampler_type_(kCustomTopPercentSampler),
37 offset_(-1) {}
38
ShardSample(int64_t num,int64_t den,int64_t par,int64_t no_of_samples,int64_t offset)39 ShardSample::ShardSample(int64_t num, int64_t den, int64_t par, int64_t no_of_samples, int64_t offset)
40 : numerator_(num),
41 denominator_(den),
42 partition_id_(par),
43 no_of_samples_(no_of_samples),
44 indices_({}),
45 sampler_type_(kCustomTopPercentSampler),
46 offset_(offset) {}
47
ShardSample(const std::vector<int64_t> & indices)48 ShardSample::ShardSample(const std::vector<int64_t> &indices)
49 : numerator_(0),
50 denominator_(0),
51 partition_id_(0),
52 no_of_samples_(0),
53 indices_(indices),
54 sampler_type_(kSubsetSampler) {}
55
ShardSample(const std::vector<int64_t> & indices,uint32_t seed)56 ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed) : ShardSample(indices) {
57 sampler_type_ = kSubsetRandomSampler;
58 shuffle_op_ = std::make_shared<ShardShuffle>(seed);
59 }
60
GetNumSamples(int64_t dataset_size,int64_t num_classes)61 int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
62 if (sampler_type_ == kCustomTopNSampler) {
63 return no_of_samples_;
64 }
65
66 if (sampler_type_ == kCustomTopPercentSampler) {
67 if (dataset_size % denominator_ == 0) {
68 return dataset_size / denominator_ * numerator_;
69 } else {
70 return dataset_size / denominator_ * numerator_ + 1;
71 }
72 }
73 if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
74 return indices_.size();
75 }
76 return 0;
77 }
78
UpdateTasks(ShardTaskList & tasks,int64_t taking)79 Status ShardSample::UpdateTasks(ShardTaskList &tasks, int64_t taking) {
80 if (tasks.permutation_.empty()) {
81 ShardTaskList new_tasks;
82 auto total_no = tasks.sample_ids_.size();
83 CHECK_FAIL_RETURN_UNEXPECTED_MR(
84 total_no > 0, "[Internal ERROR] 'total_no' should be positive but got: " + std::to_string(total_no));
85 if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
86 for (int64_t i = 0; i < indices_.size(); ++i) {
87 int64_t index = ((indices_[i] % total_no) + total_no) % total_no;
88 new_tasks.AssignTask(tasks, index); // different mod result between c and python
89 }
90 } else {
91 int64_t count = 0;
92 if (nums_per_shard_.empty()) {
93 for (int64_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
94 if (no_of_samples_ != 0 && count == no_of_samples_) {
95 break;
96 }
97 new_tasks.AssignTask(tasks, i % total_no); // rounding up. if overflow, go back to start
98 count++;
99 }
100 } else {
101 // Get samples within a specific range
102 int64_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
103 for (; i < nums_per_shard_[partition_id_]; i++) {
104 if (no_of_samples_ != 0 && count == no_of_samples_) {
105 break;
106 }
107 new_tasks.AssignTask(tasks, i % total_no);
108 count++;
109 }
110 }
111 }
112 ShardTaskList::TaskListSwap(tasks, new_tasks);
113 } else {
114 ShardTaskList new_tasks;
115 int64_t total_no = tasks.permutation_.size();
116 CHECK_FAIL_RETURN_UNEXPECTED_MR(
117 total_no > 0, "[Internal ERROR] 'total_no' should be positive but got: " + std::to_string(total_no));
118 int64_t cnt = 0;
119 for (int64_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
120 if (no_of_samples_ != 0 && cnt == no_of_samples_) {
121 break;
122 }
123 new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
124 cnt++;
125 }
126 ShardTaskList::TaskListSwap(tasks, new_tasks);
127 }
128 return Status::OK();
129 }
130
UpdatePartitionWhenSlowMode(ShardTaskList & tasks)131 Status ShardSample::UpdatePartitionWhenSlowMode(ShardTaskList &tasks) {
132 // distribtued sample when load mode is slow load
133 // split shard sample
134 // 0 : 17 - shard0 has 17 samples - pre shard 2
135 // 1 : 32 - shard1 has 15 samples - pre shard 0
136 // 2 : 58 - shard2 has 26 samples - pre shard 1
137 // padded_sample = 6
138 // Assuming this is an 8-card training
139 // card 0 : kCommonTask, 0, 0, 8
140 // card 1 : kCommonTask, 0, 8, 16
141 // card 2 : kCommonTask, 0, 16, 17
142 // card 2 : kCommonTask, 1, 17, 24
143 // card 3 : kCommonTask, 1, 24, 32
144 // card 4 : kCommonTask, 2, 32, 40
145 // card 5 : kCommonTask, 2, 40, 48
146 // card 6 : kCommonTask, 2, 48, 56
147 // card 7 : kCommonTask, 2, 56, 58
148 // card 7 : kPaddedTask, -1, 58, 64
149 auto tasks_shard_sample_count = tasks.shuffled_shard_sample_count_;
150 int64_t total_sample = tasks_shard_sample_count[tasks_shard_sample_count.size() - 1] + tasks.padded_sample_;
151 int64_t step = total_sample % denominator_ == 0 ? total_sample / denominator_ : total_sample / denominator_ + 1;
152 int64_t start = partition_id_ * step;
153 int64_t end = (partition_id_ + 1) * step;
154 std::vector<PartitionedShardSampleCount> vpssc;
155 int64_t tmp_start = start;
156 int64_t tmp_end = end;
157 for (int32_t shard_index = 0; shard_index < tasks_shard_sample_count.size(); shard_index++) {
158 if (tmp_start >= tasks_shard_sample_count[shard_index]) {
159 continue;
160 }
161
162 if (tmp_end <= tasks_shard_sample_count[shard_index]) {
163 tmp_end = end;
164 // add new range to vp
165 PartitionedShardSampleCount pssc;
166 pssc.task_type = TaskType::kCommonTask;
167 pssc.shard_id = shard_index;
168 pssc.start = tmp_start;
169 pssc.end = tmp_end;
170 vpssc.push_back(pssc);
171 break;
172 } else {
173 PartitionedShardSampleCount pssc;
174 pssc.task_type = TaskType::kCommonTask;
175 pssc.shard_id = shard_index;
176 pssc.start = tmp_start;
177 pssc.end = tasks_shard_sample_count[shard_index];
178 vpssc.push_back(pssc);
179 tmp_start = tasks_shard_sample_count[shard_index];
180 }
181 }
182
183 // retrieve from the start or padded sample
184 if (end > tasks_shard_sample_count[tasks_shard_sample_count.size() - 1]) {
185 // padded scenario
186 if (tasks.padded_sample_ > 0) {
187 if (end - tasks_shard_sample_count[tasks_shard_sample_count.size() - 1] <= tasks.padded_sample_) {
188 PartitionedShardSampleCount pssc;
189 pssc.task_type = TaskType::kPaddedTask;
190 pssc.shard_id = -1;
191 pssc.start = tmp_start;
192 pssc.end = end;
193 vpssc.push_back(pssc);
194 } else {
195 RETURN_STATUS_UNEXPECTED_MR(
196 "It's padded sample scenario, but the total sample: " + std::to_string(total_sample) +
197 " which is not divisible by " + std::to_string(denominator_));
198 }
199 } else {
200 tmp_start = 0;
201 end = end - tasks_shard_sample_count[tasks_shard_sample_count.size() - 1];
202 tmp_end = end;
203 for (int32_t shard_index = 0; shard_index < tasks_shard_sample_count.size(); shard_index++) {
204 if (tmp_start >= tasks_shard_sample_count[shard_index]) {
205 continue;
206 }
207
208 if (tmp_end <= tasks_shard_sample_count[shard_index]) {
209 tmp_end = end;
210 // add new range to vp
211 PartitionedShardSampleCount pssc;
212 pssc.task_type = TaskType::kCommonTask;
213 pssc.shard_id = shard_index;
214 pssc.start = tmp_start;
215 pssc.end = tmp_end;
216 vpssc.push_back(pssc);
217 break;
218 } else {
219 PartitionedShardSampleCount pssc;
220 pssc.task_type = TaskType::kCommonTask;
221 pssc.shard_id = shard_index;
222 pssc.start = tmp_start;
223 pssc.end = tasks_shard_sample_count[shard_index];
224 vpssc.push_back(pssc);
225 tmp_start = tasks_shard_sample_count[shard_index];
226 }
227 }
228 }
229 }
230
231 tasks.SetPartitionedShardSampleCount(vpssc);
232
233 // update vpssc by no_of_samples_
234 if (no_of_samples_ != 0) {
235 tasks.UpdatePartitionedShardSampleCountByNumSamples(no_of_samples_);
236 }
237
238 return Status::OK();
239 }
240
Execute(ShardTaskList & tasks)241 Status ShardSample::Execute(ShardTaskList &tasks) {
242 if (tasks.load_mode_ != LoadMode::kSlow) {
243 if (offset_ != -1) {
244 int64_t old_v = 0;
245 int64_t num_rows_ = tasks.sample_ids_.size();
246 for (int64_t x = 0; x < denominator_; x++) {
247 int64_t samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
248 int64_t remainder = (num_rows_ + offset_) % denominator_;
249 if (x < remainder) {
250 samples_per_buffer_++;
251 }
252 if (x < offset_) {
253 samples_per_buffer_--;
254 }
255 old_v += samples_per_buffer_;
256 // nums_per_shard_ is used to save the current shard's ending index
257 nums_per_shard_.push_back(old_v);
258 }
259 }
260 int no_of_categories = static_cast<int>(tasks.categories);
261 int64_t total_no = tasks.sample_ids_.size();
262 int64_t taking = 0;
263 if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1
264 no_of_samples_ = std::min(no_of_samples_, total_no);
265 taking = no_of_samples_ - no_of_samples_ % no_of_categories;
266 } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
267 CHECK_FAIL_RETURN_UNEXPECTED_MR(static_cast<int64_t>(indices_.size()) <= total_no,
268 "Invalid input, indices size: " + std::to_string(indices_.size()) +
269 " should be less than or equal to database size: " + std::to_string(total_no) +
270 ".");
271 } else { // constructor TopPercent
272 if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
273 if (numerator_ == 1 && denominator_ > 1) { // sharding
274 taking = (total_no + denominator_ - 1) / denominator_;
275 } else { // non sharding
276 taking = total_no * numerator_ / denominator_;
277 taking -= (taking % no_of_categories);
278 }
279 } else {
280 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] 'numerator_': " + std::to_string(numerator_) +
281 " should be positive and less than denominator_: " + std::to_string(denominator_) +
282 ".");
283 }
284 }
285 return UpdateTasks(tasks, taking);
286 }
287
288 return UpdatePartitionWhenSlowMode(tasks);
289 }
290
SufExecute(ShardTaskList & tasks)291 Status ShardSample::SufExecute(ShardTaskList &tasks) {
292 if (sampler_type_ == kSubsetRandomSampler) {
293 RETURN_IF_NOT_OK_MR((*shuffle_op_)(tasks));
294 }
295 return Status::OK();
296 }
297 } // namespace mindrecord
298 } // namespace mindspore
299