• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_shuffle.h"
18 
19 #include <algorithm>
20 
21 namespace mindspore {
22 namespace mindrecord {
ShardShuffle(uint32_t seed,ShuffleType shuffle_type)23 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
24     : shuffle_seed_(seed),
25       no_of_samples_(0),
26       replacement_(false),
27       reshuffle_each_epoch_(true),
28       shuffle_type_(shuffle_type) {}
29 
ShardShuffle(uint32_t seed,int64_t no_of_samples,bool replacement,bool reshuffle_each_epoch,ShuffleType shuffle_type)30 ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
31                            ShuffleType shuffle_type)
32     : shuffle_seed_(seed),
33       no_of_samples_(no_of_samples),
34       replacement_(replacement),
35       reshuffle_each_epoch_(reshuffle_each_epoch),
36       shuffle_type_(shuffle_type) {}
37 
GetNumSamples(int64_t dataset_size,int64_t num_classes)38 int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
39   if (replacement_) {
40     return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
41   }
42   return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
43 }
44 
CategoryShuffle(ShardTaskList & tasks)45 Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
46   int64_t individual_size = tasks.sample_ids_.size() / tasks.categories;
47   std::vector<std::vector<int64_t>> new_permutations(tasks.categories, std::vector<int64_t>(individual_size));
48   for (int64_t i = 0; i < tasks.categories; i++) {
49     for (int64_t j = 0; j < individual_size; j++) {
50       new_permutations[i][j] = j;
51     }
52     std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
53   }
54   tasks.permutation_.clear();
55   for (int64_t j = 0; j < individual_size; j++) {
56     for (int64_t i = 0; i < tasks.categories; i++) {
57       tasks.permutation_.push_back(new_permutations[i][j] * tasks.categories + i);
58     }
59   }
60 
61   ShardTaskList new_tasks;
62   for (int64_t i = 0; i < individual_size; ++i) {
63     new_tasks.AssignTask(tasks, tasks.permutation_[i]);
64   }
65   ShardTaskList::TaskListSwap(tasks, new_tasks);
66 
67   return Status::OK();
68 }
69 
ShuffleFiles(ShardTaskList & tasks)70 Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
71   if (no_of_samples_ == 0) {
72     no_of_samples_ = tasks.Size();
73   }
74   CHECK_FAIL_RETURN_UNEXPECTED_MR(
75     no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
76   auto shard_sample_cout = GetShardSampleCount();
77 
78   // shuffle the files index
79   std::vector<int64_t> shuffle_files;
80   for (int64_t i = 0; i < shard_sample_cout.size(); i++) {
81     shuffle_files.push_back(i);
82   }
83   std::shuffle(shuffle_files.begin(), shuffle_files.end(), std::default_random_engine(shuffle_seed_));
84 
85   // reconstruct the permutation between files
86   // -- before --
87   // file1: [0, 1, 2]
88   // file2: [3, 4, 5, 6]
89   // file3: [7, 8]
90   // file4: [9, 10]
91   // files: [file1, file2, file3, file4]
92   // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
93   // -- after --
94   // files: [file4, file1, file3, file2]
95   // permutation : [9, 10, 0, 1, 2, 7, 8, 3, 4, 5, 6]
96   auto original_permutation = tasks.permutation_;
97   int64_t whole_index = 0;
98   for (int64_t i = 0; i < shuffle_files.size(); i++) {
99     int64_t start_index = 0;
100     int64_t current_size = 0;
101     if (shuffle_files[i] == 0) {
102       start_index = 0;
103       current_size = shard_sample_cout[shuffle_files[i]];
104     } else {
105       start_index = shard_sample_cout[shuffle_files[i] - 1];
106       current_size = shard_sample_cout[shuffle_files[i]] - start_index;
107     }
108     (void)std::copy(original_permutation.begin() + start_index,
109                     original_permutation.begin() + start_index + current_size,
110                     tasks.permutation_.begin() + whole_index);
111     whole_index += current_size;
112   }
113 
114   auto total_no = tasks.Size();
115   int64_t samples_to_assign =
116     (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
117   ShardTaskList new_tasks;
118   for (int64_t i = 0; i < samples_to_assign; ++i) {
119     new_tasks.AssignTask(tasks, tasks.permutation_[i]);
120   }
121   ShardTaskList::TaskListSwap(tasks, new_tasks);
122   return Status::OK();
123 }
124 
ShuffleInfile(ShardTaskList & tasks)125 Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
126   if (no_of_samples_ == 0) {
127     no_of_samples_ = tasks.Size();
128   }
129   CHECK_FAIL_RETURN_UNEXPECTED_MR(
130     no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
131   // reconstruct the permutation in file
132   // -- before --
133   // file1: [0, 1, 2]
134   // file2: [3, 4, 5, 6]
135   // file3: [7, 8]
136   // file4: [9, 10]
137   // files: [file1, file2, file3, file4]
138   // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
139   // -- after --
140   // permutation: [2, 0, 1, 4, 6, 3, 5, 8, 7, 9, 10]
141   auto shard_sample_cout = GetShardSampleCount();
142   int64_t start_index = 0;
143   for (int64_t i = 0; i < shard_sample_cout.size(); i++) {
144     auto current_size = shard_sample_cout[i] - start_index;
145     std::shuffle(tasks.permutation_.begin() + start_index, tasks.permutation_.begin() + start_index + current_size,
146                  std::default_random_engine(shuffle_seed_));
147     start_index = shard_sample_cout[i];
148   }
149   auto total_no = tasks.Size();
150   ShardTaskList new_tasks;
151   int64_t samples_to_assign =
152     (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
153   for (int64_t i = 0; i < samples_to_assign; ++i) {
154     new_tasks.AssignTask(tasks, tasks.permutation_[i]);
155   }
156   ShardTaskList::TaskListSwap(tasks, new_tasks);
157   return Status::OK();
158 }
159 
Execute(ShardTaskList & tasks)160 Status ShardShuffle::Execute(ShardTaskList &tasks) {
161   if (reshuffle_each_epoch_) {
162     shuffle_seed_++;
163   }
164   CHECK_FAIL_RETURN_UNEXPECTED_MR(tasks.categories >= 1,
165                                   "[Internal ERROR] task categories should be greater than or equal to 1 but got: " +
166                                     std::to_string(tasks.categories));
167   if (tasks.load_mode_ != LoadMode::kSlow) {
168     if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
169       if (tasks.permutation_.empty() == true) {
170         tasks.MakePerm();
171       }
172       if (GetShuffleMode() == dataset::ShuffleMode::kGlobal) {
173         if (replacement_ == true) {
174           ShardTaskList new_tasks;
175           if (no_of_samples_ == 0) {
176             no_of_samples_ = tasks.sample_ids_.size();
177           }
178           CHECK_FAIL_RETURN_UNEXPECTED_MR(
179             no_of_samples_ > 0,
180             "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
181           for (uint32_t i = 0; i < no_of_samples_; ++i) {
182             new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
183           }
184 
185           ShardTaskList::TaskListSwap(tasks, new_tasks);
186         } else {
187           std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
188           auto total_no = tasks.Size();
189           ShardTaskList new_tasks;
190           int64_t samples_to_assign =
191             (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
192           for (int64_t i = 0; i < samples_to_assign; ++i) {
193             new_tasks.AssignTask(tasks, tasks.permutation_[i]);
194           }
195           ShardTaskList::TaskListSwap(tasks, new_tasks);
196         }
197       } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) {
198         RETURN_IF_NOT_OK_MR(ShuffleInfile(tasks));
199       } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) {
200         RETURN_IF_NOT_OK_MR(ShuffleFiles(tasks));
201       }
202     } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
203       return this->CategoryShuffle(tasks);
204     }
205   } else {
206     // just shuffle file names
207     CHECK_FAIL_RETURN_UNEXPECTED_MR(
208       GetShuffleMode() != dataset::ShuffleMode::kInfile && GetShuffleMode() != dataset::ShuffleMode::kFiles,
209       "The shuffle mode Shuffle.FILES and Shuffle.INFILE are not supported because "
210       "the number of samples in the dataset is greater than " +
211         std::to_string(SLOW_LOAD_THRESHOLD) + ".");
212     auto shard_sample_count = GetShardSampleCount();
213     std::vector<int32_t> file_ids;
214     for (int32_t i = 0; i < shard_sample_count.size(); i++) {
215       file_ids.push_back(i);
216     }
217     std::shuffle(file_ids.begin(), file_ids.end(), std::default_random_engine(shuffle_seed_));
218     tasks.SetFileIds(file_ids);
219     tasks.need_shuffle_ = true;
220     tasks.shuffle_seed_ = shuffle_seed_;
221     int64_t samples_to_assign =
222       (no_of_samples_ > 0 && no_of_samples_ < tasks.SizeAfterSampling()) ? no_of_samples_ : tasks.SizeAfterSampling();
223     tasks.UpdatePartitionedShardSampleCountByNumSamples(samples_to_assign);
224   }
225   return Status::OK();
226 }
227 }  // namespace mindrecord
228 }  // namespace mindspore
229