• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_sample.h"
18 
19 namespace mindspore {
20 namespace mindrecord {
// Top-N sampling: keep exactly the first `n` samples of the dataset.
// Numerator/denominator are unused (0) and offset_ is -1 ("not set"),
// matching the sentinel Execute() tests before building nums_per_shard_.
ShardSample::ShardSample(int64_t n)
    : numerator_(0),
      denominator_(0),
      partition_id_(0),
      no_of_samples_(n),
      indices_({}),
      sampler_type_(kCustomTopNSampler),
      offset_(-1) {}
29 
// Top-percent sampling: keep roughly num/den of the dataset
// (see GetNumSamples() for the exact rounding). No partitioning
// (partition_id_ = 0) and no explicit offset (offset_ = -1).
ShardSample::ShardSample(int64_t num, int64_t den)
    : numerator_(num),
      denominator_(den),
      partition_id_(0),
      no_of_samples_(0),
      indices_({}),
      sampler_type_(kCustomTopPercentSampler),
      offset_(-1) {}
38 
// Distributed (sharded) sampling: this instance serves partition `par`
// out of `den` partitions, taking a num/den share of the dataset, with an
// optional cap of `no_of_samples` rows and an explicit row `offset`
// (offset != -1 triggers the per-shard range computation in Execute()).
ShardSample::ShardSample(int64_t num, int64_t den, int64_t par, int64_t no_of_samples, int64_t offset)
    : numerator_(num),
      denominator_(den),
      partition_id_(par),
      no_of_samples_(no_of_samples),
      indices_({}),
      sampler_type_(kCustomTopPercentSampler),
      offset_(offset) {}
47 
ShardSample(const std::vector<int64_t> & indices)48 ShardSample::ShardSample(const std::vector<int64_t> &indices)
49     : numerator_(0),
50       denominator_(0),
51       partition_id_(0),
52       no_of_samples_(0),
53       indices_(indices),
54       sampler_type_(kSubsetSampler) {}
55 
// Random subset sampling: same as the subset constructor, but the selected
// samples are additionally shuffled by a ShardShuffle seeded with `seed`
// (the shuffle itself runs later, in SufExecute()).
ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed) : ShardSample(indices) {
  sampler_type_ = kSubsetRandomSampler;
  shuffle_op_ = std::make_shared<ShardShuffle>(seed);
}
60 
GetNumSamples(int64_t dataset_size,int64_t num_classes)61 int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
62   if (sampler_type_ == kCustomTopNSampler) {
63     return no_of_samples_;
64   }
65 
66   if (sampler_type_ == kCustomTopPercentSampler) {
67     if (dataset_size % denominator_ == 0) {
68       return dataset_size / denominator_ * numerator_;
69     } else {
70       return dataset_size / denominator_ * numerator_ + 1;
71     }
72   }
73   if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
74     return indices_.size();
75   }
76   return 0;
77 }
78 
UpdateTasks(ShardTaskList & tasks,int64_t taking)79 Status ShardSample::UpdateTasks(ShardTaskList &tasks, int64_t taking) {
80   if (tasks.permutation_.empty()) {
81     ShardTaskList new_tasks;
82     auto total_no = tasks.sample_ids_.size();
83     CHECK_FAIL_RETURN_UNEXPECTED_MR(
84       total_no > 0, "[Internal ERROR] 'total_no' should be positive but got: " + std::to_string(total_no));
85     if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
86       for (int64_t i = 0; i < indices_.size(); ++i) {
87         int64_t index = ((indices_[i] % total_no) + total_no) % total_no;
88         new_tasks.AssignTask(tasks, index);  // different mod result between c and python
89       }
90     } else {
91       int64_t count = 0;
92       if (nums_per_shard_.empty()) {
93         for (int64_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
94           if (no_of_samples_ != 0 && count == no_of_samples_) {
95             break;
96           }
97           new_tasks.AssignTask(tasks, i % total_no);  // rounding up. if overflow, go back to start
98           count++;
99         }
100       } else {
101         // Get samples within a specific range
102         int64_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
103         for (; i < nums_per_shard_[partition_id_]; i++) {
104           if (no_of_samples_ != 0 && count == no_of_samples_) {
105             break;
106           }
107           new_tasks.AssignTask(tasks, i % total_no);
108           count++;
109         }
110       }
111     }
112     ShardTaskList::TaskListSwap(tasks, new_tasks);
113   } else {
114     ShardTaskList new_tasks;
115     int64_t total_no = tasks.permutation_.size();
116     CHECK_FAIL_RETURN_UNEXPECTED_MR(
117       total_no > 0, "[Internal ERROR] 'total_no' should be positive but got: " + std::to_string(total_no));
118     int64_t cnt = 0;
119     for (int64_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
120       if (no_of_samples_ != 0 && cnt == no_of_samples_) {
121         break;
122       }
123       new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
124       cnt++;
125     }
126     ShardTaskList::TaskListSwap(tasks, new_tasks);
127   }
128   return Status::OK();
129 }
130 
UpdatePartitionWhenSlowMode(ShardTaskList & tasks)131 Status ShardSample::UpdatePartitionWhenSlowMode(ShardTaskList &tasks) {
132   // distribtued sample when load mode is slow load
133   // split shard sample
134   // 0 : 17  -  shard0 has 17 samples - pre shard 2
135   // 1 : 32  -  shard1 has 15 samples - pre shard 0
136   // 2 : 58  -  shard2 has 26 samples - pre shard 1
137   // padded_sample = 6
138   // Assuming this is an 8-card training
139   // card 0 : kCommonTask, 0, 0, 8
140   // card 1 : kCommonTask, 0, 8, 16
141   // card 2 : kCommonTask, 0, 16, 17
142   // card 2 : kCommonTask, 1, 17, 24
143   // card 3 : kCommonTask, 1, 24, 32
144   // card 4 : kCommonTask, 2, 32, 40
145   // card 5 : kCommonTask, 2, 40, 48
146   // card 6 : kCommonTask, 2, 48, 56
147   // card 7 : kCommonTask, 2, 56, 58
148   // card 7 : kPaddedTask, -1, 58, 64
149   auto tasks_shard_sample_count = tasks.shuffled_shard_sample_count_;
150   int64_t total_sample = tasks_shard_sample_count[tasks_shard_sample_count.size() - 1] + tasks.padded_sample_;
151   int64_t step = total_sample % denominator_ == 0 ? total_sample / denominator_ : total_sample / denominator_ + 1;
152   int64_t start = partition_id_ * step;
153   int64_t end = (partition_id_ + 1) * step;
154   std::vector<PartitionedShardSampleCount> vpssc;
155   int64_t tmp_start = start;
156   int64_t tmp_end = end;
157   for (int32_t shard_index = 0; shard_index < tasks_shard_sample_count.size(); shard_index++) {
158     if (tmp_start >= tasks_shard_sample_count[shard_index]) {
159       continue;
160     }
161 
162     if (tmp_end <= tasks_shard_sample_count[shard_index]) {
163       tmp_end = end;
164       // add new range to vp
165       PartitionedShardSampleCount pssc;
166       pssc.task_type = TaskType::kCommonTask;
167       pssc.shard_id = shard_index;
168       pssc.start = tmp_start;
169       pssc.end = tmp_end;
170       vpssc.push_back(pssc);
171       break;
172     } else {
173       PartitionedShardSampleCount pssc;
174       pssc.task_type = TaskType::kCommonTask;
175       pssc.shard_id = shard_index;
176       pssc.start = tmp_start;
177       pssc.end = tasks_shard_sample_count[shard_index];
178       vpssc.push_back(pssc);
179       tmp_start = tasks_shard_sample_count[shard_index];
180     }
181   }
182 
183   // retrieve from the start or padded sample
184   if (end > tasks_shard_sample_count[tasks_shard_sample_count.size() - 1]) {
185     // padded scenario
186     if (tasks.padded_sample_ > 0) {
187       if (end - tasks_shard_sample_count[tasks_shard_sample_count.size() - 1] <= tasks.padded_sample_) {
188         PartitionedShardSampleCount pssc;
189         pssc.task_type = TaskType::kPaddedTask;
190         pssc.shard_id = -1;
191         pssc.start = tmp_start;
192         pssc.end = end;
193         vpssc.push_back(pssc);
194       } else {
195         RETURN_STATUS_UNEXPECTED_MR(
196           "It's padded sample scenario, but the total sample: " + std::to_string(total_sample) +
197           " which is not divisible by " + std::to_string(denominator_));
198       }
199     } else {
200       tmp_start = 0;
201       end = end - tasks_shard_sample_count[tasks_shard_sample_count.size() - 1];
202       tmp_end = end;
203       for (int32_t shard_index = 0; shard_index < tasks_shard_sample_count.size(); shard_index++) {
204         if (tmp_start >= tasks_shard_sample_count[shard_index]) {
205           continue;
206         }
207 
208         if (tmp_end <= tasks_shard_sample_count[shard_index]) {
209           tmp_end = end;
210           // add new range to vp
211           PartitionedShardSampleCount pssc;
212           pssc.task_type = TaskType::kCommonTask;
213           pssc.shard_id = shard_index;
214           pssc.start = tmp_start;
215           pssc.end = tmp_end;
216           vpssc.push_back(pssc);
217           break;
218         } else {
219           PartitionedShardSampleCount pssc;
220           pssc.task_type = TaskType::kCommonTask;
221           pssc.shard_id = shard_index;
222           pssc.start = tmp_start;
223           pssc.end = tasks_shard_sample_count[shard_index];
224           vpssc.push_back(pssc);
225           tmp_start = tasks_shard_sample_count[shard_index];
226         }
227       }
228     }
229   }
230 
231   tasks.SetPartitionedShardSampleCount(vpssc);
232 
233   // update vpssc by no_of_samples_
234   if (no_of_samples_ != 0) {
235     tasks.UpdatePartitionedShardSampleCountByNumSamples(no_of_samples_);
236   }
237 
238   return Status::OK();
239 }
240 
Execute(ShardTaskList & tasks)241 Status ShardSample::Execute(ShardTaskList &tasks) {
242   if (tasks.load_mode_ != LoadMode::kSlow) {
243     if (offset_ != -1) {
244       int64_t old_v = 0;
245       int64_t num_rows_ = tasks.sample_ids_.size();
246       for (int64_t x = 0; x < denominator_; x++) {
247         int64_t samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
248         int64_t remainder = (num_rows_ + offset_) % denominator_;
249         if (x < remainder) {
250           samples_per_buffer_++;
251         }
252         if (x < offset_) {
253           samples_per_buffer_--;
254         }
255         old_v += samples_per_buffer_;
256         // nums_per_shard_ is used to save the current shard's ending index
257         nums_per_shard_.push_back(old_v);
258       }
259     }
260     int no_of_categories = static_cast<int>(tasks.categories);
261     int64_t total_no = tasks.sample_ids_.size();
262     int64_t taking = 0;
263     if (sampler_type_ == kCustomTopNSampler) {  // non sharding case constructor #1
264       no_of_samples_ = std::min(no_of_samples_, total_no);
265       taking = no_of_samples_ - no_of_samples_ % no_of_categories;
266     } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
267       CHECK_FAIL_RETURN_UNEXPECTED_MR(static_cast<int64_t>(indices_.size()) <= total_no,
268                                       "Invalid input, indices size: " + std::to_string(indices_.size()) +
269                                         " should be less than or equal to database size: " + std::to_string(total_no) +
270                                         ".");
271     } else {  // constructor TopPercent
272       if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
273         if (numerator_ == 1 && denominator_ > 1) {  // sharding
274           taking = (total_no + denominator_ - 1) / denominator_;
275         } else {  // non sharding
276           taking = total_no * numerator_ / denominator_;
277           taking -= (taking % no_of_categories);
278         }
279       } else {
280         RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] 'numerator_': " + std::to_string(numerator_) +
281                                     " should be positive and less than denominator_: " + std::to_string(denominator_) +
282                                     ".");
283       }
284     }
285     return UpdateTasks(tasks, taking);
286   }
287 
288   return UpdatePartitionWhenSlowMode(tasks);
289 }
290 
SufExecute(ShardTaskList & tasks)291 Status ShardSample::SufExecute(ShardTaskList &tasks) {
292   if (sampler_type_ == kSubsetRandomSampler) {
293     RETURN_IF_NOT_OK_MR((*shuffle_op_)(tasks));
294   }
295   return Status::OK();
296 }
297 }  // namespace mindrecord
298 }  // namespace mindspore
299