/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/mindrecord/include/shard_task_list.h"
#include "minddata/mindrecord/include/common/shard_utils.h"

namespace mindspore {
namespace mindrecord {
// In slow load mode, sample ids are generated and shuffled in batches of 1,000,000.
const int64_t ShuffleSize = 1000000;

GeneratorIds::GeneratorIds() : partitioned_shard_sample_count_(), partition_index_(0), partition_sample_index_(0) {}

void GeneratorIds::SetShardSampleCount(const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count) {
  partitioned_shard_sample_count_ = partitioned_shard_sample_count;
  partition_index_ = 0;
  partition_sample_index_ = 0;
}

void GeneratorIds::ResetShardIndexAndID() {
  partition_index_ = 0;
  partition_sample_index_ = 0;
}

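// Return the next batch of sample ids from the partitioned ranges, at most ShuffleSize ids per call, optionally
// shuffled with the given seed. partition_index_ / partition_sample_index_ record where the previous call stopped,
// so repeated calls walk through all ranges; an empty result means the ranges are exhausted.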
std::vector<int64_t> GeneratorIds::GetNextSampleIds(const bool &need_shuffle, const uint32_t &seed) {
  // Each entry of partitioned_shard_sample_count_ describes a task type and a [start, end) sample id range, e.g.:
  // CommonTask, 0, 16680, 17777
  // CommonTask, 0, 0, 15
  std::vector<int64_t> ids;
  for (int32_t i = partition_index_; i < partitioned_shard_sample_count_.size(); i++) {
    for (int64_t j = partitioned_shard_sample_count_[i].start + partition_sample_index_;
         j < partitioned_shard_sample_count_[i].end; j++) {
      ids.push_back(j);
      partition_sample_index_++;
      if (ids.size() >= ShuffleSize) {
        if (need_shuffle) {
          std::shuffle(ids.begin(), ids.end(), std::default_random_engine(seed));
        }
        return ids;
      }
    }
    partition_index_++;
    partition_sample_index_ = 0;
  }
  if (need_shuffle) {
    std::shuffle(ids.begin(), ids.end(), std::default_random_engine(seed));
  }
  return ids;
}

ShardTaskList::ShardTaskList()
    : categories(1), padded_sample_(0), need_shuffle_(false), shuffle_seed_(0), load_mode_(LoadMode::kFast) {}

ShardTaskList::ShardTaskList(const ShardTaskList &other)
    : categories(other.categories),
      permutation_(other.permutation_),
      sample_ids_(other.sample_ids_),
      task_list_(other.task_list_),
      sample_meta_list_(other.sample_meta_list_),
      shard_sample_count_(other.shard_sample_count_),
      padded_sample_(other.padded_sample_),
      file_ids_(other.file_ids_),
      shuffled_shard_sample_count_(other.shuffled_shard_sample_count_),
      partitioned_shard_sample_count_(other.partitioned_shard_sample_count_),
      need_shuffle_(other.need_shuffle_),
      shuffle_seed_(other.shuffle_seed_),
      generator_ids_(other.generator_ids_),
      load_mode_(other.load_mode_) {}

ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
  ShardTaskList tmp(other);
  std::swap(categories, tmp.categories);
  permutation_.swap(tmp.permutation_);
  sample_ids_.swap(tmp.sample_ids_);
  task_list_.swap(tmp.task_list_);
  sample_meta_list_.swap(tmp.sample_meta_list_);
  shard_sample_count_.swap(tmp.shard_sample_count_);
  padded_sample_ = tmp.padded_sample_;
  file_ids_.swap(tmp.file_ids_);
  shuffled_shard_sample_count_.swap(tmp.shuffled_shard_sample_count_);
  partitioned_shard_sample_count_.swap(tmp.partitioned_shard_sample_count_);
  need_shuffle_ = tmp.need_shuffle_;
  shuffle_seed_ = tmp.shuffle_seed_;
  generator_ids_ = tmp.generator_ids_;
  load_mode_ = tmp.load_mode_;
  return *this;
}

void ShardTaskList::InitSampleIds() {
  // No-op if sample ids already exist; do not clobber the previous list.
  if (sample_ids_.empty()) {
    sample_ids_ = std::vector<int64_t>(task_list_.size());
    for (auto i = 0; i < task_list_.size(); i++) {
      sample_ids_[i] = i;
    }
  }
}

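// Initialize permutation_ with the identity mapping (0, 1, ..., n-1), where n is the number of sample ids.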
void ShardTaskList::MakePerm() {
  int64_t perm_size = sample_ids_.size();
  permutation_ = std::vector<int64_t>(perm_size);
  for (int64_t i = 0; i < perm_size; i++) {
    permutation_[i] = i;
  }
}

// Swap new_tasks with orig_tasks
void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) {
  // If orig_tasks contains fields that must survive the swap, swapping them with a new_tasks that lacks those
  // fields would clobber the data. In particular, task_list_ must not be lost: this function can be called in the
  // middle of a mindrecord epoch, while orig_tasks.task_list_ is still being used by the mindrecord op's worker
  // threads. So only the fields below are swapped, and task_list_ is left untouched.

  std::swap(orig_tasks.categories, new_tasks.categories);
  std::swap(orig_tasks.permutation_, new_tasks.permutation_);
  std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_);
}

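// Drop the last task; in fast load mode the cached sample meta is dropped with it.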
void ShardTaskList::PopBack() {
  task_list_.pop_back();
  if (load_mode_ == LoadMode::kFast) {
    sample_meta_list_.pop_back();
  }
}

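// Total number of samples: in fast/lazy load mode this is the number of tasks; in slow load mode tasks are not
// materialized, so the last cumulative shard count plus any padded samples is used instead.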
int64_t ShardTaskList::Size() const {
  if (load_mode_ != LoadMode::kSlow) {
    return static_cast<int64_t>(task_list_.size());
  }

  // slow load mode
  return shard_sample_count_[shard_sample_count_.size() - 1] + padded_sample_;
}

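// Number of samples remaining after sampling: the number of selected sample ids in fast/lazy load mode, or the sum
// of the partitioned [start, end) ranges in slow load mode.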
int64_t ShardTaskList::SizeAfterSampling() const {
  if (load_mode_ != LoadMode::kSlow) {
    return static_cast<int64_t>(sample_ids_.size());
  }

  // slow load mode
  int64_t count = 0;
  for (int32_t i = 0; i < partitioned_shard_sample_count_.size(); i++) {
    count += partitioned_shard_sample_count_[i].end - partitioned_shard_sample_count_[i].start;
  }
  return count;
}

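// Sum the number of rows over all pages; only meaningful in fast load mode, where each task carries its page's
// row count in the sample meta.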
int64_t ShardTaskList::SizeOfRows() const {
  int64_t size_of_rows = 0;
  if (task_list_.size() == 0) {
    return size_of_rows;
  }

  if (load_mode_ == LoadMode::kFast) {
    // 1 task is 1 page, and the blob index starts from 2
    auto sum_num_rows = [](int64_t x, SampleMeta y) { return x + std::get<0>(y)[0]; };
    size_of_rows =
      std::accumulate(sample_meta_list_.begin(), sample_meta_list_.end(), static_cast<int64_t>(0), sum_num_rows);
  } else {
    MS_LOG(WARNING) << "In lazy load mode, size of rows will be " << size_of_rows << " which is not correct.";
  }
  return size_of_rows;
}

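// Build the task for a global sample id. In fast load mode the cached sample meta is returned with the task; in
// lazy load mode the meta is left empty. In slow load mode the task is reconstructed on the fly: the id is located
// in the partitioned ranges to get the task type and (shuffled) shard, mapped back to the original shard id via
// file_ids_, and converted to a row id within that shard.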
ShardTask ShardTaskList::GetTaskByID(int64_t id) {
  if (load_mode_ == LoadMode::kFast) {
    return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), std::get<0>(sample_meta_list_[id]),
            std::get<1>(sample_meta_list_[id])};
  } else if (load_mode_ == LoadMode::kLazy) {
    return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), {}, json()};
  }

  TaskType task_type = TaskType::kCommonTask;
  // get the partitioned shard id
  int32_t shard_id = 0;
  int32_t row_id = 0;
  for (int32_t i = 0; i < partitioned_shard_sample_count_.size(); i++) {
    if (id >= partitioned_shard_sample_count_[i].start && id < partitioned_shard_sample_count_[i].end) {
      task_type = partitioned_shard_sample_count_[i].task_type;
      shard_id = partitioned_shard_sample_count_[i].shard_id;
      break;
    }
  }

  if (shard_id == -1) {
    return {TaskType::kPaddedTask, std::make_tuple(shard_id, row_id), {}, json()};
  }

  // map back to the original shard_id, which follows the order of the mindrecord files
  shard_id = file_ids_[shard_id];

  // get the row id in the shard
  row_id = id;
  for (int32_t i = 0; i < shuffled_shard_sample_count_.size(); i++) {
    if (id < shuffled_shard_sample_count_[i]) {
      if (i > 0) {
        row_id = id - shuffled_shard_sample_count_[i - 1];
      }
      break;
    }
  }

  return {task_type, std::make_tuple(shard_id, row_id), {}, json()};
}

int64_t ShardTaskList::GetTaskSampleByID(int64_t id) { return sample_ids_[id]; }

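// Pick a random index into the sampled ids.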
int64_t ShardTaskList::GetRandomTaskID() {
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1);
  return dis(gen);
}

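// Pick a task uniformly at random; in fast load mode the cached sample meta is attached, otherwise it is left empty.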
ShardTask ShardTaskList::GetRandomTask() {
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
  size_t random = dis(gen);
  if (load_mode_ == LoadMode::kFast) {
    return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), std::get<0>(sample_meta_list_[random]),
            std::get<1>(sample_meta_list_[random])};
  } else {
    return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), {}, json()};
  }
}

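// Merge per-category task lists into one list. Without replacement, tasks are taken round-robin across categories up
// to the smallest category size; with replacement, each category contributes random tasks up to the largest category
// size (or num_elements if it is set). In both cases num_samples, when non-zero, caps the total number of tasks.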
ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
                                     int64_t num_samples) {
  ShardTaskList res;
  if (category_tasks.empty()) {
    return res;
  }
  auto total_categories = category_tasks.size();
  res.categories = static_cast<int64_t>(total_categories);
  if (!replacement) {
    auto minTasks = category_tasks[0].Size();
    for (int64_t i = 1; i < total_categories; i++) {
      minTasks = std::min(minTasks, category_tasks[i].Size());
    }
    int64_t count = 0;
    for (int64_t task_no = 0; task_no < minTasks; task_no++) {
      for (int64_t i = 0; i < total_categories; i++) {
        if (num_samples != 0 && count == num_samples) {
          break;
        }
        res.InsertTask(std::move(category_tasks[i].GetTaskByID(task_no)));
        count++;
      }
    }
  } else {
    auto maxTasks = category_tasks[0].Size();
    for (int64_t i = 1; i < total_categories; i++) {
      maxTasks = std::max(maxTasks, category_tasks[i].Size());
    }
    if (num_elements != std::numeric_limits<int64_t>::max()) {
      maxTasks = static_cast<decltype(maxTasks)>(num_elements);
    }
    int64_t count = 0;
    for (int64_t i = 0; i < total_categories; i++) {
      for (int64_t j = 0; j < maxTasks; j++) {
        if (num_samples != 0 && count == num_samples) {
          break;
        }
        res.InsertTask(category_tasks[i].GetRandomTask());
        count++;
      }
    }
  }

  return res;
}

void ShardTaskList::SetShardSampleCount(const std::vector<int64_t> &shard_sample_count) {
  // shard_sample_count holds the cumulative sample count per shard_id, e.g.:
  // shard_id : cumulative count
  // 0 : 15 - shard0 has 15 samples
  // 1 : 41 - shard1 has 26 samples
  // 2 : 58 - shard2 has 17 samples
  shard_sample_count_ = shard_sample_count;

  // generate new file_ids
  std::vector<int32_t> file_ids;
  for (int32_t i = 0; i < shard_sample_count_.size(); i++) {
    file_ids.push_back(i);
  }
  SetFileIds(file_ids);
}

void ShardTaskList::SetPaddedSample(const int32_t &padded_sample) { padded_sample_ = padded_sample; }

void ShardTaskList::SetFileIds(const std::vector<int32_t> &file_ids) {
  file_ids_ = file_ids;

  // The original shard_sample_count_ is cumulative per shard_id, e.g.:
  // 0 : 15 - shard0 has 15 samples
  // 1 : 41 - shard1 has 26 samples
  // 2 : 58 - shard2 has 17 samples
  // Rebuild shuffled_shard_sample_count_ as cumulative counts in the shuffled file order.
  // If the shuffled order is (2, 0, 1):
  // 0 : 17 - position 0 has 17 samples - originally shard2
  // 1 : 32 - position 1 has 15 samples - originally shard0
  // 2 : 58 - position 2 has 26 samples - originally shard1
  std::vector<int64_t> shuffled_shard_sample_count;
  int64_t count = 0;
  int64_t start;
  for (int32_t i = 0; i < file_ids_.size(); i++) {
    if (file_ids[i] == 0) {
      start = 0;
    } else {
      start = shard_sample_count_[file_ids[i] - 1];
    }
    shuffled_shard_sample_count.push_back(shard_sample_count_[file_ids[i]] - start + count);
    count += shard_sample_count_[file_ids[i]] - start;
  }
  SetShuffledShardSampleCount(shuffled_shard_sample_count);
}

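// Turn the cumulative shuffled counts into explicit [start, end) ranges per (shuffled) shard, appending one extra
// padded-task range at the end when padded samples are present.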
void ShardTaskList::SetShuffledShardSampleCount(const std::vector<int64_t> &shuffled_shard_sample_count) {
  shuffled_shard_sample_count_ = shuffled_shard_sample_count;

  // generate new partitioned_shard_sample_count
  std::vector<PartitionedShardSampleCount> vpssc;
  int64_t start = 0;
  for (int32_t shard_index = 0; shard_index < shuffled_shard_sample_count_.size(); shard_index++) {
    // add a new range to vpssc
    PartitionedShardSampleCount pssc;
    pssc.task_type = TaskType::kCommonTask;
    pssc.shard_id = shard_index;
    pssc.start = start;
    pssc.end = shuffled_shard_sample_count_[shard_index];
    vpssc.push_back(pssc);
    start = shuffled_shard_sample_count_[shard_index];
  }

  // padded scenario
  if (padded_sample_ > 0) {
    PartitionedShardSampleCount pssc;
    pssc.task_type = TaskType::kPaddedTask;
    pssc.shard_id = -1;
    pssc.start = start;
    pssc.end = start + padded_sample_;
    vpssc.push_back(pssc);
  }

  SetPartitionedShardSampleCount(vpssc);
}

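// Store the partitioned ranges and hand them to the id generator, which resets its iteration state.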
void ShardTaskList::SetPartitionedShardSampleCount(
  const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count) {
  partitioned_shard_sample_count_ = partitioned_shard_sample_count;
  generator_ids_.SetShardSampleCount(partitioned_shard_sample_count_);
}

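// Truncate the partitioned ranges so that only the first num_samples samples remain: whole ranges are kept while
// they fit, and the first range that exceeds the remaining budget is cut short.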
void ShardTaskList::UpdatePartitionedShardSampleCountByNumSamples(const int64_t &num_samples) {
  auto count = num_samples;
  std::vector<PartitionedShardSampleCount> new_partitioned_shard_sample_count = {};
  for (int32_t i = 0; i < partitioned_shard_sample_count_.size(); i++) {
    auto start = partitioned_shard_sample_count_[i].start;
    if (partitioned_shard_sample_count_[i].end - start <= count) {
      new_partitioned_shard_sample_count.push_back(partitioned_shard_sample_count_[i]);
      count = count - (partitioned_shard_sample_count_[i].end - start);
    } else {
      PartitionedShardSampleCount pssc;
      pssc.task_type = partitioned_shard_sample_count_[i].task_type;
      pssc.shard_id = partitioned_shard_sample_count_[i].shard_id;
      pssc.start = start;
      pssc.end = start + count;
      new_partitioned_shard_sample_count.push_back(pssc);
      break;
    }
  }

  SetPartitionedShardSampleCount(new_partitioned_shard_sample_count);
}

std::vector<int64_t> ShardTaskList::GetNextSampleIds() {
  return generator_ids_.GetNextSampleIds(need_shuffle_, shuffle_seed_);
}
}  // namespace mindrecord
}  // namespace mindspore