• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "minddata/mindrecord/include/shard_task_list.h"

#include <algorithm>
#include <numeric>

#include "minddata/mindrecord/include/common/shard_utils.h"
19 
20 namespace mindspore {
21 namespace mindrecord {
// When mindrecord is in slow load mode, sample ids are handed out in chunks of at most 1,000,000 ids;
// GeneratorIds::GetNextSampleIds stops filling (and optionally shuffles) a chunk once it reaches this size.
const int64_t ShuffleSize = 1000000;
24 
GeneratorIds()25 GeneratorIds::GeneratorIds() : partitioned_shard_sample_count_(), partition_index_(0), partition_sample_index_(0) {}
26 
SetShardSampleCount(const std::vector<PartitionedShardSampleCount> & partitioned_shard_sample_count)27 void GeneratorIds::SetShardSampleCount(const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count) {
28   partitioned_shard_sample_count_ = partitioned_shard_sample_count;
29   partition_index_ = 0;
30   partition_sample_index_ = 0;
31 }
32 
ResetShardIndexAndID()33 void GeneratorIds::ResetShardIndexAndID() {
34   partition_index_ = 0;
35   partition_sample_index_ = 0;
36 }
37 
GetNextSampleIds(const bool & need_shuffle,const uint32_t & seed)38 std::vector<int64_t> GeneratorIds::GetNextSampleIds(const bool &need_shuffle, const uint32_t &seed) {
39   // partitioned_shard_sample_count_ is:
40   // CommonTask, 0, 16680, 17777
41   // CommonTask, 0, 0, 15
42   std::vector<int64_t> ids;
43   for (int32_t i = partition_index_; i < partitioned_shard_sample_count_.size(); i++) {
44     for (int64_t j = partitioned_shard_sample_count_[i].start + partition_sample_index_;
45          j < partitioned_shard_sample_count_[i].end; j++) {
46       ids.push_back(j);
47       partition_sample_index_++;
48       if (ids.size() >= ShuffleSize) {
49         if (need_shuffle) {
50           std::shuffle(ids.begin(), ids.end(), std::default_random_engine(seed));
51         }
52         return ids;
53       }
54     }
55     partition_index_++;
56     partition_sample_index_ = 0;
57   }
58   if (need_shuffle) {
59     std::shuffle(ids.begin(), ids.end(), std::default_random_engine(seed));
60   }
61   return ids;
62 }
63 
ShardTaskList()64 ShardTaskList::ShardTaskList()
65     : categories(1), padded_sample_(0), need_shuffle_(false), shuffle_seed_(0), load_mode_(LoadMode::kFast) {}
66 
// Copy constructor: copies each member listed below from `other`, one by one.
ShardTaskList::ShardTaskList(const ShardTaskList &other)
    : categories(other.categories),
      permutation_(other.permutation_),
      sample_ids_(other.sample_ids_),
      task_list_(other.task_list_),
      sample_meta_list_(other.sample_meta_list_),
      shard_sample_count_(other.shard_sample_count_),
      padded_sample_(other.padded_sample_),
      file_ids_(other.file_ids_),
      shuffled_shard_sample_count_(other.shuffled_shard_sample_count_),
      partitioned_shard_sample_count_(other.partitioned_shard_sample_count_),
      need_shuffle_(other.need_shuffle_),
      shuffle_seed_(other.shuffle_seed_),
      generator_ids_(other.generator_ids_),
      load_mode_(other.load_mode_) {}
82 
// Copy assignment: take a full copy of `other` first, then exchange/assign every member from that copy,
// so `other` itself is never modified (self-assignment therefore needs no special case).
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
  ShardTaskList snapshot(other);

  // Containers (and categories) are exchanged with the snapshot.
  std::swap(categories, snapshot.categories);
  std::swap(permutation_, snapshot.permutation_);
  std::swap(sample_ids_, snapshot.sample_ids_);
  std::swap(task_list_, snapshot.task_list_);
  std::swap(sample_meta_list_, snapshot.sample_meta_list_);
  std::swap(shard_sample_count_, snapshot.shard_sample_count_);
  std::swap(file_ids_, snapshot.file_ids_);
  std::swap(shuffled_shard_sample_count_, snapshot.shuffled_shard_sample_count_);
  std::swap(partitioned_shard_sample_count_, snapshot.partitioned_shard_sample_count_);

  // Remaining members are plainly assigned.
  padded_sample_ = snapshot.padded_sample_;
  need_shuffle_ = snapshot.need_shuffle_;
  shuffle_seed_ = snapshot.shuffle_seed_;
  generator_ids_ = snapshot.generator_ids_;
  load_mode_ = snapshot.load_mode_;
  return *this;
}
101 
InitSampleIds()102 void ShardTaskList::InitSampleIds() {
103   // no-op if there already exists sample ids.  Do not clobber previous list
104   if (sample_ids_.empty()) {
105     sample_ids_ = std::vector<int64_t>(task_list_.size());
106     for (auto i = 0; i < task_list_.size(); i++) {
107       sample_ids_[i] = i;
108     }
109   }
110 }
111 
MakePerm()112 void ShardTaskList::MakePerm() {
113   int64_t perm_size = sample_ids_.size();
114   permutation_ = std::vector<int64_t>(perm_size);
115   for (int64_t i = 0; i < perm_size; i++) {
116     permutation_[i] = i;
117   }
118 }
119 
120 // Swap the new_tasks with orig_tasks
TaskListSwap(ShardTaskList & orig_tasks,ShardTaskList & new_tasks)121 void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) {
122   // When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a
123   // new_tasks that does not have those fields will result in clobbering/losing the data after the swap.
124   // The task_list_ should not be lost/clobbered.
125   // This function can be called in the middle of mindrecord's epoch, when orig_tasks.task_list_ is still being
126   // used by mindrecord op's worker threads. So don't touch its task_list_ since this field should be preserved anyways.
127 
128   std::swap(orig_tasks.categories, new_tasks.categories);
129   std::swap(orig_tasks.permutation_, new_tasks.permutation_);
130   std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_);
131 }
132 
PopBack()133 void ShardTaskList::PopBack() {
134   task_list_.pop_back();
135   if (load_mode_ == LoadMode::kFast) {
136     sample_meta_list_.pop_back();
137   }
138 }
139 
Size() const140 int64_t ShardTaskList::Size() const {
141   if (load_mode_ != LoadMode::kSlow) {
142     return static_cast<int64_t>(task_list_.size());
143   }
144 
145   // slow load mode
146   return shard_sample_count_[shard_sample_count_.size() - 1] + padded_sample_;
147 }
148 
SizeAfterSampling() const149 int64_t ShardTaskList::SizeAfterSampling() const {
150   if (load_mode_ != LoadMode::kSlow) {
151     return static_cast<int64_t>(sample_ids_.size());
152   }
153 
154   // slow load mode
155   int64_t count = 0;
156   for (int32_t i = 0; i < partitioned_shard_sample_count_.size(); i++) {
157     count += partitioned_shard_sample_count_[i].end - partitioned_shard_sample_count_[i].start;
158   }
159   return count;
160 }
161 
SizeOfRows() const162 int64_t ShardTaskList::SizeOfRows() const {
163   int64_t size_of_rows = 0;
164   if (task_list_.size() == 0) {
165     return size_of_rows;
166   }
167 
168   if (load_mode_ == LoadMode::kFast) {
169     // 1 task is 1 page,blob index start from 2
170     auto sum_num_rows = [](int64_t x, SampleMeta y) { return x + std::get<0>(y)[0]; };
171     size_of_rows = std::accumulate(sample_meta_list_.begin(), sample_meta_list_.end(), 0, sum_num_rows);
172   } else {
173     MS_LOG(WARNING) << "In lazy load mode, size of rows will be " << size_of_rows << " which is not correctly.";
174   }
175   return size_of_rows;
176 }
177 
GetTaskByID(int64_t id)178 ShardTask ShardTaskList::GetTaskByID(int64_t id) {
179   if (load_mode_ == LoadMode::kFast) {
180     return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), std::get<0>(sample_meta_list_[id]),
181             std::get<1>(sample_meta_list_[id])};
182   } else if (load_mode_ == LoadMode::kLazy) {
183     return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), {}, json()};
184   }
185 
186   TaskType task_type = TaskType::kCommonTask;
187   // get the partitioned shard id
188   int32_t shard_id = 0;
189   int32_t row_id = 0;
190   for (int32_t i = 0; i < partitioned_shard_sample_count_.size(); i++) {
191     if (id >= partitioned_shard_sample_count_[i].start && id < partitioned_shard_sample_count_[i].end) {
192       task_type = partitioned_shard_sample_count_[i].task_type;
193       shard_id = partitioned_shard_sample_count_[i].shard_id;
194       break;
195     }
196   }
197 
198   if (shard_id == -1) {
199     return {TaskType::kPaddedTask, std::make_tuple(shard_id, row_id), {}, json()};
200   }
201 
202   // get the original shard_id which is in order with mindrecord files
203   shard_id = file_ids_[shard_id];
204 
205   // get the row id in the shard
206   row_id = id;
207   for (int32_t i = 0; i < shuffled_shard_sample_count_.size(); i++) {
208     if (id < shuffled_shard_sample_count_[i]) {
209       if (i > 0) {
210         row_id = id - shuffled_shard_sample_count_[i - 1];
211       }
212       break;
213     }
214   }
215 
216   return {task_type, std::make_tuple(shard_id, row_id), {}, json()};
217 }
218 
GetTaskSampleByID(int64_t id)219 int64_t ShardTaskList::GetTaskSampleByID(int64_t id) { return sample_ids_[id]; }
220 
GetRandomTaskID()221 int64_t ShardTaskList::GetRandomTaskID() {
222   std::mt19937 gen = GetRandomDevice();
223   std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1);
224   return dis(gen);
225 }
226 
GetRandomTask()227 ShardTask ShardTaskList::GetRandomTask() {
228   std::mt19937 gen = GetRandomDevice();
229   std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
230   size_t random = dis(gen);
231   if (load_mode_ == LoadMode::kFast) {
232     return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), std::get<0>(sample_meta_list_[random]),
233             std::get<1>(sample_meta_list_[random])};
234   } else {
235     return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), {}, json()};
236   }
237 }
238 
Combine(std::vector<ShardTaskList> & category_tasks,bool replacement,int64_t num_elements,int64_t num_samples)239 ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
240                                      int64_t num_samples) {
241   ShardTaskList res;
242   if (category_tasks.empty()) {
243     return res;
244   }
245   auto total_categories = category_tasks.size();
246   res.categories = static_cast<int64_t>(total_categories);
247   if (!replacement) {
248     auto minTasks = category_tasks[0].Size();
249     for (int64_t i = 1; i < total_categories; i++) {
250       minTasks = std::min(minTasks, category_tasks[i].Size());
251     }
252     int64_t count = 0;
253     for (int64_t task_no = 0; task_no < minTasks; task_no++) {
254       for (int64_t i = 0; i < total_categories; i++) {
255         if (num_samples != 0 && count == num_samples) {
256           break;
257         }
258         res.InsertTask(std::move(category_tasks[i].GetTaskByID(task_no)));
259         count++;
260       }
261     }
262   } else {
263     auto maxTasks = category_tasks[0].Size();
264     for (int64_t i = 1; i < total_categories; i++) {
265       maxTasks = std::max(maxTasks, category_tasks[i].Size());
266     }
267     if (num_elements != std::numeric_limits<int64_t>::max()) {
268       maxTasks = static_cast<decltype(maxTasks)>(num_elements);
269     }
270     int64_t count = 0;
271     for (int64_t i = 0; i < total_categories; i++) {
272       for (int64_t j = 0; j < maxTasks; j++) {
273         if (num_samples != 0 && count == num_samples) {
274           break;
275         }
276         res.InsertTask(category_tasks[i].GetRandomTask());
277         count++;
278       }
279     }
280   }
281 
282   return res;
283 }
284 
SetShardSampleCount(const std::vector<int64_t> & shard_sample_count)285 void ShardTaskList::SetShardSampleCount(const std::vector<int64_t> &shard_sample_count) {
286   // original shard sample count like:
287   // indicate shard_id : inc_count
288   // 0 : 15  -  shard0 has 15 samples
289   // 1 : 41  -  shard1 has 26 samples
290   // 2 : 58  -  shard2 has 17 samples
291   shard_sample_count_ = shard_sample_count;
292 
293   // generate new file_ids
294   std::vector<int32_t> file_ids;
295   for (int32_t i = 0; i < shard_sample_count_.size(); i++) {
296     file_ids.push_back(i);
297   }
298   SetFileIds(file_ids);
299 }
300 
SetPaddedSample(const int32_t & padded_sample)301 void ShardTaskList::SetPaddedSample(const int32_t &padded_sample) { padded_sample_ = padded_sample; }
302 
SetFileIds(const std::vector<int32_t> & file_ids)303 void ShardTaskList::SetFileIds(const std::vector<int32_t> &file_ids) {
304   file_ids_ = file_ids;
305 
306   // original shard sample count like:
307   // indicate shard_id : inc_count
308   // 0 : 15  -  shard0 has 15 samples
309   // 1 : 41  -  shard1 has 26 samples
310   // 2 : 58  -  shard2 has 17 samples
311   // create shuffled_shard_sample_count_
312   // after shuffle
313   // 0 : 17  -  shard0 has 17 samples - pre shard2
314   // 1 : 32  -  shard1 has 15 samples - pre shard0
315   // 2 : 58  -  shard2 has 26 samples - pre shard1
316   std::vector<int64_t> shuffled_shard_sample_count;
317   int64_t count = 0;
318   int64_t start;
319   for (int32_t i = 0; i < file_ids_.size(); i++) {
320     if (file_ids[i] == 0) {
321       start = 0;
322     } else {
323       start = shard_sample_count_[file_ids[i] - 1];
324     }
325     shuffled_shard_sample_count.push_back(shard_sample_count_[file_ids[i]] - start + count);
326     count += shard_sample_count_[file_ids[i]] - start;
327   }
328   SetShuffledShardSampleCount(shuffled_shard_sample_count);
329 }
330 
SetShuffledShardSampleCount(const std::vector<int64_t> & shuffled_shard_sample_count)331 void ShardTaskList::SetShuffledShardSampleCount(const std::vector<int64_t> &shuffled_shard_sample_count) {
332   shuffled_shard_sample_count_ = shuffled_shard_sample_count;
333 
334   // generate new partitioned_shard_sample_count
335   std::vector<PartitionedShardSampleCount> vpssc;
336   int64_t start = 0;
337   for (int32_t shard_index = 0; shard_index < shuffled_shard_sample_count_.size(); shard_index++) {
338     // add new range to vp
339     PartitionedShardSampleCount pssc;
340     pssc.task_type = TaskType::kCommonTask;
341     pssc.shard_id = shard_index;
342     pssc.start = start;
343     pssc.end = shuffled_shard_sample_count_[shard_index];
344     vpssc.push_back(pssc);
345     start = shuffled_shard_sample_count_[shard_index];
346   }
347 
348   // padded scenario
349   if (padded_sample_ > 0) {
350     PartitionedShardSampleCount pssc;
351     pssc.task_type = TaskType::kPaddedTask;
352     pssc.shard_id = -1;
353     pssc.start = start;
354     pssc.end = start + padded_sample_;
355     vpssc.push_back(pssc);
356   }
357 
358   SetPartitionedShardSampleCount(vpssc);
359 }
360 
SetPartitionedShardSampleCount(const std::vector<PartitionedShardSampleCount> & partitioned_shard_sample_count)361 void ShardTaskList::SetPartitionedShardSampleCount(
362   const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count) {
363   partitioned_shard_sample_count_ = partitioned_shard_sample_count;
364   generator_ids_.SetShardSampleCount(partitioned_shard_sample_count_);
365 }
366 
UpdatePartitionedShardSampleCountByNumSamples(const int64_t & num_samples)367 void ShardTaskList::UpdatePartitionedShardSampleCountByNumSamples(const int64_t &num_samples) {
368   auto count = num_samples;
369   std::vector<PartitionedShardSampleCount> new_partitioned_shard_sample_count = {};
370   for (int32_t i = 0; i < partitioned_shard_sample_count_.size(); i++) {
371     auto start = partitioned_shard_sample_count_[i].start;
372     if (partitioned_shard_sample_count_[i].end - start <= count) {
373       new_partitioned_shard_sample_count.push_back(partitioned_shard_sample_count_[i]);
374       count = count - (partitioned_shard_sample_count_[i].end - start);
375     } else {
376       PartitionedShardSampleCount pssc;
377       pssc.task_type = partitioned_shard_sample_count_[i].task_type;
378       pssc.shard_id = partitioned_shard_sample_count_[i].shard_id;
379       pssc.start = start;
380       pssc.end = start + count;
381       new_partitioned_shard_sample_count.push_back(pssc);
382       break;
383     }
384   }
385 
386   SetPartitionedShardSampleCount(new_partitioned_shard_sample_count);
387 }
388 
GetNextSampleIds()389 std::vector<int64_t> ShardTaskList::GetNextSampleIds() {
390   return generator_ids_.GetNextSampleIds(need_shuffle_, shuffle_seed_);
391 }
392 }  // namespace mindrecord
393 }  // namespace mindspore
394